mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
crypto: save source of megolm sessions
This commit is contained in:
parent
67d30e054c
commit
bc79822eab
8 changed files with 58 additions and 19 deletions
|
|
@ -200,13 +200,14 @@ func (mach *OlmMachine) ImportRoomKeyFromBackupWithoutSaving(
|
||||||
SigningKey: keyBackupData.SenderClaimedKeys.Ed25519,
|
SigningKey: keyBackupData.SenderClaimedKeys.Ed25519,
|
||||||
SenderKey: keyBackupData.SenderKey,
|
SenderKey: keyBackupData.SenderKey,
|
||||||
RoomID: roomID,
|
RoomID: roomID,
|
||||||
ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()),
|
ForwardingChains: keyBackupData.ForwardingKeyChain,
|
||||||
id: sessionID,
|
id: sessionID,
|
||||||
|
|
||||||
ReceivedAt: time.Now().UTC(),
|
ReceivedAt: time.Now().UTC(),
|
||||||
MaxAge: maxAge.Milliseconds(),
|
MaxAge: maxAge.Milliseconds(),
|
||||||
MaxMessages: maxMessages,
|
MaxMessages: maxMessages,
|
||||||
KeyBackupVersion: version,
|
KeyBackupVersion: version,
|
||||||
|
KeySource: id.KeySourceBackup,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -108,14 +108,13 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor
|
||||||
return false, ErrMismatchingExportedSessionID
|
return false, ErrMismatchingExportedSessionID
|
||||||
}
|
}
|
||||||
igs := &InboundGroupSession{
|
igs := &InboundGroupSession{
|
||||||
Internal: igsInternal,
|
Internal: igsInternal,
|
||||||
SigningKey: session.SenderClaimedKeys.Ed25519,
|
SigningKey: session.SenderClaimedKeys.Ed25519,
|
||||||
SenderKey: session.SenderKey,
|
SenderKey: session.SenderKey,
|
||||||
RoomID: session.RoomID,
|
RoomID: session.RoomID,
|
||||||
// TODO should we add something here to mark the signing key as unverified like key requests do?
|
|
||||||
ForwardingChains: session.ForwardingChains,
|
ForwardingChains: session.ForwardingChains,
|
||||||
|
KeySource: id.KeySourceImport,
|
||||||
ReceivedAt: time.Now().UTC(),
|
ReceivedAt: time.Now().UTC(),
|
||||||
}
|
}
|
||||||
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
|
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
|
||||||
firstKnownIndex := igs.Internal.FirstKnownIndex()
|
firstKnownIndex := igs.Internal.FirstKnownIndex()
|
||||||
|
|
|
||||||
|
|
@ -189,6 +189,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
|
||||||
MaxAge: maxAge.Milliseconds(),
|
MaxAge: maxAge.Milliseconds(),
|
||||||
MaxMessages: maxMessages,
|
MaxMessages: maxMessages,
|
||||||
IsScheduled: content.IsScheduled,
|
IsScheduled: content.IsScheduled,
|
||||||
|
KeySource: id.KeySourceForward,
|
||||||
}
|
}
|
||||||
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
|
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
|
||||||
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
|
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
|
||||||
|
|
|
||||||
|
|
@ -117,6 +117,7 @@ type InboundGroupSession struct {
|
||||||
MaxMessages int
|
MaxMessages int
|
||||||
IsScheduled bool
|
IsScheduled bool
|
||||||
KeyBackupVersion id.KeyBackupVersion
|
KeyBackupVersion id.KeyBackupVersion
|
||||||
|
KeySource id.KeySource
|
||||||
|
|
||||||
id id.SessionID
|
id id.SessionID
|
||||||
}
|
}
|
||||||
|
|
@ -136,6 +137,7 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI
|
||||||
MaxAge: maxAge.Milliseconds(),
|
MaxAge: maxAge.Milliseconds(),
|
||||||
MaxMessages: maxMessages,
|
MaxMessages: maxMessages,
|
||||||
IsScheduled: isScheduled,
|
IsScheduled: isScheduled,
|
||||||
|
KeySource: id.KeySourceDirect,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -346,22 +346,23 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *Inbou
|
||||||
Int("max_messages", session.MaxMessages).
|
Int("max_messages", session.MaxMessages).
|
||||||
Bool("is_scheduled", session.IsScheduled).
|
Bool("is_scheduled", session.IsScheduled).
|
||||||
Stringer("key_backup_version", session.KeyBackupVersion).
|
Stringer("key_backup_version", session.KeyBackupVersion).
|
||||||
|
Stringer("key_source", session.KeySource).
|
||||||
Msg("Upserting megolm inbound group session")
|
Msg("Upserting megolm inbound group session")
|
||||||
_, err = store.DB.Exec(ctx, `
|
_, err = store.DB.Exec(ctx, `
|
||||||
INSERT INTO crypto_megolm_inbound_session (
|
INSERT INTO crypto_megolm_inbound_session (
|
||||||
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
|
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
|
||||||
ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, account_id
|
ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source, account_id
|
||||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
|
||||||
ON CONFLICT (session_id, account_id) DO UPDATE
|
ON CONFLICT (session_id, account_id) DO UPDATE
|
||||||
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key,
|
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key,
|
||||||
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains,
|
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains,
|
||||||
ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at,
|
ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at,
|
||||||
max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled,
|
max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled,
|
||||||
key_backup_version=excluded.key_backup_version
|
key_backup_version=excluded.key_backup_version, key_source=excluded.key_source
|
||||||
`,
|
`,
|
||||||
session.ID(), session.SenderKey, session.SigningKey, session.RoomID, sessionBytes, forwardingChains,
|
session.ID(), session.SenderKey, session.SigningKey, session.RoomID, sessionBytes, forwardingChains,
|
||||||
ratchetSafety, datePtr(session.ReceivedAt), dbutil.NumPtr(session.MaxAge), dbutil.NumPtr(session.MaxMessages),
|
ratchetSafety, datePtr(session.ReceivedAt), dbutil.NumPtr(session.MaxAge), dbutil.NumPtr(session.MaxMessages),
|
||||||
session.IsScheduled, session.KeyBackupVersion, store.AccountID,
|
session.IsScheduled, session.KeyBackupVersion, session.KeySource, store.AccountID,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -374,12 +375,13 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room
|
||||||
var maxAge, maxMessages sql.NullInt64
|
var maxAge, maxMessages sql.NullInt64
|
||||||
var isScheduled bool
|
var isScheduled bool
|
||||||
var version id.KeyBackupVersion
|
var version id.KeyBackupVersion
|
||||||
|
var keySource id.KeySource
|
||||||
err := store.DB.QueryRow(ctx, `
|
err := store.DB.QueryRow(ctx, `
|
||||||
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
|
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source
|
||||||
FROM crypto_megolm_inbound_session
|
FROM crypto_megolm_inbound_session
|
||||||
WHERE room_id=$1 AND session_id=$2 AND account_id=$3`,
|
WHERE room_id=$1 AND session_id=$2 AND account_id=$3`,
|
||||||
roomID, sessionID, store.AccountID,
|
roomID, sessionID, store.AccountID,
|
||||||
).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version)
|
).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource)
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
|
|
@ -410,6 +412,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room
|
||||||
MaxMessages: int(maxMessages.Int64),
|
MaxMessages: int(maxMessages.Int64),
|
||||||
IsScheduled: isScheduled,
|
IsScheduled: isScheduled,
|
||||||
KeyBackupVersion: version,
|
KeyBackupVersion: version,
|
||||||
|
KeySource: keySource,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -534,7 +537,8 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In
|
||||||
var maxAge, maxMessages sql.NullInt64
|
var maxAge, maxMessages sql.NullInt64
|
||||||
var isScheduled bool
|
var isScheduled bool
|
||||||
var version id.KeyBackupVersion
|
var version id.KeyBackupVersion
|
||||||
err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version)
|
var keySource id.KeySource
|
||||||
|
err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -554,12 +558,13 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In
|
||||||
MaxMessages: int(maxMessages.Int64),
|
MaxMessages: int(maxMessages.Int64),
|
||||||
IsScheduled: isScheduled,
|
IsScheduled: isScheduled,
|
||||||
KeyBackupVersion: version,
|
KeyBackupVersion: version,
|
||||||
|
KeySource: keySource,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] {
|
func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] {
|
||||||
rows, err := store.DB.Query(ctx, `
|
rows, err := store.DB.Query(ctx, `
|
||||||
SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
|
SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source
|
||||||
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
|
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
|
||||||
roomID, store.AccountID,
|
roomID, store.AccountID,
|
||||||
)
|
)
|
||||||
|
|
@ -568,7 +573,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID
|
||||||
|
|
||||||
func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] {
|
func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] {
|
||||||
rows, err := store.DB.Query(ctx, `
|
rows, err := store.DB.Query(ctx, `
|
||||||
SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
|
SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source
|
||||||
FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`,
|
FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`,
|
||||||
store.AccountID,
|
store.AccountID,
|
||||||
)
|
)
|
||||||
|
|
@ -577,7 +582,7 @@ func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.Row
|
||||||
|
|
||||||
func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] {
|
func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] {
|
||||||
rows, err := store.DB.Query(ctx, `
|
rows, err := store.DB.Query(ctx, `
|
||||||
SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
|
SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source
|
||||||
FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`,
|
FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`,
|
||||||
store.AccountID, version,
|
store.AccountID, version,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
-- v0 -> v18 (compatible with v15+): Latest revision
|
-- v0 -> v19 (compatible with v15+): Latest revision
|
||||||
CREATE TABLE IF NOT EXISTS crypto_account (
|
CREATE TABLE IF NOT EXISTS crypto_account (
|
||||||
account_id TEXT PRIMARY KEY,
|
account_id TEXT PRIMARY KEY,
|
||||||
device_id TEXT NOT NULL,
|
device_id TEXT NOT NULL,
|
||||||
|
|
@ -71,6 +71,7 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
|
||||||
max_messages INTEGER,
|
max_messages INTEGER,
|
||||||
is_scheduled BOOLEAN NOT NULL DEFAULT false,
|
is_scheduled BOOLEAN NOT NULL DEFAULT false,
|
||||||
key_backup_version TEXT NOT NULL DEFAULT '',
|
key_backup_version TEXT NOT NULL DEFAULT '',
|
||||||
|
key_source TEXT NOT NULL DEFAULT '',
|
||||||
PRIMARY KEY (account_id, session_id)
|
PRIMARY KEY (account_id, session_id)
|
||||||
);
|
);
|
||||||
-- Useful index to find keys that need backing up
|
-- Useful index to find keys that need backing up
|
||||||
|
|
|
||||||
2
crypto/sql_store_upgrade/19-megolm-session-source.sql
Normal file
2
crypto/sql_store_upgrade/19-megolm-session-source.sql
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
-- v19 (compatible with v15+): Store megolm session source
|
||||||
|
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN key_source TEXT NOT NULL DEFAULT '';
|
||||||
28
id/crypto.go
28
id/crypto.go
|
|
@ -53,6 +53,34 @@ const (
|
||||||
KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2"
|
KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type KeySource string
|
||||||
|
|
||||||
|
func (source KeySource) String() string {
|
||||||
|
return string(source)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (source KeySource) Int() int {
|
||||||
|
switch source {
|
||||||
|
case KeySourceDirect:
|
||||||
|
return 100
|
||||||
|
case KeySourceBackup:
|
||||||
|
return 90
|
||||||
|
case KeySourceImport:
|
||||||
|
return 80
|
||||||
|
case KeySourceForward:
|
||||||
|
return 50
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
KeySourceDirect KeySource = "direct"
|
||||||
|
KeySourceBackup KeySource = "backup"
|
||||||
|
KeySourceImport KeySource = "import"
|
||||||
|
KeySourceForward KeySource = "forward"
|
||||||
|
)
|
||||||
|
|
||||||
// BackupVersion is an arbitrary string that identifies a server side key backup.
|
// BackupVersion is an arbitrary string that identifies a server side key backup.
|
||||||
type KeyBackupVersion string
|
type KeyBackupVersion string
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue