Compare commits

...
Sign in to create a new pull request.

3 commits

Author SHA1 Message Date
23f432ea02
fix tests
All checks were successful
ci/woodpecker/push/test Pipeline was successful
ci/woodpecker/push/build Pipeline was successful
2024-07-18 17:07:33 +02:00
5eb3701f9f
add column escaper function
Some checks failed
ci/woodpecker/push/test Pipeline failed
ci/woodpecker/push/build unknown status
2024-07-18 17:05:30 +02:00
6754905f24
add column escaper function
Some checks failed
ci/woodpecker/push/test Pipeline failed
ci/woodpecker/push/build unknown status
2024-07-18 17:03:32 +02:00
3 changed files with 21 additions and 7 deletions

View file

@ -170,7 +170,7 @@ func (a *App) UpdateRows(c config.SchemaConfigAction, globalColumns map[string]s
updates = append(updates, database.GetNamedParameter(a.DbConfig.Type, col, len(values)+1))
values[len(values)+1] = value.FinalValue()
} else {
updates = append(updates, fmt.Sprintf("%s=%s", col, value.FinalValue()))
updates = append(updates, fmt.Sprintf("%s=%s", database.EscapeColumn(a.DbConfig.Type, col), value.FinalValue()))
}
}
}
@ -179,7 +179,7 @@ func (a *App) UpdateRows(c config.SchemaConfigAction, globalColumns map[string]s
value := row[col]
if !value.IsString || value.IsNull {
pkeys = append(pkeys, fmt.Sprintf("%s=%s", col, value.FinalValue()))
pkeys = append(pkeys, fmt.Sprintf("%s=%s", database.EscapeColumn(a.DbConfig.Type, col), value.FinalValue()))
} else {
pkeys = append(pkeys, database.GetNamedParameter(a.DbConfig.Type, col, len(values)+1))
values[len(values)+1] = value.FinalValue()

View file

@ -15,12 +15,16 @@ func EscapeTable(dbType, table string) string {
return fmt.Sprintf("\"%s\"", table)
}
func GetNamedParameter(dbType, col string, number int) string {
if dbType == "mysql" {
return fmt.Sprintf("%s=?", col)
func EscapeColumn(dbType, col string) string {
return EscapeTable(dbType, col)
}
return fmt.Sprintf("%s=$%d", col, number)
func GetNamedParameter(dbType, col string, number int) string {
if dbType == "mysql" {
return fmt.Sprintf("%s=?", EscapeColumn(dbType, col))
}
return fmt.Sprintf("%s=$%d", EscapeColumn(dbType, col), number)
}
func IsPgNumberType(value string) bool {
@ -80,7 +84,7 @@ func GetRows(db *sql.DB, query, table, dbType string) map[int]map[string]data.Da
if value != nil {
if dbType == "postgres" {
if len(columnsTypes[col]) == 0 {
typeQuery := fmt.Sprintf("SELECT pg_typeof(%s) as value FROM %s", col, EscapeTable(dbType, table))
typeQuery := fmt.Sprintf("SELECT pg_typeof(%s) as value FROM %s", EscapeColumn(dbType, col), EscapeTable(dbType, table))
db.QueryRow(typeQuery).Scan(&typeValue)
columnsTypes[col] = typeValue
}

View file

@ -14,12 +14,22 @@ func TestEscapeTable(t *testing.T) {
}
}
func TestEscapeColumn(t *testing.T) {
if EscapeColumn("mysql", "foo") != "`foo`" {
t.Fatalf("TestEscapeColumn: mysql check failed")
}
if EscapeColumn("postgres", "foo") != "\"foo\"" {
t.Fatalf("TestEscapeColumn: postgres check failed")
}
}
func TestGetNamedParameter(t *testing.T) {
if GetNamedParameter("mysql", "foo", 1) != "foo=?" {
if GetNamedParameter("mysql", "foo", 1) != "`foo`=?" {
t.Fatalf("TestGetNamedParameter: mysql check failed")
}
if GetNamedParameter("postgres", "foo", 1) != "foo=$1" {
if GetNamedParameter("postgres", "foo", 1) != "\"foo\"=$1" {
t.Fatalf("TestGetNamedParameter: postgres check failed")
}
}