diff --git a/cmd/internal/planetscale_edge_database.go b/cmd/internal/planetscale_edge_database.go index 3687fd0..7006101 100644 --- a/cmd/internal/planetscale_edge_database.go +++ b/cmd/internal/planetscale_edge_database.go @@ -381,10 +381,15 @@ func (p PlanetScaleEdgeDatabase) sync( qr := sqltypes.Proto3ToResult(result) for _, row := range qr.Rows { nread += 1 - data := QueryResultToRecords(&sqltypes.Result{ + + data, err := QueryResultToRecords(&sqltypes.Result{ Fields: result.Fields, Rows: []sqltypes.Row{row}, }, &ps) + if err != nil { + return tc, nread, fmt.Errorf("query result to records: %w", err) + } + for _, record := range data { if p.Logger.QueueFull() { if err := checkpoint(p.Logger.Flush, tc, state); err != nil { diff --git a/cmd/internal/types.go b/cmd/internal/types.go index 014fe6f..1d8fdf5 100644 --- a/cmd/internal/types.go +++ b/cmd/internal/types.go @@ -136,7 +136,7 @@ func TableCursorToSerializedCursor(cursor *psdbconnect.TableCursor) (*Serialized return sc, nil } -func QueryResultToRecords(qr *sqltypes.Result, ps *PlanetScaleSource) []map[string]interface{} { +func QueryResultToRecords(qr *sqltypes.Result, ps *PlanetScaleSource) ([]map[string]interface{}, error) { data := make([]map[string]interface{}, 0, len(qr.Rows)) columns := make([]string, 0, len(qr.Fields)) for _, field := range qr.Fields { @@ -147,7 +147,12 @@ func QueryResultToRecords(qr *sqltypes.Result, ps *PlanetScaleSource) []map[stri record := make(map[string]interface{}) for idx, val := range row { if idx < len(columns) { - parsedValue := parseValue(val, qr.Fields[idx].GetColumnType(), qr.Fields[idx].GetType(), ps) + parsedValue, err := parseValue( + val, qr.Fields[idx].GetColumnType(), qr.Fields[idx].GetType(), ps, + ) + if err != nil { + return nil, fmt.Errorf("parse value: %w", err) + } if parsedValue.isBool { record[columns[idx]] = parsedValue.boolValue } else if parsedValue.isNull { @@ -160,7 +165,7 @@ func QueryResultToRecords(qr *sqltypes.Result, ps *PlanetScaleSource) []map[stri data = append(data, record) } - return data + return data, nil } type Value struct { @@ -172,11 +177,11 @@ type Value struct { // After the initial COPY phase, enum and set values may appear as an index instead of a value. // For example, a value might look like a "1" instead of "apple" in an enum('apple','banana','orange') column) -func parseValue(val sqltypes.Value, columnType string, queryColumnType query.Type, ps *PlanetScaleSource) Value { +func parseValue(val sqltypes.Value, columnType string, queryColumnType query.Type, ps *PlanetScaleSource) (Value, error) { if val.IsNull() { return Value{ isNull: true, - } + }, nil } switch queryColumnType { @@ -186,16 +191,20 @@ func parseValue(val sqltypes.Value, columnType string, queryColumnType query.Typ values := parseEnumOrSetValues(columnType) return Value{ sqlValue: mapEnumValue(val, values), - } + }, nil case query.Type_SET: values := parseEnumOrSetValues(columnType) return Value{ sqlValue: mapSetValue(val, values), - } + }, nil case query.Type_DECIMAL: - return Value{ - sqlValue: leadDecimalWithZero(val), + newVal, err := leadDecimalWithZero(val) + if err != nil { + return Value{}, fmt.Errorf("lead decimal with zero: %w", err) } + return Value{ + sqlValue: newVal, + }, nil case query.Type_BINARY, query.Type_BIT, query.Type_BITNUM, query.Type_BLOB, query.Type_CHAR, query.Type_EXPRESSION, query.Type_FLOAT32, query.Type_FLOAT64, query.Type_GEOMETRY, @@ -207,22 +216,23 @@ func parseValue(val sqltypes.Value, columnType string, queryColumnType query.Typ query.Type_VARCHAR, query.Type_YEAR: // No special handling. default: - panic(fmt.Sprintf("unexpected query.Type: %#v", queryColumnType)) + return Value{}, fmt.Errorf("unexpected query.Type: %#v", queryColumnType) } if strings.ToLower(columnType) == "tinyint(1)" && !ps.Options.DoNotTreatTinyIntAsBoolean { - return mapTinyIntToBool(val) + return mapTinyIntToBool(val), nil } return Value{ sqlValue: val, - } + }, nil } -func leadDecimalWithZero(val sqltypes.Value) sqltypes.Value { +func leadDecimalWithZero(val sqltypes.Value) (sqltypes.Value, error) { if !val.IsDecimal() { - panic("non-decimal value") + return val, errors.New("decimal required") } + valS := val.ToString() if strings.HasPrefix(valS, ".") || strings.HasPrefix(valS, "-.") { var newVal sqltypes.Value @@ -233,15 +243,18 @@ func leadDecimalWithZero(val sqltypes.Value) sqltypes.Value { newVal, err = sqltypes.NewValue(val.Type(), fmt.Appendf(nil, "-0%s", valS[1:])) } if err != nil { - panic(fmt.Sprintf("failed to reconstruct decimal with leading zero: %v", err)) + return val, fmt.Errorf("failed to reconstruct decimal with leading zero: %w", err) } - return newVal + return newVal, nil } - return val + + return val, nil } func mapTinyIntToBool(val sqltypes.Value) Value { sqlVal, err := val.ToBool() + // TODO: should we really be doing this? + // // Fallback to the original value if we can't convert to bool if err != nil { return Value{ @@ -271,7 +284,7 @@ func parseEnumOrSetValues(columnType string) []string { return values } -func formatISO8601(mysqlType query.Type, value sqltypes.Value) Value { +func formatISO8601(mysqlType query.Type, value sqltypes.Value) (Value, error) { var formatString string var layout string if mysqlType == query.Type_DATE { @@ -299,21 +312,24 @@ func formatISO8601(mysqlType query.Type, value sqltypes.Value) Value { } else { mysqlTime, err = time.Parse(formatString, parsedDatetime) if err != nil { + // TODO: should we really be doing this? // fallback to default value if datetime is not parseable return Value{ sqlValue: value, - } + }, nil } } - } iso8601Datetime := mysqlTime.Format(layout) - formattedValue, _ := sqltypes.NewValue(mysqlType, []byte(iso8601Datetime)) + formattedValue, err := sqltypes.NewValue(mysqlType, []byte(iso8601Datetime)) + if err != nil { + return Value{}, fmt.Errorf("new sql value from formatted datetime: %w", err) + } return Value{ sqlValue: formattedValue, - } + }, nil } func mapSetValue(value sqltypes.Value, values []string) sqltypes.Value { diff --git a/cmd/internal/types_test.go b/cmd/internal/types_test.go index 00fd66f..f8cc1df 100644 --- a/cmd/internal/types_test.go +++ b/cmd/internal/types_test.go @@ -99,7 +99,8 @@ func TestCanMapEnumAndSetValues(t *testing.T) { }, } - output := QueryResultToRecords(&input, &PlanetScaleSource{}) + output, err := QueryResultToRecords(&input, &PlanetScaleSource{}) + assert.NoError(t, err) assert.Equal(t, 2, len(output)) firstRow := output[0] assert.Equal(t, "active", firstRow["status"].(sqltypes.Value).ToString()) @@ -120,11 +121,12 @@ func TestCanMapTinyIntValues(t *testing.T) { }, } - output := QueryResultToRecords(&input, &PlanetScaleSource{ + output, err := QueryResultToRecords(&input, &PlanetScaleSource{ Options: CustomSourceOptions{ DoNotTreatTinyIntAsBoolean: false, }, }) + assert.NoError(t, err) assert.Equal(t, 2, len(output)) firstRow := output[0] @@ -142,11 +144,12 @@ func TestCanMapTinyIntValues(t *testing.T) { }, } - output = QueryResultToRecords(&input, &PlanetScaleSource{ + output, err = QueryResultToRecords(&input, &PlanetScaleSource{ Options: CustomSourceOptions{ DoNotTreatTinyIntAsBoolean: true, }, }) + assert.NoError(t, err) assert.Equal(t, 2, len(output)) firstRow = output[0] @@ -181,7 +184,8 @@ func TestCanFormatISO8601Values(t *testing.T) { }, } - output := QueryResultToRecords(&input, &PlanetScaleSource{}) + output, err := QueryResultToRecords(&input, &PlanetScaleSource{}) + assert.NoError(t, err) assert.Equal(t, 3, len(output)) row := output[0] assert.Equal(t, "2025-02-14T08:08:08.000000", row["datetime_created_at"].(sqltypes.Value).ToString()) @@ -228,7 +232,8 @@ func TestCanLeadDecimalWithZero(t *testing.T) { }, } - output := QueryResultToRecords(&input, &PlanetScaleSource{}) + output, err := QueryResultToRecords(&input, &PlanetScaleSource{}) + assert.NoError(t, err) assert.Equal(t, 1, len(output)) row := output[0]