diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 3e0617ae..30b806bc 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -495,6 +495,9 @@ type FetchMessagesResponse struct { ApproxRemainingCount int // Approximate total number of messages in the chat. ApproxTotalCount int + + // An optional function that is called after the backfill batch has been sent. + CompleteCallback func() } // BackfillingNetworkAPI is an optional interface that network connectors can implement to support backfilling message history. diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 6a912355..4350cfa2 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -74,7 +74,7 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, log.Warn().Msg("No messages left to backfill after cutting off old messages") return } - portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, false) + portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, false, resp.CompleteCallback) } func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin, task *database.BackfillTask) error { @@ -134,7 +134,7 @@ func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin if len(resp.Messages) == 0 { return fmt.Errorf("no messages left to backfill after cutting off too new messages") } - portal.sendBackfill(ctx, source, resp.Messages, false, resp.MarkRead, false) + portal.sendBackfill(ctx, source, resp.Messages, false, resp.MarkRead, false, resp.CompleteCallback) if len(resp.Messages) > 0 { task.OldestMessageID = resp.Messages[0].ID } @@ -182,7 +182,7 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t } resp := portal.fetchThreadBackfill(ctx, source, anchorMessage) if resp != nil { - portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, true) + portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, true, resp.CompleteCallback) } } @@ -257,7 +257,15 @@ func (portal *Portal) cutoffMessages(ctx context.Context, messages []*BackfillMe return messages } -func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) { +func (portal *Portal) sendBackfill( + ctx context.Context, + source *UserLogin, + messages []*BackfillMessage, + forceForward, + markRead, + inThread bool, + done func(), +) { canBatchSend := portal.Bridge.Matrix.GetCapabilities().BatchSending unreadThreshold := time.Duration(portal.Bridge.Config.Backfill.UnreadHoursThreshold) * time.Hour forceMarkRead := unreadThreshold > 0 && time.Since(messages[len(messages)-1].Timestamp) > unreadThreshold @@ -272,6 +280,9 @@ func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messa } else { portal.sendLegacyBackfill(ctx, source, messages, markRead || forceMarkRead) } + if done != nil { + done() + } zerolog.Ctx(ctx).Debug().Msg("Backfill finished") if !canBatchSend && !inThread && portal.Bridge.Config.Backfill.Threads.MaxInitialMessages > 0 { for _, msg := range messages { diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index 77bdd7fd..56726fee 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -269,8 +269,8 @@ func (portal *PortalInternals) UpdateUserLocalInfo(ctx context.Context, info *Us (*Portal)(portal).updateUserLocalInfo(ctx, info, source, didJustCreate) } -func (portal *PortalInternals) UpdateParent(ctx context.Context, newParent networkid.PortalID, source *UserLogin) bool { - return (*Portal)(portal).updateParent(ctx, newParent, source) +func (portal *PortalInternals) UpdateParent(ctx context.Context, newParentID networkid.PortalID, source *UserLogin) bool { + return (*Portal)(portal).updateParent(ctx, newParentID, source) } func (portal *PortalInternals) LockedUpdateInfoFromGhost(ctx context.Context, ghost *Ghost) { @@ -309,8 +309,8 @@ func (portal *PortalInternals) CutoffMessages(ctx context.Context, messages []*B return (*Portal)(portal).cutoffMessages(ctx, messages, aggressiveDedup, forward, lastMessage) } -func (portal *PortalInternals) SendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) { - (*Portal)(portal).sendBackfill(ctx, source, messages, forceForward, markRead, inThread) +func (portal *PortalInternals) SendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool, done func()) { + (*Portal)(portal).sendBackfill(ctx, source, messages, forceForward, markRead, inThread, done) } func (portal *PortalInternals) CompileBatchMessage(ctx context.Context, source *UserLogin, msg *BackfillMessage, out *compileBatchOutput, inThread bool) {