add table escape

This commit is contained in:
Simon Vieille 2024-03-20 11:42:18 +01:00
commit b547166c83
Signed by untrusted user: deblan
GPG key ID: 579388D585F70417
3 changed files with 22 additions and 6 deletions

View file

@ -19,12 +19,19 @@ import (
type App struct { type App struct {
Db *sql.DB Db *sql.DB
DbConfig config.DatabaseConfig
FakeManager faker.FakeManager FakeManager faker.FakeManager
} }
func (a *App) Run(db *sql.DB, c config.SchemaConfig, fakeManager faker.FakeManager) error { func (a *App) Run(
db *sql.DB,
c config.SchemaConfig,
fakeManager faker.FakeManager,
dbc config.DatabaseConfig,
) error {
a.Db = db a.Db = db
a.FakeManager = fakeManager a.FakeManager = fakeManager
a.DbConfig = dbc
for _, data := range c.Rules.Actions { for _, data := range c.Rules.Actions {
err := a.DoAction(data, c.Rules.Columns, c.Rules.Generators) err := a.DoAction(data, c.Rules.Columns, c.Rules.Generators)
@ -39,7 +46,7 @@ func (a *App) Run(db *sql.DB, c config.SchemaConfig, fakeManager faker.FakeManag
func (a *App) TruncateTable(c config.SchemaConfigAction) error { func (a *App) TruncateTable(c config.SchemaConfigAction) error {
if c.Query == "" { if c.Query == "" {
_, err := a.Db.Exec(fmt.Sprintf("TRUNCATE %s", c.Table)) _, err := a.Db.Exec(fmt.Sprintf("TRUNCATE %s", database.EscapeTable(a.DbConfig.Type, c.Table)))
return err return err
} }
@ -59,7 +66,7 @@ func (a *App) TruncateTable(c config.SchemaConfigAction) error {
sql := fmt.Sprintf( sql := fmt.Sprintf(
"DELETE FROM %s WHERE %s", "DELETE FROM %s WHERE %s",
c.Table, database.EscapeTable(a.DbConfig.Type, c.Table),
strings.Join(pkeys, " AND "), strings.Join(pkeys, " AND "),
) )
@ -161,7 +168,7 @@ func (a *App) UpdateRows(c config.SchemaConfigAction, globalColumns map[string]s
if len(updates) > 0 { if len(updates) > 0 {
sql := fmt.Sprintf( sql := fmt.Sprintf(
"UPDATE %s SET %s WHERE %s", "UPDATE %s SET %s WHERE %s",
c.Table, database.EscapeTable(a.DbConfig.Type, c.Table),
strings.Join(updates, ", "), strings.Join(updates, ", "),
strings.Join(pkeys, " AND "), strings.Join(pkeys, " AND "),
) )
@ -185,7 +192,7 @@ func (a *App) CreateSelectQuery(c config.SchemaConfigAction) string {
return c.Query return c.Query
} }
return fmt.Sprintf("SELECT * FROM %s", c.Table) return fmt.Sprintf("SELECT * FROM %s", database.EscapeTable(a.DbConfig.Type, c.Table))
} }
func (a *App) DoAction(c config.SchemaConfigAction, globalColumns map[string]string, generators map[string][]string) error { func (a *App) DoAction(c config.SchemaConfigAction, globalColumns map[string]string, generators map[string][]string) error {

View file

@ -2,10 +2,19 @@ package database
import ( import (
"database/sql" "database/sql"
"fmt"
"gitnet.fr/deblan/database-anonymizer/data" "gitnet.fr/deblan/database-anonymizer/data"
"gitnet.fr/deblan/database-anonymizer/logger" "gitnet.fr/deblan/database-anonymizer/logger"
) )
func EscapeTable(dbType, table string) string {
if dbType == "mysql" {
return fmt.Sprintf("`%s`", table)
}
return fmt.Sprintf("\"%s\"", table)
}
func GetRows(db *sql.DB, query string) map[int]map[string]data.Data { func GetRows(db *sql.DB, query string) map[int]map[string]data.Data {
rows, err := db.Query(query) rows, err := db.Query(query)
defer rows.Close() defer rows.Close()

View file

@ -39,7 +39,7 @@ func main() {
logger.LogFatalExitIf(err) logger.LogFatalExitIf(err)
app := app.App{} app := app.App{}
return app.Run(db, schema, faker.NewFakeManager()) return app.Run(db, schema, faker.NewFakeManager(), databaseConfig)
}, },
} }