Database level support for key backup versioning

This doesn't plumb anything in yet but adds the columns and types for an
external implementation.

Key backup version is now typed.
This commit is contained in:
Toni Spets 2024-01-30 14:48:52 +02:00 committed by Toni Spets
commit 11c2907f2e
14 changed files with 160 additions and 105 deletions

View file

@ -1948,9 +1948,9 @@ func (cli *Client) GetKeyChanges(ctx context.Context, from, to string) (resp *Re
// GetKeyBackup retrieves the keys from the backup.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeys
func (cli *Client) GetKeyBackup(ctx context.Context, version string) (resp *RespRoomKeys[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) {
func (cli *Client) GetKeyBackup(ctx context.Context, version id.KeyBackupVersion) (resp *RespRoomKeys[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{
"version": version,
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return
@ -1959,9 +1959,9 @@ func (cli *Client) GetKeyBackup(ctx context.Context, version string) (resp *Resp
// PutKeysInBackup stores several keys in the backup.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeys
func (cli *Client) PutKeysInBackup(ctx context.Context, version string, req *ReqKeyBackup) (resp *RespRoomKeysUpdate, err error) {
func (cli *Client) PutKeysInBackup(ctx context.Context, version id.KeyBackupVersion, req *ReqKeyBackup) (resp *RespRoomKeysUpdate, err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{
"version": version,
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp)
return
@ -1970,9 +1970,9 @@ func (cli *Client) PutKeysInBackup(ctx context.Context, version string, req *Req
// DeleteKeyBackup deletes all keys from the backup.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeys
func (cli *Client) DeleteKeyBackup(ctx context.Context, version string) (resp *RespRoomKeysUpdate, err error) {
func (cli *Client) DeleteKeyBackup(ctx context.Context, version id.KeyBackupVersion) (resp *RespRoomKeysUpdate, err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{
"version": version,
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp)
return
@ -1982,10 +1982,10 @@ func (cli *Client) DeleteKeyBackup(ctx context.Context, version string) (resp *R
//
// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeysroomid
func (cli *Client) GetKeyBackupForRoom(
ctx context.Context, version string, roomID id.RoomID,
ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID,
) (resp *RespRoomKeyBackup[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{
"version": version,
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return
@ -1994,9 +1994,9 @@ func (cli *Client) GetKeyBackupForRoom(
// PutKeysInBackupForRoom stores several keys in the backup for the given room.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeysroomid
func (cli *Client) PutKeysInBackupForRoom(ctx context.Context, version string, roomID id.RoomID, req *ReqRoomKeyBackup) (resp *RespRoomKeysUpdate, err error) {
func (cli *Client) PutKeysInBackupForRoom(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, req *ReqRoomKeyBackup) (resp *RespRoomKeysUpdate, err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{
"version": version,
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp)
return
@ -2006,9 +2006,9 @@ func (cli *Client) PutKeysInBackupForRoom(ctx context.Context, version string, r
// room.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeysroomid
func (cli *Client) DeleteKeysFromBackupForRoom(ctx context.Context, version string, roomID id.RoomID) (resp *RespRoomKeysUpdate, err error) {
func (cli *Client) DeleteKeysFromBackupForRoom(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID) (resp *RespRoomKeysUpdate, err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{
"version": version,
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp)
return
@ -2018,10 +2018,10 @@ func (cli *Client) DeleteKeysFromBackupForRoom(ctx context.Context, version stri
//
// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid
func (cli *Client) GetKeyBackupForRoomAndSession(
ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID,
ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID,
) (resp *RespKeyBackupData[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{
"version": version,
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return
@ -2030,9 +2030,9 @@ func (cli *Client) GetKeyBackupForRoomAndSession(
// PutKeysInBackupForRoomAndSession stores a key in the backup.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeysroomidsessionid
func (cli *Client) PutKeysInBackupForRoomAndSession(ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID, req *ReqKeyBackupData) (resp *RespRoomKeysUpdate, err error) {
func (cli *Client) PutKeysInBackupForRoomAndSession(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, req *ReqKeyBackupData) (resp *RespRoomKeysUpdate, err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{
"version": version,
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp)
return
@ -2041,9 +2041,9 @@ func (cli *Client) PutKeysInBackupForRoomAndSession(ctx context.Context, version
// DeleteKeysInBackupForRoomAndSession deletes a key from the backup.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeysroomidsessionid
func (cli *Client) DeleteKeysInBackupForRoomAndSession(ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID) (resp *RespRoomKeysUpdate, err error) {
func (cli *Client) DeleteKeysInBackupForRoomAndSession(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID) (resp *RespRoomKeysUpdate, err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{
"version": version,
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp)
return
@ -2070,7 +2070,7 @@ func (cli *Client) CreateKeyBackupVersion(ctx context.Context, req *ReqRoomKeysV
// GetKeyBackupVersion returns information about an existing key backup.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keysversionversion
func (cli *Client) GetKeyBackupVersion(ctx context.Context, version string) (resp *RespRoomKeysVersion[backup.MegolmAuthData], err error) {
func (cli *Client) GetKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) (resp *RespRoomKeysVersion[backup.MegolmAuthData], err error) {
urlPath := cli.BuildClientURL("v3", "room_keys", "version", version)
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return
@ -2080,7 +2080,7 @@ func (cli *Client) GetKeyBackupVersion(ctx context.Context, version string) (res
// the auth_data can be modified.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keysversionversion
func (cli *Client) UpdateKeyBackupVersion(ctx context.Context, version string, req *ReqRoomKeysVersionUpdate) error {
func (cli *Client) UpdateKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion, req *ReqRoomKeysVersionUpdate) error {
urlPath := cli.BuildClientURL("v3", "room_keys", "version", version)
_, err := cli.MakeRequest(ctx, http.MethodPut, urlPath, nil, nil)
return err
@ -2091,7 +2091,7 @@ func (cli *Client) UpdateKeyBackupVersion(ctx context.Context, version string, r
// deleted.
//
// See: https://spec.matrix.org/v1.1/client-server-api/#delete_matrixclientv3room_keysversionversion
func (cli *Client) DeleteKeyBackupVersion(ctx context.Context, version string) error {
func (cli *Client) DeleteKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) error {
urlPath := cli.BuildClientURL("v3", "room_keys", "version", version)
_, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil)
return err

View file

@ -14,10 +14,11 @@ import (
)
type OlmAccount struct {
Internal olm.Account
signingKey id.SigningKey
identityKey id.IdentityKey
Shared bool
Internal olm.Account
signingKey id.SigningKey
identityKey id.IdentityKey
Shared bool
KeyBackupVersion id.KeyBackupVersion
}
func NewOlmAccount() *OlmAccount {

View file

@ -15,7 +15,7 @@ import (
"maunium.net/go/mautrix/id"
)
func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, megolmBackupKey *backup.MegolmBackupKey) error {
func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, megolmBackupKey *backup.MegolmBackupKey) (id.KeyBackupVersion, error) {
log := mach.machOrContextLog(ctx).With().
Str("action", "download and store latest key backup").
Logger()
@ -24,12 +24,13 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg
versionInfo, err := mach.GetAndVerifyLatestKeyBackupVersion(ctx)
if err != nil {
return err
return "", err
} else if versionInfo == nil {
return nil
return "", nil
}
return mach.GetAndStoreKeyBackup(ctx, versionInfo.Version, megolmBackupKey)
err = mach.GetAndStoreKeyBackup(ctx, versionInfo.Version, megolmBackupKey)
return versionInfo.Version, err
}
func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) (*mautrix.RespRoomKeysVersion[backup.MegolmAuthData], error) {
@ -45,7 +46,7 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context)
log := mach.machOrContextLog(ctx).With().
Int("count", versionInfo.Count).
Str("etag", versionInfo.ETag).
Str("key_backup_version", versionInfo.Version).
Stringer("key_backup_version", versionInfo.Version).
Logger()
userSignatures, ok := versionInfo.AuthData.Signatures[mach.Client.UserID]
@ -93,7 +94,7 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context)
return versionInfo, nil
}
func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version string, megolmBackupKey *backup.MegolmBackupKey) error {
func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version id.KeyBackupVersion, megolmBackupKey *backup.MegolmBackupKey) error {
keys, err := mach.Client.GetKeyBackup(ctx, version)
if err != nil {
return err
@ -112,7 +113,7 @@ func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version string
continue
}
err = mach.ImportRoomKeyFromBackup(ctx, roomID, sessionID, sessionData)
err = mach.ImportRoomKeyFromBackup(ctx, version, roomID, sessionID, sessionData)
if err != nil {
log.Warn().Err(err).Msg("Failed to import room key from backup")
failedCount++
@ -130,7 +131,7 @@ func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version string
return nil
}
func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) error {
func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) error {
log := zerolog.Ctx(ctx).With().
Str("room_id", roomID.String()).
Str("session_id", sessionID.String()).
@ -166,9 +167,10 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, roomID id.R
ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()),
id: sessionID,
ReceivedAt: time.Now().UTC(),
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
ReceivedAt: time.Now().UTC(),
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
KeyBackupVersion: version,
}
err = mach.CryptoStore.PutGroupSession(ctx, roomID, keyBackupData.SenderKey, sessionID, igs)
if err != nil {

View file

@ -152,11 +152,21 @@ func (mach *OlmMachine) Load(ctx context.Context) (err error) {
return nil
}
func (mach *OlmMachine) saveAccount(ctx context.Context) {
func (mach *OlmMachine) saveAccount(ctx context.Context) error {
err := mach.CryptoStore.PutAccount(ctx, mach.account)
if err != nil {
mach.Log.Error().Err(err).Msg("Failed to save account")
}
return err
}
func (mach *OlmMachine) KeyBackupVersion() id.KeyBackupVersion {
return mach.account.KeyBackupVersion
}
func (mach *OlmMachine) SetKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) error {
mach.account.KeyBackupVersion = version
return mach.saveAccount(ctx)
}
// FlushStore calls the Flush method of the CryptoStore.
@ -698,8 +708,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
}
mach.lastOTKUpload = time.Now()
mach.account.Shared = true
mach.saveAccount(ctx)
return nil
return mach.saveAccount(ctx)
}
func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) {

View file

@ -105,10 +105,11 @@ type InboundGroupSession struct {
ForwardingChains []string
RatchetSafety RatchetSafety
ReceivedAt time.Time
MaxAge int64
MaxMessages int
IsScheduled bool
ReceivedAt time.Time
MaxAge int64
MaxMessages int
IsScheduled bool
KeyBackupVersion id.KeyBackupVersion
id id.SessionID
}

View file

@ -125,20 +125,21 @@ func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount
store.Account = account
bytes := account.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec(ctx, `
INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id) VALUES ($1, $2, $3, $4, $5)
INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id, key_backup_version) VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token,
account=excluded.account, account_id=excluded.account_id
`, store.DeviceID, account.Shared, store.SyncToken, bytes, store.AccountID)
account=excluded.account, account_id=excluded.account_id,
key_backup_version=excluded.key_backup_version
`, store.DeviceID, account.Shared, store.SyncToken, bytes, store.AccountID, account.KeyBackupVersion)
return err
}
// GetAccount retrieves an OlmAccount from the database.
func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) {
if store.Account == nil {
row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID)
row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account, key_backup_version FROM crypto_account WHERE account_id=$1", store.AccountID)
acc := &OlmAccount{Internal: *olm.NewBlankAccount()}
var accountBytes []byte
err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes)
err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes, &acc.KeyBackupVersion)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
@ -285,17 +286,18 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.Room
_, err = store.DB.Exec(ctx, `
INSERT INTO crypto_megolm_inbound_session (
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, account_id
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
ON CONFLICT (session_id, account_id) DO UPDATE
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key,
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains,
ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at,
max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled
max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled,
key_backup_version=excluded.key_backup_version
`,
sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains,
ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages),
session.IsScheduled, store.AccountID,
session.IsScheduled, session.KeyBackupVersion, store.AccountID,
)
return err
}
@ -307,12 +309,13 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room
var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
var version id.KeyBackupVersion
err := store.DB.QueryRow(ctx, `
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`,
roomID, senderKey, sessionID, store.AccountID,
).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
@ -342,6 +345,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room
MaxAge: maxAge.Int64,
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
KeyBackupVersion: version,
}, nil
}
@ -469,7 +473,8 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In
var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
var version id.KeyBackupVersion
err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version)
if err != nil {
return nil, err
}
@ -485,31 +490,35 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In
MaxAge: maxAge.Int64,
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
KeyBackupVersion: version,
}, nil
}
func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) {
func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] {
rows, err := store.DB.Query(ctx, `
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
roomID, store.AccountID,
)
if err != nil {
return nil, err
}
return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList()
return dbutil.NewRowIterWithError(rows, store.scanInboundGroupSession, err)
}
func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) ([]*InboundGroupSession, error) {
func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] {
rows, err := store.DB.Query(ctx, `
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`,
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`,
store.AccountID,
)
if err != nil {
return nil, err
}
return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList()
return dbutil.NewRowIterWithError(rows, store.scanInboundGroupSession, err)
}
func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] {
rows, err := store.DB.Query(ctx, `
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`,
store.AccountID, version,
)
return dbutil.NewRowIterWithError(rows, store.scanInboundGroupSession, err)
}
// AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices.

View file

@ -1,10 +1,11 @@
-- v0 -> v13 (compatible with v9+): Latest revision
-- v0 -> v14 (compatible with v9+): Latest revision
CREATE TABLE IF NOT EXISTS crypto_account (
account_id TEXT PRIMARY KEY,
device_id TEXT NOT NULL,
shared BOOLEAN NOT NULL,
sync_token TEXT NOT NULL,
account bytea NOT NULL
account_id TEXT PRIMARY KEY,
device_id TEXT NOT NULL,
shared BOOLEAN NOT NULL,
sync_token TEXT NOT NULL,
account bytea NOT NULL,
key_backup_version TEXT NOT NULL DEFAULT ''
);
CREATE TABLE IF NOT EXISTS crypto_message_index (
@ -44,20 +45,21 @@ CREATE TABLE IF NOT EXISTS crypto_olm_session (
);
CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
account_id TEXT,
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
signing_key CHAR(43),
room_id TEXT NOT NULL,
session bytea,
forwarding_chains bytea,
withheld_code TEXT,
withheld_reason TEXT,
ratchet_safety jsonb,
received_at timestamp,
max_age BIGINT,
max_messages INTEGER,
is_scheduled BOOLEAN NOT NULL DEFAULT false,
account_id TEXT,
session_id CHAR(43),
sender_key CHAR(43) NOT NULL,
signing_key CHAR(43),
room_id TEXT NOT NULL,
session bytea,
forwarding_chains bytea,
withheld_code TEXT,
withheld_reason TEXT,
ratchet_safety jsonb,
received_at timestamp,
max_age BIGINT,
max_messages INTEGER,
is_scheduled BOOLEAN NOT NULL DEFAULT false,
key_backup_version TEXT NOT NULL DEFAULT '',
PRIMARY KEY (account_id, session_id)
);

View file

@ -0,0 +1,4 @@
-- v14 (compatible with v9+): Add key_backup_version column to account and igs
ALTER TABLE crypto_account ADD COLUMN key_backup_version TEXT NOT NULL DEFAULT '';
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN key_backup_version TEXT NOT NULL DEFAULT '';

View file

@ -12,6 +12,8 @@ import (
"sort"
"sync"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@ -68,10 +70,12 @@ type Store interface {
// GetGroupSessionsForRoom gets all the inbound Megolm sessions for a specific room. This is used for creating key
// export files. Unlike GetGroupSession, this should not return any errors about withheld keys.
GetGroupSessionsForRoom(context.Context, id.RoomID) ([]*InboundGroupSession, error)
GetGroupSessionsForRoom(context.Context, id.RoomID) dbutil.RowIter[*InboundGroupSession]
// GetAllGroupSessions gets all the inbound Megolm sessions in the store. This is used for creating key export
// files. Unlike GetGroupSession, this should not return any errors about withheld keys.
GetAllGroupSessions(context.Context) ([]*InboundGroupSession, error)
GetAllGroupSessions(context.Context) dbutil.RowIter[*InboundGroupSession]
// GetGroupSessionsWithoutKeyBackupVersion gets all the inbound Megolm sessions in the store that do not match given key backup version.
GetGroupSessionsWithoutKeyBackupVersion(context.Context, id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession]
// AddOutboundGroupSession inserts the given outbound Megolm session into the store.
//
@ -376,12 +380,12 @@ func (gs *MemoryStore) GetWithheldGroupSession(_ context.Context, roomID id.Room
return session, nil
}
func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) {
func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] {
gs.lock.Lock()
defer gs.lock.Unlock()
room, ok := gs.GroupSessions[roomID]
if !ok {
return []*InboundGroupSession{}, nil
return nil
}
var result []*InboundGroupSession
for _, sessions := range room {
@ -389,10 +393,10 @@ func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.Room
result = append(result, session)
}
}
return result, nil
return dbutil.NewSliceIter[*InboundGroupSession](result)
}
func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) ([]*InboundGroupSession, error) {
func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) dbutil.RowIter[*InboundGroupSession] {
gs.lock.Lock()
var result []*InboundGroupSession
for _, room := range gs.GroupSessions {
@ -403,7 +407,23 @@ func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) ([]*InboundGroupSe
}
}
gs.lock.Unlock()
return result, nil
return dbutil.NewSliceIter[*InboundGroupSession](result)
}
func (gs *MemoryStore) GetGroupSessionsWithoutKeyBackupVersion(_ context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] {
gs.lock.Lock()
var result []*InboundGroupSession
for _, room := range gs.GroupSessions {
for _, sessions := range room {
for _, session := range sessions {
if session.KeyBackupVersion != version {
result = append(result, session)
}
}
}
}
gs.lock.Unlock()
return dbutil.NewSliceIter[*InboundGroupSession](result)
}
func (gs *MemoryStore) AddOutboundGroupSession(_ context.Context, session *OutboundGroupSession) error {

2
go.mod
View file

@ -12,7 +12,7 @@ require (
github.com/tidwall/gjson v1.17.0
github.com/tidwall/sjson v1.2.5
github.com/yuin/goldmark v1.6.0
go.mau.fi/util v0.3.0
go.mau.fi/util v0.3.1-0.20240131162106-bcac615a2941
go.mau.fi/zeroconfig v0.1.2
golang.org/x/crypto v0.18.0
golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3

4
go.sum
View file

@ -36,8 +36,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68=
github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.mau.fi/util v0.3.0 h1:Lt3lbRXP6ZBqTINK0EieRWor3zEwwwrDT14Z5N8RUCs=
go.mau.fi/util v0.3.0/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs=
go.mau.fi/util v0.3.1-0.20240131162106-bcac615a2941 h1:F9ySn0OM0uFqcGDQM2WUqlFJh4UCBYNfeSxzJd0kknM=
go.mau.fi/util v0.3.1-0.20240131162106-bcac615a2941/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs=
go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto=
go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=

View file

@ -50,6 +50,13 @@ const (
KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2"
)
// BackupVersion is an arbitrary string that identifies a server side key backup.
type KeyBackupVersion string
func (version KeyBackupVersion) String() string {
return string(version)
}
// A SessionID is an arbitrary string that identifies an Olm or Megolm session.
type SessionID string

View file

@ -429,14 +429,14 @@ type ReqBeeperSplitRoom struct {
}
type ReqRoomKeysVersionCreate struct {
Algorithm string `json:"algorithm"`
AuthData json.RawMessage `json:"auth_data"`
Algorithm id.KeyBackupAlgorithm `json:"algorithm"`
AuthData json.RawMessage `json:"auth_data"`
}
type ReqRoomKeysVersionUpdate struct {
Algorithm string `json:"algorithm"`
AuthData json.RawMessage `json:"auth_data"`
Version string `json:"version,omitempty"`
Algorithm id.KeyBackupAlgorithm `json:"algorithm"`
AuthData json.RawMessage `json:"auth_data"`
Version id.KeyBackupVersion `json:"version,omitempty"`
}
type ReqKeyBackup struct {

View file

@ -593,7 +593,7 @@ type RespTimestampToEvent struct {
}
type RespRoomKeysVersionCreate struct {
Version string `json:"version"`
Version id.KeyBackupVersion `json:"version"`
}
type RespRoomKeysVersion[A any] struct {
@ -601,7 +601,7 @@ type RespRoomKeysVersion[A any] struct {
AuthData A `json:"auth_data"`
Count int `json:"count"`
ETag string `json:"etag"`
Version string `json:"version"`
Version id.KeyBackupVersion `json:"version"`
}
type RespRoomKeys[S any] struct {