forked from deblan/database-anonymizer
add postgres column type check
add specific method to handle named parameters
This commit is contained in:
parent
2ba8561574
commit
a7cd0634ef
4 changed files with 99 additions and 45 deletions
|
|
@ -15,7 +15,15 @@ func EscapeTable(dbType, table string) string {
|
|||
return fmt.Sprintf("\"%s\"", table)
|
||||
}
|
||||
|
||||
func GetRows(db *sql.DB, query string) map[int]map[string]data.Data {
|
||||
func GetNamedParameter(dbType, col string, number int) string {
|
||||
if dbType == "mysql" {
|
||||
return fmt.Sprintf("%s=?", col)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s=$%d", col, number)
|
||||
}
|
||||
|
||||
func GetRows(db *sql.DB, query, table, dbType string) map[int]map[string]data.Data {
|
||||
rows, err := db.Query(query)
|
||||
defer rows.Close()
|
||||
logger.LogFatalExitIf(err)
|
||||
|
|
@ -29,6 +37,8 @@ func GetRows(db *sql.DB, query string) map[int]map[string]data.Data {
|
|||
|
||||
key := 0
|
||||
|
||||
columnsTypes := make(map[string]string)
|
||||
|
||||
for rows.Next() {
|
||||
row := make(map[string]data.Data)
|
||||
|
||||
|
|
@ -40,11 +50,32 @@ func GetRows(db *sql.DB, query string) map[int]map[string]data.Data {
|
|||
logger.LogFatalExitIf(err)
|
||||
}
|
||||
|
||||
var typeValue string
|
||||
|
||||
for i, col := range columns {
|
||||
value := values[i]
|
||||
d := data.Data{IsVirtual: false}
|
||||
d := data.Data{
|
||||
IsVirtual: false,
|
||||
IsNull: value == nil,
|
||||
}
|
||||
|
||||
if value != nil {
|
||||
if dbType == "postgres" {
|
||||
if len(columnsTypes[col]) == 0 {
|
||||
typeQuery := fmt.Sprintf("SELECT pg_typeof(%s) as value FROM %s", col, EscapeTable(dbType, table))
|
||||
db.QueryRow(typeQuery).Scan(&typeValue)
|
||||
columnsTypes[col] = typeValue
|
||||
}
|
||||
|
||||
dataType := columnsTypes[col]
|
||||
|
||||
d.IsInteger = dataType == "integer"
|
||||
d.IsBoolean = dataType == "boolean"
|
||||
d.IsString = !d.IsBoolean && !d.IsInteger
|
||||
} else {
|
||||
d.IsString = true
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
d.FromByte(v)
|
||||
|
|
|
|||
|
|
@ -13,3 +13,13 @@ func TestEscapeTable(t *testing.T) {
|
|||
t.Fatalf("TestEscapeTable: postgres check failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNamedParameter(t *testing.T) {
|
||||
if GetNamedParameter("mysql", "foo", 1) != "foo=?" {
|
||||
t.Fatalf("TestGetNamedParameter: mysql check failed")
|
||||
}
|
||||
|
||||
if GetNamedParameter("postgres", "foo", 1) != "foo=$1" {
|
||||
t.Fatalf("TestGetNamedParameter: postgres check failed")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue