From a2169274da2999d08239532b4ff7fb8136ec1fc5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 25 May 2024 23:03:26 +0300 Subject: [PATCH] Include room ID and first known index in SessionReceived callback --- crypto/keybackup.go | 5 +++-- crypto/keyimport.go | 5 +++-- crypto/keysharing.go | 5 +++-- crypto/machine.go | 8 ++++---- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 820f3114..7d8148f6 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -160,7 +160,8 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. maxMessages = config.RotationPeriodMessages } - if firstKnownIndex := igsInternal.FirstKnownIndex(); firstKnownIndex > 0 { + firstKnownIndex := igsInternal.FirstKnownIndex() + if firstKnownIndex > 0 { log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session") } @@ -181,6 +182,6 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. if err != nil { return fmt.Errorf("failed to store new inbound group session: %w", err) } - mach.markSessionReceived(ctx, sessionID) + mach.markSessionReceived(ctx, roomID, sessionID, firstKnownIndex) return nil } diff --git a/crypto/keyimport.go b/crypto/keyimport.go index 6c320f43..693ff6b8 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -114,7 +114,8 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor ReceivedAt: time.Now().UTC(), } existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) - if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { + firstKnownIndex := igs.Internal.FirstKnownIndex() + if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= firstKnownIndex { // We already have an equivalent or better session in the store, so don't override it. return false, nil } @@ -122,7 +123,7 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor if err != nil { return false, fmt.Errorf("failed to store imported session: %w", err) } - mach.markSessionReceived(ctx, igs.ID()) + mach.markSessionReceived(ctx, session.RoomID, igs.ID(), firstKnownIndex) return true, nil } diff --git a/crypto/keysharing.go b/crypto/keysharing.go index ad0011e5..362dee81 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -168,7 +168,8 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt if content.MaxMessages != 0 { maxMessages = content.MaxMessages } - if firstKnownIndex := igsInternal.FirstKnownIndex(); firstKnownIndex > 0 { + firstKnownIndex := igsInternal.FirstKnownIndex() + if firstKnownIndex > 0 { log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session") } igs := &InboundGroupSession{ @@ -194,7 +195,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt log.Error().Err(err).Msg("Failed to store new inbound group session") return false } - mach.markSessionReceived(ctx, content.SessionID) + mach.markSessionReceived(ctx, content.RoomID, content.SessionID, firstKnownIndex) log.Debug().Msg("Received forwarded inbound group session") return true } diff --git a/crypto/machine.go b/crypto/machine.go index abb8d540..c9c06c3b 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -53,7 +53,7 @@ type OlmMachine struct { keyWaitersLock sync.Mutex // Optional callback which is called when we save a session to store - SessionReceived func(context.Context, id.SessionID) + SessionReceived func(context.Context, id.RoomID, id.SessionID, uint32) devicesToUnwedge map[id.IdentityKey]bool devicesToUnwedgeLock sync.Mutex @@ -523,7 +523,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") return fmt.Errorf("failed to store new inbound group session: %w", err) } - mach.markSessionReceived(ctx, sessionID) + mach.markSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex()) log.Debug(). Str("session_id", sessionID.String()). Str("sender_key", senderKey.String()). @@ -534,9 +534,9 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen return nil } -func (mach *OlmMachine) markSessionReceived(ctx context.Context, id id.SessionID) { +func (mach *OlmMachine) markSessionReceived(ctx context.Context, roomID id.RoomID, id id.SessionID, firstKnownIndex uint32) { if mach.SessionReceived != nil { - mach.SessionReceived(ctx, id) + mach.SessionReceived(ctx, roomID, id, firstKnownIndex) } mach.keyWaitersLock.Lock()