bridgev2: allow adding pending message before returning from handler

This commit is contained in:
Tulir Asokan 2024-09-04 22:17:33 +03:00
commit e5ea10d64c
2 changed files with 100 additions and 45 deletions

View file

@ -273,11 +273,16 @@ type MaxFileSizeingNetwork interface {
SetMaxFileSize(maxSize int64)
}
type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error)
type MatrixMessageResponse struct {
DB *database.Message
Pending networkid.TransactionID
HandleEcho func(RemoteMessage, *database.Message) (bool, error)
// If Pending is set, the bridge will not save the provided message to the database.
// This should only be used if AddPendingToSave has been called.
Pending bool
// If RemovePending is set, the bridge will remove the provided transaction ID from pending messages
// after saving the provided message to the database. This should be used with AddPendingToIgnore.
RemovePending networkid.TransactionID
}
type FileRestriction struct {

View file

@ -59,6 +59,7 @@ type portalEvent interface {
type outgoingMessage struct {
db *database.Message
evt *event.Event
ignore bool
handle func(RemoteMessage, *database.Message) (bool, error)
}
@ -775,7 +776,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
}
}
resp, err := sender.Client.HandleMatrixMessage(ctx, &MatrixMessage{
wrappedEvt := &MatrixMessage{
MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{
Event: evt,
Content: content,
@ -784,52 +785,30 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
},
ThreadRoot: threadRoot,
ReplyTo: replyTo,
})
}
resp, err := sender.Client.HandleMatrixMessage(ctx, wrappedEvt)
if err != nil {
log.Err(err).Msg("Failed to handle Matrix message")
portal.sendErrorStatus(ctx, evt, err)
return
}
message := resp.DB
if message.MXID == "" {
message.MXID = evt.ID
}
if message.Room.ID == "" {
message.Room = portal.PortalKey
}
if message.Timestamp.IsZero() {
message.Timestamp = time.UnixMilli(evt.Timestamp)
}
if message.ReplyTo.MessageID == "" && replyTo != nil {
message.ReplyTo.MessageID = replyTo.ID
message.ReplyTo.PartID = &replyTo.PartID
}
if message.ThreadRoot == "" && threadRoot != nil {
message.ThreadRoot = threadRoot.ID
if threadRoot.ThreadRoot != "" {
message.ThreadRoot = threadRoot.ThreadRoot
}
}
if message.SenderMXID == "" {
message.SenderMXID = evt.Sender
}
if resp.Pending != "" {
// TODO if the event queue is ever removed, this will have to be done by the network connector before sending the request
// (for now this is fine because incoming messages will wait in the queue for this function to return)
portal.outgoingMessagesLock.Lock()
portal.outgoingMessages[resp.Pending] = outgoingMessage{
db: message,
evt: evt,
handle: resp.HandleEcho,
}
portal.outgoingMessagesLock.Unlock()
} else {
// Hack to ensure the ghost row exists
// TODO move to better place (like login)
portal.Bridge.GetGhostByID(ctx, message.SenderID)
err = portal.Bridge.DB.Message.Insert(ctx, message)
if err != nil {
log.Err(err).Msg("Failed to save message to database")
message := wrappedEvt.fillDBMessage(resp.DB)
if !resp.Pending {
if resp.DB == nil {
log.Error().Msg("Network connector didn't return a message to save")
} else {
// Hack to ensure the ghost row exists
// TODO move to better place (like login)
portal.Bridge.GetGhostByID(ctx, message.SenderID)
err = portal.Bridge.DB.Message.Insert(ctx, message)
if err != nil {
log.Err(err).Msg("Failed to save message to database")
}
if resp.RemovePending != "" {
portal.outgoingMessagesLock.Lock()
delete(portal.outgoingMessages, resp.RemovePending)
portal.outgoingMessagesLock.Unlock()
}
}
portal.sendSuccessStatus(ctx, evt)
}
@ -846,6 +825,75 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
}
}
// AddPendingToIgnore adds a transaction ID that should be ignored if encountered as a new message.
//
// This should be used when the network connector will return the real message ID from HandleMatrixMessage.
// The [MatrixMessageResponse] should include RemovePending with the transaction ID sto remove it from the lit
// after saving to database.
//
// See also: [MatrixMessage.AddPendingToSave]
func (evt *MatrixMessage) AddPendingToIgnore(txnID networkid.TransactionID) {
evt.Portal.outgoingMessagesLock.Lock()
evt.Portal.outgoingMessages[txnID] = outgoingMessage{
ignore: true,
}
evt.Portal.outgoingMessagesLock.Unlock()
}
// AddPendingToSave adds a transaction ID that should be processed and pointed at the existing event if encountered.
//
// This should be used when the network connector returns `Pending: true` from HandleMatrixMessage,
// i.e. when the network connector does not know the message ID at the end of the handler.
// The [MatrixMessageResponse] should set Pending to true to prevent saving the returned message to the database.
//
// The provided function will be called when the message is encountered.
func (evt *MatrixMessage) AddPendingToSave(message *database.Message, txnID networkid.TransactionID, handleEcho RemoteEchoHandler) {
evt.Portal.outgoingMessagesLock.Lock()
evt.Portal.outgoingMessages[txnID] = outgoingMessage{
db: evt.fillDBMessage(message),
evt: evt.Event,
handle: handleEcho,
}
evt.Portal.outgoingMessagesLock.Unlock()
}
// RemovePending removes a transaction ID from the list of pending messages.
// This should only be called if sending the message fails.
func (evt *MatrixMessage) RemovePending(txnID networkid.TransactionID) {
evt.Portal.outgoingMessagesLock.Lock()
delete(evt.Portal.outgoingMessages, txnID)
evt.Portal.outgoingMessagesLock.Unlock()
}
func (evt *MatrixMessage) fillDBMessage(message *database.Message) *database.Message {
if message == nil {
message = &database.Message{}
}
if message.MXID == "" {
message.MXID = evt.Event.ID
}
if message.Room.ID == "" {
message.Room = evt.Portal.PortalKey
}
if message.Timestamp.IsZero() {
message.Timestamp = time.UnixMilli(evt.Event.Timestamp)
}
if message.ReplyTo.MessageID == "" && evt.ReplyTo != nil {
message.ReplyTo.MessageID = evt.ReplyTo.ID
message.ReplyTo.PartID = &evt.ReplyTo.PartID
}
if message.ThreadRoot == "" && evt.ThreadRoot != nil {
message.ThreadRoot = evt.ThreadRoot.ID
if evt.ThreadRoot.ThreadRoot != "" {
message.ThreadRoot = evt.ThreadRoot.ThreadRoot
}
}
if message.SenderMXID == "" {
message.SenderMXID = evt.Event.Sender
}
return message
}
func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *NetworkRoomCapabilities) {
log := zerolog.Ctx(ctx)
editTargetID := content.RelatesTo.GetReplaceID()
@ -1715,6 +1763,8 @@ func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage
pending, ok := portal.outgoingMessages[txnID]
if !ok {
return false, nil
} else if pending.ignore {
return true, nil
}
delete(portal.outgoingMessages, txnID)
pending.db.ID = evt.GetID()