bridgev2: add support for disappearing messages

This commit is contained in:
Tulir Asokan 2024-06-18 17:31:48 +03:00
commit 2c6ca02eeb
9 changed files with 297 additions and 21 deletions

View file

@ -35,6 +35,8 @@ type Bridge struct {
Commands *CommandProcessor
Config *bridgeconfig.BridgeConfig
DisappearLoop *DisappearLoop
usersByMXID map[id.UserID]*User
userLoginsByID map[networkid.UserLoginID]*UserLogin
portalsByKey map[networkid.PortalKey]*Portal
@ -66,6 +68,7 @@ func NewBridge(bridgeID networkid.BridgeID, db *dbutil.Database, log zerolog.Log
br.Matrix.Init(br)
br.Bot = br.Matrix.BotIntent()
br.Network.Init(br)
br.DisappearLoop = &DisappearLoop{br: br}
return br
}
@ -100,6 +103,8 @@ func (br *Bridge) Start() error {
if err != nil {
return fmt.Errorf("failed to start network connector: %w", err)
}
// TODO only start if the network supports disappearing messages?
go br.DisappearLoop.Start()
logins, err := br.GetAllUserLogins(ctx)
if err != nil {

View file

@ -23,28 +23,30 @@ import (
type Database struct {
*dbutil.Database
BridgeID networkid.BridgeID
Portal *PortalQuery
Ghost *GhostQuery
Message *MessageQuery
Reaction *ReactionQuery
User *UserQuery
UserLogin *UserLoginQuery
UserPortal *UserPortalQuery
BridgeID networkid.BridgeID
Portal *PortalQuery
Ghost *GhostQuery
Message *MessageQuery
DisappearingMessage *DisappearingMessageQuery
Reaction *ReactionQuery
User *UserQuery
UserLogin *UserLoginQuery
UserPortal *UserPortalQuery
}
func New(bridgeID networkid.BridgeID, db *dbutil.Database) *Database {
db.UpgradeTable = upgrades.Table
return &Database{
Database: db,
BridgeID: bridgeID,
Portal: &PortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newPortal)},
Ghost: &GhostQuery{bridgeID, dbutil.MakeQueryHelper(db, newGhost)},
Message: &MessageQuery{bridgeID, dbutil.MakeQueryHelper(db, newMessage)},
Reaction: &ReactionQuery{bridgeID, dbutil.MakeQueryHelper(db, newReaction)},
User: &UserQuery{bridgeID, dbutil.MakeQueryHelper(db, newUser)},
UserLogin: &UserLoginQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserLogin)},
UserPortal: &UserPortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserPortal)},
Database: db,
BridgeID: bridgeID,
Portal: &PortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newPortal)},
Ghost: &GhostQuery{bridgeID, dbutil.MakeQueryHelper(db, newGhost)},
Message: &MessageQuery{bridgeID, dbutil.MakeQueryHelper(db, newMessage)},
DisappearingMessage: &DisappearingMessageQuery{bridgeID, dbutil.MakeQueryHelper(db, newDisappearingMessage)},
Reaction: &ReactionQuery{bridgeID, dbutil.MakeQueryHelper(db, newReaction)},
User: &UserQuery{bridgeID, dbutil.MakeQueryHelper(db, newUser)},
UserLogin: &UserLoginQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserLogin)},
UserPortal: &UserPortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserPortal)},
}
}

View file

@ -0,0 +1,106 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"database/sql"
"time"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/id"
)
// DisappearingType represents the type of a disappearing message timer.
type DisappearingType string
const (
DisappearingTypeNone DisappearingType = ""
DisappearingTypeAfterRead DisappearingType = "after_read"
DisappearingTypeAfterSend DisappearingType = "after_send"
)
// DisappearingSetting represents a disappearing message timer setting
// by combining a type with a timer and an optional start timestamp.
type DisappearingSetting struct {
Type DisappearingType
Timer time.Duration
DisappearAt time.Time
}
type DisappearingMessageQuery struct {
BridgeID networkid.BridgeID
*dbutil.QueryHelper[*DisappearingMessage]
}
type DisappearingMessage struct {
BridgeID networkid.BridgeID
RoomID id.RoomID
EventID id.EventID
DisappearingSetting
}
func newDisappearingMessage(_ *dbutil.QueryHelper[*DisappearingMessage]) *DisappearingMessage {
return &DisappearingMessage{}
}
const (
upsertDisappearingMessageQuery = `
INSERT INTO disappearing_message (bridge_id, mx_room, mxid, type, timer, disappear_at)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (bridge_id, mxid) DO UPDATE SET timer=excluded.timer, disappear_at=excluded.disappear_at
`
startDisappearingMessagesQuery = `
UPDATE disappearing_message
SET disappear_at=$1 + timer
WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read'
RETURNING bridge_id, mx_room, mxid, type, timer, disappear_at
`
getUpcomingDisappearingMessagesQuery = `
SELECT bridge_id, mx_room, mxid, type, timer, disappear_at
FROM disappearing_message WHERE bridge_id = $1 AND disappear_at IS NOT NULL AND disappear_at < $2
ORDER BY disappear_at
`
deleteDisappearingMessageQuery = `
DELETE FROM disappearing_message WHERE bridge_id=$1 AND mxid=$2
`
)
func (dmq *DisappearingMessageQuery) Put(ctx context.Context, dm *DisappearingMessage) error {
ensureBridgeIDMatches(&dm.BridgeID, dmq.BridgeID)
return dmq.Exec(ctx, upsertDisappearingMessageQuery, dm.sqlVariables()...)
}
func (dmq *DisappearingMessageQuery) StartAll(ctx context.Context, roomID id.RoomID) ([]*DisappearingMessage, error) {
return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID)
}
func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration) ([]*DisappearingMessage, error) {
return dmq.QueryMany(ctx, getUpcomingDisappearingMessagesQuery, dmq.BridgeID, time.Now().Add(duration).UnixNano())
}
func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.EventID) error {
return dmq.Exec(ctx, deleteDisappearingMessageQuery, dmq.BridgeID, eventID)
}
func (d *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) {
var disappearAt sql.NullInt64
err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, &d.Type, &d.Timer, &disappearAt)
if err != nil {
return nil, err
}
if disappearAt.Valid {
d.DisappearAt = time.Unix(0, disappearAt.Int64)
}
return d, nil
}
func (d *DisappearingMessage) sqlVariables() []any {
return []any{d.BridgeID, d.RoomID, d.EventID, d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)}
}

View file

@ -10,6 +10,7 @@ import (
"context"
"database/sql"
"encoding/hex"
"time"
"go.mau.fi/util/dbutil"
@ -23,6 +24,8 @@ type PortalQuery struct {
}
type StandardPortalMetadata struct {
DisappearType DisappearingType `json:"disappear_type,omitempty"`
DisappearTimer time.Duration `json:"disappear_timer,omitempty"`
}
type PortalMetadata struct {

View file

@ -1,4 +1,4 @@
-- v0 -> v1: Latest revision
-- v0 -> v2 (compatible with v1+): Latest revision
CREATE TABLE portal (
bridge_id TEXT NOT NULL,
id TEXT NOT NULL,
@ -76,6 +76,17 @@ CREATE TABLE message (
CONSTRAINT message_real_pkey UNIQUE (bridge_id, id, part_id)
);
CREATE TABLE disappearing_message (
bridge_id TEXT NOT NULL,
mx_room TEXT NOT NULL,
mxid TEXT NOT NULL,
type TEXT NOT NULL,
timer BIGINT NOT NULL,
disappear_at BIGINT,
PRIMARY KEY (bridge_id, mxid)
);
CREATE TABLE reaction (
bridge_id TEXT NOT NULL,
message_id TEXT NOT NULL,

View file

@ -0,0 +1,11 @@
-- v2 (compatible with v1+): Add disappearing messages table
CREATE TABLE disappearing_message (
bridge_id TEXT NOT NULL,
mx_room TEXT NOT NULL,
mxid TEXT NOT NULL,
type TEXT NOT NULL,
timer BIGINT NOT NULL,
disappear_at BIGINT,
PRIMARY KEY (bridge_id, mxid)
);

110
bridgev2/disappear.go Normal file
View file

@ -0,0 +1,110 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package bridgev2
import (
"context"
"time"
"github.com/rs/zerolog"
"golang.org/x/exp/slices"
"maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
type DisappearLoop struct {
br *Bridge
NextCheck time.Time
stop context.CancelFunc
}
const DisappearCheckInterval = 1 * time.Hour
func (dl *DisappearLoop) Start() {
log := dl.br.Log.With().Str("component", "disappear loop").Logger()
ctx := log.WithContext(context.Background())
ctx, dl.stop = context.WithCancel(ctx)
log.Debug().Msg("Disappearing message loop starting")
for {
dl.NextCheck = time.Now().Add(DisappearCheckInterval)
messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval)
if err != nil {
log.Err(err).Msg("Failed to get upcoming disappearing messages")
} else if len(messages) > 0 {
go dl.sleepAndDisappear(ctx, messages...)
}
select {
case <-time.After(time.Until(dl.NextCheck)):
case <-ctx.Done():
log.Debug().Msg("Disappearing message loop stopping")
return
}
}
}
func (dl *DisappearLoop) Stop() {
if dl.stop != nil {
dl.stop()
}
}
func (dl *DisappearLoop) StartAll(ctx context.Context, roomID id.RoomID) {
startedMessages, err := dl.br.DB.DisappearingMessage.StartAll(ctx, roomID)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to start disappearing messages")
return
}
slices.SortFunc(startedMessages, func(a, b *database.DisappearingMessage) int {
return a.DisappearAt.Compare(b.DisappearAt)
})
slices.DeleteFunc(startedMessages, func(dm *database.DisappearingMessage) bool {
return dm.DisappearAt.After(dl.NextCheck)
})
if len(startedMessages) > 0 {
go dl.sleepAndDisappear(ctx, startedMessages...)
}
}
func (dl *DisappearLoop) Add(ctx context.Context, dm *database.DisappearingMessage) {
err := dl.br.DB.DisappearingMessage.Put(ctx, dm)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("event_id", dm.EventID).
Msg("Failed to save disappearing message")
}
if !dm.DisappearAt.IsZero() && dm.DisappearAt.Before(dl.NextCheck) {
go dl.sleepAndDisappear(context.WithoutCancel(ctx), dm)
}
}
func (dl *DisappearLoop) sleepAndDisappear(ctx context.Context, dms ...*database.DisappearingMessage) {
for _, msg := range dms {
time.Sleep(time.Until(msg.DisappearAt))
resp, err := dl.br.Bot.SendMessage(ctx, msg.RoomID, event.EventRedaction, &event.Content{
Parsed: &event.RedactionEventContent{
Redacts: msg.EventID,
Reason: "Message disappeared",
},
}, time.Now())
if err != nil {
zerolog.Ctx(ctx).Err(err).Stringer("target_event_id", msg.EventID).Msg("Failed to disappear message")
} else {
zerolog.Ctx(ctx).Debug().
Stringer("target_event_id", msg.EventID).
Stringer("redaction_event_id", resp.EventID).
Msg("Disappeared message")
}
err = dl.br.DB.DisappearingMessage.Delete(ctx, msg.EventID)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("event_id", msg.EventID).
Msg("Failed to delete disappearing message entry from database")
}
}
}

View file

@ -39,6 +39,7 @@ type ConvertedMessage struct {
ReplyTo *networkid.MessageOptionalPartID
ThreadRoot *networkid.MessageOptionalPartID
Parts []*ConvertedMessagePart
Disappear database.DisappearingSetting
}
type ConvertedEditPart struct {

View file

@ -355,6 +355,7 @@ func (portal *Portal) handleMatrixReadReceipt(user *User, eventID id.EventID, re
if err != nil {
log.Err(err).Msg("Failed to save user portal metadata")
}
portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID)
}
func (portal *Portal) handleMatrixTyping(evt *event.Event) {
@ -466,6 +467,17 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
if err != nil {
log.Err(err).Msg("Failed to save message to database")
}
if portal.Metadata.DisappearType != database.DisappearingTypeNone {
go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{
RoomID: portal.MXID,
EventID: message.MXID,
DisappearingSetting: database.DisappearingSetting{
Type: portal.Metadata.DisappearType,
Timer: portal.Metadata.DisappearTimer,
DisappearAt: message.Timestamp.Add(portal.Metadata.DisappearTimer),
},
})
}
portal.sendSuccessStatus(ctx, evt)
}
@ -841,6 +853,13 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin
if err != nil {
log.Err(err).Str("part_id", string(part.ID)).Msg("Failed to save message part to database")
}
if converted.Disappear.Type != database.DisappearingTypeNone {
go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{
RoomID: portal.MXID,
EventID: dbMessage.MXID,
DisappearingSetting: converted.Disappear,
})
}
if prevThreadEvent != nil {
prevThreadEvent = dbMessage
}
@ -1111,13 +1130,17 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL
log.Warn().Msg("No target message found for read receipt")
return
}
intent := portal.getIntentFor(ctx, evt.GetSender(), source)
sender := evt.GetSender()
intent := portal.getIntentFor(ctx, sender, source)
err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt))
if err != nil {
log.Err(err).Stringer("target_mxid", lastTarget.MXID).Msg("Failed to bridge read receipt")
} else {
log.Debug().Stringer("target_mxid", lastTarget.MXID).Msg("Bridged read receipt")
}
if sender.IsFromMe {
portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID)
}
}
func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteReceipt) {
@ -1367,7 +1390,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, sender *
//}
if changed {
portal.UpdateBridgeInfo(ctx)
err := portal.Bridge.DB.Portal.Update(ctx, portal.Portal)
err := portal.Save(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after updating info")
}
@ -1473,7 +1496,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) e
portal.Bridge.portalsByMXID[roomID] = portal
portal.Bridge.cacheLock.Unlock()
portal.updateLogger()
err = portal.Bridge.DB.Portal.Update(ctx, portal.Portal)
err = portal.Save(ctx)
if err != nil {
log.Err(err).Msg("Failed to save portal to database after creating Matrix room")
return err
@ -1489,3 +1512,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) e
}
return nil
}
func (portal *Portal) Save(ctx context.Context) error {
return portal.Bridge.DB.Portal.Update(ctx, portal.Portal)
}