Include room ID and first known index in SessionReceived callback

This commit is contained in:
Tulir Asokan 2024-05-25 23:03:26 +03:00
commit a2169274da
4 changed files with 13 additions and 10 deletions

View file

@ -160,7 +160,8 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.
maxMessages = config.RotationPeriodMessages
}
if firstKnownIndex := igsInternal.FirstKnownIndex(); firstKnownIndex > 0 {
firstKnownIndex := igsInternal.FirstKnownIndex()
if firstKnownIndex > 0 {
log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session")
}
@ -181,6 +182,6 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.
if err != nil {
return fmt.Errorf("failed to store new inbound group session: %w", err)
}
mach.markSessionReceived(ctx, sessionID)
mach.markSessionReceived(ctx, roomID, sessionID, firstKnownIndex)
return nil
}

View file

@ -114,7 +114,8 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor
ReceivedAt: time.Now().UTC(),
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
firstKnownIndex := igs.Internal.FirstKnownIndex()
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= firstKnownIndex {
// We already have an equivalent or better session in the store, so don't override it.
return false, nil
}
@ -122,7 +123,7 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor
if err != nil {
return false, fmt.Errorf("failed to store imported session: %w", err)
}
mach.markSessionReceived(ctx, igs.ID())
mach.markSessionReceived(ctx, session.RoomID, igs.ID(), firstKnownIndex)
return true, nil
}

View file

@ -168,7 +168,8 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
if content.MaxMessages != 0 {
maxMessages = content.MaxMessages
}
if firstKnownIndex := igsInternal.FirstKnownIndex(); firstKnownIndex > 0 {
firstKnownIndex := igsInternal.FirstKnownIndex()
if firstKnownIndex > 0 {
log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session")
}
igs := &InboundGroupSession{
@ -194,7 +195,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
log.Error().Err(err).Msg("Failed to store new inbound group session")
return false
}
mach.markSessionReceived(ctx, content.SessionID)
mach.markSessionReceived(ctx, content.RoomID, content.SessionID, firstKnownIndex)
log.Debug().Msg("Received forwarded inbound group session")
return true
}

View file

@ -53,7 +53,7 @@ type OlmMachine struct {
keyWaitersLock sync.Mutex
// Optional callback which is called when we save a session to store
SessionReceived func(context.Context, id.SessionID)
SessionReceived func(context.Context, id.RoomID, id.SessionID, uint32)
devicesToUnwedge map[id.IdentityKey]bool
devicesToUnwedgeLock sync.Mutex
@ -523,7 +523,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
return fmt.Errorf("failed to store new inbound group session: %w", err)
}
mach.markSessionReceived(ctx, sessionID)
mach.markSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex())
log.Debug().
Str("session_id", sessionID.String()).
Str("sender_key", senderKey.String()).
@ -534,9 +534,9 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
return nil
}
func (mach *OlmMachine) markSessionReceived(ctx context.Context, id id.SessionID) {
func (mach *OlmMachine) markSessionReceived(ctx context.Context, roomID id.RoomID, id id.SessionID, firstKnownIndex uint32) {
if mach.SessionReceived != nil {
mach.SessionReceived(ctx, id)
mach.SessionReceived(ctx, roomID, id, firstKnownIndex)
}
mach.keyWaitersLock.Lock()