bridgev2: add disambiguation for relayed user displaynames

This commit is contained in:
Tulir Asokan 2024-08-02 21:33:29 +03:00
commit 83d3a0de5b
12 changed files with 174 additions and 23 deletions

View file

@ -533,6 +533,10 @@ func (br *Connector) GetMemberInfo(ctx context.Context, roomID id.RoomID, userID
return br.AS.StateStore.GetMember(ctx, roomID, userID) return br.AS.StateStore.GetMember(ctx, roomID, userID)
} }
func (br *Connector) IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error) {
return br.AS.StateStore.IsConfusableName(ctx, roomID, userID, name)
}
func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend, extras []*bridgev2.MatrixSendExtra) (*mautrix.RespBeeperBatchSend, error) { func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend, extras []*bridgev2.MatrixSendExtra) (*mautrix.RespBeeperBatchSend, error) {
if encrypted, err := br.StateStore.IsEncrypted(ctx, roomID); err != nil { if encrypted, err := br.StateStore.IsEncrypted(ctx, roomID); err != nil {
return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) return nil, fmt.Errorf("failed to check if room is encrypted: %w", err)

View file

@ -45,15 +45,25 @@ bridge:
# List of user login IDs which anyone can set as a relay, as long as the relay user is in the room. # List of user login IDs which anyone can set as a relay, as long as the relay user is in the room.
default_relays: [] default_relays: []
# The formats to use when sending messages via the relaybot. # The formats to use when sending messages via the relaybot.
# Available variables:
# .Sender.UserID - The Matrix user ID of the sender.
# .Sender.Displayname - The display name of the sender (if set).
# .Sender.RequiresDisambiguation - Whether the sender's name may be confused with the name of another user in the room.
# .Sender.DisambiguatedName - The disambiguated name of the sender. This will be the displayname if set,
# plus the user ID in parentheses if the displayname is not unique.
# If the displayname is not set, this is just the user ID.
# .Message - The `formatted_body` field of the message.
# .Caption - The `formatted_body` field of the message, if it's a caption. Otherwise an empty string.
# .FileName - The name of the file being sent.
message_formats: message_formats:
m.text: "<b>{{ .Sender.Displayname }}</b>: {{ .Message }}" m.text: "<b>{{ .Sender.DisambiguatedName }}</b>: {{ .Message }}"
m.notice: "<b>{{ .Sender.Displayname }}</b>: {{ .Message }}" m.notice: "<b>{{ .Sender.DisambiguatedName }}</b>: {{ .Message }}"
m.emote: "* <b>{{ .Sender.Displayname }}</b> {{ .Message }}" m.emote: "* <b>{{ .Sender.DisambiguatedName }}</b> {{ .Message }}"
m.file: "<b>{{ .Sender.Displayname }}</b> sent a file{{ if .Caption }}: {{ .Caption }}{{ end }}" m.file: "<b>{{ .Sender.DisambiguatedName }}</b> sent a file{{ if .Caption }}: {{ .Caption }}{{ end }}"
m.image: "<b>{{ .Sender.Displayname }}</b> sent an image{{ if .Caption }}: {{ .Caption }}{{ end }}" m.image: "<b>{{ .Sender.DisambiguatedName }}</b> sent an image{{ if .Caption }}: {{ .Caption }}{{ end }}"
m.audio: "<b>{{ .Sender.Displayname }}</b> sent an audio file{{ if .Caption }}: {{ .Caption }}{{ end }}" m.audio: "<b>{{ .Sender.DisambiguatedName }}</b> sent an audio file{{ if .Caption }}: {{ .Caption }}{{ end }}"
m.video: "<b>{{ .Sender.Displayname }}</b> sent a video{{ if .Caption }}: {{ .Caption }}{{ end }}" m.video: "<b>{{ .Sender.DisambiguatedName }}</b> sent a video{{ if .Caption }}: {{ .Caption }}{{ end }}"
m.location: "<b>{{ .Sender.Displayname }}</b> sent a location{{ if .Caption }}: {{ .Caption }}{{ end }}" m.location: "<b>{{ .Sender.DisambiguatedName }}</b> sent a location{{ if .Caption }}: {{ .Caption }}{{ end }}"
# Permissions for using the bridge. # Permissions for using the bridge.
# Permitted values: # Permitted values:

View file

@ -58,6 +58,10 @@ type MatrixConnectorWithServer interface {
GetRouter() *mux.Router GetRouter() *mux.Router
} }
type MatrixConnectorWithNameDisambiguation interface {
IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error)
}
type MatrixSendExtra struct { type MatrixSendExtra struct {
Timestamp time.Time Timestamp time.Time
MessageMeta *database.Message MessageMeta *database.Message

View file

@ -897,7 +897,12 @@ type RemoteTypingWithType interface {
} }
type OrigSender struct { type OrigSender struct {
User *User User *User
UserID id.UserID
RequiresDisambiguation bool
DisambiguatedName string
event.MemberEventContent event.MemberEventContent
} }

View file

@ -373,6 +373,26 @@ func (portal *Portal) sendErrorStatus(ctx context.Context, evt *event.Event, err
portal.Bridge.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) portal.Bridge.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
} }
func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID, name string) bool {
conn, ok := portal.Bridge.Matrix.(MatrixConnectorWithNameDisambiguation)
if !ok {
return false
}
confusableWith, err := conn.IsConfusableName(ctx, portal.MXID, userID, name)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to check if name is confusable")
return true
}
for _, confusable := range confusableWith {
// Don't disambiguate names that only conflict with ghosts of this bridge
_, isGhost := portal.Bridge.Matrix.ParseGhostMXID(confusable)
if !isGhost {
return true
}
}
return false
}
func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) {
log := portal.Log.With(). log := portal.Log.With().
Str("action", "handle matrix event"). Str("action", "handle matrix event").
@ -423,7 +443,8 @@ func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) {
if login == nil { if login == nil {
login = portal.Relay login = portal.Relay
origSender = &OrigSender{ origSender = &OrigSender{
User: sender, User: sender,
UserID: sender.MXID,
} }
memberInfo, err := portal.Bridge.Matrix.GetMemberInfo(ctx, portal.MXID, sender.MXID) memberInfo, err := portal.Bridge.Matrix.GetMemberInfo(ctx, portal.MXID, sender.MXID)
@ -431,6 +452,15 @@ func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) {
log.Warn().Err(err).Msg("Failed to get member info for user being relayed") log.Warn().Err(err).Msg("Failed to get member info for user being relayed")
} else if memberInfo != nil { } else if memberInfo != nil {
origSender.MemberEventContent = *memberInfo origSender.MemberEventContent = *memberInfo
if memberInfo.Displayname == "" {
origSender.DisambiguatedName = sender.MXID.String()
} else if origSender.RequiresDisambiguation = portal.checkConfusableName(ctx, sender.MXID, memberInfo.Displayname); origSender.RequiresDisambiguation {
origSender.DisambiguatedName = fmt.Sprintf("%s (%s)", memberInfo.Displayname, sender.MXID)
} else {
origSender.DisambiguatedName = memberInfo.Displayname
}
} else {
origSender.DisambiguatedName = sender.MXID.String()
} }
} }
log.UpdateContext(func(c zerolog.Context) zerolog.Context { log.UpdateContext(func(c zerolog.Context) zerolog.Context {

7
go.mod
View file

@ -12,13 +12,13 @@ require (
github.com/rs/zerolog v1.33.0 github.com/rs/zerolog v1.33.0
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.17.1 github.com/tidwall/gjson v1.17.3
github.com/tidwall/sjson v1.2.5 github.com/tidwall/sjson v1.2.5
github.com/yuin/goldmark v1.7.4 github.com/yuin/goldmark v1.7.4
go.mau.fi/util v0.6.1-0.20240719175439-20a6073e1dd4 go.mau.fi/util v0.6.1-0.20240802175451-b430ebbffc98
go.mau.fi/zeroconfig v0.1.3 go.mau.fi/zeroconfig v0.1.3
golang.org/x/crypto v0.25.0 golang.org/x/crypto v0.25.0
golang.org/x/exp v0.0.0-20240707233637-46b078467d37 golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56
golang.org/x/net v0.27.0 golang.org/x/net v0.27.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
maunium.net/go/mauflag v1.0.0 maunium.net/go/mauflag v1.0.0
@ -33,5 +33,6 @@ require (
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/pretty v1.2.0 // indirect
golang.org/x/sys v0.22.0 // indirect golang.org/x/sys v0.22.0 // indirect
golang.org/x/text v0.16.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
) )

14
go.sum
View file

@ -36,8 +36,8 @@ github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDq
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
@ -46,14 +46,14 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg=
github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
go.mau.fi/util v0.6.1-0.20240719175439-20a6073e1dd4 h1:CYKYs5jwJ0bFJqh6pRoWtC9NIJ0lz0/6i2SC4qEBFaU= go.mau.fi/util v0.6.1-0.20240802175451-b430ebbffc98 h1:gJ0peWecBm6TtlxKFVIc1KbooXSCHtPfsfb2Eha5A0A=
go.mau.fi/util v0.6.1-0.20240719175439-20a6073e1dd4/go.mod h1:ljYdq3sPfpICc3zMU+/mHV/sa4z0nKxc67hSBwnrk8U= go.mau.fi/util v0.6.1-0.20240802175451-b430ebbffc98/go.mod h1:S1juuPWGau2GctPY3FR/4ec/MDLhAG2QPhdnUwpzWIo=
go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM=
go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
golang.org/x/exp v0.0.0-20240707233637-46b078467d37 h1:uLDX+AfeFCct3a2C7uIWBKMJIR3CJMhcgfrUAqjRK6w= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
golang.org/x/exp v0.0.0-20240707233637-46b078467d37/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY=
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@ -62,6 +62,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=

View file

@ -92,6 +92,11 @@ func (c *ClientStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, u
return return
} }
func (c *ClientStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) {
//TODO implement me
panic("implement me")
}
func (c *ClientStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (content *event.PowerLevelsEventContent, err error) { func (c *ClientStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (content *event.PowerLevelsEventContent, err error) {
err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StatePowerLevels.Type, "").Scan(&dbutil.JSON{Data: &content}) err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StatePowerLevels.Type, "").Scan(&dbutil.JSON{Data: &content})
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {

View file

@ -17,6 +17,7 @@ import (
"strings" "strings"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"go.mau.fi/util/confusable"
"go.mau.fi/util/dbutil" "go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
@ -37,6 +38,8 @@ const VersionTableName = "mx_version"
type SQLStateStore struct { type SQLStateStore struct {
*dbutil.Database *dbutil.Database
IsBridge bool IsBridge bool
DisableNameDisambiguation bool
} }
func NewSQLStateStore(db *dbutil.Database, log dbutil.DatabaseLogger, isBridge bool) *SQLStateStore { func NewSQLStateStore(db *dbutil.Database, log dbutil.DatabaseLogger, isBridge bool) *SQLStateStore {
@ -65,6 +68,7 @@ func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID
type Member struct { type Member struct {
id.UserID id.UserID
event.MemberEventContent event.MemberEventContent
NameSkeleton [32]byte
} }
func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) (map[id.UserID]*event.MemberEventContent, error) { func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) (map[id.UserID]*event.MemberEventContent, error) {
@ -191,13 +195,32 @@ func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID,
} }
func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
var nameSkeleton []byte
if !store.DisableNameDisambiguation && len(member.Displayname) > 0 {
nameSkeletonArr := confusable.SkeletonHash(member.Displayname)
nameSkeleton = nameSkeletonArr[:]
}
_, err := store.Exec(ctx, ` _, err := store.Exec(ctx, `
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5) INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url, name_skeleton)
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url VALUES ($1, $2, $3, $4, $5, $6)
`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL) ON CONFLICT (room_id, user_id) DO UPDATE
SET membership=excluded.membership,
displayname=excluded.displayname,
avatar_url=excluded.avatar_url,
name_skeleton=excluded.name_skeleton
`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL, nameSkeleton)
return err return err
} }
func (store *SQLStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) {
if store.DisableNameDisambiguation {
return nil, nil
}
skeleton := confusable.SkeletonHash(name)
rows, err := store.Query(ctx, "SELECT user_id FROM mx_user_profile WHERE room_id=$1 AND name_skeleton=$2 AND user_id<>$3", roomID, skeleton[:], currentUser)
return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList()
}
func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error { func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error {
query := "DELETE FROM mx_user_profile WHERE room_id=$1" query := "DELETE FROM mx_user_profile WHERE room_id=$1"
params := make([]any, len(memberships)+1) params := make([]any, len(memberships)+1)

View file

@ -1,4 +1,4 @@
-- v0 -> v5: Latest revision -- v0 -> v6 (compatible with v3+): Latest revision
CREATE TABLE mx_registrations ( CREATE TABLE mx_registrations (
user_id TEXT PRIMARY KEY user_id TEXT PRIMARY KEY
@ -13,9 +13,15 @@ CREATE TABLE mx_user_profile (
membership membership NOT NULL, membership membership NOT NULL,
displayname TEXT NOT NULL DEFAULT '', displayname TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '', avatar_url TEXT NOT NULL DEFAULT '',
name_skeleton bytea,
PRIMARY KEY (room_id, user_id) PRIMARY KEY (room_id, user_id)
); );
CREATE INDEX mx_user_profile_membership_idx ON mx_user_profile (room_id, membership);
CREATE INDEX mx_user_profile_name_skeleton_idx ON mx_user_profile (room_id, name_skeleton);
CREATE TABLE mx_room_state ( CREATE TABLE mx_room_state (
room_id TEXT PRIMARY KEY, room_id TEXT PRIMARY KEY,
power_levels jsonb, power_levels jsonb,

View file

@ -0,0 +1,55 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package sqlstatestore
import (
"context"
"go.mau.fi/util/confusable"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/id"
)
type roomUserName struct {
RoomID id.RoomID
UserID id.UserID
Name string
}
func init() {
UpgradeTable.Register(-1, 6, 3, "Add disambiguation column for user profiles", dbutil.TxnModeOn, func(ctx context.Context, db *dbutil.Database) error {
_, err := db.Exec(ctx, `
ALTER TABLE mx_user_profile ADD COLUMN name_skeleton bytea;
CREATE INDEX mx_user_profile_membership_idx ON mx_user_profile (room_id, membership);
CREATE INDEX mx_user_profile_name_skeleton_idx ON mx_user_profile (room_id, name_skeleton);
`)
if err != nil {
return err
}
const ChunkSize = 1000
const GetEntriesChunkQuery = "SELECT room_id, user_id, displayname FROM mx_user_profile WHERE displayname<>'' LIMIT $1 OFFSET $2"
const SetSkeletonHashQuery = `UPDATE mx_user_profile SET name_skeleton = $3 WHERE room_id = $1 AND user_id = $2`
for offset := 0; ; offset += ChunkSize {
entries, err := dbutil.NewSimpleReflectRowIter[roomUserName](db.Query(ctx, GetEntriesChunkQuery, ChunkSize, offset)).AsList()
if err != nil {
return err
}
for _, entry := range entries {
skel := confusable.SkeletonHash(entry.Name)
_, err = db.Exec(ctx, SetSkeletonHashQuery, entry.RoomID, entry.UserID, skel[:])
if err != nil {
return err
}
}
if len(entries) < ChunkSize {
break
}
}
return nil
})
}

View file

@ -26,6 +26,7 @@ type StateStore interface {
TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error)
SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error
SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error
IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error)
ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error
SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error
@ -151,6 +152,11 @@ func (store *MemoryStateStore) GetMember(ctx context.Context, roomID id.RoomID,
return member, err return member, err
} }
func (store *MemoryStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) {
// TODO implement?
return nil, nil
}
func (store *MemoryStateStore) TryGetMember(_ context.Context, roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, err error) { func (store *MemoryStateStore) TryGetMember(_ context.Context, roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, err error) {
store.membersLock.RLock() store.membersLock.RLock()
defer store.membersLock.RUnlock() defer store.membersLock.RUnlock()