mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
Add more contexts everywhere
This commit is contained in:
parent
0a302c753d
commit
25bc36bc7a
37 changed files with 879 additions and 833 deletions
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
|
@ -80,11 +80,11 @@ type OlmMachine struct {
|
|||
// StateStore is used by OlmMachine to get room state information that's needed for encryption.
|
||||
type StateStore interface {
|
||||
// IsEncrypted returns whether a room is encrypted.
|
||||
IsEncrypted(id.RoomID) bool
|
||||
IsEncrypted(context.Context, id.RoomID) (bool, error)
|
||||
// GetEncryptionEvent returns the encryption event's content for an encrypted room.
|
||||
GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent
|
||||
GetEncryptionEvent(context.Context, id.RoomID) (*event.EncryptionEventContent, error)
|
||||
// FindSharedRooms returns the encrypted rooms that another user is also in for a user ID.
|
||||
FindSharedRooms(id.UserID) []id.RoomID
|
||||
FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, error)
|
||||
}
|
||||
|
||||
// NewOlmMachine creates an OlmMachine with the given client, logger and stores.
|
||||
|
|
@ -131,8 +131,8 @@ func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger {
|
|||
|
||||
// Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created.
|
||||
// This must be called before using the machine.
|
||||
func (mach *OlmMachine) Load() (err error) {
|
||||
mach.account, err = mach.CryptoStore.GetAccount()
|
||||
func (mach *OlmMachine) Load(ctx context.Context) (err error) {
|
||||
mach.account, err = mach.CryptoStore.GetAccount(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -143,15 +143,15 @@ func (mach *OlmMachine) Load() (err error) {
|
|||
}
|
||||
|
||||
func (mach *OlmMachine) saveAccount() {
|
||||
err := mach.CryptoStore.PutAccount(mach.account)
|
||||
err := mach.CryptoStore.PutAccount(context.TODO(), mach.account)
|
||||
if err != nil {
|
||||
mach.Log.Error().Err(err).Msg("Failed to save account")
|
||||
}
|
||||
}
|
||||
|
||||
// FlushStore calls the Flush method of the CryptoStore.
|
||||
func (mach *OlmMachine) FlushStore() error {
|
||||
return mach.CryptoStore.Flush()
|
||||
func (mach *OlmMachine) FlushStore(ctx context.Context) error {
|
||||
return mach.CryptoStore.Flush(ctx)
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() {
|
||||
|
|
@ -284,7 +284,12 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
|
|||
//
|
||||
// client.Syncer.(mautrix.ExtensibleSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent)
|
||||
func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Event) {
|
||||
if !mach.StateStore.IsEncrypted(evt.RoomID) {
|
||||
ctx := context.TODO()
|
||||
if isEncrypted, err := mach.StateStore.IsEncrypted(ctx, evt.RoomID); err != nil {
|
||||
mach.machOrContextLog(ctx).Err(err).Stringer("room_id", evt.RoomID).
|
||||
Msg("Failed to check if room is encrypted to handle member event")
|
||||
return
|
||||
} else if !isEncrypted {
|
||||
return
|
||||
}
|
||||
content := evt.Content.AsMember()
|
||||
|
|
@ -311,7 +316,7 @@ func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Even
|
|||
Str("prev_membership", string(prevContent.Membership)).
|
||||
Str("new_membership", string(content.Membership)).
|
||||
Msg("Got membership state change, invalidating group session in room")
|
||||
err := mach.CryptoStore.RemoveOutboundGroupSession(evt.RoomID)
|
||||
err := mach.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID)
|
||||
if err != nil {
|
||||
mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session")
|
||||
}
|
||||
|
|
@ -405,7 +410,7 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
|
|||
// GetOrFetchDevice attempts to retrieve the device identity for the given device from the store
|
||||
// and if it's not found it asks the server for it.
|
||||
func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
|
||||
device, err := mach.CryptoStore.GetDevice(userID, deviceID)
|
||||
device, err := mach.CryptoStore.GetDevice(ctx, userID, deviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get sender device from store: %w", err)
|
||||
} else if device != nil {
|
||||
|
|
@ -425,7 +430,7 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID,
|
|||
// store and if it's not found it asks the server for it. This returns nil if the server doesn't return a device with
|
||||
// the given identity key.
|
||||
func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
||||
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(userID, identityKey)
|
||||
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(ctx, userID, identityKey)
|
||||
if err != nil || deviceIdentity != nil {
|
||||
return deviceIdentity, err
|
||||
}
|
||||
|
|
@ -455,7 +460,7 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De
|
|||
mach.olmLock.Lock()
|
||||
defer mach.olmLock.Unlock()
|
||||
|
||||
olmSess, err := mach.CryptoStore.GetLatestSession(device.IdentityKey)
|
||||
olmSess, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -499,7 +504,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
|
|||
Msg("Mismatched session ID while creating inbound group session")
|
||||
return
|
||||
}
|
||||
err = mach.CryptoStore.PutGroupSession(roomID, senderKey, sessionID, igs)
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, roomID, senderKey, sessionID, igs)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
|
||||
return
|
||||
|
|
@ -525,7 +530,7 @@ func (mach *OlmMachine) markSessionReceived(id id.SessionID) {
|
|||
}
|
||||
|
||||
// WaitForSession waits for the given Megolm session to arrive.
|
||||
func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
mach.keyWaitersLock.Lock()
|
||||
ch, ok := mach.keyWaiters[sessionID]
|
||||
if !ok {
|
||||
|
|
@ -534,7 +539,7 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey,
|
|||
}
|
||||
mach.keyWaitersLock.Unlock()
|
||||
// Handle race conditions where a session appears between the failed decryption and WaitForSession call.
|
||||
sess, err := mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID)
|
||||
sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID)
|
||||
if sess != nil || errors.Is(err, ErrGroupSessionWithheld) {
|
||||
return true
|
||||
}
|
||||
|
|
@ -542,10 +547,12 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey,
|
|||
case <-ch:
|
||||
return true
|
||||
case <-time.After(timeout):
|
||||
sess, err = mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID)
|
||||
sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID)
|
||||
// Check if the session somehow appeared in the store without telling us
|
||||
// We accept withheld sessions as received, as then the decryption attempt will show the error.
|
||||
return sess != nil || errors.Is(err, ErrGroupSessionWithheld)
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -568,7 +575,10 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve
|
|||
return
|
||||
}
|
||||
|
||||
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
|
||||
config, err := mach.StateStore.GetEncryptionEvent(ctx, content.RoomID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get encryption event for room")
|
||||
}
|
||||
var maxAge time.Duration
|
||||
var maxMessages int
|
||||
if config != nil {
|
||||
|
|
@ -589,7 +599,7 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve
|
|||
}
|
||||
if mach.DeletePreviousKeysOnReceive && !content.IsScheduled {
|
||||
log.Debug().Msg("Redacting previous megolm sessions from sender in room")
|
||||
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(content.RoomID, evt.SenderKey, "received new key from device")
|
||||
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(ctx, content.RoomID, evt.SenderKey, "received new key from device")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to redact previous megolm sessions")
|
||||
} else {
|
||||
|
|
@ -606,7 +616,7 @@ func (mach *OlmMachine) handleRoomKeyWithheld(ctx context.Context, content *even
|
|||
zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event")
|
||||
return
|
||||
}
|
||||
err := mach.CryptoStore.PutWithheldGroupSession(*content)
|
||||
err := mach.CryptoStore.PutWithheldGroupSession(ctx, *content)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to save room key withheld event")
|
||||
}
|
||||
|
|
@ -662,7 +672,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
|
|||
func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) {
|
||||
log := mach.Log.With().Str("action", "redact expired sessions").Logger()
|
||||
for {
|
||||
sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions()
|
||||
sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to redact expired megolm sessions")
|
||||
} else if len(sessionIDs) > 0 {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue