diff --git a/crypto/devicelist.go b/crypto/devicelist.go index bbe06aae..f5c07cd3 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -27,8 +27,16 @@ var ( InvalidKeySignature = errors.New("invalid signature on device keys") ) -func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) map[id.DeviceID]*id.Device { - return mach.fetchKeys(ctx, []id.UserID{user}, "", true)[user] +func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) { + log := zerolog.Ctx(ctx) + + if keys, err := mach.FetchKeys(ctx, []id.UserID{user}, true); err != nil { + log.Err(err).Msg("Failed to load devices") + } else if keys != nil { + return keys[user] + } + + return nil } func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id.UserID, deviceID id.DeviceID, resp *mautrix.RespQueryKeys) { @@ -85,19 +93,16 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id } } -func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceToken string, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device) { - // TODO this function should probably return errors +func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device, err error) { req := &mautrix.ReqQueryKeys{ DeviceKeys: mautrix.DeviceKeysRequest{}, Timeout: 10 * 1000, - Token: sinceToken, } log := mach.machOrContextLog(ctx) if !includeUntracked { - var err error users, err = mach.CryptoStore.FilterTrackedUsers(ctx, users) if err != nil { - log.Warn().Err(err).Msg("Failed to filter tracked user list") + return nil, fmt.Errorf("failed to filter tracked user list: %w", err) } } if len(users) == 0 { @@ -109,8 +114,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT log.Debug().Strs("users", strishArray(users)).Msg("Querying keys for users") resp, err := mach.Client.QueryKeys(ctx, req) if err != nil { - log.Error().Err(err).Msg("Failed to query keys") - return + return nil, fmt.Errorf("failed to query keys: %w", err) } for server, err := range resp.Failures { log.Warn().Interface("query_error", err).Str("server", server).Msg("Query keys failure for server") @@ -189,7 +193,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT mach.storeCrossSigningKeys(ctx, resp.SelfSigningKeys, resp.DeviceKeys) mach.storeCrossSigningKeys(ctx, resp.UserSigningKeys, resp.DeviceKeys) - return data + return data, nil } // OnDevicesChanged finds all shared rooms with the given user and invalidates outbound sessions in those rooms. diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 1eee2fec..dcd36dc1 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -229,12 +229,16 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, if len(fetchKeys) > 0 { log.Debug().Strs("users", strishArray(fetchKeys)).Msg("Fetching missing keys") - for userID, devices := range mach.fetchKeys(ctx, fetchKeys, "", true) { - log.Debug(). - Int("device_count", len(devices)). - Str("target_user_id", userID.String()). - Msg("Got device keys for user") - missingSessions[userID] = devices + if keys, err := mach.FetchKeys(ctx, fetchKeys, true); err != nil { + log.Err(err).Strs("users", strishArray(fetchKeys)).Msg("Failed to fetch missing keys") + } else if keys != nil { + for userID, devices := range keys { + log.Debug(). + Int("device_count", len(devices)). + Str("target_user_id", userID.String()). + Msg("Got device keys for user") + missingSessions[userID] = devices + } } } diff --git a/crypto/machine.go b/crypto/machine.go index fc0f1742..fa0c50dc 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -33,6 +33,9 @@ type OlmMachine struct { PlaintextMentions bool + // Never ask the server for keys automatically as a side effect. + DisableKeyFetching bool + SendKeysMinTrust id.TrustState ShareKeysMinTrust id.TrustState @@ -224,7 +227,11 @@ func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string) Str("trace_id", traceID). Interface("changes", dl.Changed). Msg("Device list changes in /sync") - mach.fetchKeys(context.TODO(), dl.Changed, since, false) + if mach.DisableKeyFetching { + mach.CryptoStore.MarkTrackedUsersOutdated(context.TODO(), dl.Changed) + } else { + mach.FetchKeys(context.TODO(), dl.Changed, false) + } mach.Log.Debug().Str("trace_id", traceID).Msg("Finished handling device list changes") } } @@ -413,11 +420,12 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, device, err := mach.CryptoStore.GetDevice(ctx, userID, deviceID) if err != nil { return nil, fmt.Errorf("failed to get sender device from store: %w", err) - } else if device != nil { + } else if device != nil || mach.DisableKeyFetching { return device, nil } - usersToDevices := mach.fetchKeys(ctx, []id.UserID{userID}, "", true) - if devices, ok := usersToDevices[userID]; ok { + if usersToDevices, err := mach.FetchKeys(ctx, []id.UserID{userID}, true); err != nil { + return nil, fmt.Errorf("failed to fetch keys: %w", err) + } else if devices, ok := usersToDevices[userID]; ok { if device, ok = devices[deviceID]; ok { return device, nil } @@ -431,7 +439,7 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, // the given identity key. func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) { deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(ctx, userID, identityKey) - if err != nil || deviceIdentity != nil { + if err != nil || deviceIdentity != nil || mach.DisableKeyFetching { return deviceIdentity, err } mach.machOrContextLog(ctx).Debug(). diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 8c85f6de..99a94f0e 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -665,12 +665,19 @@ func (store *SQLCryptoStore) PutDevice(ctx context.Context, userID id.UserID, de return err } +const trackedUserUpsertQuery = ` +INSERT INTO crypto_tracked_user (user_id, devices_outdated) +VALUES ($1, false) +ON CONFLICT (user_id) DO UPDATE + SET devices_outdated = EXCLUDED.devices_outdated +` + // PutDevices stores the device identity information for the given user ID. func (store *SQLCryptoStore) PutDevices(ctx context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error { return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error { - _, err := store.DB.Exec(ctx, "INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) + _, err := store.DB.Exec(ctx, trackedUserUpsertQuery, userID) if err != nil { - return fmt.Errorf("failed to add user to tracked users list: %w", err) + return fmt.Errorf("failed to upsert user to tracked users list: %w", err) } _, err = store.DB.Exec(ctx, "UPDATE crypto_device SET deleted=true WHERE user_id=$1", userID) @@ -734,6 +741,30 @@ func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id. return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList() } +// MarkTrackedUsersOutdated flags that the device list for given users are outdated. +func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) error { + return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + // TODO refactor to use a single query + for _, userID := range users { + _, err := store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = $1", userID) + if err != nil { + return fmt.Errorf("failed to update user in the tracked users list: %w", err) + } + } + + return nil + }) +} + +// GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated. +func (store *SQLCryptoStore) GetOutdatedTrackedUsers(ctx context.Context) ([]id.UserID, error) { + rows, err := store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE devices_outdated = TRUE") + if err != nil { + return nil, err + } + return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList() +} + // PutCrossSigningKey stores a cross-signing key of some user along with its usage. func (store *SQLCryptoStore) PutCrossSigningKey(ctx context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error { _, err := store.DB.Exec(ctx, ` diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index bd8f7942..90d7d31c 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v10: Latest revision +-- v0 -> v11: Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, @@ -17,7 +17,8 @@ CREATE TABLE IF NOT EXISTS crypto_message_index ( ); CREATE TABLE IF NOT EXISTS crypto_tracked_user ( - user_id TEXT PRIMARY KEY + user_id TEXT PRIMARY KEY, + devices_outdated BOOLEAN NOT NULL DEFAULT FALSE ); CREATE TABLE IF NOT EXISTS crypto_device ( diff --git a/crypto/sql_store_upgrade/11-outdated-devices.sql b/crypto/sql_store_upgrade/11-outdated-devices.sql new file mode 100644 index 00000000..f0f0ba5b --- /dev/null +++ b/crypto/sql_store_upgrade/11-outdated-devices.sql @@ -0,0 +1,2 @@ +-- v11: Add devices_outdated field to crypto_tracked_user +ALTER TABLE crypto_tracked_user ADD COLUMN devices_outdated BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/crypto/store.go b/crypto/store.go index 09393a51..fb3d5b96 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -108,6 +108,10 @@ type Store interface { // FilterTrackedUsers returns a filtered version of the given list that only includes user IDs whose device lists // have been stored with PutDevices. A user is considered tracked even if the PutDevices list was empty. FilterTrackedUsers(context.Context, []id.UserID) ([]id.UserID, error) + // MarkTrackedUsersOutdated flags that the device list for given users are outdated. + MarkTrackedUsersOutdated(context.Context, []id.UserID) error + // GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated. + GetOutdatedTrackedUsers(context.Context) ([]id.UserID, error) // PutCrossSigningKey stores a cross-signing key of some user along with its usage. PutCrossSigningKey(context.Context, id.UserID, id.CrossSigningUsage, id.Ed25519) error @@ -148,6 +152,7 @@ type MemoryStore struct { Devices map[id.UserID]map[id.DeviceID]*id.Device CrossSigningKeys map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string + OutdatedUsers map[id.UserID]struct{} } var _ Store = (*MemoryStore)(nil) @@ -167,6 +172,7 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore { Devices: make(map[id.UserID]map[id.DeviceID]*id.Device), CrossSigningKeys: make(map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey), KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string), + OutdatedUsers: make(map[id.UserID]struct{}), } } @@ -499,6 +505,9 @@ func (gs *MemoryStore) PutDevices(_ context.Context, userID id.UserID, devices m gs.lock.Lock() gs.Devices[userID] = devices err := gs.save() + if err == nil { + delete(gs.OutdatedUsers, userID) + } gs.lock.Unlock() return err } @@ -517,6 +526,27 @@ func (gs *MemoryStore) FilterTrackedUsers(_ context.Context, users []id.UserID) return users[:ptr], nil } +func (gs *MemoryStore) MarkTrackedUsersOutdated(_ context.Context, users []id.UserID) error { + gs.lock.Lock() + for _, userID := range users { + if _, ok := gs.Devices[userID]; ok { + gs.OutdatedUsers[userID] = struct{}{} + } + } + gs.lock.Unlock() + return nil +} + +func (gs *MemoryStore) GetOutdatedTrackedUsers(_ context.Context) ([]id.UserID, error) { + gs.lock.RLock() + users := make([]id.UserID, 0, len(gs.OutdatedUsers)) + for userID := range gs.OutdatedUsers { + users = append(users, userID) + } + gs.lock.RUnlock() + return users, nil +} + func (gs *MemoryStore) PutCrossSigningKey(_ context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error { gs.lock.RLock() userKeys, ok := gs.CrossSigningKeys[userID] diff --git a/crypto/store_test.go b/crypto/store_test.go index 665e3ef9..bbadef28 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -221,6 +221,13 @@ func TestStoreDevices(t *testing.T) { stores := getCryptoStores(t) for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { + outdated, err := store.GetOutdatedTrackedUsers(context.TODO()) + if err != nil { + t.Errorf("Error filtering tracked users: %v", err) + } + if len(outdated) > 0 { + t.Errorf("Got %d outdated tracked users when expected none", len(outdated)) + } deviceMap := make(map[id.DeviceID]*id.Device) for i := 0; i < 17; i++ { iStr := strconv.Itoa(i) @@ -232,9 +239,9 @@ func TestStoreDevices(t *testing.T) { SigningKey: acc.SigningKey(), } } - err := store.PutDevices(context.TODO(), "user1", deviceMap) + err = store.PutDevices(context.TODO(), "user1", deviceMap) if err != nil { - t.Errorf("Error string devices: %v", err) + t.Errorf("Error storing devices: %v", err) } devs, err := store.GetDevices(context.TODO(), "user1") if err != nil { @@ -256,6 +263,36 @@ func TestStoreDevices(t *testing.T) { } else if len(filtered) != 1 || filtered[0] != "user1" { t.Errorf("Expected to get 'user1' from filter, got %v", filtered) } + + outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) + if err != nil { + t.Errorf("Error filtering tracked users: %v", err) + } + if len(outdated) > 0 { + t.Errorf("Got %d outdated tracked users when expected none", len(outdated)) + } + err = store.MarkTrackedUsersOutdated(context.TODO(), []id.UserID{"user0", "user1"}) + if err != nil { + t.Errorf("Error marking tracked users outdated: %v", err) + } + outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) + if err != nil { + t.Errorf("Error filtering tracked users: %v", err) + } + if len(outdated) != 1 || outdated[0] != id.UserID("user1") { + t.Errorf("Got outdated tracked users %v when expected 'user1'", outdated) + } + err = store.PutDevices(context.TODO(), "user1", deviceMap) + if err != nil { + t.Errorf("Error storing devices: %v", err) + } + outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) + if err != nil { + t.Errorf("Error filtering tracked users: %v", err) + } + if len(outdated) > 0 { + t.Errorf("Got outdated tracked users %v when expected none", outdated) + } }) } }