mirror of
https://github.com/drakkan/sftpgo.git
synced 2026-03-14 14:25:52 +01:00
ensure migration lock and transaction use the same connection
Previously, locks were session-bound but executed on a pool, allowing race conditions. Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
parent
c15aba89ae
commit
0dc3d6c804
1 changed files with 29 additions and 9 deletions
|
|
@ -3965,16 +3965,22 @@ func sqlCommonUpdateDatabaseVersion(ctx context.Context, dbHandle sqlQuerier, ve
|
|||
}
|
||||
|
||||
func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, newVersion int, isUp bool) error {
|
||||
if err := sqlAcquireLock(dbHandle); err != nil {
|
||||
return err
|
||||
}
|
||||
defer sqlReleaseLock(dbHandle)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := dbHandle.Conn(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get connection from pool: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := sqlAcquireLock(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
defer sqlReleaseLock(conn)
|
||||
|
||||
if newVersion > 0 {
|
||||
currentVersion, err := sqlCommonGetDatabaseVersion(dbHandle, false)
|
||||
currentVersion, err := sqlCommonGetDatabaseVersion(conn, false)
|
||||
if err == nil {
|
||||
if (isUp && currentVersion.Version >= newVersion) || (!isUp && currentVersion.Version <= newVersion) {
|
||||
providerLog(logger.LevelInfo, "current schema version: %v, requested: %v, did you execute simultaneous migrations?",
|
||||
|
|
@ -3984,7 +3990,7 @@ func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, n
|
|||
}
|
||||
}
|
||||
|
||||
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
|
||||
return sqlCommonExecuteTxOnConn(ctx, conn, func(tx *sql.Tx) error {
|
||||
for _, q := range sqlQueries {
|
||||
if strings.TrimSpace(q) == "" {
|
||||
continue
|
||||
|
|
@ -4001,7 +4007,7 @@ func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, n
|
|||
})
|
||||
}
|
||||
|
||||
func sqlAcquireLock(dbHandle *sql.DB) error {
|
||||
func sqlAcquireLock(dbHandle *sql.Conn) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
|
|
@ -4030,7 +4036,7 @@ func sqlAcquireLock(dbHandle *sql.DB) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func sqlReleaseLock(dbHandle *sql.DB) {
|
||||
func sqlReleaseLock(dbHandle *sql.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
|
|
@ -4052,6 +4058,20 @@ func sqlReleaseLock(dbHandle *sql.DB) {
|
|||
}
|
||||
}
|
||||
|
||||
func sqlCommonExecuteTxOnConn(ctx context.Context, conn *sql.Conn, txFn func(*sql.Tx) error) error {
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = txFn(tx)
|
||||
if err != nil {
|
||||
tx.Rollback() //nolint:errcheck
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func sqlCommonExecuteTx(ctx context.Context, dbHandle *sql.DB, txFn func(*sql.Tx) error) error {
|
||||
if config.Driver == CockroachDataProviderName {
|
||||
return crdb.ExecuteTx(ctx, dbHandle, nil, txFn)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue