add postgres column type check
add specific method to handle named parameters
This commit is contained in:
parent
2ba8561574
commit
a7cd0634ef
74
app/app.go
74
app/app.go
|
@ -5,11 +5,8 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
// "os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
nq "github.com/Knetic/go-namedParameterQuery"
|
|
||||||
"gitnet.fr/deblan/database-anonymizer/config"
|
"gitnet.fr/deblan/database-anonymizer/config"
|
||||||
"gitnet.fr/deblan/database-anonymizer/data"
|
"gitnet.fr/deblan/database-anonymizer/data"
|
||||||
"gitnet.fr/deblan/database-anonymizer/database"
|
"gitnet.fr/deblan/database-anonymizer/database"
|
||||||
|
@ -52,20 +49,20 @@ func (a *App) TruncateTable(c config.SchemaConfigAction) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
query := a.CreateSelectQuery(c)
|
query := a.CreateSelectQuery(c)
|
||||||
rows := database.GetRows(a.Db, query)
|
rows := database.GetRows(a.Db, query, c.Table, a.DbConfig.Type)
|
||||||
var scan any
|
var scan any
|
||||||
|
|
||||||
for _, row := range rows {
|
for _, row := range rows {
|
||||||
pkeys := []string{}
|
pkeys := []string{}
|
||||||
pCounter := 1
|
values := make(map[int]string)
|
||||||
|
|
||||||
for _, col := range c.PrimaryKey {
|
for _, col := range c.PrimaryKey {
|
||||||
if row[col].IsInteger {
|
if !row[col].IsString {
|
||||||
value, _ := strconv.Atoi(row[col].Value)
|
value := row[col]
|
||||||
pkeys = append(pkeys, fmt.Sprintf("%s=%d", col, value))
|
pkeys = append(pkeys, fmt.Sprintf("%s=%s", col, value.FinalValue()))
|
||||||
} else {
|
} else {
|
||||||
pkeys = append(pkeys, fmt.Sprintf("%s=:p%s", col, strconv.Itoa(pCounter)))
|
pkeys = append(pkeys, database.GetNamedParameter(a.DbConfig.Type, col, len(values)+1))
|
||||||
pCounter = pCounter + 1
|
values[len(values)+1] = row[col].Value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,17 +72,14 @@ func (a *App) TruncateTable(c config.SchemaConfigAction) error {
|
||||||
strings.Join(pkeys, " AND "),
|
strings.Join(pkeys, " AND "),
|
||||||
)
|
)
|
||||||
|
|
||||||
stmt := nq.NewNamedParameterQuery(sql)
|
var args []any
|
||||||
pCounter = 1
|
if len(values) > 0 {
|
||||||
|
for i := 1; i <= len(values); i++ {
|
||||||
for _, col := range c.PrimaryKey {
|
args = append(args, values[i])
|
||||||
if !row[col].IsInteger {
|
|
||||||
stmt.SetValue(fmt.Sprintf("p%s", strconv.Itoa(pCounter)), row[col].Value)
|
|
||||||
pCounter = pCounter + 1
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
a.Db.QueryRow(stmt.GetParsedQuery(), (stmt.GetParsedParameters())...).Scan(&scan)
|
a.Db.QueryRow(sql, args...).Scan(&scan)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -93,7 +87,7 @@ func (a *App) TruncateTable(c config.SchemaConfigAction) error {
|
||||||
|
|
||||||
func (a *App) UpdateRows(c config.SchemaConfigAction, globalColumns map[string]string, generators map[string][]string) error {
|
func (a *App) UpdateRows(c config.SchemaConfigAction, globalColumns map[string]string, generators map[string][]string) error {
|
||||||
query := a.CreateSelectQuery(c)
|
query := a.CreateSelectQuery(c)
|
||||||
rows := database.GetRows(a.Db, query)
|
rows := database.GetRows(a.Db, query, c.Table, a.DbConfig.Type)
|
||||||
var scan any
|
var scan any
|
||||||
|
|
||||||
for key, row := range rows {
|
for key, row := range rows {
|
||||||
|
@ -155,25 +149,28 @@ func (a *App) UpdateRows(c config.SchemaConfigAction, globalColumns map[string]s
|
||||||
for _, row := range rows {
|
for _, row := range rows {
|
||||||
updates := []string{}
|
updates := []string{}
|
||||||
pkeys := []string{}
|
pkeys := []string{}
|
||||||
values := make(map[int][]string)
|
values := make(map[int]string)
|
||||||
pCounter := 1
|
|
||||||
|
|
||||||
for col, value := range row {
|
for col, value := range row {
|
||||||
if value.IsUpdated && !value.IsVirtual {
|
if value.IsUpdated && !value.IsVirtual {
|
||||||
if value.IsInteger {
|
if value.IsString {
|
||||||
values[pCounter] = []string{value.Value, "int"}
|
updates = append(updates, database.GetNamedParameter(a.DbConfig.Type, col, len(values)+1))
|
||||||
|
values[len(values)+1] = value.FinalValue()
|
||||||
} else {
|
} else {
|
||||||
values[pCounter] = []string{value.Value, "string"}
|
updates = append(updates, fmt.Sprintf("%s=%s", col, value.FinalValue()))
|
||||||
}
|
}
|
||||||
updates = append(updates, fmt.Sprintf("%s=:p%s", col, strconv.Itoa(pCounter)))
|
|
||||||
pCounter = pCounter + 1
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, col := range c.PrimaryKey {
|
for _, col := range c.PrimaryKey {
|
||||||
values[pCounter] = []string{row[col].Value, "string"}
|
value := row[col]
|
||||||
pkeys = append(pkeys, fmt.Sprintf("%s=:p%s", col, strconv.Itoa(pCounter)))
|
|
||||||
pCounter = pCounter + 1
|
if !value.IsString {
|
||||||
|
pkeys = append(pkeys, fmt.Sprintf("%s=%s", col, value.FinalValue()))
|
||||||
|
} else {
|
||||||
|
pkeys = append(pkeys, database.GetNamedParameter(a.DbConfig.Type, col, len(values)+1))
|
||||||
|
values[len(values)+1] = value.FinalValue()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(updates) > 0 {
|
if len(updates) > 0 {
|
||||||
|
@ -184,19 +181,18 @@ func (a *App) UpdateRows(c config.SchemaConfigAction, globalColumns map[string]s
|
||||||
strings.Join(pkeys, " AND "),
|
strings.Join(pkeys, " AND "),
|
||||||
)
|
)
|
||||||
|
|
||||||
stmt := nq.NewNamedParameterQuery(sql)
|
var args []any
|
||||||
pCounter = 1
|
if len(values) > 0 {
|
||||||
|
for i := 1; i <= len(values); i++ {
|
||||||
for i, value := range values {
|
args = append(args, values[i])
|
||||||
if value[1] == "string" {
|
|
||||||
stmt.SetValue(fmt.Sprintf("p%s", strconv.Itoa(i)), value[0])
|
|
||||||
} else {
|
|
||||||
newValue, _ := strconv.Atoi(value[0])
|
|
||||||
stmt.SetValue(fmt.Sprintf("p%s", strconv.Itoa(i)), newValue)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
a.Db.QueryRow(stmt.GetParsedQuery(), (stmt.GetParsedParameters())...).Scan(&scan)
|
err := a.Db.QueryRow(sql, args...).Scan(&scan)
|
||||||
|
|
||||||
|
if err.Error() != "" && err.Error() != "sql: no rows in result set" {
|
||||||
|
logger.LogFatalExitIf(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
25
data/data.go
25
data/data.go
|
@ -16,30 +16,47 @@ type Data struct {
|
||||||
IsVirtual bool
|
IsVirtual bool
|
||||||
IsPrimaryKey bool
|
IsPrimaryKey bool
|
||||||
IsUpdated bool
|
IsUpdated bool
|
||||||
IsInteger bool
|
|
||||||
|
IsInteger bool
|
||||||
|
IsBoolean bool
|
||||||
|
IsString bool
|
||||||
|
IsNull bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Data) FromByte(v []byte) *Data {
|
func (d *Data) FromByte(v []byte) *Data {
|
||||||
d.Value = string(v)
|
d.Value = string(v)
|
||||||
d.IsInteger = false
|
|
||||||
|
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Data) FromInt64(v int64) *Data {
|
func (d *Data) FromInt64(v int64) *Data {
|
||||||
d.Value = strconv.FormatInt(v, 10)
|
d.Value = strconv.FormatInt(v, 10)
|
||||||
d.IsInteger = true
|
|
||||||
|
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Data) FromString(v string) *Data {
|
func (d *Data) FromString(v string) *Data {
|
||||||
d.Value = v
|
d.Value = v
|
||||||
d.IsInteger = false
|
|
||||||
|
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Data) FinalValue() string {
|
||||||
|
if d.IsNull {
|
||||||
|
return "null"
|
||||||
|
}
|
||||||
|
|
||||||
|
if d.IsBoolean {
|
||||||
|
if d.Value == "1" {
|
||||||
|
return "true"
|
||||||
|
} else {
|
||||||
|
return "false"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.Value
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Data) IsTwigExpression() bool {
|
func (d *Data) IsTwigExpression() bool {
|
||||||
return strings.Contains(d.Faker, "{{") || strings.Contains(d.Faker, "}}")
|
return strings.Contains(d.Faker, "{{") || strings.Contains(d.Faker, "}}")
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,15 @@ func EscapeTable(dbType, table string) string {
|
||||||
return fmt.Sprintf("\"%s\"", table)
|
return fmt.Sprintf("\"%s\"", table)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRows(db *sql.DB, query string) map[int]map[string]data.Data {
|
func GetNamedParameter(dbType, col string, number int) string {
|
||||||
|
if dbType == "mysql" {
|
||||||
|
return fmt.Sprintf("%s=?", col)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s=$%d", col, number)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetRows(db *sql.DB, query, table, dbType string) map[int]map[string]data.Data {
|
||||||
rows, err := db.Query(query)
|
rows, err := db.Query(query)
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
logger.LogFatalExitIf(err)
|
logger.LogFatalExitIf(err)
|
||||||
|
@ -29,6 +37,8 @@ func GetRows(db *sql.DB, query string) map[int]map[string]data.Data {
|
||||||
|
|
||||||
key := 0
|
key := 0
|
||||||
|
|
||||||
|
columnsTypes := make(map[string]string)
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
row := make(map[string]data.Data)
|
row := make(map[string]data.Data)
|
||||||
|
|
||||||
|
@ -40,11 +50,32 @@ func GetRows(db *sql.DB, query string) map[int]map[string]data.Data {
|
||||||
logger.LogFatalExitIf(err)
|
logger.LogFatalExitIf(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var typeValue string
|
||||||
|
|
||||||
for i, col := range columns {
|
for i, col := range columns {
|
||||||
value := values[i]
|
value := values[i]
|
||||||
d := data.Data{IsVirtual: false}
|
d := data.Data{
|
||||||
|
IsVirtual: false,
|
||||||
|
IsNull: value == nil,
|
||||||
|
}
|
||||||
|
|
||||||
if value != nil {
|
if value != nil {
|
||||||
|
if dbType == "postgres" {
|
||||||
|
if len(columnsTypes[col]) == 0 {
|
||||||
|
typeQuery := fmt.Sprintf("SELECT pg_typeof(%s) as value FROM %s", col, EscapeTable(dbType, table))
|
||||||
|
db.QueryRow(typeQuery).Scan(&typeValue)
|
||||||
|
columnsTypes[col] = typeValue
|
||||||
|
}
|
||||||
|
|
||||||
|
dataType := columnsTypes[col]
|
||||||
|
|
||||||
|
d.IsInteger = dataType == "integer"
|
||||||
|
d.IsBoolean = dataType == "boolean"
|
||||||
|
d.IsString = !d.IsBoolean && !d.IsInteger
|
||||||
|
} else {
|
||||||
|
d.IsString = true
|
||||||
|
}
|
||||||
|
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
case []byte:
|
case []byte:
|
||||||
d.FromByte(v)
|
d.FromByte(v)
|
||||||
|
|
|
@ -13,3 +13,13 @@ func TestEscapeTable(t *testing.T) {
|
||||||
t.Fatalf("TestEscapeTable: postgres check failed")
|
t.Fatalf("TestEscapeTable: postgres check failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetNamedParameter(t *testing.T) {
|
||||||
|
if GetNamedParameter("mysql", "foo", 1) != "foo=?" {
|
||||||
|
t.Fatalf("TestGetNamedParameter: mysql check failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if GetNamedParameter("postgres", "foo", 1) != "foo=$1" {
|
||||||
|
t.Fatalf("TestGetNamedParameter: postgres check failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue