From 6754905f2451f30517a661ae423435be04ded593 Mon Sep 17 00:00:00 2001 From: Simon Vieille Date: Thu, 18 Jul 2024 17:03:32 +0200 Subject: [PATCH 1/3] add column escaper function --- app/app.go | 4 ++-- database/database.go | 10 +++++++--- database/database_test.go | 10 ++++++++++ 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/app/app.go b/app/app.go index f167eeb..acf4e44 100644 --- a/app/app.go +++ b/app/app.go @@ -170,7 +170,7 @@ func (a *App) UpdateRows(c config.SchemaConfigAction, globalColumns map[string]s updates = append(updates, database.GetNamedParameter(a.DbConfig.Type, col, len(values)+1)) values[len(values)+1] = value.FinalValue() } else { - updates = append(updates, fmt.Sprintf("%s=%s", col, value.FinalValue())) + updates = append(updates, fmt.Sprintf("%s=%s", database.EscapeColumn(a.DbConfig.Type, col), value.FinalValue())) } } } @@ -179,7 +179,7 @@ func (a *App) UpdateRows(c config.SchemaConfigAction, globalColumns map[string]s value := row[col] if !value.IsString || value.IsNull { - pkeys = append(pkeys, fmt.Sprintf("%s=%s", col, value.FinalValue())) + pkeys = append(pkeys, fmt.Sprintf("%s=%s", database.EscapeColumn(a.DbConfig.Type, col), value.FinalValue())) } else { pkeys = append(pkeys, database.GetNamedParameter(a.DbConfig.Type, col, len(values)+1)) values[len(values)+1] = value.FinalValue() diff --git a/database/database.go b/database/database.go index 6a1c699..94b052b 100644 --- a/database/database.go +++ b/database/database.go @@ -15,12 +15,16 @@ func EscapeTable(dbType, table string) string { return fmt.Sprintf("\"%s\"", table) } +func EscapeColumn(dbType, col string) string { + return EscapeTable(dbType, col) +} + func GetNamedParameter(dbType, col string, number int) string { if dbType == "mysql" { - return fmt.Sprintf("%s=?", col) + return fmt.Sprintf("%s=?", EscapeColumn(col)) } - return fmt.Sprintf("%s=$%d", col, number) + return fmt.Sprintf("%s=$%d", EscapeColumn(col), number) } func IsPgNumberType(value string) bool { @@ -80,7 +84,7 @@ func GetRows(db *sql.DB, query, table, dbType string) map[int]map[string]data.Da if value != nil { if dbType == "postgres" { if len(columnsTypes[col]) == 0 { - typeQuery := fmt.Sprintf("SELECT pg_typeof(%s) as value FROM %s", col, EscapeTable(dbType, table)) + typeQuery := fmt.Sprintf("SELECT pg_typeof(%s) as value FROM %s", EscapeColumn(dbType, col), EscapeTable(dbType, table)) db.QueryRow(typeQuery).Scan(&typeValue) columnsTypes[col] = typeValue } diff --git a/database/database_test.go b/database/database_test.go index 75cc8e8..343d675 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -14,6 +14,16 @@ func TestEscapeTable(t *testing.T) { } } +func TestEscapeColumn(t *testing.T) { + if EscapeColumn("mysql", "foo") != "`foo`" { + t.Fatalf("TestEscapeColumn: mysql check failed") + } + + if EscapeTable("postgres", "foo") != "\"foo\"" { + t.Fatalf("TestEscapeColumn: postgres check failed") + } +} + func TestGetNamedParameter(t *testing.T) { if GetNamedParameter("mysql", "foo", 1) != "foo=?" { t.Fatalf("TestGetNamedParameter: mysql check failed") From 5eb3701f9ff05f0d103e5e2b72a21b3d6c2d1311 Mon Sep 17 00:00:00 2001 From: Simon Vieille Date: Thu, 18 Jul 2024 17:05:30 +0200 Subject: [PATCH 2/3] add column escaper function --- database/database.go | 4 ++-- database/database_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/database/database.go b/database/database.go index 94b052b..b819e4a 100644 --- a/database/database.go +++ b/database/database.go @@ -21,10 +21,10 @@ func EscapeColumn(dbType, col string) string { func GetNamedParameter(dbType, col string, number int) string { if dbType == "mysql" { - return fmt.Sprintf("%s=?", EscapeColumn(col)) + return fmt.Sprintf("%s=?", EscapeColumn(dbType, col)) } - return fmt.Sprintf("%s=$%d", EscapeColumn(col), number) + return fmt.Sprintf("%s=$%d", EscapeColumn(dbType, col), number) } func IsPgNumberType(value string) bool { diff --git a/database/database_test.go b/database/database_test.go index 343d675..81b9f67 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -19,7 +19,7 @@ func TestEscapeColumn(t *testing.T) { t.Fatalf("TestEscapeColumn: mysql check failed") } - if EscapeTable("postgres", "foo") != "\"foo\"" { + if EscapeColumn("postgres", "foo") != "\"foo\"" { t.Fatalf("TestEscapeColumn: postgres check failed") } } From 23f432ea0222c1d7275b8bcf7b34e73833c4abcf Mon Sep 17 00:00:00 2001 From: Simon Vieille Date: Thu, 18 Jul 2024 17:07:33 +0200 Subject: [PATCH 3/3] fix tests --- database/database_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/database/database_test.go b/database/database_test.go index 81b9f67..b51c1a4 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -25,11 +25,11 @@ func TestEscapeColumn(t *testing.T) { } func TestGetNamedParameter(t *testing.T) { - if GetNamedParameter("mysql", "foo", 1) != "foo=?" { + if GetNamedParameter("mysql", "foo", 1) != "`foo`=?" { t.Fatalf("TestGetNamedParameter: mysql check failed") } - if GetNamedParameter("postgres", "foo", 1) != "foo=$1" { + if GetNamedParameter("postgres", "foo", 1) != "\"foo\"=$1" { t.Fatalf("TestGetNamedParameter: postgres check failed") } }