From 0de68ae32d962650f3b19cb16ce952ac4c18f371 Mon Sep 17 00:00:00 2001 From: Simon Vieille Date: Tue, 19 Mar 2024 13:19:44 +0100 Subject: [PATCH] add truncate operation --- app/app.go | 189 +++++++++++++++++++++++----------------- config/schema_config.go | 9 +- 2 files changed, 116 insertions(+), 82 deletions(-) diff --git a/app/app.go b/app/app.go index f0e350e..150838b 100644 --- a/app/app.go +++ b/app/app.go @@ -19,17 +19,25 @@ type App struct { Db *sql.DB } -func (a *App) ApplyRule(c config.SchemaConfigData, globalColumns map[string]string, generators map[string][]string) error { +func (a *App) DoAction(c config.SchemaConfigAction, globalColumns map[string]string, generators map[string][]string) error { var query string if c.Table == "" { return errors.New("Table must be defined") } - if c.Query != "" { - query = c.Query + if c.Truncate { + if c.Query != "" { + query = c.Query + } else { + return a.TruncateTable(c.Table) + } } else { - query = fmt.Sprintf("SELECT * FROM %s", c.Table) + if c.Query != "" { + query = c.Query + } else { + query = fmt.Sprintf("SELECT * FROM %s", c.Table) + } } if len(c.PrimaryKey) == 0 { @@ -37,40 +45,56 @@ func (a *App) ApplyRule(c config.SchemaConfigData, globalColumns map[string]stri } rows := database.GetRows(a.Db, query) + var scan any - for key, row := range rows { - if len(c.VirtualColumns) > 0 { - for col, faker := range c.VirtualColumns { - rows[key][col] = data.Data{ - Value: "", - Faker: faker, - IsVirtual: true, - } + if c.Truncate { + for _, row := range rows { + pkeys := []string{} + pCounter := 1 + + for _, col := range c.PrimaryKey { + pkeys = append(pkeys, fmt.Sprintf("%s=:p%s", col, strconv.Itoa(pCounter))) + pCounter = pCounter + 1 } - } - if len(c.Columns) > 0 { - for col, faker := range c.Columns { - r := row[col] - r.Faker = faker - rows[key][col] = r + sql := fmt.Sprintf( + "DELETE FROM %s WHERE %s", + c.Table, + strings.Join(pkeys, " AND "), + ) + + stmt := nq.NewNamedParameterQuery(sql) + pCounter = 1 + + for _, col := range c.PrimaryKey { + stmt.SetValue(fmt.Sprintf("p%s", strconv.Itoa(pCounter)), row[col].Value) + pCounter = pCounter + 1 } - } - if len(globalColumns) > 0 { - for col, faker := range globalColumns { - if value, exists := row[col]; exists { - if value.Faker == "" { - value.Faker = faker - rows[key][col] = value + a.Db.QueryRow(stmt.GetParsedQuery(), (stmt.GetParsedParameters())...).Scan(&scan) + } + } else { + for key, row := range rows { + if len(c.VirtualColumns) > 0 { + for col, faker := range c.VirtualColumns { + rows[key][col] = data.Data{ + Value: "", + Faker: faker, + IsVirtual: true, } } } - } - if len(generators) > 0 { - for faker, columns := range generators { - for _, col := range columns { + if len(c.Columns) > 0 { + for col, faker := range c.Columns { + r := row[col] + r.Faker = faker + rows[key][col] = r + } + } + + if len(globalColumns) > 0 { + for col, faker := range globalColumns { if value, exists := row[col]; exists { if value.Faker == "" { value.Faker = faker @@ -79,68 +103,71 @@ func (a *App) ApplyRule(c config.SchemaConfigData, globalColumns map[string]stri } } } - } - for _, col := range c.PrimaryKey { - value := row[col] - value.IsPrimaryKey = true - rows[key][col] = value - } - - rows[key] = a.UpdateRow(rows[key]) - } - - var scan any - - x := 1 - t := len(rows) - - for _, row := range rows { - fmt.Printf("%+v/%+v\n", x, t) - x = x + 1 - - updates := []string{} - pkeys := []string{} - pCounter := 1 - - for col, value := range row { - if value.IsUpdated { - updates = append(updates, fmt.Sprintf("%s=:p%s", col, strconv.Itoa(pCounter))) - pCounter = pCounter + 1 + if len(generators) > 0 { + for faker, columns := range generators { + for _, col := range columns { + if value, exists := row[col]; exists { + if value.Faker == "" { + value.Faker = faker + rows[key][col] = value + } + } + } + } } + + for _, col := range c.PrimaryKey { + value := row[col] + value.IsPrimaryKey = true + rows[key][col] = value + } + + rows[key] = a.UpdateRow(rows[key]) } - for _, col := range c.PrimaryKey { - pkeys = append(pkeys, fmt.Sprintf("%s=:p%s", col, strconv.Itoa(pCounter))) - pCounter = pCounter + 1 - } + for _, row := range rows { + updates := []string{} + pkeys := []string{} + pCounter := 1 - if len(updates) > 0 { - sql := fmt.Sprintf( - "UPDATE %s SET %s WHERE %s", - c.Table, - strings.Join(updates, ", "), - strings.Join(pkeys, " AND "), - ) - - stmt := nq.NewNamedParameterQuery(sql) - pCounter = 1 - - for _, value := range row { + for col, value := range row { if value.IsUpdated { - stmt.SetValue(fmt.Sprintf("p%s", strconv.Itoa(pCounter)), value.Value) + updates = append(updates, fmt.Sprintf("%s=:p%s", col, strconv.Itoa(pCounter))) pCounter = pCounter + 1 } } for _, col := range c.PrimaryKey { - stmt.SetValue(fmt.Sprintf("p%s", strconv.Itoa(pCounter)), row[col].Value) + pkeys = append(pkeys, fmt.Sprintf("%s=:p%s", col, strconv.Itoa(pCounter))) pCounter = pCounter + 1 } - a.Db.QueryRow(stmt.GetParsedQuery(), (stmt.GetParsedParameters())...).Scan(&scan) + if len(updates) > 0 { + sql := fmt.Sprintf( + "UPDATE %s SET %s WHERE %s", + c.Table, + strings.Join(updates, ", "), + strings.Join(pkeys, " AND "), + ) - // fmt.Printf("%+v\n", r) + stmt := nq.NewNamedParameterQuery(sql) + pCounter = 1 + + for _, value := range row { + if value.IsUpdated { + stmt.SetValue(fmt.Sprintf("p%s", strconv.Itoa(pCounter)), value.Value) + pCounter = pCounter + 1 + } + } + + for _, col := range c.PrimaryKey { + stmt.SetValue(fmt.Sprintf("p%s", strconv.Itoa(pCounter)), row[col].Value) + pCounter = pCounter + 1 + } + + a.Db.QueryRow(stmt.GetParsedQuery(), (stmt.GetParsedParameters())...).Scan(&scan) + } } } @@ -179,11 +206,17 @@ func (a *App) UpdateRow(row map[string]data.Data) map[string]data.Data { return row } +func (a *App) TruncateTable(table string) error { + _, err := a.Db.Exec(fmt.Sprintf("TRUNCATE %s", table)) + + return err +} + func (a *App) Run(db *sql.DB, c config.SchemaConfig) error { a.Db = db - for _, data := range c.Rules.Datas { - err := a.ApplyRule(data, c.Rules.Columns, c.Rules.Generators) + for _, data := range c.Rules.Actions { + err := a.DoAction(data, c.Rules.Columns, c.Rules.Generators) logger.LogFatalExitIf(err) } diff --git a/config/schema_config.go b/config/schema_config.go index 39251c2..c2b207f 100644 --- a/config/schema_config.go +++ b/config/schema_config.go @@ -6,19 +6,20 @@ import ( "os" ) -type SchemaConfigData struct { +type SchemaConfigAction struct { Table string `yaml:"table"` Query string `yaml:"query"` VirtualColumns map[string]string `yaml:"virtual_columns"` Generators map[string][]string `yaml:"generators"` Columns map[string]string `yaml:"columns"` PrimaryKey []string `yaml:"primary_key"` + Truncate bool `yaml:"truncate"` } type SchemaConfigRules struct { - Columns map[string]string `yaml:"columns"` - Generators map[string][]string `yaml:"generators"` - Datas []SchemaConfigData `yaml:"datas"` + Columns map[string]string `yaml:"columns"` + Generators map[string][]string `yaml:"generators"` + Actions []SchemaConfigAction `yaml:"actions"` } type SchemaConfig struct {