crypto/keybackup: allow importing room keys without saving
Some checks failed
Go / Lint (latest) (push) Has been cancelled
Go / Build (old, libolm) (push) Has been cancelled
Go / Build (latest, libolm) (push) Has been cancelled
Go / Build (old, goolm) (push) Has been cancelled
Go / Build (latest, goolm) (push) Has been cancelled

This commit is contained in:
Tulir Asokan 2025-05-04 14:09:06 +03:00
commit 6eb4c7b17f
4 changed files with 43 additions and 21 deletions

View file

@ -13,6 +13,7 @@ import (
"maunium.net/go/mautrix/crypto/backup"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/crypto/signatures"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@ -161,11 +162,15 @@ var (
ErrFailedToStoreNewInboundGroupSessionFromBackup = errors.New("failed to store new inbound group session from key backup")
)
func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) (*InboundGroupSession, error) {
log := zerolog.Ctx(ctx).With().
Str("room_id", roomID.String()).
Str("session_id", sessionID.String()).
Logger()
func (mach *OlmMachine) ImportRoomKeyFromBackupWithoutSaving(
ctx context.Context,
version id.KeyBackupVersion,
roomID id.RoomID,
config *event.EncryptionEventContent,
sessionID id.SessionID,
keyBackupData *backup.MegolmSessionData,
) (*InboundGroupSession, error) {
log := zerolog.Ctx(ctx)
if keyBackupData.Algorithm != id.AlgorithmMegolmV1 {
return nil, fmt.Errorf("%w %s", ErrUnknownAlgorithmInKeyBackup, keyBackupData.Algorithm)
}
@ -175,6 +180,8 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.
return nil, fmt.Errorf("failed to import inbound group session: %w", err)
} else if igsInternal.ID() != sessionID {
log.Warn().
Stringer("room_id", roomID).
Stringer("session_id", sessionID).
Stringer("actual_session_id", igsInternal.ID()).
Msg("Mismatched session ID while creating inbound group session from key backup")
return nil, ErrMismatchingSessionIDInKeyBackup
@ -182,19 +189,12 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.
var maxAge time.Duration
var maxMessages int
if config, err := mach.StateStore.GetEncryptionEvent(ctx, roomID); err != nil {
log.Error().Err(err).Msg("Failed to get encryption event for room")
} else if config != nil {
if config != nil {
maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond
maxMessages = config.RotationPeriodMessages
}
firstKnownIndex := igsInternal.FirstKnownIndex()
if firstKnownIndex > 0 {
log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session")
}
igs := &InboundGroupSession{
return &InboundGroupSession{
Internal: igsInternal,
SigningKey: keyBackupData.SenderClaimedKeys.Ed25519,
SenderKey: keyBackupData.SenderKey,
@ -206,11 +206,33 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
KeyBackupVersion: version,
}, nil
}
func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) (*InboundGroupSession, error) {
config, err := mach.StateStore.GetEncryptionEvent(ctx, roomID)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("room_id", roomID).
Stringer("session_id", sessionID).
Msg("Failed to get encryption event for room")
}
err = mach.CryptoStore.PutGroupSession(ctx, igs)
imported, err := mach.ImportRoomKeyFromBackupWithoutSaving(ctx, version, roomID, config, sessionID, keyBackupData)
if err != nil {
return nil, err
}
firstKnownIndex := imported.Internal.FirstKnownIndex()
if firstKnownIndex > 0 {
zerolog.Ctx(ctx).Warn().
Stringer("room_id", roomID).
Stringer("session_id", sessionID).
Uint32("first_known_index", firstKnownIndex).
Msg("Importing partial session")
}
err = mach.CryptoStore.PutGroupSession(ctx, imported)
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrFailedToStoreNewInboundGroupSessionFromBackup, err)
}
mach.markSessionReceived(ctx, roomID, sessionID, firstKnownIndex)
return igs, nil
mach.MarkSessionReceived(ctx, roomID, sessionID, firstKnownIndex)
return imported, nil
}

View file

@ -127,7 +127,7 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor
if err != nil {
return false, fmt.Errorf("failed to store imported session: %w", err)
}
mach.markSessionReceived(ctx, session.RoomID, igs.ID(), firstKnownIndex)
mach.MarkSessionReceived(ctx, session.RoomID, igs.ID(), firstKnownIndex)
return true, nil
}

View file

@ -200,7 +200,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
log.Error().Err(err).Msg("Failed to store new inbound group session")
return false
}
mach.markSessionReceived(ctx, content.RoomID, content.SessionID, firstKnownIndex)
mach.MarkSessionReceived(ctx, content.RoomID, content.SessionID, firstKnownIndex)
log.Debug().Msg("Received forwarded inbound group session")
return true
}

View file

@ -584,7 +584,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
return fmt.Errorf("failed to store new inbound group session: %w", err)
}
mach.markSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex())
mach.MarkSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex())
log.Debug().
Str("session_id", sessionID.String()).
Str("sender_key", senderKey.String()).
@ -595,7 +595,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
return nil
}
func (mach *OlmMachine) markSessionReceived(ctx context.Context, roomID id.RoomID, id id.SessionID, firstKnownIndex uint32) {
func (mach *OlmMachine) MarkSessionReceived(ctx context.Context, roomID id.RoomID, id id.SessionID, firstKnownIndex uint32) {
if mach.SessionReceived != nil {
mach.SessionReceived(ctx, roomID, id, firstKnownIndex)
}