diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 15731aca..26b4ddbe 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -759,6 +759,17 @@ func (store *SQLCryptoStore) PutDevices(ctx context.Context, userID id.UserID, d }) } +func userIDsToParams(users []id.UserID) (placeholders string, params []any) { + queryString := make([]string, len(users)) + params = make([]any, len(users)) + for i, user := range users { + queryString[i] = fmt.Sprintf("$%d", i+1) + params[i] = user + } + placeholders = strings.Join(queryString, ",") + return +} + // FilterTrackedUsers finds all the user IDs out of the given ones for which the database contains identity information. func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id.UserID) ([]id.UserID, error) { var rows dbutil.Rows @@ -766,13 +777,8 @@ func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id. if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil { rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users)) } else { - queryString := make([]string, len(users)) - params := make([]interface{}, len(users)) - for i, user := range users { - queryString[i] = fmt.Sprintf("?%d", i+1) - params[i] = user - } - rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...) + placeholders, params := userIDsToParams(users) + rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+placeholders+")", params...) } if err != nil { return users, err @@ -781,18 +787,14 @@ func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id. } // MarkTrackedUsersOutdated flags that the device list for given users are outdated. -func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) error { - return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error { - // TODO refactor to use a single query - for _, userID := range users { - _, err := store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = $1", userID) - if err != nil { - return fmt.Errorf("failed to update user in the tracked users list: %w", err) - } - } - - return nil - }) +func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) (err error) { + if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil { + _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = ANY($1)", PostgresArrayWrapper(users)) + } else { + placeholders, params := userIDsToParams(users) + _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id IN ("+placeholders+")", params...) + } + return } // GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated.