diff --git a/db_test.go b/db_test.go index 57d8f43a..5419f92e 100644 --- a/db_test.go +++ b/db_test.go @@ -7,6 +7,7 @@ package gorp_test import ( + "reflect" "testing" ) @@ -181,3 +182,92 @@ AND field12 IN (:FieldIntList) }) } } + +type comment struct { + ID int64 `db:"id,primarykey,autoincrement"` + Name string `db:"name,notnull,default:'NoName',size:200"` + Text string `db:"text,notnull,size:400"` + Number int `db:"number,notnull,default:774"` + Private bool `db:"private,notnull"` +} + +func TestDbMap_DefaultTag_oneByOne(t *testing.T) { + tests := []struct { + name string + comment *comment + wantComment comment + }{ + {"Use default", + &comment{Text: "Hey!", Private: false}, + comment{ID: 1, Name: "NoName", Text: "Hey!", Number: 774, Private: false}}, + {"Specify all property", + &comment{Name: "bob", Text: "Hello!", Number: 5, Private: true}, + comment{ID: 1, Name: "bob", Text: "Hello!", Number: 5, Private: true}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dbmap := newDbMap() + dbmap.AddTableWithName(comment{}, "comments").SetKeys(true, "id") + + err := dbmap.CreateTables() + if err != nil { + t.Errorf("failed to create tables:%v", err) + } + defer dropAndClose(dbmap) + + err = dbmap.Insert(tt.comment) + if err != nil { + t.Errorf("failed to insert:%v", err) + } + var gotComment comment + err = dbmap.SelectOne(&gotComment, "SELECT * FROM comments ORDER BY id desc LIMIT 1") + if err != nil { + t.Errorf("failed to select:%v", err) + } + if !reflect.DeepEqual(gotComment, tt.wantComment) { + t.Errorf("gotComment = %+v, want %+v", gotComment, tt.wantComment) + } + }) + } +} + +func TestDbMap_DefaultTag_allAtOnce(t *testing.T) { + comments := []comment{ + comment{Text: "Hey!", Private: false}, + comment{Name: "bob", Text: "Hello!", Number: 5, Private: true}, + } + wantComments := []comment{ + comment{ID: 1, Name: "NoName", Text: "Hey!", Number: 774, Private: false}, + comment{ID: 2, Name: "bob", Text: "Hello!", Number: 5, Private: true}, + } + + t.Run("Insert at once", func(t *testing.T) { + dbmap := newDbMap() + dbmap.AddTableWithName(comment{}, "comments").SetKeys(true, "id") + + err := dbmap.CreateTables() + if err != nil { + t.Errorf("failed to create tables:%v", err) + } + defer dropAndClose(dbmap) + + for i := range comments { + err = dbmap.Insert(&comments[i]) + if err != nil { + t.Errorf("failed to insert:%v", err) + } + } + + var gotComments []comment + _, err = dbmap.Select(&gotComments, "SELECT * FROM comments ORDER BY id") + if err != nil { + t.Errorf("failed to select:%v", err) + } + for i := range gotComments { + if !reflect.DeepEqual(gotComments[i], wantComments[i]) { + t.Errorf("gotComment = %+v, want %+v", gotComments[i], wantComments[i]) + } + } + }) +} diff --git a/dialect.go b/dialect.go index fdea2b20..f9caeb88 100644 --- a/dialect.go +++ b/dialect.go @@ -50,6 +50,14 @@ type Dialect interface { // BindVar(i int) string + // bind variable string to use when forming SQL statements + // in many dbs it is "?", but Postgres appears to use $1::int when using int value. + // + // i is a zero based index of the bind variable in this statement + // t is the type of the variable + // + BindVarWithType(i int, t reflect.Type) string + // Handles quoting of a field name to ensure that it doesn't raise any // SQL parsing exceptions by using a reserved word as a field name. QuoteField(field string) string diff --git a/dialect_mysql.go b/dialect_mysql.go index d068ebe8..19394fb2 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -140,6 +140,11 @@ func (d MySQLDialect) BindVar(i int) string { return "?" } +// BindVarWithType of MySQL returns "?" +func (d MySQLDialect) BindVarWithType(i int, t reflect.Type) string { + return "?" +} + func (d MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { return standardInsertAutoIncr(exec, insertSql, params...) } diff --git a/dialect_mysql_test.go b/dialect_mysql_test.go index 20ddeac1..fe66cf56 100644 --- a/dialect_mysql_test.go +++ b/dialect_mysql_test.go @@ -14,10 +14,10 @@ import ( "testing" "time" + "github.com/go-gorp/gorp" "github.com/poy/onpar" "github.com/poy/onpar/expect" "github.com/poy/onpar/matchers" - "github.com/go-gorp/gorp" ) func TestMySQLDialect(t *testing.T) { @@ -129,6 +129,10 @@ func TestMySQLDialect(t *testing.T) { expect(dialect.BindVar(0)).To(matchers.Equal("?")) }) + o.Spec("BindVarWithType", func(expect expect.Expectation, dialect gorp.MySQLDialect) { + expect(dialect.BindVarWithType(0, reflect.TypeOf(0))).To(matchers.Equal("?")) + }) + o.Spec("QuoteField", func(expect expect.Expectation, dialect gorp.MySQLDialect) { expect(dialect.QuoteField("foo")).To(matchers.Equal("`foo`")) }) diff --git a/dialect_oracle.go b/dialect_oracle.go index 01a99b8a..7821bf2d 100644 --- a/dialect_oracle.go +++ b/dialect_oracle.go @@ -91,6 +91,11 @@ func (d OracleDialect) BindVar(i int) string { return fmt.Sprintf(":%d", i+1) } +// BindVarWithType of Oracle returns "$(i+1)" +func (d OracleDialect) BindVarWithType(i int, t reflect.Type) string { + return fmt.Sprintf(":%d", i+1) +} + // After executing the insert uses the ColMap IdQuery to get the generated id func (d OracleDialect) InsertQueryToTarget(exec SqlExecutor, insertSql, idSql string, target interface{}, params ...interface{}) error { _, err := exec.Exec(insertSql, params...) diff --git a/dialect_postgres.go b/dialect_postgres.go index 7a9c50bf..48678919 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -102,6 +102,24 @@ func (d PostgresDialect) BindVar(i int) string { return fmt.Sprintf("$%d", i+1) } +// BindVarWithType of PostgreSQL returns "$(i+1::t)" +func (d PostgresDialect) BindVarWithType(i int, t reflect.Type) string { + var s string + switch t.Kind() { + case reflect.Bool: + s = "::bool" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + s = "::int" + case reflect.Float32, reflect.Float64: + s = "::float" + case reflect.String: + s = "::text" + default: + s = "" + } + return fmt.Sprintf("$%d%s", i+1, s) +} + func (d PostgresDialect) InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error { rows, err := exec.Query(insertSql, params...) if err != nil { diff --git a/dialect_postgres_test.go b/dialect_postgres_test.go index 7b116f27..26800f21 100644 --- a/dialect_postgres_test.go +++ b/dialect_postgres_test.go @@ -12,10 +12,10 @@ import ( "testing" "time" + "github.com/go-gorp/gorp" "github.com/poy/onpar" "github.com/poy/onpar/expect" "github.com/poy/onpar/matchers" - "github.com/go-gorp/gorp" ) func TestPostgresDialect(t *testing.T) { @@ -108,6 +108,13 @@ func TestPostgresDialect(t *testing.T) { expect(dialect.BindVar(4)).To(matchers.Equal("$5")) }) + o.Spec("BindVarWithType", func(expect expect.Expectation, dialect gorp.PostgresDialect) { + expect(dialect.BindVarWithType(0, reflect.TypeOf(0))).To(matchers.Equal("$1::int")) + expect(dialect.BindVarWithType(1, reflect.TypeOf(false))).To(matchers.Equal("$2::bool")) + expect(dialect.BindVarWithType(2, reflect.TypeOf(1.23))).To(matchers.Equal("$3::float")) + expect(dialect.BindVarWithType(3, reflect.TypeOf("gopher"))).To(matchers.Equal("$4::text")) + }) + o.Group("QuoteField", func() { o.Spec("By default, case is preserved", func(expect expect.Expectation, dialect gorp.PostgresDialect) { expect(dialect.QuoteField("Foo")).To(matchers.Equal(`"Foo"`)) diff --git a/dialect_sqlite.go b/dialect_sqlite.go index 2296275b..1d2725b6 100644 --- a/dialect_sqlite.go +++ b/dialect_sqlite.go @@ -86,6 +86,11 @@ func (d SqliteDialect) BindVar(i int) string { return "?" } +// BindVarWithType of SQLite returns "?" +func (d SqliteDialect) BindVarWithType(i int, t reflect.Type) string { + return "?" +} + func (d SqliteDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { return standardInsertAutoIncr(exec, insertSql, params...) } diff --git a/dialect_sqlserver.go b/dialect_sqlserver.go index ec06aed6..66476ab7 100644 --- a/dialect_sqlserver.go +++ b/dialect_sqlserver.go @@ -101,6 +101,11 @@ func (d SqlServerDialect) BindVar(i int) string { return "?" } +// BindVarWithType of SQL Server returns "?" +func (d SqlServerDialect) BindVarWithType(i int, t reflect.Type) string { + return "?" +} + func (d SqlServerDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { return standardInsertAutoIncr(exec, insertSql, params...) } diff --git a/table_bindings.go b/table_bindings.go index 43c0b3dd..bd871223 100644 --- a/table_bindings.go +++ b/table_bindings.go @@ -8,6 +8,8 @@ import ( "bytes" "fmt" "reflect" + "strconv" + "strings" "sync" ) @@ -137,7 +139,26 @@ func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) { } x++ } else { - s2.WriteString(col.DefaultValue) + defaultVal, err := getValueAsType(col.gotype, col.DefaultValue) + if err != nil { + fmt.Println("failed to parse col.DefaultValue:", err) + } + + s2.WriteString( + fmt.Sprintf("case when %s is null or %s = %s then %v else %s end", + t.dbmap.Dialect.BindVarWithType(x, col.gotype), + t.dbmap.Dialect.BindVarWithType(x+1, col.gotype), + getZeroValueStringForSQL(col.gotype), + defaultVal, + t.dbmap.Dialect.BindVarWithType(x+2, col.gotype))) + + if col == t.version { + plan.versField = col.fieldName + plan.argFields = append(plan.argFields, versFieldConst, versFieldConst, versFieldConst) + } else { + plan.argFields = append(plan.argFields, col.fieldName, col.fieldName, col.fieldName) + } + x += 3 } } first = false @@ -161,6 +182,43 @@ func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) { return plan.createBindInstance(elem, t.dbmap.TypeConverter) } +func getZeroValueStringForSQL(t reflect.Type) (s string) { + switch t.Kind() { + case reflect.Bool: + s = "false" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + s = "0" + case reflect.Float32, reflect.Float64: + s = "0.0" + default: + s = "''" + } + return +} + +func getValueAsType(t reflect.Type, value string) (s string, err error) { + value = strings.Trim(value, "'") + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + var n int + n, err = strconv.Atoi(value) + if err != nil { + return "", err + } + s = fmt.Sprintf("%v", n) + case reflect.Float32, reflect.Float64: + var f float64 + f, err = strconv.ParseFloat(value, 64) + if err != nil { + return "", err + } + s = fmt.Sprintf("%v", f) + default: + s = fmt.Sprintf("'%v'", value) + } + return +} + func (t *TableMap) bindUpdate(elem reflect.Value, colFilter ColumnFilter) (bindInstance, error) { if colFilter == nil { colFilter = acceptAllFilter diff --git a/table_bindings_test.go b/table_bindings_test.go new file mode 100644 index 00000000..a8023f23 --- /dev/null +++ b/table_bindings_test.go @@ -0,0 +1,69 @@ +// Copyright 2012 James Cooper. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package gorp + +import ( + "reflect" + "testing" +) + +func Test_getZeroValueStringForSQL(t *testing.T) { + type args struct { + t reflect.Type + } + tests := []struct { + name string + args args + wantS string + }{ + {"bool", args{reflect.TypeOf(true)}, "false"}, + {"int", args{reflect.TypeOf(-5)}, "0"}, + {"uint", args{reflect.TypeOf(100)}, "0"}, + {"float", args{reflect.TypeOf(12.3)}, "0.0"}, + {"string", args{reflect.TypeOf("gorp")}, "''"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotS := getZeroValueStringForSQL(tt.args.t); gotS != tt.wantS { + t.Errorf("getZerovalueStringForSQL() = %v, want %v", gotS, tt.wantS) + } + }) + } +} + +func Test_getValueAsType(t *testing.T) { + type args struct { + t reflect.Type + value string + } + tests := []struct { + name string + args args + wantS string + wantErr bool + }{ + {"int", args{reflect.TypeOf(1), "774"}, "774", false}, + {"int with single quotation", args{reflect.TypeOf(1), "'774'"}, "774", false}, + {"int of empty string", args{reflect.TypeOf(1), ""}, "", true}, + {"float", args{reflect.TypeOf(1.00), "1.23"}, "1.23", false}, + {"float with single quotation", args{reflect.TypeOf(1.00), "'1.23'"}, "1.23", false}, + {"float of empty string", args{reflect.TypeOf(1.00), ""}, "", true}, + {"string", args{reflect.TypeOf(""), "Gopher"}, "'Gopher'", false}, + {"string with single quotation", args{reflect.TypeOf(""), "'Gopher'"}, "'Gopher'", false}, + {"string of empty string", args{reflect.TypeOf(""), ""}, "''", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotS, err := getValueAsType(tt.args.t, tt.args.value) + if (err != nil) != tt.wantErr { + t.Errorf("getValueAsType() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotS != tt.wantS { + t.Errorf("getValueAsType() = %v, want %v", gotS, tt.wantS) + } + }) + } +}