From ec9bb4be2a6e704e822566a27fbe6b861847a32f Mon Sep 17 00:00:00 2001 From: kemokemo Date: Sun, 3 Nov 2019 22:52:54 +0900 Subject: [PATCH 1/5] fix: Fixed the condition to use the default value --- table_bindings.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/table_bindings.go b/table_bindings.go index 43c0b3dd..71f08b26 100644 --- a/table_bindings.go +++ b/table_bindings.go @@ -127,7 +127,14 @@ func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) { plan.autoIncrIdx = y plan.autoIncrFieldName = col.fieldName } else { - if col.DefaultValue == "" { + val := elem.FieldByName(col.fieldName).Interface() + var isZeroValue bool + if val != nil { + isZeroValue = reflect.DeepEqual(val, reflect.Zero(reflect.TypeOf(val)).Interface()) + } + if (val == nil || isZeroValue) && col.DefaultValue != "" { + s2.WriteString(col.DefaultValue) + } else { s2.WriteString(t.dbmap.Dialect.BindVar(x)) if col == t.version { plan.versField = col.fieldName @@ -136,8 +143,6 @@ func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) { plan.argFields = append(plan.argFields, col.fieldName) } x++ - } else { - s2.WriteString(col.DefaultValue) } } first = false From a0478f915ecd41d54f4db21bb534e0207d692604 Mon Sep 17 00:00:00 2001 From: kemokemo Date: Sat, 14 Dec 2019 17:06:14 +0900 Subject: [PATCH 2/5] test: add test for default tag. --- db_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/db_test.go b/db_test.go index 57d8f43a..176f46e0 100644 --- a/db_test.go +++ b/db_test.go @@ -7,6 +7,7 @@ package gorp_test import ( + "reflect" "testing" ) @@ -181,3 +182,52 @@ AND field12 IN (:FieldIntList) }) } } + +type comment struct { + ID int64 `db:"id,primarykey,autoincrement"` + Name string `db:"name,notnull,default:'NoName',size:200"` + Text string `db:"text,notnull,size:400"` + Number int `db:"number,notnull,default:'774'"` + Private bool `db:"private,notnull"` +} + +func TestDbMap_DefaultTag(t *testing.T) { + tests := []struct { + name string + comment *comment + wantComment comment + }{ + {"Use default", + &comment{Text: "Hey!", Private: false}, + comment{ID: 1, Name: "NoName", Text: "Hey!", Number: 774, Private: false}}, + {"Specify all property", + &comment{Name: "bob", Text: "Hello!", Number: 5, Private: true}, + comment{ID: 1, Name: "bob", Text: "Hello!", Number: 5, Private: true}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dbmap := newDbMap() + dbmap.AddTableWithName(comment{}, "comments").SetKeys(true, "id") + + err := dbmap.CreateTables() + if err != nil { + t.Errorf("failed to create tables:%v", err) + } + defer dropAndClose(dbmap) + + err = dbmap.Insert(tt.comment) + if err != nil { + t.Errorf("failed to insert:%v", err) + } + var gotComment comment + err = dbmap.SelectOne(&gotComment, "SELECT * FROM comments ORDER BY id desc LIMIT 1") + if err != nil { + t.Errorf("failed to select:%v", err) + } + if !reflect.DeepEqual(gotComment, tt.wantComment) { + t.Errorf("gotComment = %+v, want %+v", gotComment, tt.wantComment) + } + }) + } +} From f46a5edb584a085966fc82b11ce0197cbd4503be Mon Sep 17 00:00:00 2001 From: kemokemo Date: Mon, 13 Jan 2020 16:44:44 +0900 Subject: [PATCH 3/5] fix: fix insert query using CASE WHEN phrase. --- table_bindings.go | 40 ++++++++++++++++++++++++++-------------- table_bindings_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 14 deletions(-) create mode 100644 table_bindings_test.go diff --git a/table_bindings.go b/table_bindings.go index 71f08b26..a41c8ebb 100644 --- a/table_bindings.go +++ b/table_bindings.go @@ -127,23 +127,21 @@ func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) { plan.autoIncrIdx = y plan.autoIncrFieldName = col.fieldName } else { - val := elem.FieldByName(col.fieldName).Interface() - var isZeroValue bool - if val != nil { - isZeroValue = reflect.DeepEqual(val, reflect.Zero(reflect.TypeOf(val)).Interface()) + if col.DefaultValue == "" { + s2.WriteString(t.dbmap.Dialect.BindVar(x)) + } else { + val := elem.FieldByName(col.fieldName).Interface() + s2.WriteString( + fmt.Sprintf("case when %t or %s = %s then %s else %s end", + val == nil, t.dbmap.Dialect.BindVar(x), getZeroValueStringForSQL(val), col.DefaultValue, t.dbmap.Dialect.BindVar(x))) } - if (val == nil || isZeroValue) && col.DefaultValue != "" { - s2.WriteString(col.DefaultValue) + if col == t.version { + plan.versField = col.fieldName + plan.argFields = append(plan.argFields, versFieldConst) } else { - s2.WriteString(t.dbmap.Dialect.BindVar(x)) - if col == t.version { - plan.versField = col.fieldName - plan.argFields = append(plan.argFields, versFieldConst) - } else { - plan.argFields = append(plan.argFields, col.fieldName) - } - x++ + plan.argFields = append(plan.argFields, col.fieldName) } + x++ } first = false } @@ -166,6 +164,20 @@ func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) { return plan.createBindInstance(elem, t.dbmap.TypeConverter) } +func getZeroValueStringForSQL(i interface{}) (s string) { + switch i.(type) { + case bool: + s = "false" + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + s = "0" + case float32, float64: + s = "0.0" + default: + s = "''" + } + return +} + func (t *TableMap) bindUpdate(elem reflect.Value, colFilter ColumnFilter) (bindInstance, error) { if colFilter == nil { colFilter = acceptAllFilter diff --git a/table_bindings_test.go b/table_bindings_test.go new file mode 100644 index 00000000..625d0fb6 --- /dev/null +++ b/table_bindings_test.go @@ -0,0 +1,31 @@ +// Copyright 2012 James Cooper. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package gorp + +import "testing" + +func Test_getZeroValueStringForSQL(t *testing.T) { + type args struct { + i interface{} + } + tests := []struct { + name string + args args + wantS string + }{ + {"bool", args{i: true}, "false"}, + {"int", args{i: -5}, "0"}, + {"uint", args{i: 100}, "0"}, + {"float", args{i: 12.3}, "0.0"}, + {"string", args{i: "gorp"}, "''"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotS := getZeroValueStringForSQL(tt.args.i); gotS != tt.wantS { + t.Errorf("getZerovalueStringForSQL() = %v, want %v", gotS, tt.wantS) + } + }) + } +} From b46f7b5a4e7cb3fd107ca4c01953921a1b151efe Mon Sep 17 00:00:00 2001 From: kemokemo Date: Mon, 13 Jan 2020 17:35:26 +0900 Subject: [PATCH 4/5] test: add multiple insert test --- db_test.go | 42 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/db_test.go b/db_test.go index 176f46e0..e2dd05dc 100644 --- a/db_test.go +++ b/db_test.go @@ -191,7 +191,7 @@ type comment struct { Private bool `db:"private,notnull"` } -func TestDbMap_DefaultTag(t *testing.T) { +func TestDbMap_DefaultTag_oneByOne(t *testing.T) { tests := []struct { name string comment *comment @@ -231,3 +231,43 @@ func TestDbMap_DefaultTag(t *testing.T) { }) } } + +func TestDbMap_DefaultTag_allAtOnce(t *testing.T) { + comments := []comment{ + comment{Text: "Hey!", Private: false}, + comment{Name: "bob", Text: "Hello!", Number: 5, Private: true}, + } + wantComments := []comment{ + comment{ID: 1, Name: "NoName", Text: "Hey!", Number: 774, Private: false}, + comment{ID: 2, Name: "bob", Text: "Hello!", Number: 5, Private: true}, + } + + t.Run("Insert at once", func(t *testing.T) { + dbmap := newDbMap() + dbmap.AddTableWithName(comment{}, "comments").SetKeys(true, "id") + + err := dbmap.CreateTables() + if err != nil { + t.Errorf("failed to create tables:%v", err) + } + defer dropAndClose(dbmap) + + for i := range comments { + err = dbmap.Insert(&comments[i]) + if err != nil { + t.Errorf("failed to insert:%v", err) + } + } + + var gotComments []comment + _, err = dbmap.Select(&gotComments, "SELECT * FROM comments ORDER BY id") + if err != nil { + t.Errorf("failed to select:%v", err) + } + for i := range gotComments { + if !reflect.DeepEqual(gotComments[i], wantComments[i]) { + t.Errorf("gotComment = %+v, want %+v", gotComments[i], wantComments[i]) + } + } + }) +} From 0994bbad7b96851ea473e60176f4d1e05a8b9940 Mon Sep 17 00:00:00 2001 From: kemokemo Date: Wed, 15 Jan 2020 23:03:24 +0900 Subject: [PATCH 5/5] fix: fix query and add BindVarWithType interface. --- db_test.go | 2 +- dialect.go | 8 +++++ dialect_mysql.go | 5 +++ dialect_mysql_test.go | 6 +++- dialect_oracle.go | 5 +++ dialect_postgres.go | 18 ++++++++++ dialect_postgres_test.go | 9 ++++- dialect_sqlite.go | 5 +++ dialect_sqlserver.go | 5 +++ table_bindings.go | 71 +++++++++++++++++++++++++++++++--------- table_bindings_test.go | 54 +++++++++++++++++++++++++----- 11 files changed, 162 insertions(+), 26 deletions(-) diff --git a/db_test.go b/db_test.go index e2dd05dc..5419f92e 100644 --- a/db_test.go +++ b/db_test.go @@ -187,7 +187,7 @@ type comment struct { ID int64 `db:"id,primarykey,autoincrement"` Name string `db:"name,notnull,default:'NoName',size:200"` Text string `db:"text,notnull,size:400"` - Number int `db:"number,notnull,default:'774'"` + Number int `db:"number,notnull,default:774"` Private bool `db:"private,notnull"` } diff --git a/dialect.go b/dialect.go index fdea2b20..f9caeb88 100644 --- a/dialect.go +++ b/dialect.go @@ -50,6 +50,14 @@ type Dialect interface { // BindVar(i int) string + // bind variable string to use when forming SQL statements + // in many dbs it is "?", but Postgres appears to use $1::int when using int value. + // + // i is a zero based index of the bind variable in this statement + // t is the type of the variable + // + BindVarWithType(i int, t reflect.Type) string + // Handles quoting of a field name to ensure that it doesn't raise any // SQL parsing exceptions by using a reserved word as a field name. QuoteField(field string) string diff --git a/dialect_mysql.go b/dialect_mysql.go index d068ebe8..19394fb2 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -140,6 +140,11 @@ func (d MySQLDialect) BindVar(i int) string { return "?" } +// BindVarWithType of MySQL returns "?" +func (d MySQLDialect) BindVarWithType(i int, t reflect.Type) string { + return "?" +} + func (d MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { return standardInsertAutoIncr(exec, insertSql, params...) } diff --git a/dialect_mysql_test.go b/dialect_mysql_test.go index 20ddeac1..fe66cf56 100644 --- a/dialect_mysql_test.go +++ b/dialect_mysql_test.go @@ -14,10 +14,10 @@ import ( "testing" "time" + "github.com/go-gorp/gorp" "github.com/poy/onpar" "github.com/poy/onpar/expect" "github.com/poy/onpar/matchers" - "github.com/go-gorp/gorp" ) func TestMySQLDialect(t *testing.T) { @@ -129,6 +129,10 @@ func TestMySQLDialect(t *testing.T) { expect(dialect.BindVar(0)).To(matchers.Equal("?")) }) + o.Spec("BindVarWithType", func(expect expect.Expectation, dialect gorp.MySQLDialect) { + expect(dialect.BindVarWithType(0, reflect.TypeOf(0))).To(matchers.Equal("?")) + }) + o.Spec("QuoteField", func(expect expect.Expectation, dialect gorp.MySQLDialect) { expect(dialect.QuoteField("foo")).To(matchers.Equal("`foo`")) }) diff --git a/dialect_oracle.go b/dialect_oracle.go index 01a99b8a..7821bf2d 100644 --- a/dialect_oracle.go +++ b/dialect_oracle.go @@ -91,6 +91,11 @@ func (d OracleDialect) BindVar(i int) string { return fmt.Sprintf(":%d", i+1) } +// BindVarWithType of Oracle returns "$(i+1)" +func (d OracleDialect) BindVarWithType(i int, t reflect.Type) string { + return fmt.Sprintf(":%d", i+1) +} + // After executing the insert uses the ColMap IdQuery to get the generated id func (d OracleDialect) InsertQueryToTarget(exec SqlExecutor, insertSql, idSql string, target interface{}, params ...interface{}) error { _, err := exec.Exec(insertSql, params...) diff --git a/dialect_postgres.go b/dialect_postgres.go index 7a9c50bf..48678919 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -102,6 +102,24 @@ func (d PostgresDialect) BindVar(i int) string { return fmt.Sprintf("$%d", i+1) } +// BindVarWithType of PostgreSQL returns "$(i+1::t)" +func (d PostgresDialect) BindVarWithType(i int, t reflect.Type) string { + var s string + switch t.Kind() { + case reflect.Bool: + s = "::bool" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + s = "::int" + case reflect.Float32, reflect.Float64: + s = "::float" + case reflect.String: + s = "::text" + default: + s = "" + } + return fmt.Sprintf("$%d%s", i+1, s) +} + func (d PostgresDialect) InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error { rows, err := exec.Query(insertSql, params...) if err != nil { diff --git a/dialect_postgres_test.go b/dialect_postgres_test.go index 7b116f27..26800f21 100644 --- a/dialect_postgres_test.go +++ b/dialect_postgres_test.go @@ -12,10 +12,10 @@ import ( "testing" "time" + "github.com/go-gorp/gorp" "github.com/poy/onpar" "github.com/poy/onpar/expect" "github.com/poy/onpar/matchers" - "github.com/go-gorp/gorp" ) func TestPostgresDialect(t *testing.T) { @@ -108,6 +108,13 @@ func TestPostgresDialect(t *testing.T) { expect(dialect.BindVar(4)).To(matchers.Equal("$5")) }) + o.Spec("BindVarWithType", func(expect expect.Expectation, dialect gorp.PostgresDialect) { + expect(dialect.BindVarWithType(0, reflect.TypeOf(0))).To(matchers.Equal("$1::int")) + expect(dialect.BindVarWithType(1, reflect.TypeOf(false))).To(matchers.Equal("$2::bool")) + expect(dialect.BindVarWithType(2, reflect.TypeOf(1.23))).To(matchers.Equal("$3::float")) + expect(dialect.BindVarWithType(3, reflect.TypeOf("gopher"))).To(matchers.Equal("$4::text")) + }) + o.Group("QuoteField", func() { o.Spec("By default, case is preserved", func(expect expect.Expectation, dialect gorp.PostgresDialect) { expect(dialect.QuoteField("Foo")).To(matchers.Equal(`"Foo"`)) diff --git a/dialect_sqlite.go b/dialect_sqlite.go index 2296275b..1d2725b6 100644 --- a/dialect_sqlite.go +++ b/dialect_sqlite.go @@ -86,6 +86,11 @@ func (d SqliteDialect) BindVar(i int) string { return "?" } +// BindVarWithType of SQLite returns "?" +func (d SqliteDialect) BindVarWithType(i int, t reflect.Type) string { + return "?" +} + func (d SqliteDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { return standardInsertAutoIncr(exec, insertSql, params...) } diff --git a/dialect_sqlserver.go b/dialect_sqlserver.go index ec06aed6..66476ab7 100644 --- a/dialect_sqlserver.go +++ b/dialect_sqlserver.go @@ -101,6 +101,11 @@ func (d SqlServerDialect) BindVar(i int) string { return "?" } +// BindVarWithType of SQL Server returns "?" +func (d SqlServerDialect) BindVarWithType(i int, t reflect.Type) string { + return "?" +} + func (d SqlServerDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { return standardInsertAutoIncr(exec, insertSql, params...) } diff --git a/table_bindings.go b/table_bindings.go index a41c8ebb..bd871223 100644 --- a/table_bindings.go +++ b/table_bindings.go @@ -8,6 +8,8 @@ import ( "bytes" "fmt" "reflect" + "strconv" + "strings" "sync" ) @@ -129,19 +131,35 @@ func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) { } else { if col.DefaultValue == "" { s2.WriteString(t.dbmap.Dialect.BindVar(x)) + if col == t.version { + plan.versField = col.fieldName + plan.argFields = append(plan.argFields, versFieldConst) + } else { + plan.argFields = append(plan.argFields, col.fieldName) + } + x++ } else { - val := elem.FieldByName(col.fieldName).Interface() + defaultVal, err := getValueAsType(col.gotype, col.DefaultValue) + if err != nil { + fmt.Println("failed to parse col.DefaultValue:", err) + } + s2.WriteString( - fmt.Sprintf("case when %t or %s = %s then %s else %s end", - val == nil, t.dbmap.Dialect.BindVar(x), getZeroValueStringForSQL(val), col.DefaultValue, t.dbmap.Dialect.BindVar(x))) - } - if col == t.version { - plan.versField = col.fieldName - plan.argFields = append(plan.argFields, versFieldConst) - } else { - plan.argFields = append(plan.argFields, col.fieldName) + fmt.Sprintf("case when %s is null or %s = %s then %v else %s end", + t.dbmap.Dialect.BindVarWithType(x, col.gotype), + t.dbmap.Dialect.BindVarWithType(x+1, col.gotype), + getZeroValueStringForSQL(col.gotype), + defaultVal, + t.dbmap.Dialect.BindVarWithType(x+2, col.gotype))) + + if col == t.version { + plan.versField = col.fieldName + plan.argFields = append(plan.argFields, versFieldConst, versFieldConst, versFieldConst) + } else { + plan.argFields = append(plan.argFields, col.fieldName, col.fieldName, col.fieldName) + } + x += 3 } - x++ } first = false } @@ -164,13 +182,13 @@ func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) { return plan.createBindInstance(elem, t.dbmap.TypeConverter) } -func getZeroValueStringForSQL(i interface{}) (s string) { - switch i.(type) { - case bool: +func getZeroValueStringForSQL(t reflect.Type) (s string) { + switch t.Kind() { + case reflect.Bool: s = "false" - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: s = "0" - case float32, float64: + case reflect.Float32, reflect.Float64: s = "0.0" default: s = "''" @@ -178,6 +196,29 @@ func getZeroValueStringForSQL(i interface{}) (s string) { return } +func getValueAsType(t reflect.Type, value string) (s string, err error) { + value = strings.Trim(value, "'") + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + var n int + n, err = strconv.Atoi(value) + if err != nil { + return "", err + } + s = fmt.Sprintf("%v", n) + case reflect.Float32, reflect.Float64: + var f float64 + f, err = strconv.ParseFloat(value, 64) + if err != nil { + return "", err + } + s = fmt.Sprintf("%v", f) + default: + s = fmt.Sprintf("'%v'", value) + } + return +} + func (t *TableMap) bindUpdate(elem reflect.Value, colFilter ColumnFilter) (bindInstance, error) { if colFilter == nil { colFilter = acceptAllFilter diff --git a/table_bindings_test.go b/table_bindings_test.go index 625d0fb6..a8023f23 100644 --- a/table_bindings_test.go +++ b/table_bindings_test.go @@ -4,28 +4,66 @@ package gorp -import "testing" +import ( + "reflect" + "testing" +) func Test_getZeroValueStringForSQL(t *testing.T) { type args struct { - i interface{} + t reflect.Type } tests := []struct { name string args args wantS string }{ - {"bool", args{i: true}, "false"}, - {"int", args{i: -5}, "0"}, - {"uint", args{i: 100}, "0"}, - {"float", args{i: 12.3}, "0.0"}, - {"string", args{i: "gorp"}, "''"}, + {"bool", args{reflect.TypeOf(true)}, "false"}, + {"int", args{reflect.TypeOf(-5)}, "0"}, + {"uint", args{reflect.TypeOf(100)}, "0"}, + {"float", args{reflect.TypeOf(12.3)}, "0.0"}, + {"string", args{reflect.TypeOf("gorp")}, "''"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if gotS := getZeroValueStringForSQL(tt.args.i); gotS != tt.wantS { + if gotS := getZeroValueStringForSQL(tt.args.t); gotS != tt.wantS { t.Errorf("getZerovalueStringForSQL() = %v, want %v", gotS, tt.wantS) } }) } } + +func Test_getValueAsType(t *testing.T) { + type args struct { + t reflect.Type + value string + } + tests := []struct { + name string + args args + wantS string + wantErr bool + }{ + {"int", args{reflect.TypeOf(1), "774"}, "774", false}, + {"int with single quotation", args{reflect.TypeOf(1), "'774'"}, "774", false}, + {"int of empty string", args{reflect.TypeOf(1), ""}, "", true}, + {"float", args{reflect.TypeOf(1.00), "1.23"}, "1.23", false}, + {"float with single quotation", args{reflect.TypeOf(1.00), "'1.23'"}, "1.23", false}, + {"float of empty string", args{reflect.TypeOf(1.00), ""}, "", true}, + {"string", args{reflect.TypeOf(""), "Gopher"}, "'Gopher'", false}, + {"string with single quotation", args{reflect.TypeOf(""), "'Gopher'"}, "'Gopher'", false}, + {"string of empty string", args{reflect.TypeOf(""), ""}, "''", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotS, err := getValueAsType(tt.args.t, tt.args.value) + if (err != nil) != tt.wantErr { + t.Errorf("getValueAsType() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotS != tt.wantS { + t.Errorf("getValueAsType() = %v, want %v", gotS, tt.wantS) + } + }) + } +}