bridgev2: improve handling of user logins in bad credentials

This commit is contained in:
Tulir Asokan 2024-06-27 11:32:50 +03:00
commit e25578d435
4 changed files with 39 additions and 29 deletions

View file

@ -107,18 +107,29 @@ func (br *Bridge) Start() error {
go br.DisappearLoop.Start()
}
logins, err := br.GetAllUserLogins(ctx)
userIDs, err := br.DB.UserLogin.GetAllUserIDsWithLogins(ctx)
if err != nil {
return fmt.Errorf("failed to get user logins: %w", err)
return fmt.Errorf("failed to get users with logins: %w", err)
}
for _, login := range logins {
br.Log.Info().Str("id", string(login.ID)).Msg("Starting user login")
err = login.Client.Connect(login.Log.WithContext(ctx))
startedAny := false
for _, userID := range userIDs {
br.Log.Info().Stringer("user_id", userID).Msg("Loading user")
var user *User
user, err = br.GetUserByMXID(ctx, userID)
if err != nil {
br.Log.Err(err).Msg("Failed to connect existing client")
br.Log.Err(err).Stringer("user_id", userID).Msg("Failed to load user")
} else {
for _, login := range user.GetCachedUserLogins() {
startedAny = true
br.Log.Info().Str("id", string(login.ID)).Msg("Starting user login")
err = login.Client.Connect(login.Log.WithContext(ctx))
if err != nil {
br.Log.Err(err).Msg("Failed to connect existing client")
}
}
}
}
if len(logins) == 0 {
if !startedAny {
br.Log.Info().Msg("No user logins found")
br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured})
}

View file

@ -54,10 +54,10 @@ const (
getUserLoginBaseQuery = `
SELECT bridge_id, user_mxid, id, space_room, metadata FROM user_login
`
getLoginByIDQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND id=$2`
getAllLoginsQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1`
getAllLoginsForUserQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND user_mxid=$2`
getAllLoginsInPortalQuery = `
getLoginByIDQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND id=$2`
getAllUsersWithLoginsQuery = `SELECT DISTINCT user_mxid FROM user_login WHERE bridge_id=$1`
getAllLoginsForUserQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND user_mxid=$2`
getAllLoginsInPortalQuery = `
SELECT ul.bridge_id, ul.user_mxid, ul.id, ul.space_room, ul.metadata FROM user_portal
LEFT JOIN user_login ul ON user_portal.bridge_id=ul.bridge_id AND user_portal.user_mxid=ul.user_mxid AND user_portal.login_id=ul.id
WHERE user_portal.bridge_id=$1 AND user_portal.portal_id=$2 AND user_portal.portal_receiver=$3
@ -79,8 +79,9 @@ func (uq *UserLoginQuery) GetByID(ctx context.Context, id networkid.UserLoginID)
return uq.QueryOne(ctx, getLoginByIDQuery, uq.BridgeID, id)
}
func (uq *UserLoginQuery) GetAll(ctx context.Context) ([]*UserLogin, error) {
return uq.QueryMany(ctx, getAllLoginsQuery, uq.BridgeID)
func (uq *UserLoginQuery) GetAllUserIDsWithLogins(ctx context.Context) ([]id.UserID, error) {
rows, err := uq.GetDB().Query(ctx, getAllUsersWithLoginsQuery, uq.BridgeID)
return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList()
}
func (uq *UserLoginQuery) GetAllInPortal(ctx context.Context, portal networkid.PortalKey) ([]*UserLogin, error) {

View file

@ -158,6 +158,12 @@ func (user *User) GetUserLoginIDs() []networkid.UserLoginID {
return maps.Keys(user.logins)
}
func (user *User) GetCachedUserLogins() []*UserLogin {
user.Bridge.cacheLock.Lock()
defer user.Bridge.cacheLock.Unlock()
return maps.Values(user.logins)
}
func (user *User) GetFormattedUserLogins() string {
user.Bridge.cacheLock.Lock()
logins := make([]string, len(user.logins))

View file

@ -56,7 +56,8 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da
}
err := br.Network.LoadUserLogin(ctx, userLogin)
if err != nil {
return nil, fmt.Errorf("failed to prepare: %w", err)
userLogin.Log.Err(err).Msg("Failed to load user login")
return nil, nil
}
user.logins[userLogin.ID] = userLogin
br.userLoginsByID[userLogin.ID] = userLogin
@ -65,16 +66,17 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da
}
func (br *Bridge) loadManyUserLogins(ctx context.Context, user *User, logins []*database.UserLogin) ([]*UserLogin, error) {
output := make([]*UserLogin, len(logins))
for i, dbLogin := range logins {
output := make([]*UserLogin, 0, len(logins))
for _, dbLogin := range logins {
if cached, ok := br.userLoginsByID[dbLogin.ID]; ok {
output[i] = cached
output = append(output, cached)
} else {
loaded, err := br.loadUserLogin(ctx, user, dbLogin)
if err != nil {
return nil, fmt.Errorf("failed to load user login: %w", err)
return nil, err
} else if loaded != nil {
output = append(output, loaded)
}
output[i] = loaded
}
}
return output, nil
@ -89,16 +91,6 @@ func (br *Bridge) unlockedLoadUserLoginsByMXID(ctx context.Context, user *User)
return err
}
func (br *Bridge) GetAllUserLogins(ctx context.Context) ([]*UserLogin, error) {
logins, err := br.DB.UserLogin.GetAll(ctx)
if err != nil {
return nil, err
}
br.cacheLock.Lock()
defer br.cacheLock.Unlock()
return br.loadManyUserLogins(ctx, nil, logins)
}
func (br *Bridge) GetUserLoginsInPortal(ctx context.Context, portal networkid.PortalKey) ([]*UserLogin, error) {
logins, err := br.DB.UserLogin.GetAllInPortal(ctx, portal)
if err != nil {