Skip to content

Commit 93d515f

Browse files
committed
amend impl for float64 conv
1 parent e2ee440 commit 93d515f

File tree

5 files changed

+72
-23
lines changed

5 files changed

+72
-23
lines changed

enginetest/queries/script_queries.go

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package queries
1616

1717
import (
18+
"github.com/dolthub/vitess/go/mysql"
1819
"math"
1920
"time"
2021

@@ -125,16 +126,53 @@ var ScriptTests = []ScriptTest{
125126
Name: "String-to-number comparison operators should behave consistently",
126127
Assertions: []ScriptTestAssertion{
127128
{
128-
Query: "SELECT ('A') = (0)",
129-
Expected: []sql.Row{{true}},
130-
ExpectedWarningsCount: 1,
131-
ExpectedWarning: 1292,
129+
Query: "SELECT ('A') != (0)",
130+
Expected: []sql.Row{{false}},
131+
ExpectedWarningsCount: 1,
132+
ExpectedWarning: mysql.ERTruncatedWrongValue,
133+
ExpectedWarningMessageSubstring: "Truncated incorrect double value: A",
132134
},
133135
{
134-
Query: "SELECT ('A') IN (0)",
135-
Expected: []sql.Row{{true}},
136-
ExpectedWarningsCount: 1,
137-
ExpectedWarning: 1292,
136+
Query: "SELECT ('A') <> (0)",
137+
Expected: []sql.Row{{false}},
138+
ExpectedWarningsCount: 1,
139+
ExpectedWarning: mysql.ERTruncatedWrongValue,
140+
ExpectedWarningMessageSubstring: "Truncated incorrect double value: A",
141+
},
142+
{
143+
Query: "SELECT ('A') < (0)",
144+
Expected: []sql.Row{{false}},
145+
ExpectedWarningsCount: 1,
146+
ExpectedWarning: mysql.ERTruncatedWrongValue,
147+
ExpectedWarningMessageSubstring: "Truncated incorrect double value: A",
148+
},
149+
{
150+
Query: "SELECT ('A') <= (0)",
151+
Expected: []sql.Row{{true}},
152+
ExpectedWarningsCount: 1,
153+
ExpectedWarning: mysql.ERTruncatedWrongValue,
154+
ExpectedWarningMessageSubstring: "Truncated incorrect double value: A",
155+
},
156+
{
157+
Query: "SELECT ('A') > (0)",
158+
Expected: []sql.Row{{false}},
159+
ExpectedWarningsCount: 1,
160+
ExpectedWarning: mysql.ERTruncatedWrongValue,
161+
ExpectedWarningMessageSubstring: "Truncated incorrect double value: A",
162+
},
163+
{
164+
Query: "SELECT ('A') >= (0)",
165+
Expected: []sql.Row{{true}},
166+
ExpectedWarningsCount: 1,
167+
ExpectedWarning: mysql.ERTruncatedWrongValue,
168+
ExpectedWarningMessageSubstring: "Truncated incorrect double value: A",
169+
},
170+
{
171+
Query: "SELECT ('A') NOT IN (0)",
172+
Expected: []sql.Row{{false}},
173+
ExpectedWarningsCount: 1,
174+
ExpectedWarning: mysql.ERTruncatedWrongValue,
175+
ExpectedWarningMessageSubstring: "Truncated incorrect double value: A",
138176
},
139177
},
140178
},

sql/expression/comparison.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ func (c *comparison) Compare(ctx *sql.Context, row sql.Row) (int, error) {
141141
return c.Left().Type().Compare(ctx, left, right)
142142
}
143143

144-
l, r, compareType, err := c.CastLeftAndRight(ctx, left, right)
144+
l, r, compareType, err := c.castLeftAndRight(ctx, left, right)
145145
if err != nil {
146146
return 0, err
147147
}
@@ -171,7 +171,7 @@ func (c *comparison) evalLeftAndRight(ctx *sql.Context, row sql.Row) (interface{
171171
return left, right, nil
172172
}
173173

174-
func (c *comparison) CastLeftAndRight(ctx *sql.Context, left, right interface{}) (interface{}, interface{}, sql.Type, error) {
174+
func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) (interface{}, interface{}, sql.Type, error) {
175175
leftType := c.Left().Type()
176176
rightType := c.Right().Type()
177177

@@ -452,7 +452,7 @@ func (e *NullSafeEquals) Compare(ctx *sql.Context, row sql.Row) (int, error) {
452452
}
453453

454454
var compareType sql.Type
455-
left, right, compareType, err = e.CastLeftAndRight(ctx, left, right)
455+
left, right, compareType, err = e.castLeftAndRight(ctx, left, right)
456456
if err != nil {
457457
return 0, err
458458
}

sql/expression/convert.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
360360
if err != nil {
361361
return nil, err
362362
}
363-
d, _, err := types.Float32.Convert(ctx, value)
363+
d, err := types.ConvertOrTruncate(ctx, value, types.Float32)
364364
if err != nil {
365365
return types.Float32.Zero(), nil
366366
}
@@ -370,7 +370,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
370370
if err != nil {
371371
return nil, err
372372
}
373-
d, _, err := types.Float64.Convert(ctx, value)
373+
d, err := types.ConvertOrTruncate(ctx, value, types.Float64)
374374
if err != nil {
375375
return types.Float64.Zero(), nil
376376
}
@@ -386,7 +386,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
386386
if err != nil {
387387
return nil, err
388388
}
389-
num, _, err := types.Int64.Convert(ctx, value)
389+
num, err := types.ConvertOrTruncate(ctx, value, types.Int64)
390390
if err != nil {
391391
return types.Int64.Zero(), nil
392392
}
@@ -403,7 +403,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
403403
if err != nil {
404404
return nil, err
405405
}
406-
num, _, err := types.Uint64.Convert(ctx, value)
406+
num, err := types.ConvertOrTruncate(ctx, value, types.Uint64)
407407
if err != nil {
408408
num, _, err = types.Int64.Convert(ctx, value)
409409
if err != nil {
@@ -484,7 +484,6 @@ func prepareForNumericContext(val interface{}, originType sql.Type, isInt bool)
484484
return convertHexBlobToDecimalForNumericContext(val, originType)
485485
}
486486

487-
488487
// trimStringToNumberPrefix trims a string to the appropriate number prefix
489488
func trimStringToNumberPrefix(s string, isInt bool) string {
490489
if isInt {

sql/expression/in.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
9797
}
9898

9999
comp := newComparison(NewLiteral(originalLeft, in.Left().Type()), NewLiteral(originalRight, el.Type()))
100-
l, r, compareType, err := comp.CastLeftAndRight(ctx, originalLeft, originalRight)
100+
l, r, compareType, err := comp.castLeftAndRight(ctx, originalLeft, originalRight)
101101
if err != nil {
102102
return nil, err
103103
}
@@ -152,7 +152,6 @@ func NewNotInTuple(left sql.Expression, right sql.Expression) sql.Expression {
152152
return NewNot(NewInTuple(left, right))
153153
}
154154

155-
156155
// HashInTuple is an expression that checks an expression is inside a list of expressions using a hashmap.
157156
type HashInTuple struct {
158157
in *InTuple

sql/types/number.go

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,9 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan
11341134
return uint64(v), sql.InRange, nil
11351135
case int64:
11361136
if v < 0 {
1137+
// For negative integers, use two's complement wrapping:
1138+
// -1 -> MaxUint64, -2 -> MaxUint64-1, -3 -> MaxUint64-2, etc.
1139+
// Formula: MaxUint64 - uint(-v-1)
11371140
return uint64(math.MaxUint64 - uint(-v-1)), sql.OutOfRange, nil
11381141
}
11391142
return uint64(v), sql.InRange, nil
@@ -1151,16 +1154,24 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan
11511154
if v > float32(math.MaxInt64) {
11521155
return math.MaxUint64, sql.OutOfRange, nil
11531156
} else if v < 0 {
1154-
return uint64(math.MaxUint64 - v), sql.OutOfRange, nil
1157+
// For negative floats, use the same two's complement wrapping as integers:
1158+
// -1 -> MaxUint64, -2 -> MaxUint64-1, -3 -> MaxUint64-2, etc.
1159+
// Formula: MaxUint64 - uint(-v-1)
1160+
return uint64(math.MaxUint64 - uint(-v-1)), sql.OutOfRange, nil
11551161
}
11561162
return uint64(math.Round(float64(v))), sql.InRange, nil
11571163
case float64:
11581164
if v >= float64(math.MaxUint64) {
11591165
return math.MaxUint64, sql.OutOfRange, nil
11601166
} else if v <= 0 {
1161-
return uint64(math.MaxUint64 - v), sql.OutOfRange, nil
1162-
}
1163-
return uint64(math.Round(v)), sql.InRange, nil
1167+
// For negative floats, use the same two's complement wrapping as integers:
1168+
// -1 -> MaxUint64, -2 -> MaxUint64-1, -3 -> MaxUint64-2, etc.
1169+
// Formula: MaxUint64 - uint(-v-1)
1170+
result := uint64(math.MaxUint64 - uint(-v-1))
1171+
return result, sql.OutOfRange, nil
1172+
}
1173+
result := uint64(math.Round(v))
1174+
return result, sql.InRange, nil
11641175
case decimal.Decimal:
11651176
if v.GreaterThan(dec_uint64_max) {
11661177
return math.MaxUint64, sql.OutOfRange, nil
@@ -1186,7 +1197,9 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan
11861197
return math.MaxUint64, sql.OutOfRange, nil
11871198
}
11881199
if f, err := strconv.ParseFloat(v, 64); err == nil {
1189-
if val, inRange, err := convertToUint64(t, f); err == nil && inRange {
1200+
// Note: We only check err == nil, not inRangqe, because negative numbers
1201+
// correctly return OutOfRange but should still be processed
1202+
if val, inRange, err := convertToUint64(t, f); err == nil {
11901203
return val, inRange, err
11911204
}
11921205
}

0 commit comments

Comments
 (0)