Simplify outbound session setup

Don't send inbound session to self and claim and send all keys at once.
This commit is contained in:
Nicolas Werner 2020-09-20 23:04:14 +02:00
parent 8eb74daf76
commit 54db9c89ed
9 changed files with 209 additions and 214 deletions

View File

@ -342,7 +342,7 @@ if(USE_BUNDLED_MTXCLIENT)
FetchContent_Declare( FetchContent_Declare(
MatrixClient MatrixClient
GIT_REPOSITORY https://github.com/Nheko-Reborn/mtxclient.git GIT_REPOSITORY https://github.com/Nheko-Reborn/mtxclient.git
GIT_TAG 0665c8baf4af0ce192adb8ca97761b63b681d569 GIT_TAG f84611f129b46746a4b586acaba54fc31a303bc6
) )
FetchContent_MakeAvailable(MatrixClient) FetchContent_MakeAvailable(MatrixClient)
else() else()

View File

@ -146,7 +146,7 @@
"name": "mtxclient", "name": "mtxclient",
"sources": [ "sources": [
{ {
"commit": "0665c8baf4af0ce192adb8ca97761b63b681d569", "commit": "f84611f129b46746a4b586acaba54fc31a303bc6",
"type": "git", "type": "git",
"url": "https://github.com/Nheko-Reborn/mtxclient.git" "url": "https://github.com/Nheko-Reborn/mtxclient.git"
} }

View File

@ -139,24 +139,26 @@ Cache::Cache(const QString &userId, QObject *parent)
, localUserId_{userId} , localUserId_{userId}
{ {
setup(); setup();
connect(this, connect(
&Cache::updateUserCacheFlag, this,
this, &Cache::updateUserCacheFlag,
[this](const std::string &user_id) { this,
std::optional<UserCache> cache_ = getUserCache(user_id); [this](const std::string &user_id) {
if (cache_.has_value()) { std::optional<UserCache> cache_ = getUserCache(user_id);
cache_.value().isUpdated = false; if (cache_.has_value()) {
setUserCache(user_id, cache_.value()); cache_.value().isUpdated = false;
} else { setUserCache(user_id, cache_.value());
setUserCache(user_id, UserCache{}); } else {
} setUserCache(user_id, UserCache{});
}, }
Qt::QueuedConnection); },
connect(this, Qt::QueuedConnection);
&Cache::deleteLeftUsers, connect(
this, this,
[this](const std::string &user_id) { deleteUserCache(user_id); }, &Cache::deleteLeftUsers,
Qt::QueuedConnection); this,
[this](const std::string &user_id) { deleteUserCache(user_id); },
Qt::QueuedConnection);
} }
void void

View File

@ -606,11 +606,12 @@ ChatPage::ChatPage(QSharedPointer<UserSettings> userSettings, QWidget *parent)
connect( connect(
this, &ChatPage::tryInitialSyncCb, this, &ChatPage::tryInitialSync, Qt::QueuedConnection); this, &ChatPage::tryInitialSyncCb, this, &ChatPage::tryInitialSync, Qt::QueuedConnection);
connect(this, &ChatPage::trySyncCb, this, &ChatPage::trySync, Qt::QueuedConnection); connect(this, &ChatPage::trySyncCb, this, &ChatPage::trySync, Qt::QueuedConnection);
connect(this, connect(
&ChatPage::tryDelayedSyncCb, this,
this, &ChatPage::tryDelayedSyncCb,
[this]() { QTimer::singleShot(RETRY_TIMEOUT, this, &ChatPage::trySync); }, this,
Qt::QueuedConnection); [this]() { QTimer::singleShot(RETRY_TIMEOUT, this, &ChatPage::trySync); },
Qt::QueuedConnection);
connect(this, connect(this,
&ChatPage::newSyncResponse, &ChatPage::newSyncResponse,

View File

@ -581,9 +581,11 @@ send_megolm_key_to_device(const std::string &user_id,
->create_room_key_event(UserId(user_id), pks.ed25519, payload) ->create_room_key_event(UserId(user_id), pks.ed25519, payload)
.dump(); .dump();
mtx::requests::ClaimKeys claim_keys;
claim_keys.one_time_keys[user_id][device_id] = mtx::crypto::SIGNED_CURVE25519;
http::client()->claim_keys( http::client()->claim_keys(
user_id, claim_keys,
{device_id},
[room_key, user_id, device_id, pks](const mtx::responses::ClaimKeys &res, [room_key, user_id, device_id, pks](const mtx::responses::ClaimKeys &res,
mtx::http::RequestErr err) { mtx::http::RequestErr err) {
if (err) { if (err) {

Binary file not shown.

View File

@ -32,38 +32,40 @@ EventStore::EventStore(std::string room_id, QObject *)
this->last = range->last; this->last = range->last;
} }
connect(this, connect(
&EventStore::eventFetched, this,
this, &EventStore::eventFetched,
[this](std::string id, this,
std::string relatedTo, [this](std::string id,
mtx::events::collections::TimelineEvents timeline) { std::string relatedTo,
cache::client()->storeEvent(room_id_, id, {timeline}); mtx::events::collections::TimelineEvents timeline) {
cache::client()->storeEvent(room_id_, id, {timeline});
if (!relatedTo.empty()) { if (!relatedTo.empty()) {
auto idx = idToIndex(relatedTo); auto idx = idToIndex(relatedTo);
if (idx) if (idx)
emit dataChanged(*idx, *idx); emit dataChanged(*idx, *idx);
} }
}, },
Qt::QueuedConnection); Qt::QueuedConnection);
connect(this, connect(
&EventStore::oldMessagesRetrieved, this,
this, &EventStore::oldMessagesRetrieved,
[this](const mtx::responses::Messages &res) { this,
uint64_t newFirst = cache::client()->saveOldMessages(room_id_, res); [this](const mtx::responses::Messages &res) {
if (newFirst == first) uint64_t newFirst = cache::client()->saveOldMessages(room_id_, res);
fetchMore(); if (newFirst == first)
else { fetchMore();
emit beginInsertRows(toExternalIdx(newFirst), else {
toExternalIdx(this->first - 1)); emit beginInsertRows(toExternalIdx(newFirst),
this->first = newFirst; toExternalIdx(this->first - 1));
emit endInsertRows(); this->first = newFirst;
emit fetchedMore(); emit endInsertRows();
} emit fetchedMore();
}, }
Qt::QueuedConnection); },
Qt::QueuedConnection);
connect(this, &EventStore::processPending, this, [this]() { connect(this, &EventStore::processPending, this, [this]() {
if (!current_txn.empty()) { if (!current_txn.empty()) {
@ -128,46 +130,48 @@ EventStore::EventStore(std::string room_id, QObject *)
event->data); event->data);
}); });
connect(this, connect(
&EventStore::messageFailed, this,
this, &EventStore::messageFailed,
[this](std::string txn_id) { this,
if (current_txn == txn_id) { [this](std::string txn_id) {
current_txn_error_count++; if (current_txn == txn_id) {
if (current_txn_error_count > 10) { current_txn_error_count++;
nhlog::ui()->debug("failing txn id '{}'", txn_id); if (current_txn_error_count > 10) {
cache::client()->removePendingStatus(room_id_, txn_id); nhlog::ui()->debug("failing txn id '{}'", txn_id);
current_txn_error_count = 0; cache::client()->removePendingStatus(room_id_, txn_id);
} current_txn_error_count = 0;
} }
QTimer::singleShot(1000, this, [this]() { }
nhlog::ui()->debug("timeout"); QTimer::singleShot(1000, this, [this]() {
this->current_txn = ""; nhlog::ui()->debug("timeout");
emit processPending(); this->current_txn = "";
}); emit processPending();
}, });
Qt::QueuedConnection); },
Qt::QueuedConnection);
connect(this, connect(
&EventStore::messageSent, this,
this, &EventStore::messageSent,
[this](std::string txn_id, std::string event_id) { this,
nhlog::ui()->debug("sent {}", txn_id); [this](std::string txn_id, std::string event_id) {
nhlog::ui()->debug("sent {}", txn_id);
http::client()->read_event( http::client()->read_event(
room_id_, event_id, [this, event_id](mtx::http::RequestErr err) { room_id_, event_id, [this, event_id](mtx::http::RequestErr err) {
if (err) { if (err) {
nhlog::net()->warn( nhlog::net()->warn(
"failed to read_event ({}, {})", room_id_, event_id); "failed to read_event ({}, {})", room_id_, event_id);
} }
}); });
cache::client()->removePendingStatus(room_id_, txn_id); cache::client()->removePendingStatus(room_id_, txn_id);
this->current_txn = ""; this->current_txn = "";
this->current_txn_error_count = 0; this->current_txn_error_count = 0;
emit processPending(); emit processPending();
}, },
Qt::QueuedConnection); Qt::QueuedConnection);
} }
void void

View File

@ -204,11 +204,12 @@ TimelineModel::TimelineModel(TimelineViewManager *manager, QString room_id, QObj
, room_id_(room_id) , room_id_(room_id)
, manager_(manager) , manager_(manager)
{ {
connect(this, connect(
&TimelineModel::redactionFailed, this,
this, &TimelineModel::redactionFailed,
[](const QString &msg) { emit ChatPage::instance()->showNotification(msg); }, this,
Qt::QueuedConnection); [](const QString &msg) { emit ChatPage::instance()->showNotification(msg); },
Qt::QueuedConnection);
connect(this, connect(this,
&TimelineModel::newMessageToSend, &TimelineModel::newMessageToSend,
@ -217,17 +218,17 @@ TimelineModel::TimelineModel(TimelineViewManager *manager, QString room_id, QObj
Qt::QueuedConnection); Qt::QueuedConnection);
connect(this, &TimelineModel::addPendingMessageToStore, &events, &EventStore::addPending); connect(this, &TimelineModel::addPendingMessageToStore, &events, &EventStore::addPending);
connect(&events, connect(
&EventStore::dataChanged, &events,
this, &EventStore::dataChanged,
[this](int from, int to) { this,
nhlog::ui()->debug("data changed {} to {}", [this](int from, int to) {
events.size() - to - 1, nhlog::ui()->debug(
events.size() - from - 1); "data changed {} to {}", events.size() - to - 1, events.size() - from - 1);
emit dataChanged(index(events.size() - to - 1, 0), emit dataChanged(index(events.size() - to - 1, 0),
index(events.size() - from - 1, 0)); index(events.size() - from - 1, 0));
}, },
Qt::QueuedConnection); Qt::QueuedConnection);
connect(&events, &EventStore::beginInsertRows, this, [this](int from, int to) { connect(&events, &EventStore::beginInsertRows, this, [this](int from, int to) {
int first = events.size() - to; int first = events.size() - to;
@ -916,10 +917,20 @@ TimelineModel::sendEncryptedMessage(mtx::events::RoomEvent<T> msg, mtx::events::
OutboundGroupSessionData session_data; OutboundGroupSessionData session_data;
session_data.session_id = session_id; session_data.session_id = session_id;
session_data.session_key = session_key; session_data.session_key = session_key;
session_data.message_index = 0; // TODO Update me session_data.message_index = 0;
cache::saveOutboundMegolmSession( cache::saveOutboundMegolmSession(
room_id, session_data, std::move(outbound_session)); room_id, session_data, std::move(outbound_session));
{
MegolmSessionIndex index;
index.room_id = room_id;
index.session_id = session_id;
index.sender_key = olm::client()->identity_keys().curve25519;
auto megolm_session =
olm::client()->init_inbound_group_session(session_key);
cache::saveInboundMegolmSession(index, std::move(megolm_session));
}
const auto members = cache::roomMembers(room_id); const auto members = cache::roomMembers(room_id);
nhlog::ui()->info("retrieved {} members for {}", members.size(), room_id); nhlog::ui()->info("retrieved {} members for {}", members.size(), room_id);
@ -961,19 +972,23 @@ TimelineModel::sendEncryptedMessage(mtx::events::RoomEvent<T> msg, mtx::events::
return; return;
} }
mtx::requests::ClaimKeys claim_keys;
// Mapping from user id to a device_id with valid identity keys to the
// generated room_key event used for sharing the megolm session.
std::map<std::string, std::map<std::string, std::string>> room_key_msgs;
std::map<std::string, std::map<std::string, DevicePublicKeys>> deviceKeys;
for (const auto &user : res.device_keys) { for (const auto &user : res.device_keys) {
// Mapping from a device_id with valid identity keys to the
// generated room_key event used for sharing the megolm session.
std::map<std::string, std::string> room_key_msgs;
std::map<std::string, DevicePublicKeys> deviceKeys;
room_key_msgs.clear();
deviceKeys.clear();
for (const auto &dev : user.second) { for (const auto &dev : user.second) {
const auto user_id = ::UserId(dev.second.user_id); const auto user_id = ::UserId(dev.second.user_id);
const auto device_id = DeviceId(dev.second.device_id); const auto device_id = DeviceId(dev.second.device_id);
if (user_id.get() ==
http::client()->user_id().to_string() &&
device_id.get() == http::client()->device_id())
continue;
const auto device_keys = dev.second.keys; const auto device_keys = dev.second.keys;
const auto curveKey = "curve25519:" + device_id.get(); const auto curveKey = "curve25519:" + device_id.get();
const auto edKey = "ed25519:" + device_id.get(); const auto edKey = "ed25519:" + device_id.get();
@ -1015,42 +1030,25 @@ TimelineModel::sendEncryptedMessage(mtx::events::RoomEvent<T> msg, mtx::events::
user_id, pks.ed25519, megolm_payload) user_id, pks.ed25519, megolm_payload)
.dump(); .dump();
room_key_msgs.emplace(device_id, room_key); room_key_msgs[user_id].emplace(device_id, room_key);
deviceKeys.emplace(device_id, pks); deviceKeys[user_id].emplace(device_id, pks);
claim_keys.one_time_keys[user.first][device_id] =
mtx::crypto::SIGNED_CURVE25519;
nhlog::net()->info("{}", device_id.get());
nhlog::net()->info(" curve25519 {}", pks.curve25519);
nhlog::net()->info(" ed25519 {}", pks.ed25519);
} }
std::vector<std::string> valid_devices;
valid_devices.reserve(room_key_msgs.size());
for (auto const &d : room_key_msgs) {
valid_devices.push_back(d.first);
nhlog::net()->info("{}", d.first);
nhlog::net()->info(" curve25519 {}",
deviceKeys.at(d.first).curve25519);
nhlog::net()->info(" ed25519 {}",
deviceKeys.at(d.first).ed25519);
}
nhlog::net()->info(
"sending claim request for user {} with {} devices",
user.first,
valid_devices.size());
http::client()->claim_keys(
user.first,
valid_devices,
std::bind(&TimelineModel::handleClaimedKeys,
this,
keeper,
room_key_msgs,
deviceKeys,
user.first,
std::placeholders::_1,
std::placeholders::_2));
// TODO: Wait before sending the next batch of requests.
std::this_thread::sleep_for(std::chrono::milliseconds(500));
} }
http::client()->claim_keys(claim_keys,
std::bind(&TimelineModel::handleClaimedKeys,
this,
keeper,
room_key_msgs,
deviceKeys,
std::placeholders::_1,
std::placeholders::_2));
}); });
// TODO: Let the user know about the errors. // TODO: Let the user know about the errors.
@ -1068,12 +1066,12 @@ TimelineModel::sendEncryptedMessage(mtx::events::RoomEvent<T> msg, mtx::events::
} }
void void
TimelineModel::handleClaimedKeys(std::shared_ptr<StateKeeper> keeper, TimelineModel::handleClaimedKeys(
const std::map<std::string, std::string> &room_keys, std::shared_ptr<StateKeeper> keeper,
const std::map<std::string, DevicePublicKeys> &pks, const std::map<std::string, std::map<std::string, std::string>> &room_keys,
const std::string &user_id, const std::map<std::string, std::map<std::string, DevicePublicKeys>> &pks,
const mtx::responses::ClaimKeys &res, const mtx::responses::ClaimKeys &res,
mtx::http::RequestErr err) mtx::http::RequestErr err)
{ {
if (err) { if (err) {
nhlog::net()->warn("claim keys error: {} {} {}", nhlog::net()->warn("claim keys error: {} {} {}",
@ -1083,65 +1081,53 @@ TimelineModel::handleClaimedKeys(std::shared_ptr<StateKeeper> keeper,
return; return;
} }
nhlog::net()->debug("claimed keys for {}", user_id);
if (res.one_time_keys.size() == 0) {
nhlog::net()->debug("no one-time keys found for user_id: {}", user_id);
return;
}
if (res.one_time_keys.find(user_id) == res.one_time_keys.end()) {
nhlog::net()->debug("no one-time keys found for user_id: {}", user_id);
return;
}
auto retrieved_devices = res.one_time_keys.at(user_id);
// Payload with all the to_device message to be sent. // Payload with all the to_device message to be sent.
json body; nlohmann::json body;
body["messages"][user_id] = json::object();
for (const auto &rd : retrieved_devices) { for (const auto &[user_id, retrieved_devices] : res.one_time_keys) {
const auto device_id = rd.first; nhlog::net()->debug("claimed keys for {}", user_id);
nhlog::net()->debug("{} : \n {}", device_id, rd.second.dump(2)); if (retrieved_devices.size() == 0) {
nhlog::net()->debug("no one-time keys found for user_id: {}", user_id);
// TODO: Verify signatures return;
auto otk = rd.second.begin()->at("key");
if (pks.find(device_id) == pks.end()) {
nhlog::net()->critical("couldn't find public key for device: {}",
device_id);
continue;
} }
auto id_key = pks.at(device_id).curve25519; for (const auto &rd : retrieved_devices) {
auto s = olm::client()->create_outbound_session(id_key, otk); const auto device_id = rd.first;
if (room_keys.find(device_id) == room_keys.end()) { nhlog::net()->debug("{} : \n {}", device_id, rd.second.dump(2));
nhlog::net()->critical("couldn't find m.room_key for device: {}",
device_id); // TODO: Verify signatures
continue; auto otk = rd.second.begin()->at("key");
auto id_key = pks.at(user_id).at(device_id).curve25519;
auto s = olm::client()->create_outbound_session(id_key, otk);
auto device_msg = olm::client()->create_olm_encrypted_content(
s.get(),
room_keys.at(user_id).at(device_id),
pks.at(user_id).at(device_id).curve25519);
try {
cache::saveOlmSession(id_key, std::move(s));
} catch (const lmdb::error &e) {
nhlog::db()->critical("failed to save outbound olm session: {}",
e.what());
} catch (const mtx::crypto::olm_exception &e) {
nhlog::crypto()->critical(
"failed to pickle outbound olm session: {}", e.what());
}
body["messages"][user_id][device_id] = device_msg;
} }
auto device_msg = olm::client()->create_olm_encrypted_content( nhlog::net()->info("send_to_device: {}", user_id);
s.get(), room_keys.at(device_id), pks.at(device_id).curve25519);
try {
cache::saveOlmSession(id_key, std::move(s));
} catch (const lmdb::error &e) {
nhlog::db()->critical("failed to save outbound olm session: {}", e.what());
} catch (const mtx::crypto::olm_exception &e) {
nhlog::crypto()->critical("failed to pickle outbound olm session: {}",
e.what());
}
body["messages"][user_id][device_id] = device_msg;
} }
nhlog::net()->info("send_to_device: {}", user_id);
http::client()->send_to_device( http::client()->send_to_device(
"m.room.encrypted", body, [keeper](mtx::http::RequestErr err) { mtx::events::to_string(mtx::events::EventType::RoomEncrypted),
http::client()->generate_txn_id(),
body,
[keeper](mtx::http::RequestErr err) {
if (err) { if (err) {
nhlog::net()->warn("failed to send " nhlog::net()->warn("failed to send "
"send_to_device " "send_to_device "

View File

@ -285,12 +285,12 @@ signals:
private: private:
template<typename T> template<typename T>
void sendEncryptedMessage(mtx::events::RoomEvent<T> msg, mtx::events::EventType eventType); void sendEncryptedMessage(mtx::events::RoomEvent<T> msg, mtx::events::EventType eventType);
void handleClaimedKeys(std::shared_ptr<StateKeeper> keeper, void handleClaimedKeys(
const std::map<std::string, std::string> &room_key, std::shared_ptr<StateKeeper> keeper,
const std::map<std::string, DevicePublicKeys> &pks, const std::map<std::string, std::map<std::string, std::string>> &room_keys,
const std::string &user_id, const std::map<std::string, std::map<std::string, DevicePublicKeys>> &pks,
const mtx::responses::ClaimKeys &res, const mtx::responses::ClaimKeys &res,
mtx::http::RequestErr err); mtx::http::RequestErr err);
void readEvent(const std::string &id); void readEvent(const std::string &id);
void setPaginationInProgress(const bool paginationInProgress); void setPaginationInProgress(const bool paginationInProgress);