Refactor MarkTrackedUsersOutdated to use single query

This commit is contained in:
Tulir Asokan 2024-05-26 17:30:10 +03:00
commit 5afa391317

View file

@ -759,6 +759,17 @@ func (store *SQLCryptoStore) PutDevices(ctx context.Context, userID id.UserID, d
})
}
func userIDsToParams(users []id.UserID) (placeholders string, params []any) {
queryString := make([]string, len(users))
params = make([]any, len(users))
for i, user := range users {
queryString[i] = fmt.Sprintf("$%d", i+1)
params[i] = user
}
placeholders = strings.Join(queryString, ",")
return
}
// FilterTrackedUsers finds all the user IDs out of the given ones for which the database contains identity information.
func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id.UserID) ([]id.UserID, error) {
var rows dbutil.Rows
@ -766,13 +777,8 @@ func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id.
if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil {
rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users))
} else {
queryString := make([]string, len(users))
params := make([]interface{}, len(users))
for i, user := range users {
queryString[i] = fmt.Sprintf("?%d", i+1)
params[i] = user
}
rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
placeholders, params := userIDsToParams(users)
rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+placeholders+")", params...)
}
if err != nil {
return users, err
@ -781,18 +787,14 @@ func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id.
}
// MarkTrackedUsersOutdated flags that the device list for given users are outdated.
func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) error {
return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
// TODO refactor to use a single query
for _, userID := range users {
_, err := store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = $1", userID)
if err != nil {
return fmt.Errorf("failed to update user in the tracked users list: %w", err)
}
}
return nil
})
func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) (err error) {
if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil {
_, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = ANY($1)", PostgresArrayWrapper(users))
} else {
placeholders, params := userIDsToParams(users)
_, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id IN ("+placeholders+")", params...)
}
return
}
// GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated.