mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
bridgev2: improve handling of user logins in bad credentials
This commit is contained in:
parent
c8b03b087e
commit
e25578d435
4 changed files with 39 additions and 29 deletions
|
|
@ -107,18 +107,29 @@ func (br *Bridge) Start() error {
|
|||
go br.DisappearLoop.Start()
|
||||
}
|
||||
|
||||
logins, err := br.GetAllUserLogins(ctx)
|
||||
userIDs, err := br.DB.UserLogin.GetAllUserIDsWithLogins(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user logins: %w", err)
|
||||
return fmt.Errorf("failed to get users with logins: %w", err)
|
||||
}
|
||||
for _, login := range logins {
|
||||
br.Log.Info().Str("id", string(login.ID)).Msg("Starting user login")
|
||||
err = login.Client.Connect(login.Log.WithContext(ctx))
|
||||
startedAny := false
|
||||
for _, userID := range userIDs {
|
||||
br.Log.Info().Stringer("user_id", userID).Msg("Loading user")
|
||||
var user *User
|
||||
user, err = br.GetUserByMXID(ctx, userID)
|
||||
if err != nil {
|
||||
br.Log.Err(err).Msg("Failed to connect existing client")
|
||||
br.Log.Err(err).Stringer("user_id", userID).Msg("Failed to load user")
|
||||
} else {
|
||||
for _, login := range user.GetCachedUserLogins() {
|
||||
startedAny = true
|
||||
br.Log.Info().Str("id", string(login.ID)).Msg("Starting user login")
|
||||
err = login.Client.Connect(login.Log.WithContext(ctx))
|
||||
if err != nil {
|
||||
br.Log.Err(err).Msg("Failed to connect existing client")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(logins) == 0 {
|
||||
if !startedAny {
|
||||
br.Log.Info().Msg("No user logins found")
|
||||
br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -54,10 +54,10 @@ const (
|
|||
getUserLoginBaseQuery = `
|
||||
SELECT bridge_id, user_mxid, id, space_room, metadata FROM user_login
|
||||
`
|
||||
getLoginByIDQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND id=$2`
|
||||
getAllLoginsQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1`
|
||||
getAllLoginsForUserQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND user_mxid=$2`
|
||||
getAllLoginsInPortalQuery = `
|
||||
getLoginByIDQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND id=$2`
|
||||
getAllUsersWithLoginsQuery = `SELECT DISTINCT user_mxid FROM user_login WHERE bridge_id=$1`
|
||||
getAllLoginsForUserQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND user_mxid=$2`
|
||||
getAllLoginsInPortalQuery = `
|
||||
SELECT ul.bridge_id, ul.user_mxid, ul.id, ul.space_room, ul.metadata FROM user_portal
|
||||
LEFT JOIN user_login ul ON user_portal.bridge_id=ul.bridge_id AND user_portal.user_mxid=ul.user_mxid AND user_portal.login_id=ul.id
|
||||
WHERE user_portal.bridge_id=$1 AND user_portal.portal_id=$2 AND user_portal.portal_receiver=$3
|
||||
|
|
@ -79,8 +79,9 @@ func (uq *UserLoginQuery) GetByID(ctx context.Context, id networkid.UserLoginID)
|
|||
return uq.QueryOne(ctx, getLoginByIDQuery, uq.BridgeID, id)
|
||||
}
|
||||
|
||||
func (uq *UserLoginQuery) GetAll(ctx context.Context) ([]*UserLogin, error) {
|
||||
return uq.QueryMany(ctx, getAllLoginsQuery, uq.BridgeID)
|
||||
func (uq *UserLoginQuery) GetAllUserIDsWithLogins(ctx context.Context) ([]id.UserID, error) {
|
||||
rows, err := uq.GetDB().Query(ctx, getAllUsersWithLoginsQuery, uq.BridgeID)
|
||||
return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList()
|
||||
}
|
||||
|
||||
func (uq *UserLoginQuery) GetAllInPortal(ctx context.Context, portal networkid.PortalKey) ([]*UserLogin, error) {
|
||||
|
|
|
|||
|
|
@ -158,6 +158,12 @@ func (user *User) GetUserLoginIDs() []networkid.UserLoginID {
|
|||
return maps.Keys(user.logins)
|
||||
}
|
||||
|
||||
func (user *User) GetCachedUserLogins() []*UserLogin {
|
||||
user.Bridge.cacheLock.Lock()
|
||||
defer user.Bridge.cacheLock.Unlock()
|
||||
return maps.Values(user.logins)
|
||||
}
|
||||
|
||||
func (user *User) GetFormattedUserLogins() string {
|
||||
user.Bridge.cacheLock.Lock()
|
||||
logins := make([]string, len(user.logins))
|
||||
|
|
|
|||
|
|
@ -56,7 +56,8 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da
|
|||
}
|
||||
err := br.Network.LoadUserLogin(ctx, userLogin)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to prepare: %w", err)
|
||||
userLogin.Log.Err(err).Msg("Failed to load user login")
|
||||
return nil, nil
|
||||
}
|
||||
user.logins[userLogin.ID] = userLogin
|
||||
br.userLoginsByID[userLogin.ID] = userLogin
|
||||
|
|
@ -65,16 +66,17 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da
|
|||
}
|
||||
|
||||
func (br *Bridge) loadManyUserLogins(ctx context.Context, user *User, logins []*database.UserLogin) ([]*UserLogin, error) {
|
||||
output := make([]*UserLogin, len(logins))
|
||||
for i, dbLogin := range logins {
|
||||
output := make([]*UserLogin, 0, len(logins))
|
||||
for _, dbLogin := range logins {
|
||||
if cached, ok := br.userLoginsByID[dbLogin.ID]; ok {
|
||||
output[i] = cached
|
||||
output = append(output, cached)
|
||||
} else {
|
||||
loaded, err := br.loadUserLogin(ctx, user, dbLogin)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load user login: %w", err)
|
||||
return nil, err
|
||||
} else if loaded != nil {
|
||||
output = append(output, loaded)
|
||||
}
|
||||
output[i] = loaded
|
||||
}
|
||||
}
|
||||
return output, nil
|
||||
|
|
@ -89,16 +91,6 @@ func (br *Bridge) unlockedLoadUserLoginsByMXID(ctx context.Context, user *User)
|
|||
return err
|
||||
}
|
||||
|
||||
func (br *Bridge) GetAllUserLogins(ctx context.Context) ([]*UserLogin, error) {
|
||||
logins, err := br.DB.UserLogin.GetAll(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
br.cacheLock.Lock()
|
||||
defer br.cacheLock.Unlock()
|
||||
return br.loadManyUserLogins(ctx, nil, logins)
|
||||
}
|
||||
|
||||
func (br *Bridge) GetUserLoginsInPortal(ctx context.Context, portal networkid.PortalKey) ([]*UserLogin, error) {
|
||||
logins, err := br.DB.UserLogin.GetAllInPortal(ctx, portal)
|
||||
if err != nil {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue