From 2c6ca02eeb56eeacfafc8dbfa442eeac052f472e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 18 Jun 2024 17:31:48 +0300 Subject: [PATCH] bridgev2: add support for disappearing messages --- bridgev2/bridge.go | 5 + bridgev2/database/database.go | 36 +++--- bridgev2/database/disappear.go | 106 +++++++++++++++++ bridgev2/database/portal.go | 3 + bridgev2/database/upgrades/00-latest.sql | 13 ++- .../upgrades/02-disappearing-messages.sql | 11 ++ bridgev2/disappear.go | 110 ++++++++++++++++++ bridgev2/networkinterface.go | 1 + bridgev2/portal.go | 33 +++++- 9 files changed, 297 insertions(+), 21 deletions(-) create mode 100644 bridgev2/database/disappear.go create mode 100644 bridgev2/database/upgrades/02-disappearing-messages.sql create mode 100644 bridgev2/disappear.go diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 8eaf6a1e..9a73bc2a 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -35,6 +35,8 @@ type Bridge struct { Commands *CommandProcessor Config *bridgeconfig.BridgeConfig + DisappearLoop *DisappearLoop + usersByMXID map[id.UserID]*User userLoginsByID map[networkid.UserLoginID]*UserLogin portalsByKey map[networkid.PortalKey]*Portal @@ -66,6 +68,7 @@ func NewBridge(bridgeID networkid.BridgeID, db *dbutil.Database, log zerolog.Log br.Matrix.Init(br) br.Bot = br.Matrix.BotIntent() br.Network.Init(br) + br.DisappearLoop = &DisappearLoop{br: br} return br } @@ -100,6 +103,8 @@ func (br *Bridge) Start() error { if err != nil { return fmt.Errorf("failed to start network connector: %w", err) } + // TODO only start if the network supports disappearing messages? + go br.DisappearLoop.Start() logins, err := br.GetAllUserLogins(ctx) if err != nil { diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go index c910498a..c6d1e4eb 100644 --- a/bridgev2/database/database.go +++ b/bridgev2/database/database.go @@ -23,28 +23,30 @@ import ( type Database struct { *dbutil.Database - BridgeID networkid.BridgeID - Portal *PortalQuery - Ghost *GhostQuery - Message *MessageQuery - Reaction *ReactionQuery - User *UserQuery - UserLogin *UserLoginQuery - UserPortal *UserPortalQuery + BridgeID networkid.BridgeID + Portal *PortalQuery + Ghost *GhostQuery + Message *MessageQuery + DisappearingMessage *DisappearingMessageQuery + Reaction *ReactionQuery + User *UserQuery + UserLogin *UserLoginQuery + UserPortal *UserPortalQuery } func New(bridgeID networkid.BridgeID, db *dbutil.Database) *Database { db.UpgradeTable = upgrades.Table return &Database{ - Database: db, - BridgeID: bridgeID, - Portal: &PortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newPortal)}, - Ghost: &GhostQuery{bridgeID, dbutil.MakeQueryHelper(db, newGhost)}, - Message: &MessageQuery{bridgeID, dbutil.MakeQueryHelper(db, newMessage)}, - Reaction: &ReactionQuery{bridgeID, dbutil.MakeQueryHelper(db, newReaction)}, - User: &UserQuery{bridgeID, dbutil.MakeQueryHelper(db, newUser)}, - UserLogin: &UserLoginQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserLogin)}, - UserPortal: &UserPortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserPortal)}, + Database: db, + BridgeID: bridgeID, + Portal: &PortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newPortal)}, + Ghost: &GhostQuery{bridgeID, dbutil.MakeQueryHelper(db, newGhost)}, + Message: &MessageQuery{bridgeID, dbutil.MakeQueryHelper(db, newMessage)}, + DisappearingMessage: &DisappearingMessageQuery{bridgeID, dbutil.MakeQueryHelper(db, newDisappearingMessage)}, + Reaction: &ReactionQuery{bridgeID, dbutil.MakeQueryHelper(db, newReaction)}, + User: &UserQuery{bridgeID, dbutil.MakeQueryHelper(db, newUser)}, + UserLogin: &UserLoginQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserLogin)}, + UserPortal: &UserPortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserPortal)}, } } diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go new file mode 100644 index 00000000..22a5be5c --- /dev/null +++ b/bridgev2/database/disappear.go @@ -0,0 +1,106 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "time" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +// DisappearingType represents the type of a disappearing message timer. +type DisappearingType string + +const ( + DisappearingTypeNone DisappearingType = "" + DisappearingTypeAfterRead DisappearingType = "after_read" + DisappearingTypeAfterSend DisappearingType = "after_send" +) + +// DisappearingSetting represents a disappearing message timer setting +// by combining a type with a timer and an optional start timestamp. +type DisappearingSetting struct { + Type DisappearingType + Timer time.Duration + DisappearAt time.Time +} + +type DisappearingMessageQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*DisappearingMessage] +} + +type DisappearingMessage struct { + BridgeID networkid.BridgeID + RoomID id.RoomID + EventID id.EventID + DisappearingSetting +} + +func newDisappearingMessage(_ *dbutil.QueryHelper[*DisappearingMessage]) *DisappearingMessage { + return &DisappearingMessage{} +} + +const ( + upsertDisappearingMessageQuery = ` + INSERT INTO disappearing_message (bridge_id, mx_room, mxid, type, timer, disappear_at) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (bridge_id, mxid) DO UPDATE SET timer=excluded.timer, disappear_at=excluded.disappear_at + ` + startDisappearingMessagesQuery = ` + UPDATE disappearing_message + SET disappear_at=$1 + timer + WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' + RETURNING bridge_id, mx_room, mxid, type, timer, disappear_at + ` + getUpcomingDisappearingMessagesQuery = ` + SELECT bridge_id, mx_room, mxid, type, timer, disappear_at + FROM disappearing_message WHERE bridge_id = $1 AND disappear_at IS NOT NULL AND disappear_at < $2 + ORDER BY disappear_at + ` + deleteDisappearingMessageQuery = ` + DELETE FROM disappearing_message WHERE bridge_id=$1 AND mxid=$2 + ` +) + +func (dmq *DisappearingMessageQuery) Put(ctx context.Context, dm *DisappearingMessage) error { + ensureBridgeIDMatches(&dm.BridgeID, dmq.BridgeID) + return dmq.Exec(ctx, upsertDisappearingMessageQuery, dm.sqlVariables()...) +} + +func (dmq *DisappearingMessageQuery) StartAll(ctx context.Context, roomID id.RoomID) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID) +} + +func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, getUpcomingDisappearingMessagesQuery, dmq.BridgeID, time.Now().Add(duration).UnixNano()) +} + +func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.EventID) error { + return dmq.Exec(ctx, deleteDisappearingMessageQuery, dmq.BridgeID, eventID) +} + +func (d *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) { + var disappearAt sql.NullInt64 + err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, &d.Type, &d.Timer, &disappearAt) + if err != nil { + return nil, err + } + if disappearAt.Valid { + d.DisappearAt = time.Unix(0, disappearAt.Int64) + } + return d, nil +} + +func (d *DisappearingMessage) sqlVariables() []any { + return []any{d.BridgeID, d.RoomID, d.EventID, d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)} +} diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index db11a78f..ded53177 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -10,6 +10,7 @@ import ( "context" "database/sql" "encoding/hex" + "time" "go.mau.fi/util/dbutil" @@ -23,6 +24,8 @@ type PortalQuery struct { } type StandardPortalMetadata struct { + DisappearType DisappearingType `json:"disappear_type,omitempty"` + DisappearTimer time.Duration `json:"disappear_timer,omitempty"` } type PortalMetadata struct { diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 8d9d150e..adceab3c 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v1: Latest revision +-- v0 -> v2 (compatible with v1+): Latest revision CREATE TABLE portal ( bridge_id TEXT NOT NULL, id TEXT NOT NULL, @@ -76,6 +76,17 @@ CREATE TABLE message ( CONSTRAINT message_real_pkey UNIQUE (bridge_id, id, part_id) ); +CREATE TABLE disappearing_message ( + bridge_id TEXT NOT NULL, + mx_room TEXT NOT NULL, + mxid TEXT NOT NULL, + type TEXT NOT NULL, + timer BIGINT NOT NULL, + disappear_at BIGINT, + + PRIMARY KEY (bridge_id, mxid) +); + CREATE TABLE reaction ( bridge_id TEXT NOT NULL, message_id TEXT NOT NULL, diff --git a/bridgev2/database/upgrades/02-disappearing-messages.sql b/bridgev2/database/upgrades/02-disappearing-messages.sql new file mode 100644 index 00000000..e1425e75 --- /dev/null +++ b/bridgev2/database/upgrades/02-disappearing-messages.sql @@ -0,0 +1,11 @@ +-- v2 (compatible with v1+): Add disappearing messages table +CREATE TABLE disappearing_message ( + bridge_id TEXT NOT NULL, + mx_room TEXT NOT NULL, + mxid TEXT NOT NULL, + type TEXT NOT NULL, + timer BIGINT NOT NULL, + disappear_at BIGINT, + + PRIMARY KEY (bridge_id, mxid) +); diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go new file mode 100644 index 00000000..971a6c39 --- /dev/null +++ b/bridgev2/disappear.go @@ -0,0 +1,110 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "time" + + "github.com/rs/zerolog" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type DisappearLoop struct { + br *Bridge + NextCheck time.Time + stop context.CancelFunc +} + +const DisappearCheckInterval = 1 * time.Hour + +func (dl *DisappearLoop) Start() { + log := dl.br.Log.With().Str("component", "disappear loop").Logger() + ctx := log.WithContext(context.Background()) + ctx, dl.stop = context.WithCancel(ctx) + log.Debug().Msg("Disappearing message loop starting") + for { + dl.NextCheck = time.Now().Add(DisappearCheckInterval) + messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval) + if err != nil { + log.Err(err).Msg("Failed to get upcoming disappearing messages") + } else if len(messages) > 0 { + go dl.sleepAndDisappear(ctx, messages...) + } + select { + case <-time.After(time.Until(dl.NextCheck)): + case <-ctx.Done(): + log.Debug().Msg("Disappearing message loop stopping") + return + } + } +} + +func (dl *DisappearLoop) Stop() { + if dl.stop != nil { + dl.stop() + } +} + +func (dl *DisappearLoop) StartAll(ctx context.Context, roomID id.RoomID) { + startedMessages, err := dl.br.DB.DisappearingMessage.StartAll(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to start disappearing messages") + return + } + slices.SortFunc(startedMessages, func(a, b *database.DisappearingMessage) int { + return a.DisappearAt.Compare(b.DisappearAt) + }) + slices.DeleteFunc(startedMessages, func(dm *database.DisappearingMessage) bool { + return dm.DisappearAt.After(dl.NextCheck) + }) + if len(startedMessages) > 0 { + go dl.sleepAndDisappear(ctx, startedMessages...) + } +} + +func (dl *DisappearLoop) Add(ctx context.Context, dm *database.DisappearingMessage) { + err := dl.br.DB.DisappearingMessage.Put(ctx, dm) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("event_id", dm.EventID). + Msg("Failed to save disappearing message") + } + if !dm.DisappearAt.IsZero() && dm.DisappearAt.Before(dl.NextCheck) { + go dl.sleepAndDisappear(context.WithoutCancel(ctx), dm) + } +} + +func (dl *DisappearLoop) sleepAndDisappear(ctx context.Context, dms ...*database.DisappearingMessage) { + for _, msg := range dms { + time.Sleep(time.Until(msg.DisappearAt)) + resp, err := dl.br.Bot.SendMessage(ctx, msg.RoomID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: msg.EventID, + Reason: "Message disappeared", + }, + }, time.Now()) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("target_event_id", msg.EventID).Msg("Failed to disappear message") + } else { + zerolog.Ctx(ctx).Debug(). + Stringer("target_event_id", msg.EventID). + Stringer("redaction_event_id", resp.EventID). + Msg("Disappeared message") + } + err = dl.br.DB.DisappearingMessage.Delete(ctx, msg.EventID) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("event_id", msg.EventID). + Msg("Failed to delete disappearing message entry from database") + } + } +} diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 66328004..65406cbe 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -39,6 +39,7 @@ type ConvertedMessage struct { ReplyTo *networkid.MessageOptionalPartID ThreadRoot *networkid.MessageOptionalPartID Parts []*ConvertedMessagePart + Disappear database.DisappearingSetting } type ConvertedEditPart struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 86667d9b..213993f2 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -355,6 +355,7 @@ func (portal *Portal) handleMatrixReadReceipt(user *User, eventID id.EventID, re if err != nil { log.Err(err).Msg("Failed to save user portal metadata") } + portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) } func (portal *Portal) handleMatrixTyping(evt *event.Event) { @@ -466,6 +467,17 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin if err != nil { log.Err(err).Msg("Failed to save message to database") } + if portal.Metadata.DisappearType != database.DisappearingTypeNone { + go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ + RoomID: portal.MXID, + EventID: message.MXID, + DisappearingSetting: database.DisappearingSetting{ + Type: portal.Metadata.DisappearType, + Timer: portal.Metadata.DisappearTimer, + DisappearAt: message.Timestamp.Add(portal.Metadata.DisappearTimer), + }, + }) + } portal.sendSuccessStatus(ctx, evt) } @@ -841,6 +853,13 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin if err != nil { log.Err(err).Str("part_id", string(part.ID)).Msg("Failed to save message part to database") } + if converted.Disappear.Type != database.DisappearingTypeNone { + go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ + RoomID: portal.MXID, + EventID: dbMessage.MXID, + DisappearingSetting: converted.Disappear, + }) + } if prevThreadEvent != nil { prevThreadEvent = dbMessage } @@ -1111,13 +1130,17 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL log.Warn().Msg("No target message found for read receipt") return } - intent := portal.getIntentFor(ctx, evt.GetSender(), source) + sender := evt.GetSender() + intent := portal.getIntentFor(ctx, sender, source) err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt)) if err != nil { log.Err(err).Stringer("target_mxid", lastTarget.MXID).Msg("Failed to bridge read receipt") } else { log.Debug().Stringer("target_mxid", lastTarget.MXID).Msg("Bridged read receipt") } + if sender.IsFromMe { + portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) + } } func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteReceipt) { @@ -1367,7 +1390,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, sender * //} if changed { portal.UpdateBridgeInfo(ctx) - err := portal.Bridge.DB.Portal.Update(ctx, portal.Portal) + err := portal.Save(ctx) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after updating info") } @@ -1473,7 +1496,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) e portal.Bridge.portalsByMXID[roomID] = portal portal.Bridge.cacheLock.Unlock() portal.updateLogger() - err = portal.Bridge.DB.Portal.Update(ctx, portal.Portal) + err = portal.Save(ctx) if err != nil { log.Err(err).Msg("Failed to save portal to database after creating Matrix room") return err @@ -1489,3 +1512,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) e } return nil } + +func (portal *Portal) Save(ctx context.Context) error { + return portal.Bridge.DB.Portal.Update(ctx, portal.Portal) +}