# Patch summary (git unified diff over six files under bridgev2/).
#
# bridgev2/backfillqueue.go, actuallyDoBackfillTask:
#   - When task.BatchCount is negative, the number of messages in the portal is
#     counted via br.DB.Message.CountMessagesInPortal and BatchCount is set to
#     msgCount / Config.Backfill.Queue.BatchSize (integer division).
#   - maxBatches starts from Config.Backfill.Queue.MaxBatches. login.Client must
#     implement BackfillingNetworkAPI, otherwise an error is returned. If it also
#     implements BackfillingNetworkAPIWithLimits, maxBatches is replaced by the
#     result of GetBackfillMaxBatchCount.
#   - portal.DoBackwardsBackfill is called, and BatchCount incremented, only when
#     maxBatches < 0 or maxBatches > task.BatchCount. Otherwise a debug line is
#     logged and no backfill happens.
# bridgev2/bridgeconfig/backfill.go:
#   - Adds BackfillQueueConfig.GetOverride(name), which returns
#     MaxBatchesOverride[name] when the key exists and MaxBatches otherwise.
# bridgev2/database/backfillqueue.go:
#   - ensureBackfillExistsQuery now inserts batch_count = -1 instead of 0.
# bridgev2/database/message.go:
#   - Adds countMessagesInPortalQuery and MessageQuery.CountMessagesInPortal,
#     which scans a COUNT(*) filtered by bridge_id, room_id and room_receiver.
# bridgev2/networkinterface.go:
#   - MergeCaption now starts with textPart = Parts[1] and mediaPart = Parts[0],
#     and swaps them when textPart's MsgType is not event.MsgText. Previously it
#     started with Parts[0], Parts[1] and swapped when textPart was media.
#   - Adds a Task field to FetchMessagesParams.
#   - Adds the BackfillingNetworkAPIWithLimits interface.
# bridgev2/portalbackfill.go, DoBackwardsBackfill:
#   - Sets Task: task in the parameter struct of the fetch call.
#
# NOTE(review): msgCount / BatchSize panics at run time if BatchSize is 0.
#   Confirm the config is validated or defaulted elsewhere.
# NOTE(review): GetOverride is added but not called in any hunk shown here, and
#   the "TODO apply max batch overrides" comment is removed. Confirm the
#   override is applied somewhere outside this patch.
# NOTE(review): when maxBatches == 0 the backfill branch is skipped, and the
#   IsDone expression (which requires maxBatches > 0) does not mark the task
#   done. Confirm that is the intended handling of 0.
diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index 345c3391..63a01f68 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -180,13 +180,39 @@ func (br *Bridge) actuallyDoBackfillTask(ctx context.Context, task *database.Bac return false, nil } } - maxBatches := br.Config.Backfill.Queue.MaxBatches - // TODO apply max batch overrides - err = portal.DoBackwardsBackfill(ctx, login, task) - if err != nil { - return false, fmt.Errorf("failed to backfill: %w", err) + if task.BatchCount < 0 { + var msgCount int + msgCount, err = br.DB.Message.CountMessagesInPortal(ctx, task.PortalKey) + if err != nil { + return false, fmt.Errorf("failed to count messages in portal: %w", err) + } + task.BatchCount = msgCount / br.Config.Backfill.Queue.BatchSize + log.Debug(). + Int("message_count", msgCount). + Int("batch_count", task.BatchCount). + Msg("Calculated existing batch count") + } + maxBatches := br.Config.Backfill.Queue.MaxBatches + api, ok := login.Client.(BackfillingNetworkAPI) + if !ok { + return false, fmt.Errorf("network API does not support backfilling") + } + limiterAPI, ok := api.(BackfillingNetworkAPIWithLimits) + if ok { + maxBatches = limiterAPI.GetBackfillMaxBatchCount(ctx, portal, task) + } + if maxBatches < 0 || maxBatches > task.BatchCount { + err = portal.DoBackwardsBackfill(ctx, login, task) + if err != nil { + return false, fmt.Errorf("failed to backfill: %w", err) + } + task.BatchCount++ + } else { + log.Debug(). + Int("max_batches", maxBatches). + Int("batch_count", task.BatchCount). 
+ Msg("Not actually backfilling: max batches reached") } - task.BatchCount++ task.IsDone = task.IsDone || (maxBatches > 0 && task.BatchCount >= maxBatches) batchDelay := time.Duration(br.Config.Backfill.Queue.BatchDelay) * time.Second task.CompletedAt = time.Now() diff --git a/bridgev2/bridgeconfig/backfill.go b/bridgev2/bridgeconfig/backfill.go index fe464569..44d2d588 100644 --- a/bridgev2/bridgeconfig/backfill.go +++ b/bridgev2/bridgeconfig/backfill.go @@ -28,3 +28,11 @@ type BackfillQueueConfig struct { MaxBatchesOverride map[string]int `yaml:"max_batches_override"` } + +func (bqc *BackfillQueueConfig) GetOverride(name string) int { + override, ok := bqc.MaxBatchesOverride[name] + if !ok { + return bqc.MaxBatches + } + return override +} diff --git a/bridgev2/database/backfillqueue.go b/bridgev2/database/backfillqueue.go index 5d7cf854..fed7452d 100644 --- a/bridgev2/database/backfillqueue.go +++ b/bridgev2/database/backfillqueue.go @@ -40,7 +40,7 @@ var BackfillNextDispatchNever = time.Unix(0, (1<<63)-1) const ( ensureBackfillExistsQuery = ` INSERT INTO backfill_task (bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, next_dispatch_min_ts) - VALUES ($1, $2, $3, $4, 0, false, $5) + VALUES ($1, $2, $3, $4, -1, false, $5) ON CONFLICT (bridge_id, portal_id, portal_receiver) DO UPDATE SET user_login_id=CASE WHEN backfill_task.user_login_id='' diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index b2e023d0..8173ad05 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -64,6 +64,10 @@ const ( getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1` + countMessagesInPortalQuery = ` + SELECT COUNT(*) FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 + ` + insertMessageQuery = ` INSERT INTO message ( bridge_id, id, part_id, mxid, room_id, 
room_receiver, sender_id, sender_mxid, @@ -155,6 +159,11 @@ func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error { return mq.Exec(ctx, deleteMessagePartByRowIDQuery, mq.BridgeID, rowID) } +func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid.PortalKey) (count int, err error) { + err = mq.GetDB().QueryRow(ctx, countMessagesInPortalQuery, mq.BridgeID, key.ID, key.Receiver).Scan(&count) + return +} + func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { var timestamp int64 var threadRootID, replyToID, replyToPartID sql.NullString diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index a28b13d2..54ec81a8 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -114,8 +114,8 @@ func (cm *ConvertedMessage) MergeCaption() bool { if len(cm.Parts) != 2 { return false } - textPart, mediaPart := cm.Parts[0], cm.Parts[1] - if textPart.Content.MsgType.IsMedia() { + textPart, mediaPart := cm.Parts[1], cm.Parts[0] + if textPart.Content.MsgType != event.MsgText { textPart, mediaPart = mediaPart, textPart } if (!mediaPart.Content.MsgType.IsMedia() && mediaPart.Content.MsgType != event.MsgNotice) || textPart.Content.MsgType != event.MsgText { @@ -369,6 +369,9 @@ type FetchMessagesParams struct { // The preferred number of messages to return. The returned batch can be bigger or smaller // without any side effects, but the network connector should aim for this number. Count int + + // When the messages are being fetched for a queued backfill, this is the task object. + Task *database.BackfillTask } // BackfillReaction is an individual reaction to a message in a history pagination request. 
@@ -436,6 +439,11 @@ type BackfillingNetworkAPI interface { FetchMessages(ctx context.Context, fetchParams FetchMessagesParams) (*FetchMessagesResponse, error) } +type BackfillingNetworkAPIWithLimits interface { + BackfillingNetworkAPI + GetBackfillMaxBatchCount(ctx context.Context, portal *Portal, task *database.BackfillTask) int +} + // EditHandlingNetworkAPI is an optional interface that network connectors can implement to handle message edits. type EditHandlingNetworkAPI interface { NetworkAPI diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 77fe23b1..6d4124e8 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -96,6 +96,7 @@ func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin Cursor: task.Cursor, AnchorMessage: firstMessage, Count: portal.Bridge.Config.Backfill.Queue.BatchSize, + Task: task, }) if err != nil { return fmt.Errorf("failed to fetch messages for backward backfill: %w", err)