diff --git a/src/Cache.cpp b/src/Cache.cpp index e090e40d..2784cf50 100644 --- a/src/Cache.cpp +++ b/src/Cache.cpp @@ -912,6 +912,29 @@ Cache::getMegolmSessionData(const MegolmSessionIndex &index) // OLM sessions. // +void +Cache::saveOlmSessions(std::vector> sessions, + uint64_t timestamp) +{ + using namespace mtx::crypto; + + auto txn = lmdb::txn::begin(env_); + for (const auto &[curve25519, session] : sessions) { + auto db = getOlmSessionsDb(txn, curve25519); + + const auto pickled = pickle(session.get(), pickle_secret_); + const auto session_id = mtx::crypto::session_id(session.get()); + + StoredOlmSession stored_session; + stored_session.pickled_session = pickled; + stored_session.last_message_ts = timestamp; + + db.put(txn, session_id, nlohmann::json(stored_session).dump()); + } + + txn.commit(); +} + void Cache::saveOlmSession(const std::string &curve25519, mtx::crypto::OlmSessionPtr session, diff --git a/src/Cache_p.h b/src/Cache_p.h index 1694adb7..742e4aab 100644 --- a/src/Cache_p.h +++ b/src/Cache_p.h @@ -277,6 +277,8 @@ public: void saveOlmSession(const std::string &curve25519, mtx::crypto::OlmSessionPtr session, uint64_t timestamp); + void saveOlmSessions(std::vector> sessions, + uint64_t timestamp); std::vector getOlmSessions(const std::string &curve25519); std::optional getOlmSession(const std::string &curve25519, const std::string &session_id); diff --git a/src/encryption/Olm.cpp b/src/encryption/Olm.cpp index 7ada2f92..a9d5b1c2 100644 --- a/src/encryption/Olm.cpp +++ b/src/encryption/Olm.cpp @@ -1299,78 +1299,83 @@ send_encrypted_to_device_messages(const std::mapidentity_keys().curve25519; - for (const auto &[user, devices] : targets) { - auto deviceKeys = cache::client()->userKeys(user); + { + auto currentTime = QDateTime::currentSecsSinceEpoch(); + std::vector> sessionsToPersist; - // no keys for user, query them - if (!deviceKeys) { - keysToQuery[user] = devices; - continue; - } + for (const auto &[user, devices] : targets) { + auto deviceKeys = cache::client()->userKeys(user); - auto deviceTargets = devices; - if (devices.empty()) { - deviceTargets.clear(); - deviceTargets.reserve(deviceKeys->device_keys.size()); - for (const auto &[device, keys] : deviceKeys->device_keys) { - (void)keys; - deviceTargets.push_back(device); - } - } - - for (const auto &device : deviceTargets) { - if (!deviceKeys->device_keys.count(device)) { - keysToQuery[user] = {}; - break; - } - - auto d = deviceKeys->device_keys.at(device); - - if (!d.keys.count("curve25519:" + device) || !d.keys.count("ed25519:" + device)) { - nhlog::crypto()->warn("Skipping device {} since it has no keys!", device); + // no keys for user, query them + if (!deviceKeys) { + keysToQuery[user] = devices; continue; } - auto device_curve = d.keys.at("curve25519:" + device); - if (device_curve == our_curve) { - nhlog::crypto()->warn("Skipping our own device, since sending " - "ourselves olm messages makes no sense."); - continue; - } - - auto session = cache::getLatestOlmSession(device_curve); - if (!session || force_new_session) { - auto currentTime = QDateTime::currentSecsSinceEpoch(); - if (rateLimit.value(QPair(user, device)) + 60 * 60 * 10 < currentTime) { - claims.one_time_keys[user][device] = mtx::crypto::SIGNED_CURVE25519; - pks[user][device].ed25519 = d.keys.at("ed25519:" + device); - pks[user][device].curve25519 = d.keys.at("curve25519:" + device); - - rateLimit.insert(QPair(user, device), currentTime); - } else { - nhlog::crypto()->warn("Not creating new session with {}:{} " - "because of rate limit", - user, - device); + auto deviceTargets = devices; + if (devices.empty()) { + deviceTargets.clear(); + deviceTargets.reserve(deviceKeys->device_keys.size()); + for (const auto &[device, keys] : deviceKeys->device_keys) { + (void)keys; + deviceTargets.push_back(device); } - continue; } - messages[mtx::identifiers::parse(user)][device] = - olm::client() - ->create_olm_encrypted_content(session->get(), - ev_json, - UserId(user), - d.keys.at("ed25519:" + device), - device_curve) - .get(); + for (const auto &device : deviceTargets) { + if (!deviceKeys->device_keys.count(device)) { + keysToQuery[user] = {}; + break; + } + const auto &d = deviceKeys->device_keys.at(device); + + if (!d.keys.count("curve25519:" + device) || !d.keys.count("ed25519:" + device)) { + nhlog::crypto()->warn("Skipping device {} since it has no keys!", device); + continue; + } + + auto device_curve = d.keys.at("curve25519:" + device); + if (device_curve == our_curve) { + nhlog::crypto()->warn("Skipping our own device, since sending " + "ourselves olm messages makes no sense."); + continue; + } + + auto session = cache::getLatestOlmSession(device_curve); + if (!session || force_new_session) { + if (rateLimit.value(QPair(user, device)) + 60 * 60 * 10 < currentTime) { + claims.one_time_keys[user][device] = mtx::crypto::SIGNED_CURVE25519; + pks[user][device].ed25519 = d.keys.at("ed25519:" + device); + pks[user][device].curve25519 = d.keys.at("curve25519:" + device); + + rateLimit.insert(QPair(user, device), currentTime); + } else { + nhlog::crypto()->warn("Not creating new session with {}:{} " + "because of rate limit", + user, + device); + } + continue; + } + + messages[mtx::identifiers::parse(user)][device] = + olm::client() + ->create_olm_encrypted_content(session->get(), + ev_json, + UserId(user), + d.keys.at("ed25519:" + device), + device_curve) + .get(); + sessionsToPersist.emplace_back(d.keys.at("curve25519:" + device), + std::move(*session)); + } + } + + if (!sessionsToPersist.empty()) { try { - nhlog::crypto()->debug("Updated olm session: {}", - mtx::crypto::session_id(session->get())); - cache::saveOlmSession(d.keys.at("curve25519:" + device), - std::move(*session), - QDateTime::currentMSecsSinceEpoch()); + nhlog::crypto()->debug("Updated olm sessions: {}", sessionsToPersist.size()); + cache::client()->saveOlmSessions(std::move(sessionsToPersist), currentTime); } catch (const lmdb::error &e) { nhlog::db()->critical("failed to save outbound olm session: {}", e.what()); } catch (const mtx::crypto::olm_exception &e) { @@ -1395,6 +1400,9 @@ send_encrypted_to_device_messages(const std::map> messages; + auto currentTime = QDateTime::currentSecsSinceEpoch(); + std::vector> sessionsToPersist; + for (const auto &[user_id, retrieved_devices] : res.one_time_keys) { nhlog::net()->debug("claimed keys for {}", user_id); if (retrieved_devices.size() == 0) { @@ -1440,21 +1448,24 @@ send_encrypted_to_device_messages(const std::map(); - try { - nhlog::crypto()->debug("Updated olm session: {}", - mtx::crypto::session_id(session.get())); - cache::saveOlmSession( - id_key, std::move(session), QDateTime::currentMSecsSinceEpoch()); - } catch (const lmdb::error &e) { - nhlog::db()->critical("failed to save outbound olm session: {}", e.what()); - } catch (const mtx::crypto::olm_exception &e) { - nhlog::crypto()->critical("failed to pickle outbound olm session: {}", - e.what()); - } + sessionsToPersist.emplace_back(id_key, std::move(session)); } nhlog::net()->info("send_to_device: {}", user_id); } + if (!sessionsToPersist.empty()) { + try { + nhlog::crypto()->debug("Updated (new) olm sessions: {}", + sessionsToPersist.size()); + cache::client()->saveOlmSessions(std::move(sessionsToPersist), currentTime); + } catch (const lmdb::error &e) { + nhlog::db()->critical("failed to save outbound olm session: {}", e.what()); + } catch (const mtx::crypto::olm_exception &e) { + nhlog::crypto()->critical("failed to pickle outbound olm session: {}", + e.what()); + } + } + if (!messages.empty()) http::client()->send_to_device( http::client()->generate_txn_id(), messages, [](mtx::http::RequestErr err) {