bridgev2/database: include portal receiver in reaction queries
Some checks are pending
Go / Lint (latest) (push) Waiting to run
Go / Build (old, libolm) (push) Waiting to run
Go / Build (latest, libolm) (push) Waiting to run
Go / Build (old, goolm) (push) Waiting to run
Go / Build (latest, goolm) (push) Waiting to run

This commit is contained in:
Tulir Asokan 2024-11-14 14:58:08 +02:00
commit 21aa3291f3
2 changed files with 24 additions and 24 deletions

View file

@ -41,11 +41,11 @@ const (
getReactionBaseQuery = `
SELECT bridge_id, message_id, message_part_id, sender_id, sender_mxid, emoji_id, emoji, room_id, room_receiver, mxid, timestamp, metadata FROM reaction
`
getReactionByIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5`
getReactionByIDWithoutMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND sender_id=$3 AND emoji_id=$4 ORDER BY message_part_id ASC LIMIT 1`
getAllReactionsToMessageBySenderQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND sender_id=$3 ORDER BY timestamp DESC`
getAllReactionsToMessageQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2`
getAllReactionsToMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3`
getReactionByIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND message_part_id=$4 AND sender_id=$5 AND emoji_id=$6`
getReactionByIDWithoutMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND sender_id=$4 AND emoji_id=$5 ORDER BY message_part_id ASC LIMIT 1`
getAllReactionsToMessageBySenderQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND sender_id=$4 ORDER BY timestamp DESC`
getAllReactionsToMessageQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3`
getAllReactionsToMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND message_part_id=$4`
getReactionByMXIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND mxid=$2`
upsertReactionQuery = `
INSERT INTO reaction (bridge_id, message_id, message_part_id, sender_id, sender_mxid, emoji_id, emoji, room_id, room_receiver, mxid, timestamp, metadata)
@ -54,28 +54,28 @@ const (
DO UPDATE SET sender_mxid=excluded.sender_mxid, mxid=excluded.mxid, timestamp=excluded.timestamp, emoji=excluded.emoji, metadata=excluded.metadata
`
deleteReactionQuery = `
DELETE FROM reaction WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5
DELETE FROM reaction WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND message_part_id=$4 AND sender_id=$5 AND emoji_id=$6
`
)
func (rq *ReactionQuery) GetByID(ctx context.Context, messageID networkid.MessageID, messagePartID networkid.PartID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) {
return rq.QueryOne(ctx, getReactionByIDQuery, rq.BridgeID, messageID, messagePartID, senderID, emojiID)
func (rq *ReactionQuery) GetByID(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, messagePartID networkid.PartID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) {
return rq.QueryOne(ctx, getReactionByIDQuery, rq.BridgeID, receiver, messageID, messagePartID, senderID, emojiID)
}
func (rq *ReactionQuery) GetByIDWithoutMessagePart(ctx context.Context, messageID networkid.MessageID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) {
return rq.QueryOne(ctx, getReactionByIDWithoutMessagePartQuery, rq.BridgeID, messageID, senderID, emojiID)
func (rq *ReactionQuery) GetByIDWithoutMessagePart(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) {
return rq.QueryOne(ctx, getReactionByIDWithoutMessagePartQuery, rq.BridgeID, receiver, messageID, senderID, emojiID)
}
func (rq *ReactionQuery) GetAllToMessageBySender(ctx context.Context, messageID networkid.MessageID, senderID networkid.UserID) ([]*Reaction, error) {
return rq.QueryMany(ctx, getAllReactionsToMessageBySenderQuery, rq.BridgeID, messageID, senderID)
func (rq *ReactionQuery) GetAllToMessageBySender(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, senderID networkid.UserID) ([]*Reaction, error) {
return rq.QueryMany(ctx, getAllReactionsToMessageBySenderQuery, rq.BridgeID, receiver, messageID, senderID)
}
func (rq *ReactionQuery) GetAllToMessage(ctx context.Context, messageID networkid.MessageID) ([]*Reaction, error) {
return rq.QueryMany(ctx, getAllReactionsToMessageQuery, rq.BridgeID, messageID)
func (rq *ReactionQuery) GetAllToMessage(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID) ([]*Reaction, error) {
return rq.QueryMany(ctx, getAllReactionsToMessageQuery, rq.BridgeID, receiver, messageID)
}
func (rq *ReactionQuery) GetAllToMessagePart(ctx context.Context, messageID networkid.MessageID, partID networkid.PartID) ([]*Reaction, error) {
return rq.QueryMany(ctx, getAllReactionsToMessagePartQuery, rq.BridgeID, messageID, partID)
func (rq *ReactionQuery) GetAllToMessagePart(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, partID networkid.PartID) ([]*Reaction, error) {
return rq.QueryMany(ctx, getAllReactionsToMessagePartQuery, rq.BridgeID, receiver, messageID, partID)
}
func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) {
@ -89,7 +89,7 @@ func (rq *ReactionQuery) Upsert(ctx context.Context, reaction *Reaction) error {
func (rq *ReactionQuery) Delete(ctx context.Context, reaction *Reaction) error {
ensureBridgeIDMatches(&reaction.BridgeID, rq.BridgeID)
return rq.Exec(ctx, deleteReactionQuery, reaction.BridgeID, reaction.MessageID, reaction.MessagePartID, reaction.SenderID, reaction.EmojiID)
return rq.Exec(ctx, deleteReactionQuery, reaction.BridgeID, reaction.Room.Receiver, reaction.MessageID, reaction.MessagePartID, reaction.SenderID, reaction.EmojiID)
}
func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) {

View file

@ -1147,7 +1147,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
if portal.Bridge.Config.OutgoingMessageReID {
deterministicID = portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, reactionTarget, preResp.SenderID, preResp.EmojiID)
}
existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID)
existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID)
if err != nil {
log.Err(err).Msg("Failed to check if reaction is a duplicate")
return
@ -1169,7 +1169,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
}
react.PreHandleResp = &preResp
if preResp.MaxReactions > 0 {
allReactions, err := portal.Bridge.DB.Reaction.GetAllToMessageBySender(ctx, reactionTarget.ID, preResp.SenderID)
allReactions, err := portal.Bridge.DB.Reaction.GetAllToMessageBySender(ctx, portal.Receiver, reactionTarget.ID, preResp.SenderID)
if err != nil {
log.Err(err).Msg("Failed to get all reactions to message by sender")
portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get previous reactions: %w", ErrDatabaseError, err))
@ -2162,9 +2162,9 @@ func (portal *Portal) getTargetMessagePart(ctx context.Context, evt RemoteEventW
func (portal *Portal) getTargetReaction(ctx context.Context, evt RemoteReactionRemove) (*database.Reaction, error) {
if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok {
return portal.Bridge.DB.Reaction.GetByID(ctx, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart(), evt.GetSender().Sender, evt.GetRemovedEmojiID())
return portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart(), evt.GetSender().Sender, evt.GetRemovedEmojiID())
} else {
return portal.Bridge.DB.Reaction.GetByIDWithoutMessagePart(ctx, evt.GetTargetMessage(), evt.GetSender().Sender, evt.GetRemovedEmojiID())
return portal.Bridge.DB.Reaction.GetByIDWithoutMessagePart(ctx, portal.Receiver, evt.GetTargetMessage(), evt.GetSender().Sender, evt.GetRemovedEmojiID())
}
}
@ -2196,9 +2196,9 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User
}
var existingReactions []*database.Reaction
if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok {
existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessagePart(ctx, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart())
existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessagePart(ctx, portal.Receiver, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart())
} else {
existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessage(ctx, evt.GetTargetMessage())
existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessage(ctx, portal.Receiver, evt.GetTargetMessage())
}
existing := make(map[networkid.UserID]map[networkid.EmojiID]*database.Reaction)
for _, existingReaction := range existingReactions {
@ -2317,7 +2317,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi
return
}
emoji, emojiID := evt.GetReactionEmoji()
existingReaction, err := portal.Bridge.DB.Reaction.GetByID(ctx, targetMessage.ID, targetMessage.PartID, evt.GetSender().Sender, emojiID)
existingReaction, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, targetMessage.ID, targetMessage.PartID, evt.GetSender().Sender, emojiID)
if err != nil {
log.Err(err).Msg("Failed to check if reaction is a duplicate")
return