From b547166c83f7f298b2636f9d1393cc06c4fa289d Mon Sep 17 00:00:00 2001 From: Simon Vieille Date: Wed, 20 Mar 2024 11:42:18 +0100 Subject: [PATCH] add table escape --- app/app.go | 17 ++++++++++++----- database/database.go | 9 +++++++++ main.go | 2 +- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/app/app.go b/app/app.go index 1819a45..145f87a 100644 --- a/app/app.go +++ b/app/app.go @@ -19,12 +19,19 @@ import ( type App struct { Db *sql.DB + DbConfig config.DatabaseConfig FakeManager faker.FakeManager } -func (a *App) Run(db *sql.DB, c config.SchemaConfig, fakeManager faker.FakeManager) error { +func (a *App) Run( + db *sql.DB, + c config.SchemaConfig, + fakeManager faker.FakeManager, + dbc config.DatabaseConfig, +) error { a.Db = db a.FakeManager = fakeManager + a.DbConfig = dbc for _, data := range c.Rules.Actions { err := a.DoAction(data, c.Rules.Columns, c.Rules.Generators) @@ -39,7 +46,7 @@ func (a *App) Run(db *sql.DB, c config.SchemaConfig, fakeManager faker.FakeManag func (a *App) TruncateTable(c config.SchemaConfigAction) error { if c.Query == "" { - _, err := a.Db.Exec(fmt.Sprintf("TRUNCATE %s", c.Table)) + _, err := a.Db.Exec(fmt.Sprintf("TRUNCATE %s", database.EscapeTable(a.DbConfig.Type, c.Table))) return err } @@ -59,7 +66,7 @@ func (a *App) TruncateTable(c config.SchemaConfigAction) error { sql := fmt.Sprintf( "DELETE FROM %s WHERE %s", - c.Table, + database.EscapeTable(a.DbConfig.Type, c.Table), strings.Join(pkeys, " AND "), ) @@ -161,7 +168,7 @@ func (a *App) UpdateRows(c config.SchemaConfigAction, globalColumns map[string]s if len(updates) > 0 { sql := fmt.Sprintf( "UPDATE %s SET %s WHERE %s", - c.Table, + database.EscapeTable(a.DbConfig.Type, c.Table), strings.Join(updates, ", "), strings.Join(pkeys, " AND "), ) @@ -185,7 +192,7 @@ func (a *App) CreateSelectQuery(c config.SchemaConfigAction) string { return c.Query } - return fmt.Sprintf("SELECT * FROM %s", c.Table) + return fmt.Sprintf("SELECT * FROM %s", database.EscapeTable(a.DbConfig.Type, c.Table)) } func (a *App) DoAction(c config.SchemaConfigAction, globalColumns map[string]string, generators map[string][]string) error { diff --git a/database/database.go b/database/database.go index 6aec4b2..bb1c6af 100644 --- a/database/database.go +++ b/database/database.go @@ -2,10 +2,19 @@ package database import ( "database/sql" + "fmt" "gitnet.fr/deblan/database-anonymizer/data" "gitnet.fr/deblan/database-anonymizer/logger" ) +func EscapeTable(dbType, table string) string { + if dbType == "mysql" { + return fmt.Sprintf("`%s`", table) + } + + return fmt.Sprintf("\"%s\"", table) +} + func GetRows(db *sql.DB, query string) map[int]map[string]data.Data { rows, err := db.Query(query) defer rows.Close() diff --git a/main.go b/main.go index ebc71a2..23a512d 100644 --- a/main.go +++ b/main.go @@ -39,7 +39,7 @@ func main() { logger.LogFatalExitIf(err) app := app.App{} - return app.Run(db, schema, faker.NewFakeManager()) + return app.Run(db, schema, faker.NewFakeManager(), databaseConfig) }, }