diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go index 23db1448..4e6f5e0a 100644 --- a/bridgev2/database/disappear.go +++ b/bridgev2/database/disappear.go @@ -61,7 +61,7 @@ const ( getUpcomingDisappearingMessagesQuery = ` SELECT bridge_id, mx_room, mxid, type, timer, disappear_at FROM disappearing_message WHERE bridge_id = $1 AND disappear_at IS NOT NULL AND disappear_at < $2 - ORDER BY disappear_at + ORDER BY disappear_at LIMIT $3 ` deleteDisappearingMessageQuery = ` DELETE FROM disappearing_message WHERE bridge_id=$1 AND mxid=$2 @@ -77,8 +77,8 @@ func (dmq *DisappearingMessageQuery) StartAll(ctx context.Context, roomID id.Roo return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID) } -func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration) ([]*DisappearingMessage, error) { - return dmq.QueryMany(ctx, getUpcomingDisappearingMessagesQuery, dmq.BridgeID, time.Now().Add(duration).UnixNano()) +func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration, limit int) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, getUpcomingDisappearingMessagesQuery, dmq.BridgeID, time.Now().Add(duration).UnixNano(), limit) } func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.EventID) error { diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go index 1d063088..8305f84b 100644 --- a/bridgev2/disappear.go +++ b/bridgev2/disappear.go @@ -36,10 +36,21 @@ func (dl *DisappearLoop) Start() { log.Debug().Msg("Disappearing message loop starting") for { dl.NextCheck = time.Now().Add(DisappearCheckInterval) - messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval) + const MessageLimit = 200 + messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval, MessageLimit) if err != nil { log.Err(err).Msg("Failed to get upcoming disappearing messages") } else if len(messages) > 0 { + if len(messages) > MessageLimit/2 && messages[len(messages)-1].DisappearAt.Before(time.Now()) { + // If there are many messages, and they're all due immediately, + // process them synchronously and then check again. + dl.sleepAndDisappear(ctx, messages...) + log.Debug(). + Int("message_count", len(messages)). + Time("last_due", messages[len(messages)-1].DisappearAt). + Msg("Checking for disappearing messages again immediately") + continue + } go dl.sleepAndDisappear(ctx, messages...) } select { @@ -91,10 +102,17 @@ func (dl *DisappearLoop) Add(ctx context.Context, dm *database.DisappearingMessa func (dl *DisappearLoop) sleepAndDisappear(ctx context.Context, dms ...*database.DisappearingMessage) { for _, msg := range dms { - select { - case <-time.After(time.Until(msg.DisappearAt)): - case <-ctx.Done(): - return + timeUntilDisappear := time.Until(msg.DisappearAt) + if timeUntilDisappear <= 0 { + if ctx.Err() != nil { + return + } + } else { + select { + case <-time.After(timeUntilDisappear): + case <-ctx.Done(): + return + } } resp, err := dl.br.Bot.SendMessage(ctx, msg.RoomID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{