mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
Allow disabling automatic key fetching for Olm machine
Many crypto operations in the Olm machine have a possible side effect of fetching keys from the server if they are missing. This may be undesired in some special cases. To tracking which users need key fetching, CryptoStore now exposes APIs to mark and query the status.
This commit is contained in:
parent
b3910eb699
commit
a3883fcf6f
8 changed files with 144 additions and 27 deletions
|
|
@ -27,8 +27,16 @@ var (
|
|||
InvalidKeySignature = errors.New("invalid signature on device keys")
|
||||
)
|
||||
|
||||
func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) map[id.DeviceID]*id.Device {
|
||||
return mach.fetchKeys(ctx, []id.UserID{user}, "", true)[user]
|
||||
func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) {
|
||||
log := zerolog.Ctx(ctx)
|
||||
|
||||
if keys, err := mach.FetchKeys(ctx, []id.UserID{user}, true); err != nil {
|
||||
log.Err(err).Msg("Failed to load devices")
|
||||
} else if keys != nil {
|
||||
return keys[user]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id.UserID, deviceID id.DeviceID, resp *mautrix.RespQueryKeys) {
|
||||
|
|
@ -85,19 +93,16 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id
|
|||
}
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceToken string, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device) {
|
||||
// TODO this function should probably return errors
|
||||
func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device, err error) {
|
||||
req := &mautrix.ReqQueryKeys{
|
||||
DeviceKeys: mautrix.DeviceKeysRequest{},
|
||||
Timeout: 10 * 1000,
|
||||
Token: sinceToken,
|
||||
}
|
||||
log := mach.machOrContextLog(ctx)
|
||||
if !includeUntracked {
|
||||
var err error
|
||||
users, err = mach.CryptoStore.FilterTrackedUsers(ctx, users)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to filter tracked user list")
|
||||
return nil, fmt.Errorf("failed to filter tracked user list: %w", err)
|
||||
}
|
||||
}
|
||||
if len(users) == 0 {
|
||||
|
|
@ -109,8 +114,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
|
|||
log.Debug().Strs("users", strishArray(users)).Msg("Querying keys for users")
|
||||
resp, err := mach.Client.QueryKeys(ctx, req)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to query keys")
|
||||
return
|
||||
return nil, fmt.Errorf("failed to query keys: %w", err)
|
||||
}
|
||||
for server, err := range resp.Failures {
|
||||
log.Warn().Interface("query_error", err).Str("server", server).Msg("Query keys failure for server")
|
||||
|
|
@ -189,7 +193,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
|
|||
mach.storeCrossSigningKeys(ctx, resp.SelfSigningKeys, resp.DeviceKeys)
|
||||
mach.storeCrossSigningKeys(ctx, resp.UserSigningKeys, resp.DeviceKeys)
|
||||
|
||||
return data
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// OnDevicesChanged finds all shared rooms with the given user and invalidates outbound sessions in those rooms.
|
||||
|
|
|
|||
|
|
@ -229,12 +229,16 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
|
|||
|
||||
if len(fetchKeys) > 0 {
|
||||
log.Debug().Strs("users", strishArray(fetchKeys)).Msg("Fetching missing keys")
|
||||
for userID, devices := range mach.fetchKeys(ctx, fetchKeys, "", true) {
|
||||
log.Debug().
|
||||
Int("device_count", len(devices)).
|
||||
Str("target_user_id", userID.String()).
|
||||
Msg("Got device keys for user")
|
||||
missingSessions[userID] = devices
|
||||
if keys, err := mach.FetchKeys(ctx, fetchKeys, true); err != nil {
|
||||
log.Err(err).Strs("users", strishArray(fetchKeys)).Msg("Failed to fetch missing keys")
|
||||
} else if keys != nil {
|
||||
for userID, devices := range keys {
|
||||
log.Debug().
|
||||
Int("device_count", len(devices)).
|
||||
Str("target_user_id", userID.String()).
|
||||
Msg("Got device keys for user")
|
||||
missingSessions[userID] = devices
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,9 @@ type OlmMachine struct {
|
|||
|
||||
PlaintextMentions bool
|
||||
|
||||
// Never ask the server for keys automatically as a side effect.
|
||||
DisableKeyFetching bool
|
||||
|
||||
SendKeysMinTrust id.TrustState
|
||||
ShareKeysMinTrust id.TrustState
|
||||
|
||||
|
|
@ -224,7 +227,11 @@ func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string)
|
|||
Str("trace_id", traceID).
|
||||
Interface("changes", dl.Changed).
|
||||
Msg("Device list changes in /sync")
|
||||
mach.fetchKeys(context.TODO(), dl.Changed, since, false)
|
||||
if mach.DisableKeyFetching {
|
||||
mach.CryptoStore.MarkTrackedUsersOutdated(context.TODO(), dl.Changed)
|
||||
} else {
|
||||
mach.FetchKeys(context.TODO(), dl.Changed, false)
|
||||
}
|
||||
mach.Log.Debug().Str("trace_id", traceID).Msg("Finished handling device list changes")
|
||||
}
|
||||
}
|
||||
|
|
@ -413,11 +420,12 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID,
|
|||
device, err := mach.CryptoStore.GetDevice(ctx, userID, deviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get sender device from store: %w", err)
|
||||
} else if device != nil {
|
||||
} else if device != nil || mach.DisableKeyFetching {
|
||||
return device, nil
|
||||
}
|
||||
usersToDevices := mach.fetchKeys(ctx, []id.UserID{userID}, "", true)
|
||||
if devices, ok := usersToDevices[userID]; ok {
|
||||
if usersToDevices, err := mach.FetchKeys(ctx, []id.UserID{userID}, true); err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch keys: %w", err)
|
||||
} else if devices, ok := usersToDevices[userID]; ok {
|
||||
if device, ok = devices[deviceID]; ok {
|
||||
return device, nil
|
||||
}
|
||||
|
|
@ -431,7 +439,7 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID,
|
|||
// the given identity key.
|
||||
func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
||||
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(ctx, userID, identityKey)
|
||||
if err != nil || deviceIdentity != nil {
|
||||
if err != nil || deviceIdentity != nil || mach.DisableKeyFetching {
|
||||
return deviceIdentity, err
|
||||
}
|
||||
mach.machOrContextLog(ctx).Debug().
|
||||
|
|
|
|||
|
|
@ -665,12 +665,19 @@ func (store *SQLCryptoStore) PutDevice(ctx context.Context, userID id.UserID, de
|
|||
return err
|
||||
}
|
||||
|
||||
const trackedUserUpsertQuery = `
|
||||
INSERT INTO crypto_tracked_user (user_id, devices_outdated)
|
||||
VALUES ($1, false)
|
||||
ON CONFLICT (user_id) DO UPDATE
|
||||
SET devices_outdated = EXCLUDED.devices_outdated
|
||||
`
|
||||
|
||||
// PutDevices stores the device identity information for the given user ID.
|
||||
func (store *SQLCryptoStore) PutDevices(ctx context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error {
|
||||
return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
|
||||
_, err := store.DB.Exec(ctx, trackedUserUpsertQuery, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add user to tracked users list: %w", err)
|
||||
return fmt.Errorf("failed to upsert user to tracked users list: %w", err)
|
||||
}
|
||||
|
||||
_, err = store.DB.Exec(ctx, "UPDATE crypto_device SET deleted=true WHERE user_id=$1", userID)
|
||||
|
|
@ -734,6 +741,30 @@ func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id.
|
|||
return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList()
|
||||
}
|
||||
|
||||
// MarkTrackedUsersOutdated flags that the device list for given users are outdated.
|
||||
func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) error {
|
||||
return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
// TODO refactor to use a single query
|
||||
for _, userID := range users {
|
||||
_, err := store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = $1", userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user in the tracked users list: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated.
|
||||
func (store *SQLCryptoStore) GetOutdatedTrackedUsers(ctx context.Context) ([]id.UserID, error) {
|
||||
rows, err := store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE devices_outdated = TRUE")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList()
|
||||
}
|
||||
|
||||
// PutCrossSigningKey stores a cross-signing key of some user along with its usage.
|
||||
func (store *SQLCryptoStore) PutCrossSigningKey(ctx context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
|
||||
_, err := store.DB.Exec(ctx, `
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
-- v0 -> v10: Latest revision
|
||||
-- v0 -> v11: Latest revision
|
||||
CREATE TABLE IF NOT EXISTS crypto_account (
|
||||
account_id TEXT PRIMARY KEY,
|
||||
device_id TEXT NOT NULL,
|
||||
|
|
@ -17,7 +17,8 @@ CREATE TABLE IF NOT EXISTS crypto_message_index (
|
|||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS crypto_tracked_user (
|
||||
user_id TEXT PRIMARY KEY
|
||||
user_id TEXT PRIMARY KEY,
|
||||
devices_outdated BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS crypto_device (
|
||||
|
|
|
|||
2
crypto/sql_store_upgrade/11-outdated-devices.sql
Normal file
2
crypto/sql_store_upgrade/11-outdated-devices.sql
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
-- v11: Add devices_outdated field to crypto_tracked_user
|
||||
ALTER TABLE crypto_tracked_user ADD COLUMN devices_outdated BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
|
@ -108,6 +108,10 @@ type Store interface {
|
|||
// FilterTrackedUsers returns a filtered version of the given list that only includes user IDs whose device lists
|
||||
// have been stored with PutDevices. A user is considered tracked even if the PutDevices list was empty.
|
||||
FilterTrackedUsers(context.Context, []id.UserID) ([]id.UserID, error)
|
||||
// MarkTrackedUsersOutdated flags that the device list for given users are outdated.
|
||||
MarkTrackedUsersOutdated(context.Context, []id.UserID) error
|
||||
// GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated.
|
||||
GetOutdatedTrackedUsers(context.Context) ([]id.UserID, error)
|
||||
|
||||
// PutCrossSigningKey stores a cross-signing key of some user along with its usage.
|
||||
PutCrossSigningKey(context.Context, id.UserID, id.CrossSigningUsage, id.Ed25519) error
|
||||
|
|
@ -148,6 +152,7 @@ type MemoryStore struct {
|
|||
Devices map[id.UserID]map[id.DeviceID]*id.Device
|
||||
CrossSigningKeys map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey
|
||||
KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string
|
||||
OutdatedUsers map[id.UserID]struct{}
|
||||
}
|
||||
|
||||
var _ Store = (*MemoryStore)(nil)
|
||||
|
|
@ -167,6 +172,7 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore {
|
|||
Devices: make(map[id.UserID]map[id.DeviceID]*id.Device),
|
||||
CrossSigningKeys: make(map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey),
|
||||
KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string),
|
||||
OutdatedUsers: make(map[id.UserID]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -499,6 +505,9 @@ func (gs *MemoryStore) PutDevices(_ context.Context, userID id.UserID, devices m
|
|||
gs.lock.Lock()
|
||||
gs.Devices[userID] = devices
|
||||
err := gs.save()
|
||||
if err == nil {
|
||||
delete(gs.OutdatedUsers, userID)
|
||||
}
|
||||
gs.lock.Unlock()
|
||||
return err
|
||||
}
|
||||
|
|
@ -517,6 +526,27 @@ func (gs *MemoryStore) FilterTrackedUsers(_ context.Context, users []id.UserID)
|
|||
return users[:ptr], nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) MarkTrackedUsersOutdated(_ context.Context, users []id.UserID) error {
|
||||
gs.lock.Lock()
|
||||
for _, userID := range users {
|
||||
if _, ok := gs.Devices[userID]; ok {
|
||||
gs.OutdatedUsers[userID] = struct{}{}
|
||||
}
|
||||
}
|
||||
gs.lock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetOutdatedTrackedUsers(_ context.Context) ([]id.UserID, error) {
|
||||
gs.lock.RLock()
|
||||
users := make([]id.UserID, 0, len(gs.OutdatedUsers))
|
||||
for userID := range gs.OutdatedUsers {
|
||||
users = append(users, userID)
|
||||
}
|
||||
gs.lock.RUnlock()
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutCrossSigningKey(_ context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
|
||||
gs.lock.RLock()
|
||||
userKeys, ok := gs.CrossSigningKeys[userID]
|
||||
|
|
|
|||
|
|
@ -221,6 +221,13 @@ func TestStoreDevices(t *testing.T) {
|
|||
stores := getCryptoStores(t)
|
||||
for storeName, store := range stores {
|
||||
t.Run(storeName, func(t *testing.T) {
|
||||
outdated, err := store.GetOutdatedTrackedUsers(context.TODO())
|
||||
if err != nil {
|
||||
t.Errorf("Error filtering tracked users: %v", err)
|
||||
}
|
||||
if len(outdated) > 0 {
|
||||
t.Errorf("Got %d outdated tracked users when expected none", len(outdated))
|
||||
}
|
||||
deviceMap := make(map[id.DeviceID]*id.Device)
|
||||
for i := 0; i < 17; i++ {
|
||||
iStr := strconv.Itoa(i)
|
||||
|
|
@ -232,9 +239,9 @@ func TestStoreDevices(t *testing.T) {
|
|||
SigningKey: acc.SigningKey(),
|
||||
}
|
||||
}
|
||||
err := store.PutDevices(context.TODO(), "user1", deviceMap)
|
||||
err = store.PutDevices(context.TODO(), "user1", deviceMap)
|
||||
if err != nil {
|
||||
t.Errorf("Error string devices: %v", err)
|
||||
t.Errorf("Error storing devices: %v", err)
|
||||
}
|
||||
devs, err := store.GetDevices(context.TODO(), "user1")
|
||||
if err != nil {
|
||||
|
|
@ -256,6 +263,36 @@ func TestStoreDevices(t *testing.T) {
|
|||
} else if len(filtered) != 1 || filtered[0] != "user1" {
|
||||
t.Errorf("Expected to get 'user1' from filter, got %v", filtered)
|
||||
}
|
||||
|
||||
outdated, err = store.GetOutdatedTrackedUsers(context.TODO())
|
||||
if err != nil {
|
||||
t.Errorf("Error filtering tracked users: %v", err)
|
||||
}
|
||||
if len(outdated) > 0 {
|
||||
t.Errorf("Got %d outdated tracked users when expected none", len(outdated))
|
||||
}
|
||||
err = store.MarkTrackedUsersOutdated(context.TODO(), []id.UserID{"user0", "user1"})
|
||||
if err != nil {
|
||||
t.Errorf("Error marking tracked users outdated: %v", err)
|
||||
}
|
||||
outdated, err = store.GetOutdatedTrackedUsers(context.TODO())
|
||||
if err != nil {
|
||||
t.Errorf("Error filtering tracked users: %v", err)
|
||||
}
|
||||
if len(outdated) != 1 || outdated[0] != id.UserID("user1") {
|
||||
t.Errorf("Got outdated tracked users %v when expected 'user1'", outdated)
|
||||
}
|
||||
err = store.PutDevices(context.TODO(), "user1", deviceMap)
|
||||
if err != nil {
|
||||
t.Errorf("Error storing devices: %v", err)
|
||||
}
|
||||
outdated, err = store.GetOutdatedTrackedUsers(context.TODO())
|
||||
if err != nil {
|
||||
t.Errorf("Error filtering tracked users: %v", err)
|
||||
}
|
||||
if len(outdated) > 0 {
|
||||
t.Errorf("Got outdated tracked users %v when expected none", outdated)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue