diff --git a/db.go b/db.go index d78062e5..cd3bc8b0 100644 --- a/db.go +++ b/db.go @@ -887,6 +887,8 @@ type numberer interface { } func expandSliceArgs(query *string, args ...interface{}) { + valuerType := reflect.TypeOf((*driver.Valuer)(nil)).Elem() + for _, arg := range args { mapper, ok := arg.(map[string]interface{}) if !ok { @@ -905,76 +907,36 @@ func expandSliceArgs(query *string, args ...interface{}) { value = v.ToInt64Slice() } - switch v := value.(type) { - case []string: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []uint: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []uint8: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []uint16: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []uint32: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []uint64: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []int: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []int8: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []int16: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []int32: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []int64: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []float32: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []float64: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) + t := reflect.TypeOf(value) + if t.Kind() != reflect.Slice || t.Implements(valuerType) { + // Do not expand if the value is not a slice or implements driver.Valuer, + continue + } + elm := t.Elem() + // If the element of slice implements driver.Valuer or is a primitive value, + // expand the slice + isValue := elm.Implements(valuerType) + if !isValue { + switch elm.Kind() { + case reflect.String, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Float32, reflect.Float64: + isValue = true } - default: + } + if !isValue { continue } + val := reflect.ValueOf(value) + l := val.Len() + for id := 0; id < l; id++ { + v := val.Index(id) + mapper[fmt.Sprintf("%s%d", key, id)] = v.Interface() + replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) + } + if len(replacements) == 0 { continue } diff --git a/db_test.go b/db_test.go index 57d8f43a..2adbdb8d 100644 --- a/db_test.go +++ b/db_test.go @@ -7,6 +7,8 @@ package gorp_test import ( + "database/sql/driver" + "strings" "testing" ) @@ -22,6 +24,21 @@ func (c customType2) ToInt64Slice() []int64 { return []int64(c) } +type valuerSlice []string + +func (vs valuerSlice) Value() (driver.Value, error) { + return strings.Join(vs, ","), nil +} + +func (vs *valuerSlice) Scan(val interface{}) error { + *vs = strings.Split(string(val.([]byte)), ",") + return nil +} + +var _ driver.Valuer = valuerSlice([]string{}) + +type customID int64 + func TestDbMap_Select_expandSliceArgs(t *testing.T) { tests := []struct { description string @@ -83,23 +100,51 @@ AND field12 IN (:FieldIntList) }, wantLen: 3, }, + { + description: "handle customID types", + query: ` +SELECT 1 FROM crazy_table +WHERE field16 IN (:FieldCustomIDList) +`, + args: []interface{}{ + map[string]interface{}{ + "FieldCustomIDList": []customID{3, 4, 5}, + }, + }, + wantLen: 2, + }, + { + description: "handle types which are sql.Valuer", + query: ` +SELECT 1 FROM crazy_table +WHERE field15 = :FieldCustomValuer +`, + args: []interface{}{ + map[string]interface{}{ + "FieldCustomValuer": valuerSlice([]string{"aaa", "bbb"}), + }, + }, + wantLen: 1, + }, } type dataFormat struct { - Field1 int `db:"field1"` - Field2 string `db:"field2"` - Field3 uint `db:"field3"` - Field4 uint8 `db:"field4"` - Field5 uint16 `db:"field5"` - Field6 uint32 `db:"field6"` - Field7 uint64 `db:"field7"` - Field8 int `db:"field8"` - Field9 int8 `db:"field9"` - Field10 int16 `db:"field10"` - Field11 int32 `db:"field11"` - Field12 int64 `db:"field12"` - Field13 float32 `db:"field13"` - Field14 float64 `db:"field14"` + Field1 int `db:"field1"` + Field2 string `db:"field2"` + Field3 uint `db:"field3"` + Field4 uint8 `db:"field4"` + Field5 uint16 `db:"field5"` + Field6 uint32 `db:"field6"` + Field7 uint64 `db:"field7"` + Field8 int `db:"field8"` + Field9 int8 `db:"field9"` + Field10 int16 `db:"field10"` + Field11 int32 `db:"field11"` + Field12 int64 `db:"field12"` + Field13 float32 `db:"field13"` + Field14 float64 `db:"field14"` + Field15 valuerSlice `db:"field15"` + Field16 customID `db:"field16"` } dbmap := newDbMap() @@ -161,6 +206,17 @@ AND field12 IN (:FieldIntList) Field13: 3, Field14: 3, }, + &dataFormat{ + Field1: 126, + Field2: "h", + Field15: []string{"aaa", "bbb"}, + Field16: customID(4), + }, + &dataFormat{ + Field1: 127, + Field2: "o", + Field16: customID(5), + }, ) if err != nil {