mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
Database level support for key backup versioning
This doesn't plumb anything in yet but adds the columns and types for an external implementation. Key backup version is now typed.
This commit is contained in:
parent
e08ed23845
commit
11c2907f2e
14 changed files with 160 additions and 105 deletions
42
client.go
42
client.go
|
|
@ -1948,9 +1948,9 @@ func (cli *Client) GetKeyChanges(ctx context.Context, from, to string) (resp *Re
|
|||
// GetKeyBackup retrieves the keys from the backup.
|
||||
//
|
||||
// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeys
|
||||
func (cli *Client) GetKeyBackup(ctx context.Context, version string) (resp *RespRoomKeys[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) {
|
||||
func (cli *Client) GetKeyBackup(ctx context.Context, version id.KeyBackupVersion) (resp *RespRoomKeys[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) {
|
||||
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{
|
||||
"version": version,
|
||||
"version": string(version),
|
||||
})
|
||||
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
|
||||
return
|
||||
|
|
@ -1959,9 +1959,9 @@ func (cli *Client) GetKeyBackup(ctx context.Context, version string) (resp *Resp
|
|||
// PutKeysInBackup stores several keys in the backup.
|
||||
//
|
||||
// See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeys
|
||||
func (cli *Client) PutKeysInBackup(ctx context.Context, version string, req *ReqKeyBackup) (resp *RespRoomKeysUpdate, err error) {
|
||||
func (cli *Client) PutKeysInBackup(ctx context.Context, version id.KeyBackupVersion, req *ReqKeyBackup) (resp *RespRoomKeysUpdate, err error) {
|
||||
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{
|
||||
"version": version,
|
||||
"version": string(version),
|
||||
})
|
||||
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp)
|
||||
return
|
||||
|
|
@ -1970,9 +1970,9 @@ func (cli *Client) PutKeysInBackup(ctx context.Context, version string, req *Req
|
|||
// DeleteKeyBackup deletes all keys from the backup.
|
||||
//
|
||||
// See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeys
|
||||
func (cli *Client) DeleteKeyBackup(ctx context.Context, version string) (resp *RespRoomKeysUpdate, err error) {
|
||||
func (cli *Client) DeleteKeyBackup(ctx context.Context, version id.KeyBackupVersion) (resp *RespRoomKeysUpdate, err error) {
|
||||
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{
|
||||
"version": version,
|
||||
"version": string(version),
|
||||
})
|
||||
_, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp)
|
||||
return
|
||||
|
|
@ -1982,10 +1982,10 @@ func (cli *Client) DeleteKeyBackup(ctx context.Context, version string) (resp *R
|
|||
//
|
||||
// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeysroomid
|
||||
func (cli *Client) GetKeyBackupForRoom(
|
||||
ctx context.Context, version string, roomID id.RoomID,
|
||||
ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID,
|
||||
) (resp *RespRoomKeyBackup[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) {
|
||||
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{
|
||||
"version": version,
|
||||
"version": string(version),
|
||||
})
|
||||
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
|
||||
return
|
||||
|
|
@ -1994,9 +1994,9 @@ func (cli *Client) GetKeyBackupForRoom(
|
|||
// PutKeysInBackupForRoom stores several keys in the backup for the given room.
|
||||
//
|
||||
// See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeysroomid
|
||||
func (cli *Client) PutKeysInBackupForRoom(ctx context.Context, version string, roomID id.RoomID, req *ReqRoomKeyBackup) (resp *RespRoomKeysUpdate, err error) {
|
||||
func (cli *Client) PutKeysInBackupForRoom(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, req *ReqRoomKeyBackup) (resp *RespRoomKeysUpdate, err error) {
|
||||
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{
|
||||
"version": version,
|
||||
"version": string(version),
|
||||
})
|
||||
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp)
|
||||
return
|
||||
|
|
@ -2006,9 +2006,9 @@ func (cli *Client) PutKeysInBackupForRoom(ctx context.Context, version string, r
|
|||
// room.
|
||||
//
|
||||
// See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeysroomid
|
||||
func (cli *Client) DeleteKeysFromBackupForRoom(ctx context.Context, version string, roomID id.RoomID) (resp *RespRoomKeysUpdate, err error) {
|
||||
func (cli *Client) DeleteKeysFromBackupForRoom(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID) (resp *RespRoomKeysUpdate, err error) {
|
||||
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{
|
||||
"version": version,
|
||||
"version": string(version),
|
||||
})
|
||||
_, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp)
|
||||
return
|
||||
|
|
@ -2018,10 +2018,10 @@ func (cli *Client) DeleteKeysFromBackupForRoom(ctx context.Context, version stri
|
|||
//
|
||||
// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid
|
||||
func (cli *Client) GetKeyBackupForRoomAndSession(
|
||||
ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID,
|
||||
ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID,
|
||||
) (resp *RespKeyBackupData[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) {
|
||||
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{
|
||||
"version": version,
|
||||
"version": string(version),
|
||||
})
|
||||
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
|
||||
return
|
||||
|
|
@ -2030,9 +2030,9 @@ func (cli *Client) GetKeyBackupForRoomAndSession(
|
|||
// PutKeysInBackupForRoomAndSession stores a key in the backup.
|
||||
//
|
||||
// See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeysroomidsessionid
|
||||
func (cli *Client) PutKeysInBackupForRoomAndSession(ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID, req *ReqKeyBackupData) (resp *RespRoomKeysUpdate, err error) {
|
||||
func (cli *Client) PutKeysInBackupForRoomAndSession(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, req *ReqKeyBackupData) (resp *RespRoomKeysUpdate, err error) {
|
||||
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{
|
||||
"version": version,
|
||||
"version": string(version),
|
||||
})
|
||||
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp)
|
||||
return
|
||||
|
|
@ -2041,9 +2041,9 @@ func (cli *Client) PutKeysInBackupForRoomAndSession(ctx context.Context, version
|
|||
// DeleteKeysInBackupForRoomAndSession deletes a key from the backup.
|
||||
//
|
||||
// See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeysroomidsessionid
|
||||
func (cli *Client) DeleteKeysInBackupForRoomAndSession(ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID) (resp *RespRoomKeysUpdate, err error) {
|
||||
func (cli *Client) DeleteKeysInBackupForRoomAndSession(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID) (resp *RespRoomKeysUpdate, err error) {
|
||||
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{
|
||||
"version": version,
|
||||
"version": string(version),
|
||||
})
|
||||
_, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp)
|
||||
return
|
||||
|
|
@ -2070,7 +2070,7 @@ func (cli *Client) CreateKeyBackupVersion(ctx context.Context, req *ReqRoomKeysV
|
|||
// GetKeyBackupVersion returns information about an existing key backup.
|
||||
//
|
||||
// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keysversionversion
|
||||
func (cli *Client) GetKeyBackupVersion(ctx context.Context, version string) (resp *RespRoomKeysVersion[backup.MegolmAuthData], err error) {
|
||||
func (cli *Client) GetKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) (resp *RespRoomKeysVersion[backup.MegolmAuthData], err error) {
|
||||
urlPath := cli.BuildClientURL("v3", "room_keys", "version", version)
|
||||
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
|
||||
return
|
||||
|
|
@ -2080,7 +2080,7 @@ func (cli *Client) GetKeyBackupVersion(ctx context.Context, version string) (res
|
|||
// the auth_data can be modified.
|
||||
//
|
||||
// See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keysversionversion
|
||||
func (cli *Client) UpdateKeyBackupVersion(ctx context.Context, version string, req *ReqRoomKeysVersionUpdate) error {
|
||||
func (cli *Client) UpdateKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion, req *ReqRoomKeysVersionUpdate) error {
|
||||
urlPath := cli.BuildClientURL("v3", "room_keys", "version", version)
|
||||
_, err := cli.MakeRequest(ctx, http.MethodPut, urlPath, nil, nil)
|
||||
return err
|
||||
|
|
@ -2091,7 +2091,7 @@ func (cli *Client) UpdateKeyBackupVersion(ctx context.Context, version string, r
|
|||
// deleted.
|
||||
//
|
||||
// See: https://spec.matrix.org/v1.1/client-server-api/#delete_matrixclientv3room_keysversionversion
|
||||
func (cli *Client) DeleteKeyBackupVersion(ctx context.Context, version string) error {
|
||||
func (cli *Client) DeleteKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) error {
|
||||
urlPath := cli.BuildClientURL("v3", "room_keys", "version", version)
|
||||
_, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil)
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -14,10 +14,11 @@ import (
|
|||
)
|
||||
|
||||
type OlmAccount struct {
|
||||
Internal olm.Account
|
||||
signingKey id.SigningKey
|
||||
identityKey id.IdentityKey
|
||||
Shared bool
|
||||
Internal olm.Account
|
||||
signingKey id.SigningKey
|
||||
identityKey id.IdentityKey
|
||||
Shared bool
|
||||
KeyBackupVersion id.KeyBackupVersion
|
||||
}
|
||||
|
||||
func NewOlmAccount() *OlmAccount {
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import (
|
|||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, megolmBackupKey *backup.MegolmBackupKey) error {
|
||||
func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, megolmBackupKey *backup.MegolmBackupKey) (id.KeyBackupVersion, error) {
|
||||
log := mach.machOrContextLog(ctx).With().
|
||||
Str("action", "download and store latest key backup").
|
||||
Logger()
|
||||
|
|
@ -24,12 +24,13 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg
|
|||
|
||||
versionInfo, err := mach.GetAndVerifyLatestKeyBackupVersion(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
return "", err
|
||||
} else if versionInfo == nil {
|
||||
return nil
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return mach.GetAndStoreKeyBackup(ctx, versionInfo.Version, megolmBackupKey)
|
||||
err = mach.GetAndStoreKeyBackup(ctx, versionInfo.Version, megolmBackupKey)
|
||||
return versionInfo.Version, err
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) (*mautrix.RespRoomKeysVersion[backup.MegolmAuthData], error) {
|
||||
|
|
@ -45,7 +46,7 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context)
|
|||
log := mach.machOrContextLog(ctx).With().
|
||||
Int("count", versionInfo.Count).
|
||||
Str("etag", versionInfo.ETag).
|
||||
Str("key_backup_version", versionInfo.Version).
|
||||
Stringer("key_backup_version", versionInfo.Version).
|
||||
Logger()
|
||||
|
||||
userSignatures, ok := versionInfo.AuthData.Signatures[mach.Client.UserID]
|
||||
|
|
@ -93,7 +94,7 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context)
|
|||
return versionInfo, nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version string, megolmBackupKey *backup.MegolmBackupKey) error {
|
||||
func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version id.KeyBackupVersion, megolmBackupKey *backup.MegolmBackupKey) error {
|
||||
keys, err := mach.Client.GetKeyBackup(ctx, version)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -112,7 +113,7 @@ func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version string
|
|||
continue
|
||||
}
|
||||
|
||||
err = mach.ImportRoomKeyFromBackup(ctx, roomID, sessionID, sessionData)
|
||||
err = mach.ImportRoomKeyFromBackup(ctx, version, roomID, sessionID, sessionData)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to import room key from backup")
|
||||
failedCount++
|
||||
|
|
@ -130,7 +131,7 @@ func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version string
|
|||
return nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) error {
|
||||
func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) error {
|
||||
log := zerolog.Ctx(ctx).With().
|
||||
Str("room_id", roomID.String()).
|
||||
Str("session_id", sessionID.String()).
|
||||
|
|
@ -166,9 +167,10 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, roomID id.R
|
|||
ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()),
|
||||
id: sessionID,
|
||||
|
||||
ReceivedAt: time.Now().UTC(),
|
||||
MaxAge: maxAge.Milliseconds(),
|
||||
MaxMessages: maxMessages,
|
||||
ReceivedAt: time.Now().UTC(),
|
||||
MaxAge: maxAge.Milliseconds(),
|
||||
MaxMessages: maxMessages,
|
||||
KeyBackupVersion: version,
|
||||
}
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, roomID, keyBackupData.SenderKey, sessionID, igs)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -152,11 +152,21 @@ func (mach *OlmMachine) Load(ctx context.Context) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) saveAccount(ctx context.Context) {
|
||||
func (mach *OlmMachine) saveAccount(ctx context.Context) error {
|
||||
err := mach.CryptoStore.PutAccount(ctx, mach.account)
|
||||
if err != nil {
|
||||
mach.Log.Error().Err(err).Msg("Failed to save account")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) KeyBackupVersion() id.KeyBackupVersion {
|
||||
return mach.account.KeyBackupVersion
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) SetKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) error {
|
||||
mach.account.KeyBackupVersion = version
|
||||
return mach.saveAccount(ctx)
|
||||
}
|
||||
|
||||
// FlushStore calls the Flush method of the CryptoStore.
|
||||
|
|
@ -698,8 +708,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
|
|||
}
|
||||
mach.lastOTKUpload = time.Now()
|
||||
mach.account.Shared = true
|
||||
mach.saveAccount(ctx)
|
||||
return nil
|
||||
return mach.saveAccount(ctx)
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) {
|
||||
|
|
|
|||
|
|
@ -105,10 +105,11 @@ type InboundGroupSession struct {
|
|||
ForwardingChains []string
|
||||
RatchetSafety RatchetSafety
|
||||
|
||||
ReceivedAt time.Time
|
||||
MaxAge int64
|
||||
MaxMessages int
|
||||
IsScheduled bool
|
||||
ReceivedAt time.Time
|
||||
MaxAge int64
|
||||
MaxMessages int
|
||||
IsScheduled bool
|
||||
KeyBackupVersion id.KeyBackupVersion
|
||||
|
||||
id id.SessionID
|
||||
}
|
||||
|
|
|
|||
|
|
@ -125,20 +125,21 @@ func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount
|
|||
store.Account = account
|
||||
bytes := account.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec(ctx, `
|
||||
INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id) VALUES ($1, $2, $3, $4, $5)
|
||||
INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id, key_backup_version) VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token,
|
||||
account=excluded.account, account_id=excluded.account_id
|
||||
`, store.DeviceID, account.Shared, store.SyncToken, bytes, store.AccountID)
|
||||
account=excluded.account, account_id=excluded.account_id,
|
||||
key_backup_version=excluded.key_backup_version
|
||||
`, store.DeviceID, account.Shared, store.SyncToken, bytes, store.AccountID, account.KeyBackupVersion)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAccount retrieves an OlmAccount from the database.
|
||||
func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) {
|
||||
if store.Account == nil {
|
||||
row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID)
|
||||
row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account, key_backup_version FROM crypto_account WHERE account_id=$1", store.AccountID)
|
||||
acc := &OlmAccount{Internal: *olm.NewBlankAccount()}
|
||||
var accountBytes []byte
|
||||
err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes)
|
||||
err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes, &acc.KeyBackupVersion)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
|
|
@ -285,17 +286,18 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.Room
|
|||
_, err = store.DB.Exec(ctx, `
|
||||
INSERT INTO crypto_megolm_inbound_session (
|
||||
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
|
||||
ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||
ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, account_id
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
ON CONFLICT (session_id, account_id) DO UPDATE
|
||||
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key,
|
||||
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains,
|
||||
ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at,
|
||||
max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled
|
||||
max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled,
|
||||
key_backup_version=excluded.key_backup_version
|
||||
`,
|
||||
sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains,
|
||||
ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages),
|
||||
session.IsScheduled, store.AccountID,
|
||||
session.IsScheduled, session.KeyBackupVersion, store.AccountID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
|
@ -307,12 +309,13 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room
|
|||
var receivedAt sql.NullTime
|
||||
var maxAge, maxMessages sql.NullInt64
|
||||
var isScheduled bool
|
||||
var version id.KeyBackupVersion
|
||||
err := store.DB.QueryRow(ctx, `
|
||||
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
|
||||
FROM crypto_megolm_inbound_session
|
||||
WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`,
|
||||
roomID, senderKey, sessionID, store.AccountID,
|
||||
).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
|
||||
).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
|
|
@ -342,6 +345,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room
|
|||
MaxAge: maxAge.Int64,
|
||||
MaxMessages: int(maxMessages.Int64),
|
||||
IsScheduled: isScheduled,
|
||||
KeyBackupVersion: version,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -469,7 +473,8 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In
|
|||
var receivedAt sql.NullTime
|
||||
var maxAge, maxMessages sql.NullInt64
|
||||
var isScheduled bool
|
||||
err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
|
||||
var version id.KeyBackupVersion
|
||||
err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -485,31 +490,35 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In
|
|||
MaxAge: maxAge.Int64,
|
||||
MaxMessages: int(maxMessages.Int64),
|
||||
IsScheduled: isScheduled,
|
||||
KeyBackupVersion: version,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) {
|
||||
func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] {
|
||||
rows, err := store.DB.Query(ctx, `
|
||||
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
|
||||
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
|
||||
roomID, store.AccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList()
|
||||
return dbutil.NewRowIterWithError(rows, store.scanInboundGroupSession, err)
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) ([]*InboundGroupSession, error) {
|
||||
func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] {
|
||||
rows, err := store.DB.Query(ctx, `
|
||||
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`,
|
||||
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
|
||||
FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`,
|
||||
store.AccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList()
|
||||
return dbutil.NewRowIterWithError(rows, store.scanInboundGroupSession, err)
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] {
|
||||
rows, err := store.DB.Query(ctx, `
|
||||
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
|
||||
FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`,
|
||||
store.AccountID, version,
|
||||
)
|
||||
return dbutil.NewRowIterWithError(rows, store.scanInboundGroupSession, err)
|
||||
}
|
||||
|
||||
// AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices.
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
-- v0 -> v13 (compatible with v9+): Latest revision
|
||||
-- v0 -> v14 (compatible with v9+): Latest revision
|
||||
CREATE TABLE IF NOT EXISTS crypto_account (
|
||||
account_id TEXT PRIMARY KEY,
|
||||
device_id TEXT NOT NULL,
|
||||
shared BOOLEAN NOT NULL,
|
||||
sync_token TEXT NOT NULL,
|
||||
account bytea NOT NULL
|
||||
account_id TEXT PRIMARY KEY,
|
||||
device_id TEXT NOT NULL,
|
||||
shared BOOLEAN NOT NULL,
|
||||
sync_token TEXT NOT NULL,
|
||||
account bytea NOT NULL,
|
||||
key_backup_version TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS crypto_message_index (
|
||||
|
|
@ -44,20 +45,21 @@ CREATE TABLE IF NOT EXISTS crypto_olm_session (
|
|||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
|
||||
account_id TEXT,
|
||||
session_id CHAR(43),
|
||||
sender_key CHAR(43) NOT NULL,
|
||||
signing_key CHAR(43),
|
||||
room_id TEXT NOT NULL,
|
||||
session bytea,
|
||||
forwarding_chains bytea,
|
||||
withheld_code TEXT,
|
||||
withheld_reason TEXT,
|
||||
ratchet_safety jsonb,
|
||||
received_at timestamp,
|
||||
max_age BIGINT,
|
||||
max_messages INTEGER,
|
||||
is_scheduled BOOLEAN NOT NULL DEFAULT false,
|
||||
account_id TEXT,
|
||||
session_id CHAR(43),
|
||||
sender_key CHAR(43) NOT NULL,
|
||||
signing_key CHAR(43),
|
||||
room_id TEXT NOT NULL,
|
||||
session bytea,
|
||||
forwarding_chains bytea,
|
||||
withheld_code TEXT,
|
||||
withheld_reason TEXT,
|
||||
ratchet_safety jsonb,
|
||||
received_at timestamp,
|
||||
max_age BIGINT,
|
||||
max_messages INTEGER,
|
||||
is_scheduled BOOLEAN NOT NULL DEFAULT false,
|
||||
key_backup_version TEXT NOT NULL DEFAULT '',
|
||||
PRIMARY KEY (account_id, session_id)
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,4 @@
|
|||
-- v14 (compatible with v9+): Add key_backup_version column to account and igs
|
||||
|
||||
ALTER TABLE crypto_account ADD COLUMN key_backup_version TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN key_backup_version TEXT NOT NULL DEFAULT '';
|
||||
|
|
@ -12,6 +12,8 @@ import (
|
|||
"sort"
|
||||
"sync"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
|
@ -68,10 +70,12 @@ type Store interface {
|
|||
|
||||
// GetGroupSessionsForRoom gets all the inbound Megolm sessions for a specific room. This is used for creating key
|
||||
// export files. Unlike GetGroupSession, this should not return any errors about withheld keys.
|
||||
GetGroupSessionsForRoom(context.Context, id.RoomID) ([]*InboundGroupSession, error)
|
||||
GetGroupSessionsForRoom(context.Context, id.RoomID) dbutil.RowIter[*InboundGroupSession]
|
||||
// GetAllGroupSessions gets all the inbound Megolm sessions in the store. This is used for creating key export
|
||||
// files. Unlike GetGroupSession, this should not return any errors about withheld keys.
|
||||
GetAllGroupSessions(context.Context) ([]*InboundGroupSession, error)
|
||||
GetAllGroupSessions(context.Context) dbutil.RowIter[*InboundGroupSession]
|
||||
// GetGroupSessionsWithoutKeyBackupVersion gets all the inbound Megolm sessions in the store that do not match given key backup version.
|
||||
GetGroupSessionsWithoutKeyBackupVersion(context.Context, id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession]
|
||||
|
||||
// AddOutboundGroupSession inserts the given outbound Megolm session into the store.
|
||||
//
|
||||
|
|
@ -376,12 +380,12 @@ func (gs *MemoryStore) GetWithheldGroupSession(_ context.Context, roomID id.Room
|
|||
return session, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) {
|
||||
func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] {
|
||||
gs.lock.Lock()
|
||||
defer gs.lock.Unlock()
|
||||
room, ok := gs.GroupSessions[roomID]
|
||||
if !ok {
|
||||
return []*InboundGroupSession{}, nil
|
||||
return nil
|
||||
}
|
||||
var result []*InboundGroupSession
|
||||
for _, sessions := range room {
|
||||
|
|
@ -389,10 +393,10 @@ func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.Room
|
|||
result = append(result, session)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
return dbutil.NewSliceIter[*InboundGroupSession](result)
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) ([]*InboundGroupSession, error) {
|
||||
func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) dbutil.RowIter[*InboundGroupSession] {
|
||||
gs.lock.Lock()
|
||||
var result []*InboundGroupSession
|
||||
for _, room := range gs.GroupSessions {
|
||||
|
|
@ -403,7 +407,23 @@ func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) ([]*InboundGroupSe
|
|||
}
|
||||
}
|
||||
gs.lock.Unlock()
|
||||
return result, nil
|
||||
return dbutil.NewSliceIter[*InboundGroupSession](result)
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetGroupSessionsWithoutKeyBackupVersion(_ context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] {
|
||||
gs.lock.Lock()
|
||||
var result []*InboundGroupSession
|
||||
for _, room := range gs.GroupSessions {
|
||||
for _, sessions := range room {
|
||||
for _, session := range sessions {
|
||||
if session.KeyBackupVersion != version {
|
||||
result = append(result, session)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
gs.lock.Unlock()
|
||||
return dbutil.NewSliceIter[*InboundGroupSession](result)
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) AddOutboundGroupSession(_ context.Context, session *OutboundGroupSession) error {
|
||||
|
|
|
|||
2
go.mod
2
go.mod
|
|
@ -12,7 +12,7 @@ require (
|
|||
github.com/tidwall/gjson v1.17.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/yuin/goldmark v1.6.0
|
||||
go.mau.fi/util v0.3.0
|
||||
go.mau.fi/util v0.3.1-0.20240131162106-bcac615a2941
|
||||
go.mau.fi/zeroconfig v0.1.2
|
||||
golang.org/x/crypto v0.18.0
|
||||
golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -36,8 +36,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
|||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68=
|
||||
github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
go.mau.fi/util v0.3.0 h1:Lt3lbRXP6ZBqTINK0EieRWor3zEwwwrDT14Z5N8RUCs=
|
||||
go.mau.fi/util v0.3.0/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs=
|
||||
go.mau.fi/util v0.3.1-0.20240131162106-bcac615a2941 h1:F9ySn0OM0uFqcGDQM2WUqlFJh4UCBYNfeSxzJd0kknM=
|
||||
go.mau.fi/util v0.3.1-0.20240131162106-bcac615a2941/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs=
|
||||
go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto=
|
||||
go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
|
||||
golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
|
||||
|
|
|
|||
|
|
@ -50,6 +50,13 @@ const (
|
|||
KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2"
|
||||
)
|
||||
|
||||
// BackupVersion is an arbitrary string that identifies a server side key backup.
|
||||
type KeyBackupVersion string
|
||||
|
||||
func (version KeyBackupVersion) String() string {
|
||||
return string(version)
|
||||
}
|
||||
|
||||
// A SessionID is an arbitrary string that identifies an Olm or Megolm session.
|
||||
type SessionID string
|
||||
|
||||
|
|
|
|||
10
requests.go
10
requests.go
|
|
@ -429,14 +429,14 @@ type ReqBeeperSplitRoom struct {
|
|||
}
|
||||
|
||||
type ReqRoomKeysVersionCreate struct {
|
||||
Algorithm string `json:"algorithm"`
|
||||
AuthData json.RawMessage `json:"auth_data"`
|
||||
Algorithm id.KeyBackupAlgorithm `json:"algorithm"`
|
||||
AuthData json.RawMessage `json:"auth_data"`
|
||||
}
|
||||
|
||||
type ReqRoomKeysVersionUpdate struct {
|
||||
Algorithm string `json:"algorithm"`
|
||||
AuthData json.RawMessage `json:"auth_data"`
|
||||
Version string `json:"version,omitempty"`
|
||||
Algorithm id.KeyBackupAlgorithm `json:"algorithm"`
|
||||
AuthData json.RawMessage `json:"auth_data"`
|
||||
Version id.KeyBackupVersion `json:"version,omitempty"`
|
||||
}
|
||||
|
||||
type ReqKeyBackup struct {
|
||||
|
|
|
|||
|
|
@ -593,7 +593,7 @@ type RespTimestampToEvent struct {
|
|||
}
|
||||
|
||||
type RespRoomKeysVersionCreate struct {
|
||||
Version string `json:"version"`
|
||||
Version id.KeyBackupVersion `json:"version"`
|
||||
}
|
||||
|
||||
type RespRoomKeysVersion[A any] struct {
|
||||
|
|
@ -601,7 +601,7 @@ type RespRoomKeysVersion[A any] struct {
|
|||
AuthData A `json:"auth_data"`
|
||||
Count int `json:"count"`
|
||||
ETag string `json:"etag"`
|
||||
Version string `json:"version"`
|
||||
Version id.KeyBackupVersion `json:"version"`
|
||||
}
|
||||
|
||||
type RespRoomKeys[S any] struct {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue