add table escape
This commit is contained in:
parent
4bdc6a04d3
commit
b547166c83
17
app/app.go
17
app/app.go
|
@ -19,12 +19,19 @@ import (
|
|||
|
||||
type App struct {
|
||||
Db *sql.DB
|
||||
DbConfig config.DatabaseConfig
|
||||
FakeManager faker.FakeManager
|
||||
}
|
||||
|
||||
func (a *App) Run(db *sql.DB, c config.SchemaConfig, fakeManager faker.FakeManager) error {
|
||||
func (a *App) Run(
|
||||
db *sql.DB,
|
||||
c config.SchemaConfig,
|
||||
fakeManager faker.FakeManager,
|
||||
dbc config.DatabaseConfig,
|
||||
) error {
|
||||
a.Db = db
|
||||
a.FakeManager = fakeManager
|
||||
a.DbConfig = dbc
|
||||
|
||||
for _, data := range c.Rules.Actions {
|
||||
err := a.DoAction(data, c.Rules.Columns, c.Rules.Generators)
|
||||
|
@ -39,7 +46,7 @@ func (a *App) Run(db *sql.DB, c config.SchemaConfig, fakeManager faker.FakeManag
|
|||
|
||||
func (a *App) TruncateTable(c config.SchemaConfigAction) error {
|
||||
if c.Query == "" {
|
||||
_, err := a.Db.Exec(fmt.Sprintf("TRUNCATE %s", c.Table))
|
||||
_, err := a.Db.Exec(fmt.Sprintf("TRUNCATE %s", database.EscapeTable(a.DbConfig.Type, c.Table)))
|
||||
|
||||
return err
|
||||
}
|
||||
|
@ -59,7 +66,7 @@ func (a *App) TruncateTable(c config.SchemaConfigAction) error {
|
|||
|
||||
sql := fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE %s",
|
||||
c.Table,
|
||||
database.EscapeTable(a.DbConfig.Type, c.Table),
|
||||
strings.Join(pkeys, " AND "),
|
||||
)
|
||||
|
||||
|
@ -161,7 +168,7 @@ func (a *App) UpdateRows(c config.SchemaConfigAction, globalColumns map[string]s
|
|||
if len(updates) > 0 {
|
||||
sql := fmt.Sprintf(
|
||||
"UPDATE %s SET %s WHERE %s",
|
||||
c.Table,
|
||||
database.EscapeTable(a.DbConfig.Type, c.Table),
|
||||
strings.Join(updates, ", "),
|
||||
strings.Join(pkeys, " AND "),
|
||||
)
|
||||
|
@ -185,7 +192,7 @@ func (a *App) CreateSelectQuery(c config.SchemaConfigAction) string {
|
|||
return c.Query
|
||||
}
|
||||
|
||||
return fmt.Sprintf("SELECT * FROM %s", c.Table)
|
||||
return fmt.Sprintf("SELECT * FROM %s", database.EscapeTable(a.DbConfig.Type, c.Table))
|
||||
}
|
||||
|
||||
func (a *App) DoAction(c config.SchemaConfigAction, globalColumns map[string]string, generators map[string][]string) error {
|
||||
|
|
|
@ -2,10 +2,19 @@ package database
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"gitnet.fr/deblan/database-anonymizer/data"
|
||||
"gitnet.fr/deblan/database-anonymizer/logger"
|
||||
)
|
||||
|
||||
func EscapeTable(dbType, table string) string {
|
||||
if dbType == "mysql" {
|
||||
return fmt.Sprintf("`%s`", table)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("\"%s\"", table)
|
||||
}
|
||||
|
||||
func GetRows(db *sql.DB, query string) map[int]map[string]data.Data {
|
||||
rows, err := db.Query(query)
|
||||
defer rows.Close()
|
||||
|
|
Loading…
Reference in a new issue