From 11c2907f2e5ae8600714ec1dec3b01a5087cfc3a Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Tue, 30 Jan 2024 14:48:52 +0200 Subject: [PATCH] Database level support for key backup versioning This doesn't plumb anything in yet but adds the columns and types for an external implementation. Key backup version is now typed. --- client.go | 42 ++++++------- crypto/account.go | 9 +-- crypto/keybackup.go | 24 ++++---- crypto/machine.go | 15 ++++- crypto/sessions.go | 9 +-- crypto/sql_store.go | 59 +++++++++++-------- .../sql_store_upgrade/00-latest-revision.sql | 42 ++++++------- .../14-account-key-backup-version.sql | 4 ++ crypto/store.go | 34 ++++++++--- go.mod | 2 +- go.sum | 4 +- id/crypto.go | 7 +++ requests.go | 10 ++-- responses.go | 4 +- 14 files changed, 160 insertions(+), 105 deletions(-) create mode 100644 crypto/sql_store_upgrade/14-account-key-backup-version.sql diff --git a/client.go b/client.go index 18f2c019..2712789b 100644 --- a/client.go +++ b/client.go @@ -1948,9 +1948,9 @@ func (cli *Client) GetKeyChanges(ctx context.Context, from, to string) (resp *Re // GetKeyBackup retrieves the keys from the backup. // // See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeys -func (cli *Client) GetKeyBackup(ctx context.Context, version string) (resp *RespRoomKeys[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) { +func (cli *Client) GetKeyBackup(ctx context.Context, version id.KeyBackupVersion) (resp *RespRoomKeys[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return @@ -1959,9 +1959,9 @@ func (cli *Client) GetKeyBackup(ctx context.Context, version string) (resp *Resp // PutKeysInBackup stores several keys in the backup. // // See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeys -func (cli *Client) PutKeysInBackup(ctx context.Context, version string, req *ReqKeyBackup) (resp *RespRoomKeysUpdate, err error) { +func (cli *Client) PutKeysInBackup(ctx context.Context, version id.KeyBackupVersion, req *ReqKeyBackup) (resp *RespRoomKeysUpdate, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp) return @@ -1970,9 +1970,9 @@ func (cli *Client) PutKeysInBackup(ctx context.Context, version string, req *Req // DeleteKeyBackup deletes all keys from the backup. // // See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeys -func (cli *Client) DeleteKeyBackup(ctx context.Context, version string) (resp *RespRoomKeysUpdate, err error) { +func (cli *Client) DeleteKeyBackup(ctx context.Context, version id.KeyBackupVersion) (resp *RespRoomKeysUpdate, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp) return @@ -1982,10 +1982,10 @@ func (cli *Client) DeleteKeyBackup(ctx context.Context, version string) (resp *R // // See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeysroomid func (cli *Client) GetKeyBackupForRoom( - ctx context.Context, version string, roomID id.RoomID, + ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, ) (resp *RespRoomKeyBackup[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return @@ -1994,9 +1994,9 @@ func (cli *Client) GetKeyBackupForRoom( // PutKeysInBackupForRoom stores several keys in the backup for the given room. // // See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeysroomid -func (cli *Client) PutKeysInBackupForRoom(ctx context.Context, version string, roomID id.RoomID, req *ReqRoomKeyBackup) (resp *RespRoomKeysUpdate, err error) { +func (cli *Client) PutKeysInBackupForRoom(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, req *ReqRoomKeyBackup) (resp *RespRoomKeysUpdate, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp) return @@ -2006,9 +2006,9 @@ func (cli *Client) PutKeysInBackupForRoom(ctx context.Context, version string, r // room. // // See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeysroomid -func (cli *Client) DeleteKeysFromBackupForRoom(ctx context.Context, version string, roomID id.RoomID) (resp *RespRoomKeysUpdate, err error) { +func (cli *Client) DeleteKeysFromBackupForRoom(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID) (resp *RespRoomKeysUpdate, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp) return @@ -2018,10 +2018,10 @@ func (cli *Client) DeleteKeysFromBackupForRoom(ctx context.Context, version stri // // See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid func (cli *Client) GetKeyBackupForRoomAndSession( - ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID, + ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, ) (resp *RespKeyBackupData[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return @@ -2030,9 +2030,9 @@ func (cli *Client) GetKeyBackupForRoomAndSession( // PutKeysInBackupForRoomAndSession stores a key in the backup. // // See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeysroomidsessionid -func (cli *Client) PutKeysInBackupForRoomAndSession(ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID, req *ReqKeyBackupData) (resp *RespRoomKeysUpdate, err error) { +func (cli *Client) PutKeysInBackupForRoomAndSession(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, req *ReqKeyBackupData) (resp *RespRoomKeysUpdate, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp) return @@ -2041,9 +2041,9 @@ func (cli *Client) PutKeysInBackupForRoomAndSession(ctx context.Context, version // DeleteKeysInBackupForRoomAndSession deletes a key from the backup. // // See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeysroomidsessionid -func (cli *Client) DeleteKeysInBackupForRoomAndSession(ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID) (resp *RespRoomKeysUpdate, err error) { +func (cli *Client) DeleteKeysInBackupForRoomAndSession(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID) (resp *RespRoomKeysUpdate, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp) return @@ -2070,7 +2070,7 @@ func (cli *Client) CreateKeyBackupVersion(ctx context.Context, req *ReqRoomKeysV // GetKeyBackupVersion returns information about an existing key backup. // // See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keysversionversion -func (cli *Client) GetKeyBackupVersion(ctx context.Context, version string) (resp *RespRoomKeysVersion[backup.MegolmAuthData], err error) { +func (cli *Client) GetKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) (resp *RespRoomKeysVersion[backup.MegolmAuthData], err error) { urlPath := cli.BuildClientURL("v3", "room_keys", "version", version) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return @@ -2080,7 +2080,7 @@ func (cli *Client) GetKeyBackupVersion(ctx context.Context, version string) (res // the auth_data can be modified. // // See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keysversionversion -func (cli *Client) UpdateKeyBackupVersion(ctx context.Context, version string, req *ReqRoomKeysVersionUpdate) error { +func (cli *Client) UpdateKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion, req *ReqRoomKeysVersionUpdate) error { urlPath := cli.BuildClientURL("v3", "room_keys", "version", version) _, err := cli.MakeRequest(ctx, http.MethodPut, urlPath, nil, nil) return err @@ -2091,7 +2091,7 @@ func (cli *Client) UpdateKeyBackupVersion(ctx context.Context, version string, r // deleted. // // See: https://spec.matrix.org/v1.1/client-server-api/#delete_matrixclientv3room_keysversionversion -func (cli *Client) DeleteKeyBackupVersion(ctx context.Context, version string) error { +func (cli *Client) DeleteKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) error { urlPath := cli.BuildClientURL("v3", "room_keys", "version", version) _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil) return err diff --git a/crypto/account.go b/crypto/account.go index 78fbfa5f..d242df6f 100644 --- a/crypto/account.go +++ b/crypto/account.go @@ -14,10 +14,11 @@ import ( ) type OlmAccount struct { - Internal olm.Account - signingKey id.SigningKey - identityKey id.IdentityKey - Shared bool + Internal olm.Account + signingKey id.SigningKey + identityKey id.IdentityKey + Shared bool + KeyBackupVersion id.KeyBackupVersion } func NewOlmAccount() *OlmAccount { diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 328309dc..9090e76c 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -15,7 +15,7 @@ import ( "maunium.net/go/mautrix/id" ) -func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, megolmBackupKey *backup.MegolmBackupKey) error { +func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, megolmBackupKey *backup.MegolmBackupKey) (id.KeyBackupVersion, error) { log := mach.machOrContextLog(ctx).With(). Str("action", "download and store latest key backup"). Logger() @@ -24,12 +24,13 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg versionInfo, err := mach.GetAndVerifyLatestKeyBackupVersion(ctx) if err != nil { - return err + return "", err } else if versionInfo == nil { - return nil + return "", nil } - return mach.GetAndStoreKeyBackup(ctx, versionInfo.Version, megolmBackupKey) + err = mach.GetAndStoreKeyBackup(ctx, versionInfo.Version, megolmBackupKey) + return versionInfo.Version, err } func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) (*mautrix.RespRoomKeysVersion[backup.MegolmAuthData], error) { @@ -45,7 +46,7 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) log := mach.machOrContextLog(ctx).With(). Int("count", versionInfo.Count). Str("etag", versionInfo.ETag). - Str("key_backup_version", versionInfo.Version). + Stringer("key_backup_version", versionInfo.Version). Logger() userSignatures, ok := versionInfo.AuthData.Signatures[mach.Client.UserID] @@ -93,7 +94,7 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) return versionInfo, nil } -func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version string, megolmBackupKey *backup.MegolmBackupKey) error { +func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version id.KeyBackupVersion, megolmBackupKey *backup.MegolmBackupKey) error { keys, err := mach.Client.GetKeyBackup(ctx, version) if err != nil { return err @@ -112,7 +113,7 @@ func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version string continue } - err = mach.ImportRoomKeyFromBackup(ctx, roomID, sessionID, sessionData) + err = mach.ImportRoomKeyFromBackup(ctx, version, roomID, sessionID, sessionData) if err != nil { log.Warn().Err(err).Msg("Failed to import room key from backup") failedCount++ @@ -130,7 +131,7 @@ func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version string return nil } -func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) error { +func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) error { log := zerolog.Ctx(ctx).With(). Str("room_id", roomID.String()). Str("session_id", sessionID.String()). @@ -166,9 +167,10 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, roomID id.R ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()), id: sessionID, - ReceivedAt: time.Now().UTC(), - MaxAge: maxAge.Milliseconds(), - MaxMessages: maxMessages, + ReceivedAt: time.Now().UTC(), + MaxAge: maxAge.Milliseconds(), + MaxMessages: maxMessages, + KeyBackupVersion: version, } err = mach.CryptoStore.PutGroupSession(ctx, roomID, keyBackupData.SenderKey, sessionID, igs) if err != nil { diff --git a/crypto/machine.go b/crypto/machine.go index 180e05f0..e5058ed8 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -152,11 +152,21 @@ func (mach *OlmMachine) Load(ctx context.Context) (err error) { return nil } -func (mach *OlmMachine) saveAccount(ctx context.Context) { +func (mach *OlmMachine) saveAccount(ctx context.Context) error { err := mach.CryptoStore.PutAccount(ctx, mach.account) if err != nil { mach.Log.Error().Err(err).Msg("Failed to save account") } + return err +} + +func (mach *OlmMachine) KeyBackupVersion() id.KeyBackupVersion { + return mach.account.KeyBackupVersion +} + +func (mach *OlmMachine) SetKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) error { + mach.account.KeyBackupVersion = version + return mach.saveAccount(ctx) } // FlushStore calls the Flush method of the CryptoStore. @@ -698,8 +708,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro } mach.lastOTKUpload = time.Now() mach.account.Shared = true - mach.saveAccount(ctx) - return nil + return mach.saveAccount(ctx) } func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) { diff --git a/crypto/sessions.go b/crypto/sessions.go index ad8c2ae8..045af933 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -105,10 +105,11 @@ type InboundGroupSession struct { ForwardingChains []string RatchetSafety RatchetSafety - ReceivedAt time.Time - MaxAge int64 - MaxMessages int - IsScheduled bool + ReceivedAt time.Time + MaxAge int64 + MaxMessages int + IsScheduled bool + KeyBackupVersion id.KeyBackupVersion id id.SessionID } diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 1d86fec9..ef1be25b 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -125,20 +125,21 @@ func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount store.Account = account bytes := account.Internal.Pickle(store.PickleKey) _, err := store.DB.Exec(ctx, ` - INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id) VALUES ($1, $2, $3, $4, $5) + INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id, key_backup_version) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token, - account=excluded.account, account_id=excluded.account_id - `, store.DeviceID, account.Shared, store.SyncToken, bytes, store.AccountID) + account=excluded.account, account_id=excluded.account_id, + key_backup_version=excluded.key_backup_version + `, store.DeviceID, account.Shared, store.SyncToken, bytes, store.AccountID, account.KeyBackupVersion) return err } // GetAccount retrieves an OlmAccount from the database. func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) { if store.Account == nil { - row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID) + row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account, key_backup_version FROM crypto_account WHERE account_id=$1", store.AccountID) acc := &OlmAccount{Internal: *olm.NewBlankAccount()} var accountBytes []byte - err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes) + err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes, &acc.KeyBackupVersion) if err == sql.ErrNoRows { return nil, nil } else if err != nil { @@ -285,17 +286,18 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.Room _, err = store.DB.Exec(ctx, ` INSERT INTO crypto_megolm_inbound_session ( session_id, sender_key, signing_key, room_id, session, forwarding_chains, - ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, account_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) ON CONFLICT (session_id, account_id) DO UPDATE SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains, ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at, - max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled + max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled, + key_backup_version=excluded.key_backup_version `, sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains, ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages), - session.IsScheduled, store.AccountID, + session.IsScheduled, session.KeyBackupVersion, store.AccountID, ) return err } @@ -307,12 +309,13 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room var receivedAt sql.NullTime var maxAge, maxMessages sql.NullInt64 var isScheduled bool + var version id.KeyBackupVersion err := store.DB.QueryRow(ctx, ` - SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled + SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`, roomID, senderKey, sessionID, store.AccountID, - ).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled) + ).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { @@ -342,6 +345,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room MaxAge: maxAge.Int64, MaxMessages: int(maxMessages.Int64), IsScheduled: isScheduled, + KeyBackupVersion: version, }, nil } @@ -469,7 +473,8 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In var receivedAt sql.NullTime var maxAge, maxMessages sql.NullInt64 var isScheduled bool - err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled) + var version id.KeyBackupVersion + err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) if err != nil { return nil, err } @@ -485,31 +490,35 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In MaxAge: maxAge.Int64, MaxMessages: int(maxMessages.Int64), IsScheduled: isScheduled, + KeyBackupVersion: version, }, nil } -func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) { +func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled + SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`, roomID, store.AccountID, ) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList() + return dbutil.NewRowIterWithError(rows, store.scanInboundGroupSession, err) } -func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) ([]*InboundGroupSession, error) { +func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled - FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`, + SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`, store.AccountID, ) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList() + return dbutil.NewRowIterWithError(rows, store.scanInboundGroupSession, err) +} + +func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] { + rows, err := store.DB.Query(ctx, ` + SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`, + store.AccountID, version, + ) + return dbutil.NewRowIterWithError(rows, store.scanInboundGroupSession, err) } // AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices. diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index a8c31153..06aea750 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,10 +1,11 @@ --- v0 -> v13 (compatible with v9+): Latest revision +-- v0 -> v14 (compatible with v9+): Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( - account_id TEXT PRIMARY KEY, - device_id TEXT NOT NULL, - shared BOOLEAN NOT NULL, - sync_token TEXT NOT NULL, - account bytea NOT NULL + account_id TEXT PRIMARY KEY, + device_id TEXT NOT NULL, + shared BOOLEAN NOT NULL, + sync_token TEXT NOT NULL, + account bytea NOT NULL, + key_backup_version TEXT NOT NULL DEFAULT '' ); CREATE TABLE IF NOT EXISTS crypto_message_index ( @@ -44,20 +45,21 @@ CREATE TABLE IF NOT EXISTS crypto_olm_session ( ); CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( - account_id TEXT, - session_id CHAR(43), - sender_key CHAR(43) NOT NULL, - signing_key CHAR(43), - room_id TEXT NOT NULL, - session bytea, - forwarding_chains bytea, - withheld_code TEXT, - withheld_reason TEXT, - ratchet_safety jsonb, - received_at timestamp, - max_age BIGINT, - max_messages INTEGER, - is_scheduled BOOLEAN NOT NULL DEFAULT false, + account_id TEXT, + session_id CHAR(43), + sender_key CHAR(43) NOT NULL, + signing_key CHAR(43), + room_id TEXT NOT NULL, + session bytea, + forwarding_chains bytea, + withheld_code TEXT, + withheld_reason TEXT, + ratchet_safety jsonb, + received_at timestamp, + max_age BIGINT, + max_messages INTEGER, + is_scheduled BOOLEAN NOT NULL DEFAULT false, + key_backup_version TEXT NOT NULL DEFAULT '', PRIMARY KEY (account_id, session_id) ); diff --git a/crypto/sql_store_upgrade/14-account-key-backup-version.sql b/crypto/sql_store_upgrade/14-account-key-backup-version.sql new file mode 100644 index 00000000..e5236b62 --- /dev/null +++ b/crypto/sql_store_upgrade/14-account-key-backup-version.sql @@ -0,0 +1,4 @@ +-- v14 (compatible with v9+): Add key_backup_version column to account and igs + +ALTER TABLE crypto_account ADD COLUMN key_backup_version TEXT NOT NULL DEFAULT ''; +ALTER TABLE crypto_megolm_inbound_session ADD COLUMN key_backup_version TEXT NOT NULL DEFAULT ''; diff --git a/crypto/store.go b/crypto/store.go index f900a3fa..3b6e6564 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -12,6 +12,8 @@ import ( "sort" "sync" + "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -68,10 +70,12 @@ type Store interface { // GetGroupSessionsForRoom gets all the inbound Megolm sessions for a specific room. This is used for creating key // export files. Unlike GetGroupSession, this should not return any errors about withheld keys. - GetGroupSessionsForRoom(context.Context, id.RoomID) ([]*InboundGroupSession, error) + GetGroupSessionsForRoom(context.Context, id.RoomID) dbutil.RowIter[*InboundGroupSession] // GetAllGroupSessions gets all the inbound Megolm sessions in the store. This is used for creating key export // files. Unlike GetGroupSession, this should not return any errors about withheld keys. - GetAllGroupSessions(context.Context) ([]*InboundGroupSession, error) + GetAllGroupSessions(context.Context) dbutil.RowIter[*InboundGroupSession] + // GetGroupSessionsWithoutKeyBackupVersion gets all the inbound Megolm sessions in the store that do not match given key backup version. + GetGroupSessionsWithoutKeyBackupVersion(context.Context, id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] // AddOutboundGroupSession inserts the given outbound Megolm session into the store. // @@ -376,12 +380,12 @@ func (gs *MemoryStore) GetWithheldGroupSession(_ context.Context, roomID id.Room return session, nil } -func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) { +func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] { gs.lock.Lock() defer gs.lock.Unlock() room, ok := gs.GroupSessions[roomID] if !ok { - return []*InboundGroupSession{}, nil + return nil } var result []*InboundGroupSession for _, sessions := range room { @@ -389,10 +393,10 @@ func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.Room result = append(result, session) } } - return result, nil + return dbutil.NewSliceIter[*InboundGroupSession](result) } -func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) ([]*InboundGroupSession, error) { +func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) dbutil.RowIter[*InboundGroupSession] { gs.lock.Lock() var result []*InboundGroupSession for _, room := range gs.GroupSessions { @@ -403,7 +407,23 @@ func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) ([]*InboundGroupSe } } gs.lock.Unlock() - return result, nil + return dbutil.NewSliceIter[*InboundGroupSession](result) +} + +func (gs *MemoryStore) GetGroupSessionsWithoutKeyBackupVersion(_ context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] { + gs.lock.Lock() + var result []*InboundGroupSession + for _, room := range gs.GroupSessions { + for _, sessions := range room { + for _, session := range sessions { + if session.KeyBackupVersion != version { + result = append(result, session) + } + } + } + } + gs.lock.Unlock() + return dbutil.NewSliceIter[*InboundGroupSession](result) } func (gs *MemoryStore) AddOutboundGroupSession(_ context.Context, session *OutboundGroupSession) error { diff --git a/go.mod b/go.mod index 48ff59e0..08eb341b 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.6.0 - go.mau.fi/util v0.3.0 + go.mau.fi/util v0.3.1-0.20240131162106-bcac615a2941 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.18.0 golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3 diff --git a/go.sum b/go.sum index 9061a651..64186d85 100644 --- a/go.sum +++ b/go.sum @@ -36,8 +36,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68= github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mau.fi/util v0.3.0 h1:Lt3lbRXP6ZBqTINK0EieRWor3zEwwwrDT14Z5N8RUCs= -go.mau.fi/util v0.3.0/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= +go.mau.fi/util v0.3.1-0.20240131162106-bcac615a2941 h1:F9ySn0OM0uFqcGDQM2WUqlFJh4UCBYNfeSxzJd0kknM= +go.mau.fi/util v0.3.1-0.20240131162106-bcac615a2941/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= diff --git a/id/crypto.go b/id/crypto.go index f28e3d88..9334198e 100644 --- a/id/crypto.go +++ b/id/crypto.go @@ -50,6 +50,13 @@ const ( KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2" ) +// BackupVersion is an arbitrary string that identifies a server side key backup. +type KeyBackupVersion string + +func (version KeyBackupVersion) String() string { + return string(version) +} + // A SessionID is an arbitrary string that identifies an Olm or Megolm session. type SessionID string diff --git a/requests.go b/requests.go index 1551e63b..61fe8a55 100644 --- a/requests.go +++ b/requests.go @@ -429,14 +429,14 @@ type ReqBeeperSplitRoom struct { } type ReqRoomKeysVersionCreate struct { - Algorithm string `json:"algorithm"` - AuthData json.RawMessage `json:"auth_data"` + Algorithm id.KeyBackupAlgorithm `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` } type ReqRoomKeysVersionUpdate struct { - Algorithm string `json:"algorithm"` - AuthData json.RawMessage `json:"auth_data"` - Version string `json:"version,omitempty"` + Algorithm id.KeyBackupAlgorithm `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` + Version id.KeyBackupVersion `json:"version,omitempty"` } type ReqKeyBackup struct { diff --git a/responses.go b/responses.go index b8552b58..e182a722 100644 --- a/responses.go +++ b/responses.go @@ -593,7 +593,7 @@ type RespTimestampToEvent struct { } type RespRoomKeysVersionCreate struct { - Version string `json:"version"` + Version id.KeyBackupVersion `json:"version"` } type RespRoomKeysVersion[A any] struct { @@ -601,7 +601,7 @@ type RespRoomKeysVersion[A any] struct { AuthData A `json:"auth_data"` Count int `json:"count"` ETag string `json:"etag"` - Version string `json:"version"` + Version id.KeyBackupVersion `json:"version"` } type RespRoomKeys[S any] struct {