From bb6c88faf3cea0c65c6f3671d13dc2e72256298f Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Wed, 6 Mar 2024 08:00:06 +0200 Subject: [PATCH] Add callback on megolm session receive --- crypto/keybackup.go | 2 +- crypto/keyimport.go | 2 +- crypto/keysharing.go | 2 +- crypto/machine.go | 11 +++++++++-- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/crypto/keybackup.go b/crypto/keybackup.go index cf5e747f..d3701e93 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -179,6 +179,6 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. if err != nil { return fmt.Errorf("failed to store new inbound group session: %w", err) } - mach.markSessionReceived(sessionID) + mach.markSessionReceived(ctx, sessionID) return nil } diff --git a/crypto/keyimport.go b/crypto/keyimport.go index 2d9f3486..da51774f 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -122,7 +122,7 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor if err != nil { return false, fmt.Errorf("failed to store imported session: %w", err) } - mach.markSessionReceived(igs.ID()) + mach.markSessionReceived(ctx, igs.ID()) return true, nil } diff --git a/crypto/keysharing.go b/crypto/keysharing.go index fa422ca5..05e7f894 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -189,7 +189,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt log.Error().Err(err).Msg("Failed to store new inbound group session") return false } - mach.markSessionReceived(content.SessionID) + mach.markSessionReceived(ctx, content.SessionID) log.Debug().Msg("Received forwarded inbound group session") return true } diff --git a/crypto/machine.go b/crypto/machine.go index 4a691166..4417faf3 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -52,6 +52,9 @@ type OlmMachine struct { keyWaiters map[id.SessionID]chan struct{} keyWaitersLock sync.Mutex + // Optional callback which is called when we save a session to store + SessionReceived func(context.Context, id.SessionID) + devicesToUnwedge map[id.IdentityKey]bool devicesToUnwedgeLock sync.Mutex recentlyUnwedged map[id.IdentityKey]time.Time @@ -520,7 +523,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") return } - mach.markSessionReceived(sessionID) + mach.markSessionReceived(ctx, sessionID) log.Debug(). Str("session_id", sessionID.String()). Str("sender_key", senderKey.String()). @@ -530,7 +533,11 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen Msg("Received inbound group session") } -func (mach *OlmMachine) markSessionReceived(id id.SessionID) { +func (mach *OlmMachine) markSessionReceived(ctx context.Context, id id.SessionID) { + if mach.SessionReceived != nil { + mach.SessionReceived(ctx, id) + } + mach.keyWaitersLock.Lock() ch, ok := mach.keyWaiters[id] if ok {