From 5b400770339ab3b75e23255bc4f26c0ca2d40224 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 6 May 2020 19:49:53 +0300 Subject: [PATCH] Don't query changed devices if device list isn't tracked --- crypto/devicelist.go | 8 +++++++- crypto/encryptmegolm.go | 2 +- crypto/machine.go | 2 +- crypto/store.go | 15 +++++++++++++++ 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/crypto/devicelist.go b/crypto/devicelist.go index e5d9d746..04b0ec50 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -23,12 +23,18 @@ var ( InvalidKeySignature = errors.New("invalid signature on device keys") ) -func (mach *OlmMachine) fetchKeys(users []id.UserID, sinceToken string) (data map[id.UserID]map[id.DeviceID]*DeviceIdentity) { +func (mach *OlmMachine) fetchKeys(users []id.UserID, sinceToken string, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*DeviceIdentity) { req := &mautrix.ReqQueryKeys{ DeviceKeys: mautrix.DeviceKeysRequest{}, Timeout: 10 * 1000, Token: sinceToken, } + if !includeUntracked { + users = mach.CryptoStore.FilterTrackedUsers(users) + } + if len(users) == 0 { + return + } for _, userID := range users { req.DeviceKeys[userID] = mautrix.DeviceIDList{} } diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 4c9c28bc..95d237f0 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -103,7 +103,7 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e if len(fetchKeys) > 0 { mach.Log.Trace("Fetching missing keys for %v", fetchKeys) - for userID, devices := range mach.fetchKeys(fetchKeys, "") { + for userID, devices := range mach.fetchKeys(fetchKeys, "", true) { mach.Log.Trace("Got %d device keys for %s", len(devices), userID) missingSessions[userID] = devices } diff --git a/crypto/machine.go b/crypto/machine.go index e295a3fe..791804ee 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -72,7 +72,7 @@ func (mach *OlmMachine) FlushStore() error { func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string) { if len(resp.DeviceLists.Changed) > 0 { mach.Log.Trace("Device list changes in /sync: %v", resp.DeviceLists.Changed) - mach.fetchKeys(resp.DeviceLists.Changed, since) + mach.fetchKeys(resp.DeviceLists.Changed, since, false) } for _, evt := range resp.ToDevice.Events { diff --git a/crypto/store.go b/crypto/store.go index 266b68e1..62e7e4cd 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -60,6 +60,7 @@ type Store interface { GetDevices(id.UserID) (map[id.DeviceID]*DeviceIdentity, error) PutDevices(id.UserID, map[id.DeviceID]*DeviceIdentity) error + FilterTrackedUsers([]id.UserID) []id.UserID } type messageIndexKey struct { @@ -291,3 +292,17 @@ func (gs *GobStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*Device gs.lock.Unlock() return err } + +func (gs *GobStore) FilterTrackedUsers(users []id.UserID) []id.UserID { + gs.lock.RLock() + var ptr int + for _, userID := range users { + _, ok := gs.Devices[userID] + if ok { + users[ptr] = userID + ptr++ + } + } + gs.lock.RUnlock() + return users[:ptr] +}