From 9e1a8cd56e31a99a005f9327ade5d2688c30422f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 Aug 2024 16:19:49 +0300 Subject: [PATCH] bridgev2/matrix: use cached member list if available --- bridgev2/matrix/connector.go | 7 ++++- bridgev2/matrix/intent.go | 4 +++ client.go | 12 +++++++ hicli/database/statestore.go | 13 ++++++++ sqlstatestore/statestore.go | 41 ++++++++++++++++++++++++ sqlstatestore/v00-latest-revision.sql | 19 ++++++------ sqlstatestore/v07-full-member-flag.sql | 2 ++ statestore.go | 43 +++++++++++++++++++++----- 8 files changed, 123 insertions(+), 18 deletions(-) create mode 100644 sqlstatestore/v07-full-member-flag.sql diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index dee28b8d..115250f2 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -540,7 +540,12 @@ func (br *Connector) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*eve } func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { - // TODO use cache? + fetched, err := br.Bot.StateStore.HasFetchedMembers(ctx, roomID) + if err != nil { + return nil, err + } else if fetched { + return br.Bot.StateStore.GetAllMembers(ctx, roomID) + } members, err := br.Bot.Members(ctx, roomID) if err != nil { return nil, err diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 43d23a4c..0f668f8c 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -535,6 +535,10 @@ func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnl if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to leave room while cleaning up portal") } + err = as.Matrix.StateStore.ClearCachedMembers(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to clear cached members while cleaning up portal") + } return nil } diff --git a/client.go b/client.go index a0e86bdb..b3bd9158 100644 --- a/client.go +++ b/client.go @@ -1444,6 +1444,11 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt UpdateStateStore(ctx, cli.StateStore, evt) } } + clearErr = cli.StateStore.MarkMembersFetched(ctx, roomID) + if clearErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(clearErr). + Msg("Failed to mark members as fetched after fetching full room state") + } } return } @@ -1840,6 +1845,13 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb for _, evt := range resp.Chunk { UpdateStateStore(ctx, cli.StateStore, evt) } + if extra.NotMembership == "" && extra.Membership == "" { + markErr := cli.StateStore.MarkMembersFetched(ctx, roomID) + if markErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(markErr). + Msg("Failed to mark members as fetched after fetching full member list") + } + } } return } diff --git a/hicli/database/statestore.go b/hicli/database/statestore.go index e8050e93..cefe76d3 100644 --- a/hicli/database/statestore.go +++ b/hicli/database/statestore.go @@ -10,6 +10,7 @@ import ( "context" "database/sql" "errors" + "fmt" "go.mau.fi/util/dbutil" "golang.org/x/exp/slices" @@ -115,6 +116,18 @@ func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, ro return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() } +func (c *ClientStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) { + return false, fmt.Errorf("not implemented") +} + +func (c *ClientStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error { + return fmt.Errorf("not implemented") +} + +func (c *ClientStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { + return nil, fmt.Errorf("not implemented") +} + func (c *ClientStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (isEncrypted bool, err error) { err = c.QueryRow(ctx, isRoomEncryptedQuery, roomID).Scan(&isEncrypted) if errors.Is(err, sql.ErrNoRows) { diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index 0e5c4184..2cfd1b97 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -234,9 +234,50 @@ func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.Ro query += fmt.Sprintf(" AND membership IN (%s)", strings.Join(placeholders, ",")) } _, err := store.Exec(ctx, query, params...) + if err != nil { + return err + } + _, err = store.Exec(ctx, "UPDATE mx_room_state SET members_fetched=false WHERE room_id=$1", roomID) return err } +func (store *SQLStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (fetched bool, err error) { + err = store.QueryRow(ctx, "SELECT members_fetched FROM mx_room_state WHERE room_id=$1", roomID).Scan(&fetched) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (store *SQLStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error { + _, err := store.Exec(ctx, ` + INSERT INTO mx_room_state (room_id, members_fetched) VALUES ($1, true) + ON CONFLICT (room_id) DO UPDATE SET members_fetched=true + `, roomID) + return err +} + +type userAndMembership struct { + UserID id.UserID + event.MemberEventContent +} + +func (store *SQLStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { + rows, err := store.Query(ctx, "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID) + if err != nil { + return nil, err + } + output := make(map[id.UserID]*event.MemberEventContent) + err = dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (res userAndMembership, err error) { + err = row.Scan(&res.UserID, &res.Membership, &res.Displayname, &res.AvatarURL) + return + }, err).Iter(func(member userAndMembership) (bool, error) { + output[member.UserID] = &member.MemberEventContent + return true, nil + }) + return output, err +} + func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { contentBytes, err := json.Marshal(content) if err != nil { diff --git a/sqlstatestore/v00-latest-revision.sql b/sqlstatestore/v00-latest-revision.sql index b2bb2ae6..a58cc56a 100644 --- a/sqlstatestore/v00-latest-revision.sql +++ b/sqlstatestore/v00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v6 (compatible with v3+): Latest revision +-- v0 -> v7 (compatible with v3+): Latest revision CREATE TABLE mx_registrations ( user_id TEXT PRIMARY KEY @@ -8,11 +8,11 @@ CREATE TABLE mx_registrations ( CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock'); CREATE TABLE mx_user_profile ( - room_id TEXT, - user_id TEXT, - membership membership NOT NULL, - displayname TEXT NOT NULL DEFAULT '', - avatar_url TEXT NOT NULL DEFAULT '', + room_id TEXT, + user_id TEXT, + membership membership NOT NULL, + displayname TEXT NOT NULL DEFAULT '', + avatar_url TEXT NOT NULL DEFAULT '', name_skeleton bytea, @@ -23,7 +23,8 @@ CREATE INDEX mx_user_profile_membership_idx ON mx_user_profile (room_id, members CREATE INDEX mx_user_profile_name_skeleton_idx ON mx_user_profile (room_id, name_skeleton); CREATE TABLE mx_room_state ( - room_id TEXT PRIMARY KEY, - power_levels jsonb, - encryption jsonb + room_id TEXT PRIMARY KEY, + power_levels jsonb, + encryption jsonb, + members_fetched BOOLEAN NOT NULL DEFAULT false ); diff --git a/sqlstatestore/v07-full-member-flag.sql b/sqlstatestore/v07-full-member-flag.sql new file mode 100644 index 00000000..32f2ef6c --- /dev/null +++ b/sqlstatestore/v07-full-member-flag.sql @@ -0,0 +1,2 @@ +-- v7 (compatible with v3+): Add flag for whether the full member list has been fetched +ALTER TABLE mx_room_state ADD COLUMN members_fetched BOOLEAN NOT NULL DEFAULT false; diff --git a/statestore.go b/statestore.go index 35bfc6ab..5f210e4f 100644 --- a/statestore.go +++ b/statestore.go @@ -8,6 +8,7 @@ package mautrix import ( "context" + "maps" "sync" "github.com/rs/zerolog" @@ -32,6 +33,10 @@ type StateStore interface { SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error) + HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) + MarkMembersFetched(ctx context.Context, roomID id.RoomID) error + GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) + SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) @@ -90,10 +95,11 @@ func (cli *Client) StateStoreSyncHandler(ctx context.Context, evt *event.Event) } type MemoryStateStore struct { - Registrations map[id.UserID]bool `json:"registrations"` - Members map[id.RoomID]map[id.UserID]*event.MemberEventContent `json:"memberships"` - PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"` - Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"` + Registrations map[id.UserID]bool `json:"registrations"` + Members map[id.RoomID]map[id.UserID]*event.MemberEventContent `json:"memberships"` + MembersFetched map[id.RoomID]bool `json:"members_fetched"` + PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"` + Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"` registrationsLock sync.RWMutex membersLock sync.RWMutex @@ -103,10 +109,11 @@ type MemoryStateStore struct { func NewMemoryStateStore() StateStore { return &MemoryStateStore{ - Registrations: make(map[id.UserID]bool), - Members: make(map[id.RoomID]map[id.UserID]*event.MemberEventContent), - PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent), - Encryption: make(map[id.RoomID]*event.EncryptionEventContent), + Registrations: make(map[id.UserID]bool), + Members: make(map[id.RoomID]map[id.UserID]*event.MemberEventContent), + MembersFetched: make(map[id.RoomID]bool), + PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent), + Encryption: make(map[id.RoomID]*event.EncryptionEventContent), } } @@ -246,9 +253,29 @@ func (store *MemoryStateStore) ClearCachedMembers(_ context.Context, roomID id.R } } } + store.MembersFetched[roomID] = false return nil } +func (store *MemoryStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) { + store.membersLock.RLock() + defer store.membersLock.RUnlock() + return store.MembersFetched[roomID], nil +} + +func (store *MemoryStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error { + store.membersLock.Lock() + defer store.membersLock.Unlock() + store.MembersFetched[roomID] = true + return nil +} + +func (store *MemoryStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { + store.membersLock.Lock() + defer store.membersLock.Unlock() + return maps.Clone(store.Members[roomID]), nil +} + func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { store.powerLevelsLock.Lock() store.PowerLevels[roomID] = levels