From e7df474b56ae85148c4eefdc5c61d4693224a97c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 31 Mar 2023 13:22:46 +0300 Subject: [PATCH] Fix FindSharedRooms for non-bridge sqlstatestores --- bridge/bridge.go | 2 +- crypto/cryptohelper/cryptohelper.go | 2 +- sqlstatestore/statestore.go | 16 +++++++++++++--- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/bridge/bridge.go b/bridge/bridge.go index c327c933..4c65c035 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -492,7 +492,7 @@ func (br *Bridge) init() { br.DB.IgnoreForeignTables = *ignoreForeignTables br.ZLog.Debug().Msg("Initializing state store") - br.StateStore = sqlstatestore.NewSQLStateStore(br.DB, dbutil.ZeroLogger(br.ZLog.With().Str("db_section", "matrix_state").Logger())) + br.StateStore = sqlstatestore.NewSQLStateStore(br.DB, dbutil.ZeroLogger(br.ZLog.With().Str("db_section", "matrix_state").Logger()), true) br.AS.StateStore = br.StateStore br.ZLog.Debug().Msg("Initializing Matrix event processor") diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index c8498d56..f5f860bb 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -86,7 +86,7 @@ func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoH } log := cli.Log.With().Str("component", "crypto").Logger() if cli.StateStore == nil && dbForManagedStores != nil { - managedStateStore = sqlstatestore.NewSQLStateStore(dbForManagedStores, dbutil.ZeroLogger(log.With().Str("db_section", "matrix_state").Logger())) + managedStateStore = sqlstatestore.NewSQLStateStore(dbForManagedStores, dbutil.ZeroLogger(log.With().Str("db_section", "matrix_state").Logger()), false) cli.StateStore = managedStateStore } else if _, isCryptoCompatible := cli.StateStore.(crypto.StateStore); !isCryptoCompatible { return nil, fmt.Errorf("the client state store must implement crypto.StateStore") diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index 429e308a..ab9b22d3 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -32,11 +32,13 @@ const VersionTableName = "mx_version" type SQLStateStore struct { *dbutil.Database + IsBridge bool } -func NewSQLStateStore(db *dbutil.Database, log dbutil.DatabaseLogger) *SQLStateStore { +func NewSQLStateStore(db *dbutil.Database, log dbutil.DatabaseLogger, isBridge bool) *SQLStateStore { return &SQLStateStore{ Database: db.Child(VersionTableName, UpgradeTable, log), + IsBridge: isBridge, } } @@ -130,11 +132,19 @@ func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*e } func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) { - rows, err := store.Query(` + query := ` SELECT room_id FROM mx_user_profile LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id WHERE mx_user_profile.user_id=$1 AND portal.encrypted=true - `, userID) + ` + if !store.IsBridge { + query = ` + SELECT room_id FROM mx_user_profile + LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id + WHERE mx_user_profile.user_id=$1 AND mx_room_state.encryption IS NOT NULL + ` + } + rows, err := store.Query(query, userID) if err != nil { store.Log.Warn("Failed to query shared rooms with %s: %v", userID, err) return