diff --git a/app/app.go b/app/app.go index 18ccbf0..0b1437e 100644 --- a/app/app.go +++ b/app/app.go @@ -5,11 +5,8 @@ import ( "errors" "fmt" - // "os" - "strconv" "strings" - nq "github.com/Knetic/go-namedParameterQuery" "gitnet.fr/deblan/database-anonymizer/config" "gitnet.fr/deblan/database-anonymizer/data" "gitnet.fr/deblan/database-anonymizer/database" @@ -52,20 +49,20 @@ func (a *App) TruncateTable(c config.SchemaConfigAction) error { } query := a.CreateSelectQuery(c) - rows := database.GetRows(a.Db, query) + rows := database.GetRows(a.Db, query, c.Table, a.DbConfig.Type) var scan any for _, row := range rows { pkeys := []string{} - pCounter := 1 + values := make(map[int]string) for _, col := range c.PrimaryKey { - if row[col].IsInteger { - value, _ := strconv.Atoi(row[col].Value) - pkeys = append(pkeys, fmt.Sprintf("%s=%d", col, value)) + if !row[col].IsString { + value := row[col] + pkeys = append(pkeys, fmt.Sprintf("%s=%s", col, value.FinalValue())) } else { - pkeys = append(pkeys, fmt.Sprintf("%s=:p%s", col, strconv.Itoa(pCounter))) - pCounter = pCounter + 1 + pkeys = append(pkeys, database.GetNamedParameter(a.DbConfig.Type, col, len(values)+1)) + values[len(values)+1] = row[col].Value } } @@ -75,17 +72,14 @@ func (a *App) TruncateTable(c config.SchemaConfigAction) error { strings.Join(pkeys, " AND "), ) - stmt := nq.NewNamedParameterQuery(sql) - pCounter = 1 - - for _, col := range c.PrimaryKey { - if !row[col].IsInteger { - stmt.SetValue(fmt.Sprintf("p%s", strconv.Itoa(pCounter)), row[col].Value) - pCounter = pCounter + 1 + var args []any + if len(values) > 0 { + for i := 1; i <= len(values); i++ { + args = append(args, values[i]) } } - a.Db.QueryRow(stmt.GetParsedQuery(), (stmt.GetParsedParameters())...).Scan(&scan) + a.Db.QueryRow(sql, args...).Scan(&scan) } return nil @@ -93,7 +87,7 @@ func (a *App) TruncateTable(c config.SchemaConfigAction) error { func (a *App) UpdateRows(c config.SchemaConfigAction, globalColumns map[string]string, generators map[string][]string) error { query := a.CreateSelectQuery(c) - rows := database.GetRows(a.Db, query) + rows := database.GetRows(a.Db, query, c.Table, a.DbConfig.Type) var scan any for key, row := range rows { @@ -155,25 +149,28 @@ func (a *App) UpdateRows(c config.SchemaConfigAction, globalColumns map[string]s for _, row := range rows { updates := []string{} pkeys := []string{} - values := make(map[int][]string) - pCounter := 1 + values := make(map[int]string) for col, value := range row { if value.IsUpdated && !value.IsVirtual { - if value.IsInteger { - values[pCounter] = []string{value.Value, "int"} + if value.IsString { + updates = append(updates, database.GetNamedParameter(a.DbConfig.Type, col, len(values)+1)) + values[len(values)+1] = value.FinalValue() } else { - values[pCounter] = []string{value.Value, "string"} + updates = append(updates, fmt.Sprintf("%s=%s", col, value.FinalValue())) } - updates = append(updates, fmt.Sprintf("%s=:p%s", col, strconv.Itoa(pCounter))) - pCounter = pCounter + 1 } } for _, col := range c.PrimaryKey { - values[pCounter] = []string{row[col].Value, "string"} - pkeys = append(pkeys, fmt.Sprintf("%s=:p%s", col, strconv.Itoa(pCounter))) - pCounter = pCounter + 1 + value := row[col] + + if !value.IsString { + pkeys = append(pkeys, fmt.Sprintf("%s=%s", col, value.FinalValue())) + } else { + pkeys = append(pkeys, database.GetNamedParameter(a.DbConfig.Type, col, len(values)+1)) + values[len(values)+1] = value.FinalValue() + } } if len(updates) > 0 { @@ -184,19 +181,18 @@ func (a *App) UpdateRows(c config.SchemaConfigAction, globalColumns map[string]s strings.Join(pkeys, " AND "), ) - stmt := nq.NewNamedParameterQuery(sql) - pCounter = 1 - - for i, value := range values { - if value[1] == "string" { - stmt.SetValue(fmt.Sprintf("p%s", strconv.Itoa(i)), value[0]) - } else { - newValue, _ := strconv.Atoi(value[0]) - stmt.SetValue(fmt.Sprintf("p%s", strconv.Itoa(i)), newValue) + var args []any + if len(values) > 0 { + for i := 1; i <= len(values); i++ { + args = append(args, values[i]) } } - a.Db.QueryRow(stmt.GetParsedQuery(), (stmt.GetParsedParameters())...).Scan(&scan) + err := a.Db.QueryRow(sql, args...).Scan(&scan) + + if err.Error() != "" && err.Error() != "sql: no rows in result set" { + logger.LogFatalExitIf(err) + } } } diff --git a/data/data.go b/data/data.go index 9f24cdb..d48891b 100644 --- a/data/data.go +++ b/data/data.go @@ -16,30 +16,47 @@ type Data struct { IsVirtual bool IsPrimaryKey bool IsUpdated bool - IsInteger bool + + IsInteger bool + IsBoolean bool + IsString bool + IsNull bool } func (d *Data) FromByte(v []byte) *Data { d.Value = string(v) - d.IsInteger = false return d } func (d *Data) FromInt64(v int64) *Data { d.Value = strconv.FormatInt(v, 10) - d.IsInteger = true return d } func (d *Data) FromString(v string) *Data { d.Value = v - d.IsInteger = false return d } +func (d *Data) FinalValue() string { + if d.IsNull { + return "null" + } + + if d.IsBoolean { + if d.Value == "1" { + return "true" + } else { + return "false" + } + } + + return d.Value +} + func (d *Data) IsTwigExpression() bool { return strings.Contains(d.Faker, "{{") || strings.Contains(d.Faker, "}}") } diff --git a/database/database.go b/database/database.go index bb1c6af..d97eb1c 100644 --- a/database/database.go +++ b/database/database.go @@ -15,7 +15,15 @@ func EscapeTable(dbType, table string) string { return fmt.Sprintf("\"%s\"", table) } -func GetRows(db *sql.DB, query string) map[int]map[string]data.Data { +func GetNamedParameter(dbType, col string, number int) string { + if dbType == "mysql" { + return fmt.Sprintf("%s=?", col) + } + + return fmt.Sprintf("%s=$%d", col, number) +} + +func GetRows(db *sql.DB, query, table, dbType string) map[int]map[string]data.Data { rows, err := db.Query(query) defer rows.Close() logger.LogFatalExitIf(err) @@ -29,6 +37,8 @@ func GetRows(db *sql.DB, query string) map[int]map[string]data.Data { key := 0 + columnsTypes := make(map[string]string) + for rows.Next() { row := make(map[string]data.Data) @@ -40,11 +50,32 @@ func GetRows(db *sql.DB, query string) map[int]map[string]data.Data { logger.LogFatalExitIf(err) } + var typeValue string + for i, col := range columns { value := values[i] - d := data.Data{IsVirtual: false} + d := data.Data{ + IsVirtual: false, + IsNull: value == nil, + } if value != nil { + if dbType == "postgres" { + if len(columnsTypes[col]) == 0 { + typeQuery := fmt.Sprintf("SELECT pg_typeof(%s) as value FROM %s", col, EscapeTable(dbType, table)) + db.QueryRow(typeQuery).Scan(&typeValue) + columnsTypes[col] = typeValue + } + + dataType := columnsTypes[col] + + d.IsInteger = dataType == "integer" + d.IsBoolean = dataType == "boolean" + d.IsString = !d.IsBoolean && !d.IsInteger + } else { + d.IsString = true + } + switch v := value.(type) { case []byte: d.FromByte(v) diff --git a/database/database_test.go b/database/database_test.go index c5666b8..75cc8e8 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -13,3 +13,13 @@ func TestEscapeTable(t *testing.T) { t.Fatalf("TestEscapeTable: postgres check failed") } } + +func TestGetNamedParameter(t *testing.T) { + if GetNamedParameter("mysql", "foo", 1) != "foo=?" { + t.Fatalf("TestGetNamedParameter: mysql check failed") + } + + if GetNamedParameter("postgres", "foo", 1) != "foo=$1" { + t.Fatalf("TestGetNamedParameter: postgres check failed") + } +}