Add more contexts everywhere

This commit is contained in:
Tulir Asokan 2024-01-07 22:44:06 +02:00
commit 25bc36bc7a
37 changed files with 879 additions and 833 deletions

View file

@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -80,11 +80,11 @@ type OlmMachine struct {
// StateStore is used by OlmMachine to get room state information that's needed for encryption.
type StateStore interface {
// IsEncrypted returns whether a room is encrypted.
IsEncrypted(id.RoomID) bool
IsEncrypted(context.Context, id.RoomID) (bool, error)
// GetEncryptionEvent returns the encryption event's content for an encrypted room.
GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent
GetEncryptionEvent(context.Context, id.RoomID) (*event.EncryptionEventContent, error)
// FindSharedRooms returns the encrypted rooms that another user is also in for a user ID.
FindSharedRooms(id.UserID) []id.RoomID
FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, error)
}
// NewOlmMachine creates an OlmMachine with the given client, logger and stores.
@ -131,8 +131,8 @@ func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger {
// Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created.
// This must be called before using the machine.
func (mach *OlmMachine) Load() (err error) {
mach.account, err = mach.CryptoStore.GetAccount()
func (mach *OlmMachine) Load(ctx context.Context) (err error) {
mach.account, err = mach.CryptoStore.GetAccount(ctx)
if err != nil {
return
}
@ -143,15 +143,15 @@ func (mach *OlmMachine) Load() (err error) {
}
func (mach *OlmMachine) saveAccount() {
err := mach.CryptoStore.PutAccount(mach.account)
err := mach.CryptoStore.PutAccount(context.TODO(), mach.account)
if err != nil {
mach.Log.Error().Err(err).Msg("Failed to save account")
}
}
// FlushStore calls the Flush method of the CryptoStore.
func (mach *OlmMachine) FlushStore() error {
return mach.CryptoStore.Flush()
func (mach *OlmMachine) FlushStore(ctx context.Context) error {
return mach.CryptoStore.Flush(ctx)
}
func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() {
@ -284,7 +284,12 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
//
// client.Syncer.(mautrix.ExtensibleSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent)
func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Event) {
if !mach.StateStore.IsEncrypted(evt.RoomID) {
ctx := context.TODO()
if isEncrypted, err := mach.StateStore.IsEncrypted(ctx, evt.RoomID); err != nil {
mach.machOrContextLog(ctx).Err(err).Stringer("room_id", evt.RoomID).
Msg("Failed to check if room is encrypted to handle member event")
return
} else if !isEncrypted {
return
}
content := evt.Content.AsMember()
@ -311,7 +316,7 @@ func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Even
Str("prev_membership", string(prevContent.Membership)).
Str("new_membership", string(content.Membership)).
Msg("Got membership state change, invalidating group session in room")
err := mach.CryptoStore.RemoveOutboundGroupSession(evt.RoomID)
err := mach.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID)
if err != nil {
mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session")
}
@ -405,7 +410,7 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
// GetOrFetchDevice attempts to retrieve the device identity for the given device from the store
// and if it's not found it asks the server for it.
func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
device, err := mach.CryptoStore.GetDevice(userID, deviceID)
device, err := mach.CryptoStore.GetDevice(ctx, userID, deviceID)
if err != nil {
return nil, fmt.Errorf("failed to get sender device from store: %w", err)
} else if device != nil {
@ -425,7 +430,7 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID,
// store and if it's not found it asks the server for it. This returns nil if the server doesn't return a device with
// the given identity key.
func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(userID, identityKey)
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(ctx, userID, identityKey)
if err != nil || deviceIdentity != nil {
return deviceIdentity, err
}
@ -455,7 +460,7 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De
mach.olmLock.Lock()
defer mach.olmLock.Unlock()
olmSess, err := mach.CryptoStore.GetLatestSession(device.IdentityKey)
olmSess, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey)
if err != nil {
return err
}
@ -499,7 +504,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
Msg("Mismatched session ID while creating inbound group session")
return
}
err = mach.CryptoStore.PutGroupSession(roomID, senderKey, sessionID, igs)
err = mach.CryptoStore.PutGroupSession(ctx, roomID, senderKey, sessionID, igs)
if err != nil {
log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
return
@ -525,7 +530,7 @@ func (mach *OlmMachine) markSessionReceived(id id.SessionID) {
}
// WaitForSession waits for the given Megolm session to arrive.
func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
mach.keyWaitersLock.Lock()
ch, ok := mach.keyWaiters[sessionID]
if !ok {
@ -534,7 +539,7 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey,
}
mach.keyWaitersLock.Unlock()
// Handle race conditions where a session appears between the failed decryption and WaitForSession call.
sess, err := mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID)
sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID)
if sess != nil || errors.Is(err, ErrGroupSessionWithheld) {
return true
}
@ -542,10 +547,12 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey,
case <-ch:
return true
case <-time.After(timeout):
sess, err = mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID)
sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID)
// Check if the session somehow appeared in the store without telling us
// We accept withheld sessions as received, as then the decryption attempt will show the error.
return sess != nil || errors.Is(err, ErrGroupSessionWithheld)
case <-ctx.Done():
return false
}
}
@ -568,7 +575,10 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve
return
}
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
config, err := mach.StateStore.GetEncryptionEvent(ctx, content.RoomID)
if err != nil {
log.Error().Err(err).Msg("Failed to get encryption event for room")
}
var maxAge time.Duration
var maxMessages int
if config != nil {
@ -589,7 +599,7 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve
}
if mach.DeletePreviousKeysOnReceive && !content.IsScheduled {
log.Debug().Msg("Redacting previous megolm sessions from sender in room")
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(content.RoomID, evt.SenderKey, "received new key from device")
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(ctx, content.RoomID, evt.SenderKey, "received new key from device")
if err != nil {
log.Err(err).Msg("Failed to redact previous megolm sessions")
} else {
@ -606,7 +616,7 @@ func (mach *OlmMachine) handleRoomKeyWithheld(ctx context.Context, content *even
zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event")
return
}
err := mach.CryptoStore.PutWithheldGroupSession(*content)
err := mach.CryptoStore.PutWithheldGroupSession(ctx, *content)
if err != nil {
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to save room key withheld event")
}
@ -662,7 +672,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) {
log := mach.Log.With().Str("action", "redact expired sessions").Logger()
for {
sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions()
sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions(ctx)
if err != nil {
log.Err(err).Msg("Failed to redact expired megolm sessions")
} else if len(sessionIDs) > 0 {