From e25578d435a2a5bf972e4c59fde99ba78c2ce112 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 27 Jun 2024 11:32:50 +0300 Subject: [PATCH] bridgev2: improve handling of user logins in bad credentials --- bridgev2/bridge.go | 25 ++++++++++++++++++------- bridgev2/database/userlogin.go | 13 +++++++------ bridgev2/user.go | 6 ++++++ bridgev2/userlogin.go | 24 ++++++++---------------- 4 files changed, 39 insertions(+), 29 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 765041d2..36f7aa06 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -107,18 +107,29 @@ func (br *Bridge) Start() error { go br.DisappearLoop.Start() } - logins, err := br.GetAllUserLogins(ctx) + userIDs, err := br.DB.UserLogin.GetAllUserIDsWithLogins(ctx) if err != nil { - return fmt.Errorf("failed to get user logins: %w", err) + return fmt.Errorf("failed to get users with logins: %w", err) } - for _, login := range logins { - br.Log.Info().Str("id", string(login.ID)).Msg("Starting user login") - err = login.Client.Connect(login.Log.WithContext(ctx)) + startedAny := false + for _, userID := range userIDs { + br.Log.Info().Stringer("user_id", userID).Msg("Loading user") + var user *User + user, err = br.GetUserByMXID(ctx, userID) if err != nil { - br.Log.Err(err).Msg("Failed to connect existing client") + br.Log.Err(err).Stringer("user_id", userID).Msg("Failed to load user") + } else { + for _, login := range user.GetCachedUserLogins() { + startedAny = true + br.Log.Info().Str("id", string(login.ID)).Msg("Starting user login") + err = login.Client.Connect(login.Log.WithContext(ctx)) + if err != nil { + br.Log.Err(err).Msg("Failed to connect existing client") + } + } } } - if len(logins) == 0 { + if !startedAny { br.Log.Info().Msg("No user logins found") br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured}) } diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index b371483c..cc92e7d4 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -54,10 +54,10 @@ const ( getUserLoginBaseQuery = ` SELECT bridge_id, user_mxid, id, space_room, metadata FROM user_login ` - getLoginByIDQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND id=$2` - getAllLoginsQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1` - getAllLoginsForUserQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND user_mxid=$2` - getAllLoginsInPortalQuery = ` + getLoginByIDQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND id=$2` + getAllUsersWithLoginsQuery = `SELECT DISTINCT user_mxid FROM user_login WHERE bridge_id=$1` + getAllLoginsForUserQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND user_mxid=$2` + getAllLoginsInPortalQuery = ` SELECT ul.bridge_id, ul.user_mxid, ul.id, ul.space_room, ul.metadata FROM user_portal LEFT JOIN user_login ul ON user_portal.bridge_id=ul.bridge_id AND user_portal.user_mxid=ul.user_mxid AND user_portal.login_id=ul.id WHERE user_portal.bridge_id=$1 AND user_portal.portal_id=$2 AND user_portal.portal_receiver=$3 @@ -79,8 +79,9 @@ func (uq *UserLoginQuery) GetByID(ctx context.Context, id networkid.UserLoginID) return uq.QueryOne(ctx, getLoginByIDQuery, uq.BridgeID, id) } -func (uq *UserLoginQuery) GetAll(ctx context.Context) ([]*UserLogin, error) { - return uq.QueryMany(ctx, getAllLoginsQuery, uq.BridgeID) +func (uq *UserLoginQuery) GetAllUserIDsWithLogins(ctx context.Context) ([]id.UserID, error) { + rows, err := uq.GetDB().Query(ctx, getAllUsersWithLoginsQuery, uq.BridgeID) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() } func (uq *UserLoginQuery) GetAllInPortal(ctx context.Context, portal networkid.PortalKey) ([]*UserLogin, error) { diff --git a/bridgev2/user.go b/bridgev2/user.go index 500a51c0..9fca8de3 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -158,6 +158,12 @@ func (user *User) GetUserLoginIDs() []networkid.UserLoginID { return maps.Keys(user.logins) } +func (user *User) GetCachedUserLogins() []*UserLogin { + user.Bridge.cacheLock.Lock() + defer user.Bridge.cacheLock.Unlock() + return maps.Values(user.logins) +} + func (user *User) GetFormattedUserLogins() string { user.Bridge.cacheLock.Lock() logins := make([]string, len(user.logins)) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index e18c2b25..1bc81190 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -56,7 +56,8 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da } err := br.Network.LoadUserLogin(ctx, userLogin) if err != nil { - return nil, fmt.Errorf("failed to prepare: %w", err) + userLogin.Log.Err(err).Msg("Failed to load user login") + return nil, nil } user.logins[userLogin.ID] = userLogin br.userLoginsByID[userLogin.ID] = userLogin @@ -65,16 +66,17 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da } func (br *Bridge) loadManyUserLogins(ctx context.Context, user *User, logins []*database.UserLogin) ([]*UserLogin, error) { - output := make([]*UserLogin, len(logins)) - for i, dbLogin := range logins { + output := make([]*UserLogin, 0, len(logins)) + for _, dbLogin := range logins { if cached, ok := br.userLoginsByID[dbLogin.ID]; ok { - output[i] = cached + output = append(output, cached) } else { loaded, err := br.loadUserLogin(ctx, user, dbLogin) if err != nil { - return nil, fmt.Errorf("failed to load user login: %w", err) + return nil, err + } else if loaded != nil { + output = append(output, loaded) } - output[i] = loaded } } return output, nil @@ -89,16 +91,6 @@ func (br *Bridge) unlockedLoadUserLoginsByMXID(ctx context.Context, user *User) return err } -func (br *Bridge) GetAllUserLogins(ctx context.Context) ([]*UserLogin, error) { - logins, err := br.DB.UserLogin.GetAll(ctx) - if err != nil { - return nil, err - } - br.cacheLock.Lock() - defer br.cacheLock.Unlock() - return br.loadManyUserLogins(ctx, nil, logins) -} - func (br *Bridge) GetUserLoginsInPortal(ctx context.Context, portal networkid.PortalKey) ([]*UserLogin, error) { logins, err := br.DB.UserLogin.GetAllInPortal(ctx, portal) if err != nil {