diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index c49dbf1c..d2969eda 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -18,6 +18,8 @@ import ( ) var ( + ErrIgnoringRemoteEvent error = errors.New("ignoring remote event") + ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go index 65d34609..46f82155 100644 --- a/bridgev2/networkid/bridgeid.go +++ b/bridgev2/networkid/bridgeid.go @@ -86,6 +86,11 @@ type UserLoginID string // Message IDs must be unique across rooms and consistent across users (i.e. globally unique within the bridge). type MessageID string +// TransactionID is a client-generated identifier for a message send operation on the remote network. +// +// Transaction IDs must be unique across users in a room, but don't need to be unique across different rooms. +type TransactionID string + // PartID is the ID of a message part on the remote network (e.g. index of image in album). // // Part IDs are only unique within a message, not globally. diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 1e943d4d..c932a4c1 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -241,6 +241,9 @@ type MaxFileSizeingNetwork interface { type MatrixMessageResponse struct { DB *database.Message + + Pending networkid.TransactionID + HandleEcho func(RemoteMessage, *database.Message) (bool, error) } type FileRestriction struct { @@ -629,6 +632,8 @@ func (ret RemoteEventType) String() string { return "RemoteEventUnknown" case RemoteEventMessage: return "RemoteEventMessage" + case RemoteEventMessageUpsert: + return "RemoteEventMessageUpsert" case RemoteEventEdit: return "RemoteEventEdit" case RemoteEventReaction: @@ -663,6 +668,7 @@ func (ret RemoteEventType) String() string { const ( RemoteEventUnknown RemoteEventType = iota RemoteEventMessage + RemoteEventMessageUpsert RemoteEventEdit RemoteEventReaction RemoteEventReactionRemove @@ -744,6 +750,21 @@ type RemoteMessage interface { ConvertMessage(ctx context.Context, portal *Portal, intent MatrixAPI) (*ConvertedMessage, error) } +type UpsertResult struct { + SubEvents []RemoteEvent + ContinueMessageHandling bool +} + +type RemoteMessageUpsert interface { + RemoteMessage + HandleExisting(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message) (UpsertResult, error) +} + +type RemoteMessageWithTransactionID interface { + RemoteMessage + GetTransactionID() networkid.TransactionID +} + type RemoteEdit interface { RemoteEventWithTargetMessage ConvertEdit(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message) (*ConvertedEdit, error) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c02c6bf9..0d24b9c4 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -54,6 +54,12 @@ type portalEvent interface { isPortalEvent() } +type outgoingMessage struct { + db *database.Message + evt *event.Event + handle func(RemoteMessage, *database.Message) (bool, error) +} + type Portal struct { *database.Portal Bridge *Bridge @@ -65,6 +71,9 @@ type Portal struct { currentlyTypingLogins map[id.UserID]*UserLogin currentlyTypingLock sync.Mutex + outgoingMessages map[networkid.TransactionID]outgoingMessage + outgoingMessagesLock sync.Mutex + roomCreateLock sync.Mutex events chan portalEvent @@ -93,9 +102,9 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que Portal: dbPortal, Bridge: br, - events: make(chan portalEvent, PortalEventBuffer), - + events: make(chan portalEvent, PortalEventBuffer), currentlyTypingLogins: make(map[id.UserID]*UserLogin), + outgoingMessages: make(map[networkid.TransactionID]outgoingMessage), } br.portalsByKey[portal.PortalKey] = portal if portal.MXID != "" { @@ -767,12 +776,25 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin if message.SenderMXID == "" { message.SenderMXID = evt.Sender } - // Hack to ensure the ghost row exists - // TODO move to better place (like login) - portal.Bridge.GetGhostByID(ctx, message.SenderID) - err = portal.Bridge.DB.Message.Insert(ctx, message) - if err != nil { - log.Err(err).Msg("Failed to save message to database") + if resp.Pending != "" { + // TODO if the event queue is ever removed, this will have to be done by the network connector before sending the request + // (for now this is fine because incoming messages will wait in the queue for this function to return) + portal.outgoingMessagesLock.Lock() + portal.outgoingMessages[resp.Pending] = outgoingMessage{ + db: message, + evt: evt, + handle: resp.HandleEcho, + } + portal.outgoingMessagesLock.Unlock() + } else { + // Hack to ensure the ghost row exists + // TODO move to better place (like login) + portal.Bridge.GetGhostByID(ctx, message.SenderID) + err = portal.Bridge.DB.Message.Insert(ctx, message) + if err != nil { + log.Err(err).Msg("Failed to save message to database") + } + portal.sendSuccessStatus(ctx, evt) } if portal.Disappear.Type != database.DisappearingTypeNone { go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ @@ -785,7 +807,6 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin }, }) } - portal.sendSuccessStatus(ctx, evt) } func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *NetworkRoomCapabilities) { @@ -1227,7 +1248,7 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { switch evtType { case RemoteEventUnknown: log.Debug().Msg("Ignoring remote event with type unknown") - case RemoteEventMessage: + case RemoteEventMessage, RemoteEventMessageUpsert: portal.handleRemoteMessage(ctx, source, evt.(RemoteMessage)) case RemoteEventEdit: portal.handleRemoteEdit(ctx, source, evt.(RemoteEdit)) @@ -1366,7 +1387,7 @@ func (portal *Portal) applyRelationMeta(content *event.MessageEventContent, repl } } -func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, sender EventSender, converted *ConvertedMessage, ts time.Time, logContext func(*zerolog.Event) *zerolog.Event) []*database.Message { +func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, senderID networkid.UserID, converted *ConvertedMessage, ts time.Time, logContext func(*zerolog.Event) *zerolog.Event) []*database.Message { if logContext == nil { logContext = func(e *zerolog.Event) *zerolog.Event { return e @@ -1381,7 +1402,7 @@ func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.Mes ID: id, PartID: part.ID, Room: portal.PortalKey, - SenderID: sender.Sender, + SenderID: senderID, SenderMXID: intent.GetMXID(), Timestamp: ts, ThreadRoot: ptr.Val(converted.ThreadRoot), @@ -1430,14 +1451,94 @@ func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.Mes return output } +func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage) (bool, *database.Message) { + evtWithTxn, ok := evt.(RemoteMessageWithTransactionID) + if !ok { + return false, nil + } + txnID := evtWithTxn.GetTransactionID() + if txnID == "" { + return false, nil + } + portal.outgoingMessagesLock.Lock() + defer portal.outgoingMessagesLock.Unlock() + pending, ok := portal.outgoingMessages[txnID] + if !ok { + return false, nil + } + delete(portal.outgoingMessages, txnID) + pending.db.ID = evt.GetID() + if pending.db.SenderID == "" { + pending.db.SenderID = evt.GetSender().Sender + } + evtWithTimestamp, ok := evt.(RemoteEventWithTimestamp) + if ok { + pending.db.Timestamp = evtWithTimestamp.GetTimestamp() + } + var statusErr error + saveMessage := true + if pending.handle != nil { + saveMessage, statusErr = pending.handle(evt, pending.db) + } + if saveMessage { + // Hack to ensure the ghost row exists + // TODO move to better place (like login) + portal.Bridge.GetGhostByID(ctx, pending.db.SenderID) + err := portal.Bridge.DB.Message.Insert(ctx, pending.db) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save message to database after receiving remote echo") + } + } + if statusErr != nil { + portal.sendErrorStatus(ctx, pending.evt, statusErr) + } else { + portal.sendSuccessStatus(ctx, pending.evt) + } + zerolog.Ctx(ctx).Debug().Stringer("event_id", pending.evt.ID).Msg("Received remote echo for message") + return true, pending.db +} + +func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) bool { + log := zerolog.Ctx(ctx) + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageUpsert) + if intent == nil { + return false + } + res, err := evt.HandleExisting(ctx, portal, intent, existing) + if err != nil { + log.Err(err).Msg("Failed to handle existing message in upsert event after receiving remote echo") + } else if len(res.SubEvents) > 0 { + for _, subEvt := range res.SubEvents { + portal.handleRemoteEvent(source, subEvt) + } + } + return res.ContinueMessageHandling +} + func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) { log := zerolog.Ctx(ctx) - existing, err := portal.Bridge.DB.Message.GetFirstPartByID(ctx, portal.Receiver, evt.GetID()) + upsertEvt, isUpsert := evt.(RemoteMessageUpsert) + isUpsert = isUpsert && evt.GetType() == RemoteEventMessageUpsert + if wasPending, dbMessage := portal.checkPendingMessage(ctx, evt); wasPending { + if isUpsert { + portal.handleRemoteUpsert(ctx, source, upsertEvt, []*database.Message{dbMessage}) + } + return + } + existing, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, evt.GetID()) if err != nil { log.Err(err).Msg("Failed to check if message is a duplicate") - } else if existing != nil { - log.Debug().Stringer("existing_mxid", existing.MXID).Msg("Ignoring duplicate message") - return + } else if len(existing) > 0 { + if isUpsert { + if portal.handleRemoteUpsert(ctx, source, upsertEvt, existing) { + log.Debug().Msg("Upsert handler said to continue message handling normally") + } else { + return + } + } else { + log.Debug().Stringer("existing_mxid", existing[0].MXID).Msg("Ignoring duplicate message") + return + } } intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessage) if intent == nil { @@ -1446,11 +1547,15 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin ts := getEventTS(evt) converted, err := evt.ConvertMessage(ctx, portal, intent) if err != nil { - log.Err(err).Msg("Failed to convert remote message") - portal.sendRemoteErrorNotice(ctx, intent, err, ts, "message") + if errors.Is(err, ErrIgnoringRemoteEvent) { + log.Debug().Err(err).Msg("Remote event handling was cancelled by convert function") + } else { + log.Err(err).Msg("Failed to convert remote message") + portal.sendRemoteErrorNotice(ctx, intent, err, ts, "message") + } return } - portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender(), converted, ts, nil) + portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender().Sender, converted, ts, nil) } func (portal *Portal) sendRemoteErrorNotice(ctx context.Context, intent MatrixAPI, err error, ts time.Time, evtTypeName string) { diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 8b1bdeb1..b05b16e5 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -369,7 +369,7 @@ func (portal *Portal) sendLegacyBackfill(ctx context.Context, source *UserLogin, var lastPart id.EventID for _, msg := range messages { intent := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) - dbMessages := portal.sendConvertedMessage(ctx, msg.ID, intent, msg.Sender, msg.ConvertedMessage, msg.Timestamp, func(z *zerolog.Event) *zerolog.Event { + dbMessages := portal.sendConvertedMessage(ctx, msg.ID, intent, msg.Sender.Sender, msg.ConvertedMessage, msg.Timestamp, func(z *zerolog.Event) *zerolog.Event { return z. Str("message_id", string(msg.ID)). Any("sender_id", msg.Sender).