From 308e3583b06f03da67da38a5ff4d711cd5fa02d1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 13 Jan 2024 18:56:12 +0200 Subject: [PATCH] Add contexts to event handlers --- CHANGELOG.md | 1 + appservice/eventprocessor.go | 55 +++++------ bridge/bridge.go | 8 +- bridge/crypto.go | 8 +- bridge/matrix.go | 35 ++++--- client.go | 7 +- crypto/cryptohelper/cryptohelper.go | 18 ++-- crypto/machine.go | 32 +++---- event/events.go | 2 + event/eventsource.go | 72 ++++++++++++++ statestore.go | 4 +- sync.go | 140 +++++++--------------------- 12 files changed, 196 insertions(+), 186 deletions(-) create mode 100644 event/eventsource.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 7abbe587..a04fbff4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ functions. * **Breaking change *(everything)*** Added context parameters to all functions (started by [@recht] in [#144]). +* *(client)* Moved `EventSource` to `event.Source`. * *(crypto)* Added experimental pure Go Olm implementation to replace libolm (thanks to [@DerLukas15] in [#106]). * You can use the `goolm` build tag to the new implementation. diff --git a/appservice/eventprocessor.go b/appservice/eventprocessor.go index 376a4fc4..4cd2ce4e 100644 --- a/appservice/eventprocessor.go +++ b/appservice/eventprocessor.go @@ -7,6 +7,7 @@ package appservice import ( + "context" "encoding/json" "runtime/debug" "time" @@ -25,9 +26,9 @@ const ( Sync ) -type EventHandler = func(evt *event.Event) -type OTKHandler = func(otk *mautrix.OTKCount) -type DeviceListHandler = func(lists *mautrix.DeviceLists, since string) +type EventHandler = func(ctx context.Context, evt *event.Event) +type OTKHandler = func(ctx context.Context, otk *mautrix.OTKCount) +type DeviceListHandler = func(ctx context.Context, lists *mautrix.DeviceLists, since string) type EventProcessor struct { ExecMode ExecMode @@ -97,34 +98,34 @@ func (ep *EventProcessor) recoverFunc(data interface{}) { } } -func (ep *EventProcessor) callHandler(handler EventHandler, evt *event.Event) { +func (ep *EventProcessor) callHandler(ctx context.Context, handler EventHandler, evt *event.Event) { defer ep.recoverFunc(evt) - handler(evt) + handler(ctx, evt) } -func (ep *EventProcessor) callOTKHandler(handler OTKHandler, otk *mautrix.OTKCount) { +func (ep *EventProcessor) callOTKHandler(ctx context.Context, handler OTKHandler, otk *mautrix.OTKCount) { defer ep.recoverFunc(otk) - handler(otk) + handler(ctx, otk) } -func (ep *EventProcessor) callDeviceListHandler(handler DeviceListHandler, dl *mautrix.DeviceLists) { +func (ep *EventProcessor) callDeviceListHandler(ctx context.Context, handler DeviceListHandler, dl *mautrix.DeviceLists) { defer ep.recoverFunc(dl) - handler(dl, "") + handler(ctx, dl, "") } -func (ep *EventProcessor) DispatchOTK(otk *mautrix.OTKCount) { +func (ep *EventProcessor) DispatchOTK(ctx context.Context, otk *mautrix.OTKCount) { for _, handler := range ep.otkHandlers { - go ep.callOTKHandler(handler, otk) + go ep.callOTKHandler(ctx, handler, otk) } } -func (ep *EventProcessor) DispatchDeviceList(dl *mautrix.DeviceLists) { +func (ep *EventProcessor) DispatchDeviceList(ctx context.Context, dl *mautrix.DeviceLists) { for _, handler := range ep.deviceListHandlers { - go ep.callDeviceListHandler(handler, dl) + go ep.callDeviceListHandler(ctx, handler, dl) } } -func (ep *EventProcessor) Dispatch(evt *event.Event) { +func (ep *EventProcessor) Dispatch(ctx context.Context, evt *event.Event) { handlers, ok := ep.handlers[evt.Type] if !ok { return @@ -132,25 +133,25 @@ func (ep *EventProcessor) Dispatch(evt *event.Event) { switch ep.ExecMode { case AsyncHandlers: for _, handler := range handlers { - go ep.callHandler(handler, evt) + go ep.callHandler(ctx, handler, evt) } case AsyncLoop: go func() { for _, handler := range handlers { - ep.callHandler(handler, evt) + ep.callHandler(ctx, handler, evt) } }() case Sync: if ep.ExecSyncWarnTime == 0 && ep.ExecSyncTimeout == 0 { for _, handler := range handlers { - ep.callHandler(handler, evt) + ep.callHandler(ctx, handler, evt) } return } doneChan := make(chan struct{}) go func() { for _, handler := range handlers { - ep.callHandler(handler, evt) + ep.callHandler(ctx, handler, evt) } close(doneChan) }() @@ -172,35 +173,35 @@ func (ep *EventProcessor) Dispatch(evt *event.Event) { } } } -func (ep *EventProcessor) startEvents() { +func (ep *EventProcessor) startEvents(ctx context.Context) { for { select { case evt := <-ep.as.Events: - ep.Dispatch(evt) + ep.Dispatch(ctx, evt) case <-ep.stop: return } } } -func (ep *EventProcessor) startEncryption() { +func (ep *EventProcessor) startEncryption(ctx context.Context) { for { select { case evt := <-ep.as.ToDeviceEvents: - ep.Dispatch(evt) + ep.Dispatch(ctx, evt) case otk := <-ep.as.OTKCounts: - ep.DispatchOTK(otk) + ep.DispatchOTK(ctx, otk) case dl := <-ep.as.DeviceLists: - ep.DispatchDeviceList(dl) + ep.DispatchDeviceList(ctx, dl) case <-ep.stop: return } } } -func (ep *EventProcessor) Start() { - go ep.startEvents() - go ep.startEncryption() +func (ep *EventProcessor) Start(ctx context.Context) { + go ep.startEvents(ctx) + go ep.startEncryption(ctx) } func (ep *EventProcessor) Stop() { diff --git a/bridge/bridge.go b/bridge/bridge.go index 6ad19720..7d5333ce 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -214,7 +214,7 @@ type Bridge struct { } type Crypto interface { - HandleMemberEvent(*event.Event) + HandleMemberEvent(context.Context, *event.Event) Decrypt(context.Context, *event.Event) (*event.Event, error) Encrypt(context.Context, id.RoomID, event.Type, *event.Content) error WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool @@ -321,7 +321,7 @@ func (br *Bridge) ensureConnection(ctx context.Context) { if errors.Is(err, mautrix.MUnknownToken) { br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?") } else if errors.Is(err, mautrix.MExclusive) { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain and username template in the config correct, and do they match the values in the registration?") + br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain, bot username and username template in the config correct, and do they match the values in the registration?") } else { br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("/whoami request failed with unknown error") } @@ -674,7 +674,7 @@ func (br *Bridge) start() { } br.ZLog.Debug().Msg("Checking connection to homeserver") - ctx := context.Background() + ctx := br.ZLog.WithContext(context.Background()) br.ensureConnection(ctx) go br.fetchMediaConfig(ctx) @@ -687,7 +687,7 @@ func (br *Bridge) start() { } br.ZLog.Debug().Msg("Starting event processor") - br.EventProcessor.Start() + br.EventProcessor.Start(ctx) go br.UpdateBotProfile(ctx) if br.Crypto != nil { diff --git a/bridge/crypto.go b/bridge/crypto.go index 872bf8a6..f0b90056 100644 --- a/bridge/crypto.go +++ b/bridge/crypto.go @@ -425,10 +425,10 @@ func (helper *CryptoHelper) ResetSession(ctx context.Context, roomID id.RoomID) } } -func (helper *CryptoHelper) HandleMemberEvent(evt *event.Event) { +func (helper *CryptoHelper) HandleMemberEvent(ctx context.Context, evt *event.Event) { helper.lock.RLock() defer helper.lock.RUnlock() - helper.mach.HandleMemberEvent(0, evt) + helper.mach.HandleMemberEvent(ctx, evt) } // ShareKeys uploads the given number of one-time-keys to the server. @@ -440,7 +440,7 @@ type cryptoSyncer struct { *crypto.OlmMachine } -func (syncer *cryptoSyncer) ProcessResponse(resp *mautrix.RespSync, since string) error { +func (syncer *cryptoSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { done := make(chan struct{}) go func() { defer func() { @@ -454,7 +454,7 @@ func (syncer *cryptoSyncer) ProcessResponse(resp *mautrix.RespSync, since string done <- struct{}{} }() syncer.Log.Trace().Str("since", since).Msg("Starting sync response handling") - syncer.ProcessSyncResponse(resp, since) + syncer.ProcessSyncResponse(ctx, resp, since) syncer.Log.Trace().Str("since", since).Msg("Successfully handled sync response") }() select { diff --git a/bridge/matrix.go b/bridge/matrix.go index 00994dd2..5aa457fa 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -68,13 +68,13 @@ func NewMatrixHandler(br *Bridge) *MatrixHandler { return handler } -func (mx *MatrixHandler) sendBridgeCheckpoint(evt *event.Event) { +func (mx *MatrixHandler) sendBridgeCheckpoint(_ context.Context, evt *event.Event) { if !evt.Mautrix.CheckpointSent { go mx.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepBridge, 0) } } -func (mx *MatrixHandler) HandleEncryption(evt *event.Event) { +func (mx *MatrixHandler) HandleEncryption(ctx context.Context, evt *event.Event) { defer mx.TrackEventDuration(evt.Type)() if evt.Content.AsEncryption().Algorithm != id.AlgorithmMegolmV1 { return @@ -87,7 +87,7 @@ func (mx *MatrixHandler) HandleEncryption(evt *event.Event) { Msg("Encryption was enabled in room") portal.MarkEncrypted() if portal.IsPrivateChat() { - err := mx.as.BotIntent().EnsureJoined(context.TODO(), evt.RoomID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client}) + err := mx.as.BotIntent().EnsureJoined(ctx, evt.RoomID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client}) if err != nil { mx.log.Err(err). Str("room_id", evt.RoomID.String()). @@ -232,15 +232,14 @@ func (mx *MatrixHandler) HandleGhostInvite(ctx context.Context, evt *event.Event } } -func (mx *MatrixHandler) HandleMembership(evt *event.Event) { +func (mx *MatrixHandler) HandleMembership(ctx context.Context, evt *event.Event) { if evt.Sender == mx.bridge.Bot.UserID || mx.bridge.Child.IsGhost(evt.Sender) { return } defer mx.TrackEventDuration(evt.Type)() - ctx := context.TODO() if mx.bridge.Crypto != nil { - mx.bridge.Crypto.HandleMemberEvent(evt) + mx.bridge.Crypto.HandleMemberEvent(ctx, evt) } log := mx.log.With(). @@ -300,7 +299,7 @@ func (mx *MatrixHandler) HandleMembership(evt *event.Event) { // TODO kicking/inviting non-ghost users users } -func (mx *MatrixHandler) HandleRoomMetadata(evt *event.Event) { +func (mx *MatrixHandler) HandleRoomMetadata(ctx context.Context, evt *event.Event) { defer mx.TrackEventDuration(evt.Type)() if mx.shouldIgnoreEvent(evt) { return @@ -469,20 +468,20 @@ func (mx *MatrixHandler) postDecrypt(ctx context.Context, original, decrypted *e mx.bridge.SendMessageSuccessCheckpoint(decrypted, status.MsgStepDecrypted, retryCount) decrypted.Mautrix.CheckpointSent = true decrypted.Mautrix.DecryptionDuration = duration - mx.bridge.EventProcessor.Dispatch(decrypted) + decrypted.Mautrix.EventSource |= event.SourceDecrypted + mx.bridge.EventProcessor.Dispatch(ctx, decrypted) if errorEventID != "" { _, _ = mx.bridge.Bot.RedactEvent(ctx, decrypted.RoomID, errorEventID) } } -func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) { +func (mx *MatrixHandler) HandleEncrypted(ctx context.Context, evt *event.Event) { defer mx.TrackEventDuration(evt.Type)() if mx.shouldIgnoreEvent(evt) { return } content := evt.Content.AsEncrypted() - ctx := context.TODO() - log := mx.log.With(). + log := zerolog.Ctx(ctx).With(). Str("event_id", evt.ID.String()). Str("session_id", content.SessionID.String()). Logger() @@ -546,14 +545,14 @@ func (mx *MatrixHandler) waitLongerForSession(ctx context.Context, evt *event.Ev mx.postDecrypt(ctx, evt, decrypted, 2, errorEventID, time.Since(decryptionStart)) } -func (mx *MatrixHandler) HandleMessage(evt *event.Event) { +func (mx *MatrixHandler) HandleMessage(ctx context.Context, evt *event.Event) { defer mx.TrackEventDuration(evt.Type)() - log := mx.log.With(). + log := zerolog.Ctx(ctx).With(). Str("event_id", evt.ID.String()). Str("room_id", evt.RoomID.String()). Str("sender", evt.Sender.String()). Logger() - ctx := log.WithContext(context.TODO()) + ctx = log.WithContext(ctx) if mx.shouldIgnoreEvent(evt) { return } else if !evt.Mautrix.WasEncrypted && mx.bridge.Config.Bridge.GetEncryptionConfig().Require { @@ -604,7 +603,7 @@ func (mx *MatrixHandler) HandleMessage(evt *event.Event) { } } -func (mx *MatrixHandler) HandleReaction(evt *event.Event) { +func (mx *MatrixHandler) HandleReaction(_ context.Context, evt *event.Event) { defer mx.TrackEventDuration(evt.Type)() if mx.shouldIgnoreEvent(evt) { return @@ -623,7 +622,7 @@ func (mx *MatrixHandler) HandleReaction(evt *event.Event) { } } -func (mx *MatrixHandler) HandleRedaction(evt *event.Event) { +func (mx *MatrixHandler) HandleRedaction(_ context.Context, evt *event.Event) { defer mx.TrackEventDuration(evt.Type)() if mx.shouldIgnoreEvent(evt) { return @@ -642,7 +641,7 @@ func (mx *MatrixHandler) HandleRedaction(evt *event.Event) { } } -func (mx *MatrixHandler) HandleReceipt(evt *event.Event) { +func (mx *MatrixHandler) HandleReceipt(_ context.Context, evt *event.Event) { portal := mx.bridge.Child.GetIPortal(evt.RoomID) if portal == nil { return @@ -676,7 +675,7 @@ func (mx *MatrixHandler) HandleReceipt(evt *event.Event) { } } -func (mx *MatrixHandler) HandleTyping(evt *event.Event) { +func (mx *MatrixHandler) HandleTyping(_ context.Context, evt *event.Event) { portal := mx.bridge.Child.GetIPortal(evt.RoomID) if portal == nil { return diff --git a/client.go b/client.go index d1a6d8f0..dfef7231 100644 --- a/client.go +++ b/client.go @@ -236,8 +236,11 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { // Save the token now *before* processing it. This means it's possible // to not process some events, but it means that we won't get constantly stuck processing // a malformed/buggy event which keeps making us panic. - cli.Store.SaveNextBatch(ctx, cli.UserID, resSync.NextBatch) - if err = cli.Syncer.ProcessResponse(resSync, nextBatch); err != nil { + err = cli.Store.SaveNextBatch(ctx, cli.UserID, resSync.NextBatch) + if err != nil { + return err + } + if err = cli.Syncer.ProcessResponse(ctx, resSync, nextBatch); err != nil { return err } diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index eb7d7a77..a0065012 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -245,17 +245,18 @@ var NoSessionFound = crypto.NoSessionFound const initialSessionWaitTimeout = 3 * time.Second const extendedSessionWaitTimeout = 22 * time.Second -func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event.Event) { +func (helper *CryptoHelper) HandleEncrypted(ctx context.Context, evt *event.Event) { if helper == nil { return } content := evt.Content.AsEncrypted() + // TODO use context log instead of helper? log := helper.log.With(). Str("event_id", evt.ID.String()). Str("session_id", content.SessionID.String()). Logger() log.Debug().Msg("Decrypting received event") - ctx := log.WithContext(context.TODO()) + ctx = log.WithContext(ctx) decrypted, err := helper.Decrypt(ctx, evt) if errors.Is(err, NoSessionFound) { @@ -266,7 +267,7 @@ func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event. log.Debug().Msg("Got keys after waiting, trying to decrypt event again") decrypted, err = helper.Decrypt(ctx, evt) } else { - go helper.waitLongerForSession(ctx, log, src, evt) + go helper.waitLongerForSession(ctx, log, evt) return } } @@ -275,11 +276,12 @@ func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event. helper.DecryptErrorCallback(evt, err) return } - helper.postDecrypt(src, decrypted) + helper.postDecrypt(ctx, decrypted) } -func (helper *CryptoHelper) postDecrypt(src mautrix.EventSource, decrypted *event.Event) { - helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(src|mautrix.EventSourceDecrypted, decrypted) +func (helper *CryptoHelper) postDecrypt(ctx context.Context, decrypted *event.Event) { + decrypted.Mautrix.EventSource |= event.SourceDecrypted + helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(ctx, decrypted) } func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { @@ -309,7 +311,7 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID } } -func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, src mautrix.EventSource, evt *event.Event) { +func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, evt *event.Event) { content := evt.Content.AsEncrypted() log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...") @@ -329,7 +331,7 @@ func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolo return } - helper.postDecrypt(src, decrypted) + helper.postDecrypt(ctx, decrypted) } func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { diff --git a/crypto/machine.go b/crypto/machine.go index fa0c50dc..9892536a 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -197,9 +197,9 @@ func (mach *OlmMachine) OwnIdentity() *id.Device { } type asEventProcessor interface { - On(evtType event.Type, handler func(evt *event.Event)) - OnOTK(func(otk *mautrix.OTKCount)) - OnDeviceList(func(lists *mautrix.DeviceLists, since string)) + On(evtType event.Type, handler func(ctx context.Context, evt *event.Event)) + OnOTK(func(ctx context.Context, otk *mautrix.OTKCount)) + OnDeviceList(func(ctx context.Context, lists *mautrix.DeviceLists, since string)) } func (mach *OlmMachine) AddAppserviceListener(ep asEventProcessor) { @@ -220,7 +220,7 @@ func (mach *OlmMachine) AddAppserviceListener(ep asEventProcessor) { mach.Log.Debug().Msg("Added listeners for encryption data coming from appservice transactions") } -func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string) { +func (mach *OlmMachine) HandleDeviceLists(ctx context.Context, dl *mautrix.DeviceLists, since string) { if len(dl.Changed) > 0 { traceID := time.Now().Format("15:04:05.000000") mach.Log.Debug(). @@ -228,15 +228,15 @@ func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string) Interface("changes", dl.Changed). Msg("Device list changes in /sync") if mach.DisableKeyFetching { - mach.CryptoStore.MarkTrackedUsersOutdated(context.TODO(), dl.Changed) + mach.CryptoStore.MarkTrackedUsersOutdated(ctx, dl.Changed) } else { - mach.FetchKeys(context.TODO(), dl.Changed, false) + mach.FetchKeys(ctx, dl.Changed, false) } mach.Log.Debug().Str("trace_id", traceID).Msg("Finished handling device list changes") } } -func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) { +func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.OTKCount) { if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) { // TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions mach.Log.Warn(). @@ -250,7 +250,7 @@ func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) { if otkCount.SignedCurve25519 < int(minCount) { traceID := time.Now().Format("15:04:05.000000") log := mach.Log.With().Str("trace_id", traceID).Logger() - ctx := log.WithContext(context.TODO()) + ctx = log.WithContext(ctx) log.Debug(). Int("keys_left", otkCount.Curve25519). Msg("Sync response said we have less than 50 signed curve25519 keys left, sharing new ones...") @@ -268,8 +268,8 @@ func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) { // This can be easily registered into a mautrix client using .OnSync(): // // client.Syncer.(mautrix.ExtensibleSyncer).OnSync(c.crypto.ProcessSyncResponse) -func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string) bool { - mach.HandleDeviceLists(&resp.DeviceLists, since) +func (mach *OlmMachine) ProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) bool { + mach.HandleDeviceLists(ctx, &resp.DeviceLists, since) for _, evt := range resp.ToDevice.Events { evt.Type.Class = event.ToDeviceEventType @@ -278,10 +278,10 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string mach.Log.Warn().Str("event_type", evt.Type.Type).Err(err).Msg("Failed to parse to-device event") continue } - mach.HandleToDeviceEvent(evt) + mach.HandleToDeviceEvent(ctx, evt) } - mach.HandleOTKCounts(&resp.DeviceOTKCount) + mach.HandleOTKCounts(ctx, &resp.DeviceOTKCount) return true } @@ -290,8 +290,7 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string // Currently this is not automatically called, so you must add a listener yourself: // // client.Syncer.(mautrix.ExtensibleSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent) -func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Event) { - ctx := context.TODO() +func (mach *OlmMachine) HandleMemberEvent(ctx context.Context, evt *event.Event) { if isEncrypted, err := mach.StateStore.IsEncrypted(ctx, evt.RoomID); err != nil { mach.machOrContextLog(ctx).Err(err).Stringer("room_id", evt.RoomID). Msg("Failed to check if room is encrypted to handle member event") @@ -331,7 +330,7 @@ func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Even // HandleToDeviceEvent handles a single to-device event. This is automatically called by ProcessSyncResponse, so you // don't need to add any custom handlers if you use that method. -func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) { +func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Event) { if len(evt.ToUserID) > 0 && (evt.ToUserID != mach.Client.UserID || evt.ToDeviceID != mach.Client.DeviceID) { // TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions mach.Log.Debug(). @@ -341,12 +340,13 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) { return } traceID := time.Now().Format("15:04:05.000000") + // TODO use context log? log := mach.Log.With(). Str("trace_id", traceID). Str("sender", evt.Sender.String()). Str("type", evt.Type.Type). Logger() - ctx := log.WithContext(context.TODO()) + ctx = log.WithContext(ctx) if evt.Type != event.ToDeviceEncrypted { log.Debug().Msg("Starting handling to-device event") } diff --git a/event/events.go b/event/events.go index 57611221..f7b4d4d6 100644 --- a/event/events.go +++ b/event/events.go @@ -105,6 +105,8 @@ func (evt *Event) MarshalJSON() ([]byte, error) { } type MautrixInfo struct { + EventSource Source + TrustState id.TrustState ForwardedKeys bool WasEncrypted bool diff --git a/event/eventsource.go b/event/eventsource.go new file mode 100644 index 00000000..86c1cebe --- /dev/null +++ b/event/eventsource.go @@ -0,0 +1,72 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package event + +import ( + "fmt" +) + +// Source represents the part of the sync response that an event came from. +type Source int + +const ( + SourcePresence Source = 1 << iota + SourceJoin + SourceInvite + SourceLeave + SourceAccountData + SourceTimeline + SourceState + SourceEphemeral + SourceToDevice + SourceDecrypted +) + +const primaryTypes = SourcePresence | SourceAccountData | SourceToDevice | SourceTimeline | SourceState +const roomSections = SourceJoin | SourceInvite | SourceLeave +const roomableTypes = SourceAccountData | SourceTimeline | SourceState +const encryptableTypes = roomableTypes | SourceToDevice + +func (es Source) String() string { + var typeName string + switch es & primaryTypes { + case SourcePresence: + typeName = "presence" + case SourceAccountData: + typeName = "account data" + case SourceToDevice: + typeName = "to-device" + case SourceTimeline: + typeName = "timeline" + case SourceState: + typeName = "state" + default: + return fmt.Sprintf("unknown (%d)", es) + } + if es&roomableTypes != 0 { + switch es & roomSections { + case SourceJoin: + typeName = "joined room " + typeName + case SourceInvite: + typeName = "invited room " + typeName + case SourceLeave: + typeName = "left room " + typeName + default: + return fmt.Sprintf("unknown (%s+%d)", typeName, es) + } + es &^= roomSections + } + if es&encryptableTypes != 0 && es&SourceDecrypted != 0 { + typeName += " (decrypted)" + es &^= SourceDecrypted + } + es &^= primaryTypes + if es != 0 { + return fmt.Sprintf("unknown (%s+%d)", typeName, es) + } + return typeName +} diff --git a/statestore.go b/statestore.go index 63a5bfb4..8fe5f8b3 100644 --- a/statestore.go +++ b/statestore.go @@ -67,8 +67,8 @@ func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) { // client.Syncer.(mautrix.ExtensibleSyncer).OnEvent(client.StateStoreSyncHandler) // // DefaultSyncer.ParseEventContent must also be true for this to work (which it is by default). -func (cli *Client) StateStoreSyncHandler(_ EventSource, evt *event.Event) { - UpdateStateStore(cli.Log.WithContext(context.TODO()), cli.StateStore, evt) +func (cli *Client) StateStoreSyncHandler(ctx context.Context, evt *event.Event) { + UpdateStateStore(ctx, cli.StateStore, evt) } type MemoryStateStore struct { diff --git a/sync.go b/sync.go index f05e9b5f..d4208404 100644 --- a/sync.go +++ b/sync.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -7,6 +7,7 @@ package mautrix import ( + "context" "errors" "fmt" "runtime/debug" @@ -16,78 +17,17 @@ import ( "maunium.net/go/mautrix/id" ) -// EventSource represents the part of the sync response that an event came from. -type EventSource int - -const ( - EventSourcePresence EventSource = 1 << iota - EventSourceJoin - EventSourceInvite - EventSourceLeave - EventSourceAccountData - EventSourceTimeline - EventSourceState - EventSourceEphemeral - EventSourceToDevice - EventSourceDecrypted -) - -const primaryTypes = EventSourcePresence | EventSourceAccountData | EventSourceToDevice | EventSourceTimeline | EventSourceState -const roomSections = EventSourceJoin | EventSourceInvite | EventSourceLeave -const roomableTypes = EventSourceAccountData | EventSourceTimeline | EventSourceState -const encryptableTypes = roomableTypes | EventSourceToDevice - -func (es EventSource) String() string { - var typeName string - switch es & primaryTypes { - case EventSourcePresence: - typeName = "presence" - case EventSourceAccountData: - typeName = "account data" - case EventSourceToDevice: - typeName = "to-device" - case EventSourceTimeline: - typeName = "timeline" - case EventSourceState: - typeName = "state" - default: - return fmt.Sprintf("unknown (%d)", es) - } - if es&roomableTypes != 0 { - switch es & roomSections { - case EventSourceJoin: - typeName = "joined room " + typeName - case EventSourceInvite: - typeName = "invited room " + typeName - case EventSourceLeave: - typeName = "left room " + typeName - default: - return fmt.Sprintf("unknown (%s+%d)", typeName, es) - } - es &^= roomSections - } - if es&encryptableTypes != 0 && es&EventSourceDecrypted != 0 { - typeName += " (decrypted)" - es &^= EventSourceDecrypted - } - es &^= primaryTypes - if es != 0 { - return fmt.Sprintf("unknown (%s+%d)", typeName, es) - } - return typeName -} - // EventHandler handles a single event from a sync response. -type EventHandler func(source EventSource, evt *event.Event) +type EventHandler func(ctx context.Context, evt *event.Event) // SyncHandler handles a whole sync response. If the return value is false, handling will be stopped completely. -type SyncHandler func(resp *RespSync, since string) bool +type SyncHandler func(ctx context.Context, resp *RespSync, since string) bool // Syncer is an interface that must be satisfied in order to do /sync requests on a client. type Syncer interface { // ProcessResponse processes the /sync response. The since parameter is the since= value that was used to produce the response. // This is useful for detecting the very first sync (since=""). If an error is return, Syncing will be stopped permanently. - ProcessResponse(resp *RespSync, since string) error + ProcessResponse(ctx context.Context, resp *RespSync, since string) error // OnFailedSync returns either the time to wait before retrying or an error to stop syncing permanently. OnFailedSync(res *RespSync, err error) (time.Duration, error) // GetFilterJSON for the given user ID. NOT the filter ID. @@ -101,7 +41,7 @@ type ExtensibleSyncer interface { } type DispatchableSyncer interface { - Dispatch(source EventSource, evt *event.Event) + Dispatch(ctx context.Context, evt *event.Event) } // DefaultSyncer is the default syncing implementation. You can either write your own syncer, or selectively @@ -144,7 +84,7 @@ func NewDefaultSyncer() *DefaultSyncer { // ProcessResponse processes the /sync response in a way suitable for bots. "Suitable for bots" means a stream of // unrepeating events. Returns a fatal error if a listener panics. -func (s *DefaultSyncer) ProcessResponse(res *RespSync, since string) (err error) { +func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, since string) (err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("ProcessResponse panicked! since=%s panic=%s\n%s", since, r, debug.Stack()) @@ -152,38 +92,38 @@ func (s *DefaultSyncer) ProcessResponse(res *RespSync, since string) (err error) }() for _, listener := range s.syncListeners { - if !listener(res, since) { + if !listener(ctx, res, since) { return } } - s.processSyncEvents("", res.ToDevice.Events, EventSourceToDevice) - s.processSyncEvents("", res.Presence.Events, EventSourcePresence) - s.processSyncEvents("", res.AccountData.Events, EventSourceAccountData) + s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice) + s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence) + s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData) for roomID, roomData := range res.Rooms.Join { - s.processSyncEvents(roomID, roomData.State.Events, EventSourceJoin|EventSourceState) - s.processSyncEvents(roomID, roomData.Timeline.Events, EventSourceJoin|EventSourceTimeline) - s.processSyncEvents(roomID, roomData.Ephemeral.Events, EventSourceJoin|EventSourceEphemeral) - s.processSyncEvents(roomID, roomData.AccountData.Events, EventSourceJoin|EventSourceAccountData) + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState) + s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline) + s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral) + s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData) } for roomID, roomData := range res.Rooms.Invite { - s.processSyncEvents(roomID, roomData.State.Events, EventSourceInvite|EventSourceState) + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState) } for roomID, roomData := range res.Rooms.Leave { - s.processSyncEvents(roomID, roomData.State.Events, EventSourceLeave|EventSourceState) - s.processSyncEvents(roomID, roomData.Timeline.Events, EventSourceLeave|EventSourceTimeline) + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState) + s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline) } return } -func (s *DefaultSyncer) processSyncEvents(roomID id.RoomID, events []*event.Event, source EventSource) { +func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source) { for _, evt := range events { - s.processSyncEvent(roomID, evt, source) + s.processSyncEvent(ctx, roomID, evt, source) } } -func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, source EventSource) { +func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source) { evt.RoomID = roomID // Ensure the type class is correct. It's safe to mutate the class since the event type is not a pointer. @@ -191,11 +131,11 @@ func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, sou switch { case evt.StateKey != nil: evt.Type.Class = event.StateEventType - case source == EventSourcePresence, source&EventSourceEphemeral != 0: + case source == event.SourcePresence, source&event.SourceEphemeral != 0: evt.Type.Class = event.EphemeralEventType - case source&EventSourceAccountData != 0: + case source&event.SourceAccountData != 0: evt.Type.Class = event.AccountDataEventType - case source == EventSourceToDevice: + case source == event.SourceToDevice: evt.Type.Class = event.ToDeviceEventType default: evt.Type.Class = event.MessageEventType @@ -208,17 +148,18 @@ func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, sou } } - s.Dispatch(source, evt) + evt.Mautrix.EventSource = source + s.Dispatch(ctx, evt) } -func (s *DefaultSyncer) Dispatch(source EventSource, evt *event.Event) { +func (s *DefaultSyncer) Dispatch(ctx context.Context, evt *event.Event) { for _, fn := range s.globalListeners { - fn(source, evt) + fn(ctx, evt) } listeners, exists := s.listeners[evt.Type] if exists { for _, fn := range listeners { - fn(source, evt) + fn(ctx, evt) } } } @@ -266,31 +207,18 @@ func (s *DefaultSyncer) GetFilterJSON(userID id.UserID) *Filter { return s.FilterJSON } -// OldEventIgnorer is a utility struct for bots to ignore events from before the bot joined the room. -// -// Deprecated: Use Client.DontProcessOldEvents instead. -type OldEventIgnorer struct { - UserID id.UserID -} - -func (oei *OldEventIgnorer) Register(syncer ExtensibleSyncer) { - syncer.OnSync(oei.DontProcessOldEvents) -} - -func (oei *OldEventIgnorer) DontProcessOldEvents(resp *RespSync, since string) bool { - return dontProcessOldEvents(oei.UserID, resp, since) -} - // DontProcessOldEvents is a sync handler that removes rooms that the user just joined. // It's meant for bots to ignore events from before the bot joined the room. // // To use it, register it with your Syncer, e.g.: // // cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.DontProcessOldEvents) -func (cli *Client) DontProcessOldEvents(resp *RespSync, since string) bool { +func (cli *Client) DontProcessOldEvents(_ context.Context, resp *RespSync, since string) bool { return dontProcessOldEvents(cli.UserID, resp, since) } +var _ SyncHandler = (*Client)(nil).DontProcessOldEvents + func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool { if since == "" { return false @@ -327,7 +255,7 @@ func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool { // To use it, register it with your Syncer, e.g.: // // cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.MoveInviteState) -func (cli *Client) MoveInviteState(resp *RespSync, _ string) bool { +func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string) bool { for _, meta := range resp.Rooms.Invite { var inviteState []event.StrippedState var inviteEvt *event.Event @@ -352,3 +280,5 @@ func (cli *Client) MoveInviteState(resp *RespSync, _ string) bool { } return true } + +var _ SyncHandler = (*Client)(nil).MoveInviteState