From e30494d582847926eb139493f22e263a372bf403 Mon Sep 17 00:00:00 2001 From: Adam Reese Date: Mon, 27 Apr 2026 16:22:00 -0700 Subject: [PATCH] feat(pg): asyncify postgres sdk Signed-off-by: Adam Reese --- .github/workflows/build.yml | 3 + examples/pg-outbound/README.md | 30 ++ examples/pg-outbound/compose.yaml | 22 ++ examples/pg-outbound/db/pets.sql | 17 +- examples/pg-outbound/main.go | 52 ++- examples/pg-outbound/spin.toml | 2 +- go.mod | 7 + go.sum | 9 + pg/pg.go | 378 +++++++++++++----- pg/pg_test.go | 96 +++++ pg/types.go | 437 +++++++++++++++++++++ pg/types_test.go | 612 ++++++++++++++++++++++++++++++ 12 files changed, 1558 insertions(+), 107 deletions(-) create mode 100644 examples/pg-outbound/README.md create mode 100644 examples/pg-outbound/compose.yaml create mode 100644 pg/pg_test.go create mode 100644 pg/types.go create mode 100644 pg/types_test.go diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 448a4c0b..04ae3b4f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -24,5 +24,8 @@ jobs: # TODO: Switch to Spin 4.0 when it's available version: "canary" + - name: Run unit tests + run: go test -v -count=1 ./pg + - name: Run integration tests run: go test -v -count=1 . diff --git a/examples/pg-outbound/README.md b/examples/pg-outbound/README.md new file mode 100644 index 00000000..8846da8d --- /dev/null +++ b/examples/pg-outbound/README.md @@ -0,0 +1,30 @@ +# Requirements +- [**go**](https://go.dev/dl/) - v1.25+ +- [**spin**](https://github.com/spinframework/spin) - Latest version +- [**docker**](https://docs.docker.com/get-started/get-docker/) - Latest version + +# Usage +In a terminal window, use the below command to run PostgreSQL: +```sh +docker compose up -d +``` + +Then, you'll build and run your Spin app: +```sh +spin up --build +``` + +In another terminal window, you can interact with the Spin app: +```sh +curl localhost:3000 +``` + +You should see the output: +```json +[{"ID":1,"Name":"Splodge","Prey":null,"IsFinicky":false},{"ID":2,"Name":"Kiki","Prey":"Cicadas","IsFinicky":false},{"ID":3,"Name":"Slats","Prey":"Temptations","IsFinicky":true},{"ID":4,"Name":"Maya","Prey":"bananas","IsFinicky":true}] +``` + +To stop and clean up the PostgreSQL container, run the following: +```sh +docker compose down -v +``` diff --git a/examples/pg-outbound/compose.yaml b/examples/pg-outbound/compose.yaml new file mode 100644 index 00000000..a63bbccc --- /dev/null +++ b/examples/pg-outbound/compose.yaml @@ -0,0 +1,22 @@ +services: + postgres: + image: postgres:17 + container_name: postgres + restart: unless-stopped + environment: + POSTGRES_PASSWORD: postgres + POSTGRES_DB: spin_dev + ports: + - "5432:5432" + volumes: + - postgres_data:/var/lib/postgresql/data + - ./db/pets.sql:/docker-entrypoint-initdb.d/pets.sql + networks: + - postgres_network + +volumes: + postgres_data: + +networks: + postgres_network: + driver: bridge diff --git a/examples/pg-outbound/db/pets.sql b/examples/pg-outbound/db/pets.sql index bbf91d68..8ead9857 100644 --- a/examples/pg-outbound/db/pets.sql +++ b/examples/pg-outbound/db/pets.sql @@ -1,4 +1,13 @@ -CREATE TABLE pets (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, prey VARCHAR(100), is_finicky BOOL NOT NULL); -INSERT INTO pets VALUES (1, 'Splodge', NULL, false); -INSERT INTO pets VALUES (2, 'Kiki', 'Cicadas', false); -INSERT INTO pets VALUES (3, 'Slats', 'Temptations', true); +CREATE TABLE pets ( + id INT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + prey VARCHAR(100), + is_finicky BOOL NOT NULL, + timestamp TIMESTAMP +); +INSERT INTO pets VALUES (1, 'Splodge', NULL, false, '2026-04-20 12:30:00'); +INSERT INTO pets VALUES (2, 'Kiki', 'Cicadas', false, '2026-04-20 12:30:00'); +INSERT INTO pets VALUES (3, 'Slats', 'Temptations', true, '2026-04-20 12:30:00'); + +-- For creating uuids using uuid_generate_v4() in the example. +CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; diff --git a/examples/pg-outbound/main.go b/examples/pg-outbound/main.go index ef2ab975..d1399db8 100644 --- a/examples/pg-outbound/main.go +++ b/examples/pg-outbound/main.go @@ -5,6 +5,8 @@ import ( "fmt" "net/http" "os" + "slices" + "time" spinhttp "github.com/spinframework/spin-go-sdk/v3/http" "github.com/spinframework/spin-go-sdk/v3/pg" @@ -15,19 +17,53 @@ type Pet struct { Name string Prey *string // nullable field must be a pointer IsFinicky bool + Timestamp time.Time } func init() { spinhttp.Handle(func(w http.ResponseWriter, r *http.Request) { // addr is the environment variable set in `spin.toml` that points to the - // address of the Mysql server. + // address of the postgres server. addr := os.Getenv("DB_URL") db := pg.Open(addr) defer db.Close() - _, err := db.Query("INSERT INTO pets VALUES ($1, 'Maya', $2, $3);", int32(4), "bananas", true) + var uuid pg.UUID + if err := db.QueryRow(`SELECT uuid_generate_v4()`).Scan(&uuid); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + fmt.Printf("Generated UUID: %#v\n", uuid) + + // Testing Array parsing + var x []int32 + if err := db.QueryRow(`SELECT ARRAY[200, 404]`).Scan(&x); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if !slices.Equal(x, []int32{200, 404}) { + http.Error(w, fmt.Sprintf("Slices aren't equal, got: %v", x), http.StatusInternalServerError) + return + } + + // Testing Range parsing + var rangeInt32 pg.Int32Range + if err := db.QueryRow(`SELECT int4range(10, 20)`).Scan(&rangeInt32); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if *rangeInt32.Lower != 10 { + http.Error(w, fmt.Sprintf("Error parsing lower range, got: %v", *rangeInt32.Lower), http.StatusInternalServerError) + return + } + if *rangeInt32.Upper != 20 { + http.Error(w, fmt.Sprintf("Error parsing upper range, got: %v", *rangeInt32.Upper), http.StatusInternalServerError) + return + } + + _, err := db.Exec("INSERT INTO pets (id, name, prey, is_finicky, timestamp) VALUES ($1, 'Maya', $2, $3, $4);", int32(4), "bananas", true, time.Now()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -38,16 +74,24 @@ func init() { http.Error(w, err.Error(), http.StatusInternalServerError) return } + defer rows.Close() var pets []*Pet for rows.Next() { var pet Pet - if err := rows.Scan(&pet.ID, &pet.Name, &pet.Prey, &pet.IsFinicky); err != nil { + if err := rows.Scan(&pet.ID, &pet.Name, &pet.Prey, &pet.IsFinicky, &pet.Timestamp); err != nil { fmt.Println(err) } pets = append(pets, &pet) } - json.NewEncoder(w).Encode(pets) + if err := rows.Err(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err := json.NewEncoder(w).Encode(pets); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } }) } diff --git a/examples/pg-outbound/spin.toml b/examples/pg-outbound/spin.toml index ee0718b6..e3c1f300 100644 --- a/examples/pg-outbound/spin.toml +++ b/examples/pg-outbound/spin.toml @@ -11,7 +11,7 @@ route = "/..." component = "pg-outbound" [component.pg-outbound] -environment = { DB_URL = "host=localhost user=postgres dbname=spin_dev" } +environment = { DB_URL = "host=localhost user=postgres password=postgres dbname=spin_dev" } source = "main.wasm" allowed_outbound_hosts = ["postgres://localhost"] [component.pg-outbound.build] diff --git a/go.mod b/go.mod index edcfd21e..da20efa0 100644 --- a/go.mod +++ b/go.mod @@ -3,3 +3,10 @@ module github.com/spinframework/spin-go-sdk/v3 go 1.25.5 require go.bytecodealliance.org/pkg v0.2.1 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/testify v1.11.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index a18d6dc5..d2612455 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,11 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= go.bytecodealliance.org/pkg v0.2.1 h1:TdRagooIcCW3UmlKqVO4cDR3GNDyfDnbiBzGI6TOvyg= go.bytecodealliance.org/pkg v0.2.1/go.mod h1:OjA+V8g3uUFixeCKFfamm6sYhTJdg8fvwEdJ2GO0GSk= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pg/pg.go b/pg/pg.go index a6fda9d6..a94b33a7 100644 --- a/pg/pg.go +++ b/pg/pg.go @@ -6,12 +6,14 @@ import ( "database/sql" "database/sql/driver" "errors" + "fmt" "io" "reflect" + "time" + pg "github.com/spinframework/spin-go-sdk/v3/imports/spin_postgres_4_2_0_postgres" spindb "github.com/spinframework/spin-go-sdk/v3/internal/db" - pg "github.com/spinframework/spin-go-sdk/v3/imports/fermyon_spin_2_0_0_postgres" - rdbmstypes "github.com/spinframework/spin-go-sdk/v3/imports/fermyon_spin_2_0_0_rdbms_types" + wittypes "go.bytecodealliance.org/pkg/wit/types" ) // Open returns a new connection to the database. @@ -40,7 +42,7 @@ func (d *connector) Driver() driver.Driver { // Open returns a new connection to the database. func (d *connector) Open(name string) (driver.Conn, error) { - results := pg.ConnectionOpen(name) + results := pg.ConnectionOpenAsync(name) if results.IsErr() { return nil, toError(results.Err()) } @@ -61,6 +63,7 @@ func (c *conn) Prepare(query string) (driver.Stmt, error) { } func (c *conn) Close() error { + c.spinConn.Drop() return nil } @@ -94,31 +97,28 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { rdbmsParams[i] = toRdbmsParameterValue(v) } - results := s.conn.spinConn.Query(s.query, rdbmsParams) + results := s.conn.spinConn.QueryAsync(s.query, rdbmsParams) if results.IsErr() { return nil, toError(results.Err()) } - rowLen := len(results.Ok().Rows) - allRows := make([][]any, rowLen) - for rowNum, row := range results.Ok().Rows { - allRows[rowNum] = toRow(row) - } - - cols := results.Ok().Columns + tuple := results.Ok() + cols := tuple.F0 colNames := make([]string, len(cols)) colTypes := make([]uint8, len(cols)) for i, c := range cols { colNames[i] = c.Name - colTypes[i] = uint8(c.DataType) + colTypes[i] = c.DataType.Tag() } rows := &rows{ columns: colNames, columnType: colTypes, - rows: allRows, - len: int(rowLen), + stream: tuple.F1, + future: tuple.F2, } + + rows.next = rows.pull() return rows, nil } @@ -130,7 +130,7 @@ func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { rdbmsParams[i] = toRdbmsParameterValue(v) } - queryResult := s.conn.spinConn.Execute(s.query, rdbmsParams) + queryResult := s.conn.spinConn.ExecuteAsync(s.query, rdbmsParams) if queryResult.IsErr() { return &result{}, toError(queryResult.Err()) } @@ -158,10 +158,10 @@ func (r result) RowsAffected() (int64, error) { type rows struct { columns []string columnType []uint8 - pos int - len int - rows [][]any - closed bool + next []any + stream *wittypes.StreamReader[[]pg.DbValue] + future *wittypes.FutureReader[wittypes.Result[wittypes.Unit, pg.Error]] + result error } var _ driver.Rows = (*rows)(nil) @@ -175,29 +175,46 @@ func (r *rows) Columns() []string { // Close closes the rows iterator. func (r *rows) Close() error { - r.rows = nil - r.pos = 0 - r.len = 0 - r.closed = true + r.stream.Drop() + r.future.Drop() + r.stream = nil + r.future = nil + r.next = nil + r.result = io.EOF + return nil +} + +func (r *rows) pull() []any { + buffer := [][]pg.DbValue{nil} + if r.stream.Read(buffer) == 1 { + return toRow(buffer[0]) + } + result := r.future.Read() + if result.IsOk() { + r.result = io.EOF + } else { + r.result = toError(result.Err()) + } return nil } // Next moves the cursor to the next row. func (r *rows) Next(dest []driver.Value) error { if !r.HasNextResultSet() { - return io.EOF + return r.result } + next := r.next + r.next = r.pull() for i := 0; i != len(r.columns); i++ { - dest[i] = driver.Value(r.rows[r.pos][i]) + dest[i] = driver.Value(next[i]) } - r.pos++ return nil } // HasNextResultSet is called at the end of the current result set and // reports whether there is another result set after the current one. func (r *rows) HasNextResultSet() bool { - return r.pos < r.len + return r.next != nil } // NextResultSet advances the driver to the next result set even @@ -206,10 +223,10 @@ func (r *rows) HasNextResultSet() bool { // NextResultSet should return io.EOF when there are no more result sets. func (r *rows) NextResultSet() error { if r.HasNextResultSet() { - r.pos++ + r.next = r.pull() return nil } - return io.EOF // Per interface spec. + return r.result } // ColumnTypeScanType returns the value type that can be used to scan types into. @@ -220,87 +237,209 @@ func (r *rows) ColumnTypeScanType(index int) reflect.Type { func toRdbmsParameterValue(x any) pg.ParameterValue { switch v := x.(type) { case bool: - return rdbmstypes.MakeParameterValueBoolean(v) + return pg.MakeParameterValueBoolean(v) case int8: - return rdbmstypes.MakeParameterValueInt8(v) + return pg.MakeParameterValueInt8(v) case int16: - return rdbmstypes.MakeParameterValueInt16(v) + return pg.MakeParameterValueInt16(v) case int32: - return rdbmstypes.MakeParameterValueInt32(v) + return pg.MakeParameterValueInt32(v) case int64: - return rdbmstypes.MakeParameterValueInt64(v) + return pg.MakeParameterValueInt64(v) case int: - return rdbmstypes.MakeParameterValueInt64(int64(v)) - case uint8: - return rdbmstypes.MakeParameterValueUint8(v) - case uint16: - return rdbmstypes.MakeParameterValueUint16(v) - case uint32: - return rdbmstypes.MakeParameterValueUint32(v) - case uint64: - return rdbmstypes.MakeParameterValueUint64(v) + return pg.MakeParameterValueInt64(int64(v)) case float32: - return rdbmstypes.MakeParameterValueFloating32(v) + return pg.MakeParameterValueFloating32(v) case float64: - return rdbmstypes.MakeParameterValueFloating64(v) + return pg.MakeParameterValueFloating64(v) case string: - return rdbmstypes.MakeParameterValueStr(v) + return pg.MakeParameterValueStr(v) case []byte: - return rdbmstypes.MakeParameterValueBinary(v) + return pg.MakeParameterValueBinary(v) + case []string: + return pg.MakeParameterValueArrayStr(toOptionSlice(v)) + case Int32Range: + witVal, _ := v.Value() + return pg.MakeParameterValueRangeInt32(witVal.(wittypes.Tuple2[wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]]])) + case Int64Range: + witVal, _ := v.Value() + return pg.MakeParameterValueRangeInt64(witVal.(wittypes.Tuple2[wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]]])) + case []int32: + return pg.MakeParameterValueArrayInt32(toOptionSlice(v)) + case []int64: + return pg.MakeParameterValueArrayInt64(toOptionSlice(v)) + case time.Time: + v = v.UTC() + return pg.MakeParameterValueDatetime(wittypes.Tuple7[int32, uint8, uint8, uint8, uint8, uint8, uint32]{ + F0: int32(v.Year()), + F1: uint8(v.Month()), + F2: uint8(v.Day()), + F3: uint8(v.Hour()), + F4: uint8(v.Minute()), + F5: uint8(v.Second()), + F6: uint32(v.Nanosecond()), + }) case nil: - return rdbmstypes.MakeParameterValueDbNull() + return pg.MakeParameterValueDbNull() + case JSONB: + return pg.MakeParameterValueJsonb([]byte(v)) + case Date: + return pg.MakeParameterValueDate(wittypes.Tuple3[int32, uint8, uint8]{ + F0: int32(v.Year), + F1: uint8(v.Month), + F2: uint8(v.Day), + }) + case Time: + return pg.MakeParameterValueTime(wittypes.Tuple4[uint8, uint8, uint8, uint32]{ + F0: uint8(v.Hour), + F1: uint8(v.Minute), + F2: uint8(v.Second), + F3: uint32(v.Nanosecond), + }) + case Interval: + return pg.MakeParameterValueInterval(pg.Interval{ + Micros: v.Micros, + Days: v.Days, + Months: v.Months, + }) + case Decimal: + return pg.MakeParameterValueDecimal(string(v)) + case DecimalRange: + witVal, _ := v.Value() + return pg.MakeParameterValueRangeDecimal(witVal.(wittypes.Tuple2[wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]], wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]]])) + case []Decimal: + opts := make([]wittypes.Option[string], len(v)) + for i, d := range v { + opts[i] = wittypes.Some(string(d)) + } + return pg.MakeParameterValueArrayDecimal(opts) + case UUID: + return pg.MakeParameterValueUuid(string(v)) default: panic("unknown value type") } } +// QueryDBError represents a structured PostgreSQL database error returned by the runtime. +// It is returned when a query fails with a structured error from Postgres (as opposed +// to a plain-text error message). +type QueryDBError struct { + // Severity is the severity level of the error (e.g. "ERROR", "FATAL", "WARNING"). + Severity string + // Code is the PostgreSQL error code (e.g. "23505" for unique_violation). + Code string + // Message is the primary human-readable error message. + Message string + // Detail is an optional secondary message providing more detail about the error. + Detail string + // Extras contains any additional key-value error information provided by Postgres. + Extras [][2]string +} + +func (e *QueryDBError) Error() string { + msg := fmt.Sprintf("%s (%s): %s", e.Severity, e.Code, e.Message) + if e.Detail != "" { + msg += ": " + e.Detail + } + return msg +} + func toError(err pg.Error) error { switch err.Tag() { - case rdbmstypes.ErrorBadParameter: + case pg.ErrorBadParameter: return errors.New(err.BadParameter()) - case rdbmstypes.ErrorConnectionFailed: + case pg.ErrorConnectionFailed: return errors.New(err.ConnectionFailed()) - case rdbmstypes.ErrorQueryFailed: - return errors.New(err.QueryFailed()) - case rdbmstypes.ErrorValueConversionFailed: + case pg.ErrorQueryFailed: + qf := err.QueryFailed() + switch qf.Tag() { + case pg.QueryErrorText: + return errors.New(qf.Text()) + case pg.QueryErrorDbError: + dbErr := qf.DbError() + pgErr := &QueryDBError{ + Severity: dbErr.Severity, + Code: dbErr.Code, + Message: dbErr.Message, + } + if dbErr.Detail.IsSome() { + pgErr.Detail = dbErr.Detail.Some() + } + if len(dbErr.Extras) > 0 { + pgErr.Extras = make([][2]string, len(dbErr.Extras)) + for i, e := range dbErr.Extras { + pgErr.Extras[i] = [2]string{e.F0, e.F1} + } + } + return pgErr + } + case pg.ErrorValueConversionFailed: return errors.New(err.ValueConversionFailed()) - default: - // TODO: not sure if using "Other" as the default is appropriate + case pg.ErrorOther: return errors.New(err.Other()) } + panic("unknown error from runtime") } -func toRow(row []rdbmstypes.DbValue) []any { +func toRow(row []pg.DbValue) []any { result := make([]any, len(row)) for i, v := range row { switch v.Tag() { - case rdbmstypes.DbValueBoolean: + case pg.DbValueBoolean: result[i] = v.Boolean() - case rdbmstypes.DbValueInt8: + case pg.DbValueInt8: result[i] = v.Int8() - case rdbmstypes.DbValueInt16: + case pg.DbValueInt16: result[i] = v.Int16() - case rdbmstypes.DbValueInt32: + case pg.DbValueInt32: result[i] = v.Int32() - case rdbmstypes.DbValueInt64: + case pg.DbValueInt64: result[i] = v.Int64() - case rdbmstypes.DbValueUint8: - result[i] = v.Uint8() - case rdbmstypes.DbValueUint16: - result[i] = v.Uint16() - case rdbmstypes.DbValueUint32: - result[i] = v.Uint32() - case rdbmstypes.DbValueUint64: - result[i] = v.Uint64() - case rdbmstypes.DbValueFloating32: + case pg.DbValueFloating32: result[i] = v.Floating32() - case rdbmstypes.DbValueFloating64: + case pg.DbValueFloating64: result[i] = v.Floating64() - case rdbmstypes.DbValueStr: + case pg.DbValueStr: result[i] = v.Str() - case rdbmstypes.DbValueBinary: + case pg.DbValueBinary: result[i] = v.Binary() - case rdbmstypes.DbValueDbNull: + case pg.DbValueDate: + d := v.Date() + result[i] = time.Date(int(d.F0), time.Month(d.F1), int(d.F2), 0, 0, 0, 0, time.UTC) + case pg.DbValueTime: + t := v.Time() + result[i] = time.Date(0, 1, 1, int(t.F0), int(t.F1), int(t.F2), int(t.F3), time.UTC) + case pg.DbValueDatetime: + dt := v.Datetime() + result[i] = time.Date(int(dt.F0), time.Month(dt.F1), int(dt.F2), int(dt.F3), int(dt.F4), int(dt.F5), int(dt.F6), time.UTC) + case pg.DbValueTimestamp: + result[i] = time.Unix(v.Timestamp(), 0).UTC() + case pg.DbValueUuid: + result[i] = v.Uuid() + case pg.DbValueJsonb: + result[i] = []byte(v.Jsonb()) + case pg.DbValueDecimal: + result[i] = v.Decimal() + case pg.DbValueRangeInt32: + result[i] = v.RangeInt32() + case pg.DbValueRangeInt64: + result[i] = v.RangeInt64() + case pg.DbValueRangeDecimal: + result[i] = v.RangeDecimal() + case pg.DbValueArrayInt32: + result[i] = fromOptionSlice(v.ArrayInt32()) + case pg.DbValueArrayInt64: + result[i] = fromOptionSlice(v.ArrayInt64()) + case pg.DbValueArrayDecimal: + result[i] = fromOptionSlice(v.ArrayDecimal()) + case pg.DbValueArrayStr: + result[i] = fromOptionSlice(v.ArrayStr()) + case pg.DbValueInterval: + iv := v.Interval() + result[i] = Interval{Months: iv.Months, Days: iv.Days, Micros: iv.Micros} + case pg.DbValueUnsupported: + result[i] = v.Unsupported() + case pg.DbValueDbNull: result[i] = nil default: panic("unknown value type") @@ -312,30 +451,73 @@ func toRow(row []rdbmstypes.DbValue) []any { func colTypeToReflectType(typ uint8) reflect.Type { switch typ { - case uint8(rdbmstypes.DbDataTypeBoolean): - return reflect.TypeOf(false) - case uint8(rdbmstypes.DbDataTypeInt8): - return reflect.TypeOf(int8(0)) - case uint8(rdbmstypes.DbDataTypeInt16): - return reflect.TypeOf(int16(0)) - case uint8(rdbmstypes.DbDataTypeInt32): - return reflect.TypeOf(int32(0)) - case uint8(rdbmstypes.DbDataTypeInt64): - return reflect.TypeOf(int64(0)) - case uint8(rdbmstypes.DbDataTypeUint8): - return reflect.TypeOf(uint8(0)) - case uint8(rdbmstypes.DbDataTypeUint16): - return reflect.TypeOf(uint16(0)) - case uint8(rdbmstypes.DbDataTypeUint32): - return reflect.TypeOf(uint32(0)) - case uint8(rdbmstypes.DbDataTypeUint64): - return reflect.TypeOf(uint64(0)) - case uint8(rdbmstypes.DbDataTypeStr): - return reflect.TypeOf("") - case uint8(rdbmstypes.DbDataTypeBinary): - return reflect.TypeOf(new([]byte)) - case uint8(rdbmstypes.DbDataTypeOther): - return reflect.TypeOf(new(any)).Elem() + case pg.DbDataTypeBoolean: + return reflect.TypeFor[bool]() + case pg.DbDataTypeInt8: + return reflect.TypeFor[int8]() + case pg.DbDataTypeInt16: + return reflect.TypeFor[int16]() + case pg.DbDataTypeInt32: + return reflect.TypeFor[int32]() + case pg.DbDataTypeInt64: + return reflect.TypeFor[int64]() + case pg.DbDataTypeFloating32: + return reflect.TypeFor[float32]() + case pg.DbDataTypeFloating64: + return reflect.TypeFor[float64]() + case pg.DbDataTypeStr: + return reflect.TypeFor[string]() + case pg.DbDataTypeUuid: + return reflect.TypeFor[string]() + case pg.DbDataTypeDecimal: + return reflect.TypeFor[string]() + case pg.DbDataTypeBinary: + return reflect.TypeFor[[]byte]() + case pg.DbDataTypeJsonb: + return reflect.TypeFor[[]byte]() + case pg.DbDataTypeDate: + return reflect.TypeFor[Date]() + case pg.DbDataTypeTime: + return reflect.TypeFor[Time]() + case pg.DbDataTypeDatetime, + pg.DbDataTypeTimestamp: + return reflect.TypeFor[time.Time]() + case pg.DbDataTypeInterval: + return reflect.TypeFor[Interval]() + case pg.DbDataTypeRangeInt32: + return reflect.TypeFor[Int32Range]() + case pg.DbDataTypeRangeInt64: + return reflect.TypeFor[Int64Range]() + case pg.DbDataTypeRangeDecimal: + return reflect.TypeFor[DecimalRange]() + case pg.DbDataTypeArrayInt32: + return reflect.TypeFor[[]int32]() + case pg.DbDataTypeArrayInt64: + return reflect.TypeFor[[]int64]() + case pg.DbDataTypeArrayDecimal: + return reflect.TypeFor[[]string]() + case pg.DbDataTypeArrayStr: + return reflect.TypeFor[[]string]() + case pg.DbDataTypeOther: + return reflect.TypeFor[any]().Elem() } panic("invalid db column type of " + string(typ)) } + +func toOptionSlice[T any](v []T) []wittypes.Option[T] { + values := make([]wittypes.Option[T], len(v)) + for i, x := range v { + values[i] = wittypes.Some(x) + } + return values +} + +func fromOptionSlice[T any](v []wittypes.Option[T]) []T { + values := make([]T, len(v)) + for i, x := range v { + if x.IsSome() { + values[i] = x.Some() + } + } + return values +} diff --git a/pg/pg_test.go b/pg/pg_test.go new file mode 100644 index 00000000..431741cd --- /dev/null +++ b/pg/pg_test.go @@ -0,0 +1,96 @@ +package pg + +import ( + "errors" + "testing" + + pg "github.com/spinframework/spin-go-sdk/v3/imports/spin_postgres_4_2_0_postgres" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + wittypes "go.bytecodealliance.org/pkg/wit/types" +) + +func TestToOptionSlice(t *testing.T) { + got := toOptionSlice([]string{"a", "b", "c"}) + want := []wittypes.Option[string]{ + wittypes.Some("a"), + wittypes.Some("b"), + wittypes.Some("c"), + } + assert.Equal(t, want, got) +} + +func TestFromOptionSlice(t *testing.T) { + got := fromOptionSlice([]wittypes.Option[int]{ + wittypes.Some(1), + wittypes.Some(2), + wittypes.Some(3), + }) + want := []int{1, 2, 3} + assert.Equal(t, want, got) +} + +func TestToError(t *testing.T) { + t.Run("ConnectionFailed", func(t *testing.T) { + err := toError(pg.MakeErrorConnectionFailed("connection refused")) + assert.EqualError(t, err, "connection refused") + }) + + t.Run("BadParameter", func(t *testing.T) { + err := toError(pg.MakeErrorBadParameter("invalid param")) + assert.EqualError(t, err, "invalid param") + }) + + t.Run("QueryFailed/Text", func(t *testing.T) { + err := toError(pg.MakeErrorQueryFailed(pg.MakeQueryErrorText("syntax error"))) + assert.EqualError(t, err, "syntax error") + }) + + t.Run("QueryFailed/DbError", func(t *testing.T) { + dbErr := pg.DbError{ + AsText: "ERROR 23505 (unique_violation): duplicate key", + Severity: "ERROR", + Code: "23505", + Message: "duplicate key value violates unique constraint", + Detail: wittypes.Some("Key (id)=(1) already exists."), + Extras: []wittypes.Tuple2[string, string]{{F0: "constraint", F1: "users_pkey"}}, + } + err := toError(pg.MakeErrorQueryFailed(pg.MakeQueryErrorDbError(dbErr))) + + var pgErr *QueryDBError + require.True(t, errors.As(err, &pgErr)) + assert.Equal(t, "ERROR", pgErr.Severity) + assert.Equal(t, "23505", pgErr.Code) + assert.Equal(t, "duplicate key value violates unique constraint", pgErr.Message) + assert.Equal(t, "Key (id)=(1) already exists.", pgErr.Detail) + assert.Equal(t, [][2]string{{"constraint", "users_pkey"}}, pgErr.Extras) + assert.Contains(t, pgErr.Error(), "23505") + assert.Contains(t, pgErr.Error(), "duplicate key value violates unique constraint") + assert.Contains(t, pgErr.Error(), "Key (id)=(1) already exists.") + }) + + t.Run("QueryFailed/DbError/NoDetail", func(t *testing.T) { + dbErr := pg.DbError{ + Severity: "ERROR", + Code: "42703", + Message: "column does not exist", + Detail: wittypes.None[string](), + } + err := toError(pg.MakeErrorQueryFailed(pg.MakeQueryErrorDbError(dbErr))) + + var pgErr *QueryDBError + require.True(t, errors.As(err, &pgErr)) + assert.Equal(t, "", pgErr.Detail) + assert.Nil(t, pgErr.Extras) + }) + + t.Run("ValueConversionFailed", func(t *testing.T) { + err := toError(pg.MakeErrorValueConversionFailed("cannot convert")) + assert.EqualError(t, err, "cannot convert") + }) + + t.Run("Other", func(t *testing.T) { + err := toError(pg.MakeErrorOther("unknown error")) + assert.EqualError(t, err, "unknown error") + }) +} diff --git a/pg/types.go b/pg/types.go new file mode 100644 index 00000000..12f894a2 --- /dev/null +++ b/pg/types.go @@ -0,0 +1,437 @@ +package pg + +import ( + "database/sql/driver" + "fmt" + "time" + + pg "github.com/spinframework/spin-go-sdk/v3/imports/spin_postgres_4_2_0_postgres" + wittypes "go.bytecodealliance.org/pkg/wit/types" +) + +// | Go type | WIT (db-value) | Postgres type(s) | +// |-------------------------|-----------------------------------------------|----------------------------- | +// | `bool` | boolean(bool) | BOOL | +// | `int16` | int16(s16) | SMALLINT, SMALLSERIAL, INT2 | +// | `int32` | int32(s32) | INT, SERIAL, INT4 | +// | `int64` | int64(s64) | BIGINT, BIGSERIAL, INT8 | +// | `float32` | floating32(float32) | REAL, FLOAT4 | +// | `float64` | floating64(float64) | DOUBLE PRECISION, FLOAT8 | +// | `string` | str(string) | VARCHAR, CHAR(N), TEXT | +// | `[]byte` | binary(list\) | BYTEA | +// | `Date` | date(tuple) | DATE | +// | `Time` | time(tuple) | TIME | +// | `time.Time` | datetime(tuple) | TIMESTAMP | +// | `time.Time` | timestamp(s64) | BIGINT | +// | `UUID` | uuid(string) | UUID | +// | `JSONB` | jsonb(list\) | JSONB | +// | `Decimal` | decimal(string) | NUMERIC | +// | `Int32Range` | range-int32(...) | INT4RANGE | +// | `Int64Range` | range-int64(...) | INT8RANGE | +// | `DecimalRange` | range-decimal(...) | NUMERICRANGE | +// | `[]int32` | array-int32(...) | INT4[] | +// | `[]int64` | array-int64(...) | INT8[] | +// | `[]string` | array-str(...) | TEXT[] | +// | `[]Decimal` | array-decimal(...) | NUMERIC[] | +// | `Interval` | interval(interval) | INTERVAL | + +// Date represents a PostgreSQL date value. +type Date struct { + Year int + Month time.Month + Day int +} + +// Scan implements [sql.Scanner] so Date can be used as a scan destination. +func (d *Date) Scan(src any) error { + switch src := src.(type) { + case time.Time: + d.Year = src.Year() + d.Month = src.Month() + d.Day = src.Day() + case nil: + *d = Date{} + default: + return fmt.Errorf("pg: cannot scan %T into *Date", src) + } + return nil +} + +// Value implements [driver.Valuer] so Date can be used as a query parameter. +func (d Date) Value() (driver.Value, error) { + return wittypes.Tuple3[int32, uint8, uint8]{ + F0: int32(d.Year), + F1: uint8(d.Month), + F2: uint8(d.Day), + }, nil +} + +// Time represents a PostgreSQL time value (time of day without date). +type Time struct { + Hour int + Minute int + Second int + Nanosecond int +} + +// Scan implements [sql.Scanner] so Time can be used as a scan destination. +func (t *Time) Scan(src any) error { + switch src := src.(type) { + case time.Time: + t.Hour = src.Hour() + t.Minute = src.Minute() + t.Second = src.Second() + t.Nanosecond = src.Nanosecond() + case nil: + *t = Time{} + default: + return fmt.Errorf("pg: cannot scan %T into *Time", src) + } + return nil +} + +// Value implements [driver.Valuer] so Time can be used as a query parameter. +func (t Time) Value() (driver.Value, error) { + return wittypes.Tuple4[uint8, uint8, uint8, uint32]{ + F0: uint8(t.Hour), + F1: uint8(t.Minute), + F2: uint8(t.Second), + F3: uint32(t.Nanosecond), + }, nil +} + +// Interval represents a PostgreSQL interval value. +// +// PostgreSQL intervals have three components: months, days, and microseconds. +// Months and days are stored separately because the number of days in a month +// varies, and a day may have 23 or 25 hours due to daylight savings. +type Interval struct { + Months int32 + Days int32 + Micros int64 +} + +// Scan implements [sql.Scanner] so Interval can be used as a scan destination. +func (iv *Interval) Scan(src any) error { + switch src := src.(type) { + case Interval: + *iv = src + case nil: + *iv = Interval{} + default: + return fmt.Errorf("pg: cannot scan %T into *Interval", src) + } + return nil +} + +// Value implements [driver.Valuer] so Interval can be used as a query parameter. +func (iv Interval) Value() (driver.Value, error) { + return pg.Interval{ + Micros: iv.Micros, + Days: iv.Days, + Months: iv.Months, + }, nil +} + +// JSONB represents a PostgreSQL jsonb value. +type JSONB []byte + +// Scan implements [sql.Scanner] so JSONB can be used as a scan destination. +func (j *JSONB) Scan(src any) error { + switch src := src.(type) { + case []byte: + *j = make(JSONB, len(src)) + copy(*j, src) + case pg.DbValue: + if src.Tag() == pg.DbValueJsonb { + data := src.Jsonb() + *j = make(JSONB, len(data)) + copy(*j, data) + } + case nil: + *j = nil + default: + return fmt.Errorf("pg: cannot scan %T into *JSONB", src) + } + return nil +} + +// Value implements [driver.Valuer] so JSONB can be used as a query parameter. +func (j JSONB) Value() (driver.Value, error) { + return pg.MakeDbValueJsonb([]byte(j)), nil +} + +// UUID represents a PostgreSQL uuid value. +type UUID string + +// Scan implements [sql.Scanner] so UUID can be used as a scan destination. +func (u *UUID) Scan(src any) error { + switch src := src.(type) { + case string: + *u = UUID(src) + case pg.DbValue: + if src.Tag() == pg.DbValueUuid { + uuid := src.Uuid() + *u = UUID(uuid) + } + case nil: + *u = "" + default: + return fmt.Errorf("pg: cannot scan %T into *UUID", src) + } + return nil +} + +// Value implements [driver.Valuer] so UUID can be used as a query parameter. +func (u UUID) Value() (driver.Value, error) { + return pg.MakeDbValueUuid(string(u)), nil +} + +// Decimal represents a PostgreSQL numeric/decimal value. +// +// Values are stored as strings to preserve arbitrary precision. Use Decimal +// when exact numeric representation matters (e.g., monetary values). +type Decimal string + +// Scan implements [sql.Scanner] so Decimal can be used as a scan destination. +func (d *Decimal) Scan(src any) error { + switch src := src.(type) { + case string: + *d = Decimal(src) + case nil: + *d = "" + default: + return fmt.Errorf("pg: cannot scan %T into *Decimal", src) + } + return nil +} + +// Value implements [driver.Valuer] so Decimal can be used as a query parameter. +func (d Decimal) Value() (driver.Value, error) { + return pg.MakeDbValueDecimal(string(d)), nil +} + +// Int32Range represents a PostgreSQL int4range value. +// +// A nil Lower or Upper indicates an unbounded (infinite) bound. +// LowerInclusive and UpperInclusive specify whether each bound is included in the range. +type Int32Range struct { + Lower *int32 + LowerInclusive bool + Upper *int32 + UpperInclusive bool +} + +// Scan implements [sql.Scanner] so Int32Range can be used as a scan destination. +func (r *Int32Range) Scan(src any) error { + if src == nil { + *r = Int32Range{} + return nil + } + + v, ok := src.(wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + ]) + if !ok { + return fmt.Errorf("pg: cannot scan %T into *Int32Range", src) + } + + *r = Int32Range{} + + if v.F0.IsSome() { + lb := v.F0.Some() + val := lb.F0 + r.Lower = &val + r.LowerInclusive = lb.F1 == pg.RangeBoundKindInclusive + } + + if v.F1.IsSome() { + ub := v.F1.Some() + val := ub.F0 + r.Upper = &val + r.UpperInclusive = ub.F1 == pg.RangeBoundKindInclusive + } + + return nil +} + +// Value implements [driver.Valuer] so Int32Range can be used as a query parameter. +func (r Int32Range) Value() (driver.Value, error) { + var lower wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]] + if r.Lower != nil { + kind := pg.RangeBoundKindExclusive + if r.LowerInclusive { + kind = pg.RangeBoundKindInclusive + } + lower = wittypes.Some(wittypes.Tuple2[int32, pg.RangeBoundKind]{F0: *r.Lower, F1: kind}) + } else { + lower = wittypes.None[wittypes.Tuple2[int32, pg.RangeBoundKind]]() + } + + var upper wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]] + if r.Upper != nil { + kind := pg.RangeBoundKindExclusive + if r.UpperInclusive { + kind = pg.RangeBoundKindInclusive + } + upper = wittypes.Some(wittypes.Tuple2[int32, pg.RangeBoundKind]{F0: *r.Upper, F1: kind}) + } else { + upper = wittypes.None[wittypes.Tuple2[int32, pg.RangeBoundKind]]() + } + + return wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + ]{F0: lower, F1: upper}, nil +} + +// Int64Range represents a PostgreSQL int8range value. +// +// A nil Lower or Upper indicates an unbounded (infinite) bound. +// LowerInclusive and UpperInclusive specify whether each bound is included in the range. +type Int64Range struct { + Lower *int64 + LowerInclusive bool + Upper *int64 + UpperInclusive bool +} + +// Scan implements [sql.Scanner] so Int64Range can be used as a scan destination. +func (r *Int64Range) Scan(src any) error { + if src == nil { + *r = Int64Range{} + return nil + } + + v, ok := src.(wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + ]) + if !ok { + return fmt.Errorf("pg: cannot scan %T into *Int64Range", src) + } + + *r = Int64Range{} + + if v.F0.IsSome() { + lb := v.F0.Some() + val := lb.F0 + r.Lower = &val + r.LowerInclusive = lb.F1 == pg.RangeBoundKindInclusive + } + + if v.F1.IsSome() { + ub := v.F1.Some() + val := ub.F0 + r.Upper = &val + r.UpperInclusive = ub.F1 == pg.RangeBoundKindInclusive + } + + return nil +} + +// Value implements [driver.Valuer] so Int64Range can be used as a query parameter. +func (r Int64Range) Value() (driver.Value, error) { + var lower wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]] + if r.Lower != nil { + kind := pg.RangeBoundKindExclusive + if r.LowerInclusive { + kind = pg.RangeBoundKindInclusive + } + lower = wittypes.Some(wittypes.Tuple2[int64, pg.RangeBoundKind]{F0: *r.Lower, F1: kind}) + } else { + lower = wittypes.None[wittypes.Tuple2[int64, pg.RangeBoundKind]]() + } + + var upper wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]] + if r.Upper != nil { + kind := pg.RangeBoundKindExclusive + if r.UpperInclusive { + kind = pg.RangeBoundKindInclusive + } + upper = wittypes.Some(wittypes.Tuple2[int64, pg.RangeBoundKind]{F0: *r.Upper, F1: kind}) + } else { + upper = wittypes.None[wittypes.Tuple2[int64, pg.RangeBoundKind]]() + } + + return wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + ]{F0: lower, F1: upper}, nil +} + +// DecimalRange represents a PostgreSQL numrange value. +// +// A nil Lower or Upper indicates an unbounded (infinite) bound. +// LowerInclusive and UpperInclusive specify whether each bound is included in the range. +type DecimalRange struct { + Lower *Decimal + LowerInclusive bool + Upper *Decimal + UpperInclusive bool +} + +// Scan implements [sql.Scanner] so DecimalRange can be used as a scan destination. +func (r *DecimalRange) Scan(src any) error { + if src == nil { + *r = DecimalRange{} + return nil + } + + v, ok := src.(wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]], + ]) + if !ok { + return fmt.Errorf("pg: cannot scan %T into *DecimalRange", src) + } + + *r = DecimalRange{} + + if v.F0.IsSome() { + lb := v.F0.Some() + val := Decimal(lb.F0) + r.Lower = &val + r.LowerInclusive = lb.F1 == pg.RangeBoundKindInclusive + } + + if v.F1.IsSome() { + ub := v.F1.Some() + val := Decimal(ub.F0) + r.Upper = &val + r.UpperInclusive = ub.F1 == pg.RangeBoundKindInclusive + } + + return nil +} + +// Value implements [driver.Valuer] so DecimalRange can be used as a query parameter. +func (r DecimalRange) Value() (driver.Value, error) { + var lower wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]] + if r.Lower != nil { + kind := pg.RangeBoundKindExclusive + if r.LowerInclusive { + kind = pg.RangeBoundKindInclusive + } + lower = wittypes.Some(wittypes.Tuple2[string, pg.RangeBoundKind]{F0: string(*r.Lower), F1: kind}) + } else { + lower = wittypes.None[wittypes.Tuple2[string, pg.RangeBoundKind]]() + } + + var upper wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]] + if r.Upper != nil { + kind := pg.RangeBoundKindExclusive + if r.UpperInclusive { + kind = pg.RangeBoundKindInclusive + } + upper = wittypes.Some(wittypes.Tuple2[string, pg.RangeBoundKind]{F0: string(*r.Upper), F1: kind}) + } else { + upper = wittypes.None[wittypes.Tuple2[string, pg.RangeBoundKind]]() + } + + return wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]], + ]{F0: lower, F1: upper}, nil +} diff --git a/pg/types_test.go b/pg/types_test.go new file mode 100644 index 00000000..764d78d3 --- /dev/null +++ b/pg/types_test.go @@ -0,0 +1,612 @@ +package pg + +import ( + "database/sql/driver" + "testing" + "time" + + pg "github.com/spinframework/spin-go-sdk/v3/imports/spin_postgres_4_2_0_postgres" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + wittypes "go.bytecodealliance.org/pkg/wit/types" +) + +func ptr[T any](v T) *T { return &v } + +func TestJSONB_Scan(t *testing.T) { + tests := []struct { + name string + src any + want JSONB + wantErr bool + }{{ + name: "from []byte", + src: []byte(`{"key":"value"}`), + want: JSONB(`{"key":"value"}`), + }, { + name: "from DbValue", + src: pg.MakeDbValueJsonb([]byte(`[1,2,3]`)), + want: JSONB(`[1,2,3]`), + }, { + name: "nil src", + src: nil, + want: nil, + }, { + name: "invalid src type", + src: 42, + wantErr: true, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var j JSONB + err := j.Scan(tt.src) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, j) + }) + } +} + +func TestJSONB_Value(t *testing.T) { + j := JSONB(`{"key":"value"}`) + got, _ := j.Value() + assert.Equal(t, driver.Value(pg.MakeDbValueJsonb([]byte(`{"key":"value"}`))), got) +} + +func TestJSONB_RoundTrip(t *testing.T) { + original := JSONB(`{"nested":{"a":1}}`) + val, _ := original.Value() + + var recovered JSONB + recovered.Scan(val) + assert.Equal(t, original, recovered) +} + +func TestDate_Scan(t *testing.T) { + tests := []struct { + name string + src any + want Date + wantErr bool + }{{ + name: "from time.Time", + src: time.Date(2024, time.March, 15, 0, 0, 0, 0, time.UTC), + want: Date{Year: 2024, Month: time.March, Day: 15}, + }, { + name: "nil src", + src: nil, + want: Date{}, + }, { + name: "invalid src type", + src: "2024-03-15", + wantErr: true, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var d Date + err := d.Scan(tt.src) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, d) + }) + } +} + +func TestDate_Value(t *testing.T) { + d := Date{Year: 2024, Month: time.March, Day: 15} + got, _ := d.Value() + assert.Equal(t, driver.Value(wittypes.Tuple3[int32, uint8, uint8]{ + F0: 2024, F1: 3, F2: 15, + }), got) +} + +func TestDate_RoundTrip(t *testing.T) { + original := Date{Year: 2024, Month: time.December, Day: 25} + witVal, _ := original.Value() + + tuple := witVal.(wittypes.Tuple3[int32, uint8, uint8]) + asTime := time.Date(int(tuple.F0), time.Month(tuple.F1), int(tuple.F2), 0, 0, 0, 0, time.UTC) + + var recovered Date + recovered.Scan(asTime) + assert.Equal(t, original, recovered) +} + +func TestTime_Scan(t *testing.T) { + tests := []struct { + name string + src any + want Time + wantErr bool + }{{ + name: "from time.Time", + src: time.Date(0, 1, 1, 14, 30, 45, 123456789, time.UTC), + want: Time{Hour: 14, Minute: 30, Second: 45, Nanosecond: 123456789}, + }, { + name: "nil src", + src: nil, + want: Time{}, + }, { + name: "invalid src type", + src: "14:30:45", + wantErr: true, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var tm Time + err := tm.Scan(tt.src) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, tm) + }) + } +} + +func TestTime_Value(t *testing.T) { + original := Time{Hour: 14, Minute: 30, Second: 45, Nanosecond: 123456789} + witVal, _ := original.Value() + assert.Equal(t, driver.Value(wittypes.Tuple4[uint8, uint8, uint8, uint32]{ + F0: 14, F1: 30, F2: 45, F3: 123456789, + }), witVal) +} + +func TestTime_RoundTrip(t *testing.T) { + original := Time{Hour: 23, Minute: 59, Second: 59, Nanosecond: 999000000} + witVal, _ := original.Value() + + // Simulate what toRow does: convert the WIT tuple to time.Time + tuple := witVal.(wittypes.Tuple4[uint8, uint8, uint8, uint32]) + asTime := time.Date(0, 1, 1, int(tuple.F0), int(tuple.F1), int(tuple.F2), int(tuple.F3), time.UTC) + + var recovered Time + require.NoError(t, recovered.Scan(asTime)) + assert.Equal(t, original, recovered) +} + +func TestInterval_Scan(t *testing.T) { + tests := []struct { + name string + src any + want Interval + wantErr bool + }{{ + name: "full interval", + src: Interval{Months: 14, Days: 3, Micros: 7200000000}, + want: Interval{Months: 14, Days: 3, Micros: 7200000000}, + }, { + name: "micros only", + src: Interval{Micros: 3600000000}, + want: Interval{Micros: 3600000000}, + }, { + name: "nil src", + src: nil, + want: Interval{}, + }, { + name: "invalid src type", + src: "1 year 2 months", + wantErr: true, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var iv Interval + err := iv.Scan(tt.src) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, iv) + }) + } +} + +func TestInterval_Value(t *testing.T) { + iv := Interval{Months: 14, Days: 3, Micros: 7200000000} + got, _ := iv.Value() + assert.Equal(t, driver.Value(pg.Interval{ + Micros: 7200000000, Days: 3, Months: 14, + }), got) +} + +func TestInterval_RoundTrip(t *testing.T) { + original := Interval{Months: 1, Days: 15, Micros: 43200000000} + + val, _ := original.Value() + witIv := val.(pg.Interval) + fromRow := Interval{Months: witIv.Months, Days: witIv.Days, Micros: witIv.Micros} + + var recovered Interval + require.NoError(t, recovered.Scan(fromRow)) + assert.Equal(t, original, recovered) +} + +func TestInt32Range_Scan(t *testing.T) { + tests := []struct { + name string + src any + want Int32Range + wantErr bool + }{{ + name: "nil src clears range", + src: nil, + want: Int32Range{}, + }, { + name: "both bounds inclusive", + src: wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + ]{ + F0: wittypes.Some(wittypes.Tuple2[int32, pg.RangeBoundKind]{F0: 1, F1: pg.RangeBoundKindInclusive}), + F1: wittypes.Some(wittypes.Tuple2[int32, pg.RangeBoundKind]{F0: 10, F1: pg.RangeBoundKindInclusive}), + }, + want: Int32Range{Lower: ptr(int32(1)), LowerInclusive: true, Upper: ptr(int32(10)), UpperInclusive: true}, + }, { + name: "both bounds exclusive", + src: wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + ]{ + F0: wittypes.Some(wittypes.Tuple2[int32, pg.RangeBoundKind]{F0: 0, F1: pg.RangeBoundKindExclusive}), + F1: wittypes.Some(wittypes.Tuple2[int32, pg.RangeBoundKind]{F0: 5, F1: pg.RangeBoundKindExclusive}), + }, + want: Int32Range{Lower: ptr(int32(0)), LowerInclusive: false, Upper: ptr(int32(5)), UpperInclusive: false}, + }, { + name: "unbounded lower", + src: wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + ]{ + F0: wittypes.None[wittypes.Tuple2[int32, pg.RangeBoundKind]](), + F1: wittypes.Some(wittypes.Tuple2[int32, pg.RangeBoundKind]{F0: 100, F1: pg.RangeBoundKindExclusive}), + }, + want: Int32Range{Lower: nil, LowerInclusive: false, Upper: ptr(int32(100)), UpperInclusive: false}, + }, { + name: "unbounded upper", + src: wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + ]{ + F0: wittypes.Some(wittypes.Tuple2[int32, pg.RangeBoundKind]{F0: 5, F1: pg.RangeBoundKindInclusive}), + F1: wittypes.None[wittypes.Tuple2[int32, pg.RangeBoundKind]](), + }, + want: Int32Range{Lower: ptr(int32(5)), LowerInclusive: true, Upper: nil, UpperInclusive: false}, + }, { + name: "invalid src type", + src: "not a range", + wantErr: true, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var r Int32Range + err := r.Scan(tt.src) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, r) + }) + } +} + +func TestInt32Range_Value(t *testing.T) { + tests := []struct { + name string + input Int32Range + want wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + ] + }{ + { + name: "both bounds inclusive", + input: Int32Range{Lower: ptr(int32(1)), LowerInclusive: true, Upper: ptr(int32(10)), UpperInclusive: true}, + want: wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + ]{ + F0: wittypes.Some(wittypes.Tuple2[int32, pg.RangeBoundKind]{F0: 1, F1: pg.RangeBoundKindInclusive}), + F1: wittypes.Some(wittypes.Tuple2[int32, pg.RangeBoundKind]{F0: 10, F1: pg.RangeBoundKindInclusive}), + }, + }, + { + name: "unbounded both", + input: Int32Range{}, + want: wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int32, pg.RangeBoundKind]], + ]{ + F0: wittypes.None[wittypes.Tuple2[int32, pg.RangeBoundKind]](), + F1: wittypes.None[wittypes.Tuple2[int32, pg.RangeBoundKind]](), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, _ := tt.input.Value() + assert.Equal(t, driver.Value(tt.want), got) + }) + } +} + +func TestInt32Range_RoundTrip(t *testing.T) { + original := Int32Range{Lower: ptr(int32(3)), LowerInclusive: true, Upper: ptr(int32(7)), UpperInclusive: false} + witVal, _ := original.Value() + + var recovered Int32Range + recovered.Scan(witVal) + assert.Equal(t, original, recovered) +} + +func TestInt64Range_Scan(t *testing.T) { + tests := []struct { + name string + src any + want Int64Range + wantErr bool + }{{ + name: "nil src clears range", + src: nil, + want: Int64Range{}, + }, { + name: "both bounds inclusive", + src: wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + ]{ + F0: wittypes.Some(wittypes.Tuple2[int64, pg.RangeBoundKind]{F0: 1, F1: pg.RangeBoundKindInclusive}), + F1: wittypes.Some(wittypes.Tuple2[int64, pg.RangeBoundKind]{F0: 10, F1: pg.RangeBoundKindInclusive}), + }, + want: Int64Range{Lower: ptr(int64(1)), LowerInclusive: true, Upper: ptr(int64(10)), UpperInclusive: true}, + }, { + name: "both bounds exclusive", + src: wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + ]{ + F0: wittypes.Some(wittypes.Tuple2[int64, pg.RangeBoundKind]{F0: 0, F1: pg.RangeBoundKindExclusive}), + F1: wittypes.Some(wittypes.Tuple2[int64, pg.RangeBoundKind]{F0: 5, F1: pg.RangeBoundKindExclusive}), + }, + want: Int64Range{Lower: ptr(int64(0)), LowerInclusive: false, Upper: ptr(int64(5)), UpperInclusive: false}, + }, { + name: "unbounded lower", + src: wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + ]{ + F0: wittypes.None[wittypes.Tuple2[int64, pg.RangeBoundKind]](), + F1: wittypes.Some(wittypes.Tuple2[int64, pg.RangeBoundKind]{F0: 100, F1: pg.RangeBoundKindExclusive}), + }, + want: Int64Range{Lower: nil, LowerInclusive: false, Upper: ptr(int64(100)), UpperInclusive: false}, + }, { + name: "unbounded upper", + src: wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + ]{ + F0: wittypes.Some(wittypes.Tuple2[int64, pg.RangeBoundKind]{F0: 5, F1: pg.RangeBoundKindInclusive}), + F1: wittypes.None[wittypes.Tuple2[int64, pg.RangeBoundKind]](), + }, + want: Int64Range{Lower: ptr(int64(5)), LowerInclusive: true, Upper: nil, UpperInclusive: false}, + }, { + name: "invalid src type", + src: "not a range", + wantErr: true, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var r Int64Range + err := r.Scan(tt.src) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, r) + }) + } +} + +func TestInt64Range_Value(t *testing.T) { + tests := []struct { + name string + input Int64Range + want wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + ] + }{{ + name: "both bounds inclusive", + input: Int64Range{Lower: ptr(int64(1)), LowerInclusive: true, Upper: ptr(int64(10)), UpperInclusive: true}, + want: wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + ]{ + F0: wittypes.Some(wittypes.Tuple2[int64, pg.RangeBoundKind]{F0: 1, F1: pg.RangeBoundKindInclusive}), + F1: wittypes.Some(wittypes.Tuple2[int64, pg.RangeBoundKind]{F0: 10, F1: pg.RangeBoundKindInclusive}), + }, + }, { + name: "unbounded both", + input: Int64Range{}, + want: wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[int64, pg.RangeBoundKind]], + ]{ + F0: wittypes.None[wittypes.Tuple2[int64, pg.RangeBoundKind]](), + F1: wittypes.None[wittypes.Tuple2[int64, pg.RangeBoundKind]](), + }, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, _ := tt.input.Value() + assert.Equal(t, driver.Value(tt.want), got) + }) + } +} + +func TestInt64Range_RoundTrip(t *testing.T) { + original := Int64Range{Lower: ptr(int64(3)), LowerInclusive: true, Upper: ptr(int64(7)), UpperInclusive: false} + witVal, _ := original.Value() + + var recovered Int64Range + recovered.Scan(witVal) + assert.Equal(t, original, recovered) +} + +func TestDecimal_Scan(t *testing.T) { + tests := []struct { + name string + src any + want Decimal + wantErr bool + }{{ + name: "from string", + src: "123.456", + want: Decimal("123.456"), + }, { + name: "large precision", + src: "99999999999999999999.99999999999999999999", + want: Decimal("99999999999999999999.99999999999999999999"), + }, { + name: "nil src", + src: nil, + want: Decimal(""), + }, { + name: "invalid src type", + src: 42, + wantErr: true, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var d Decimal + err := d.Scan(tt.src) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, d) + }) + } +} + +func TestDecimal_Value(t *testing.T) { + d := Decimal("123.456") + got, _ := d.Value() + assert.Equal(t, driver.Value(pg.MakeDbValueDecimal("123.456")), got) +} + +func TestDecimalRange_Scan(t *testing.T) { + tests := []struct { + name string + src any + want DecimalRange + wantErr bool + }{{ + name: "nil src clears range", + src: nil, + want: DecimalRange{}, + }, { + name: "both bounds inclusive", + src: wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]], + ]{ + F0: wittypes.Some(wittypes.Tuple2[string, pg.RangeBoundKind]{F0: "1.5", F1: pg.RangeBoundKindInclusive}), + F1: wittypes.Some(wittypes.Tuple2[string, pg.RangeBoundKind]{F0: "9.99", F1: pg.RangeBoundKindInclusive}), + }, + want: DecimalRange{Lower: ptr(Decimal("1.5")), LowerInclusive: true, Upper: ptr(Decimal("9.99")), UpperInclusive: true}, + }, { + name: "unbounded lower", + src: wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]], + ]{ + F0: wittypes.None[wittypes.Tuple2[string, pg.RangeBoundKind]](), + F1: wittypes.Some(wittypes.Tuple2[string, pg.RangeBoundKind]{F0: "100.00", F1: pg.RangeBoundKindExclusive}), + }, + want: DecimalRange{Lower: nil, Upper: ptr(Decimal("100.00")), UpperInclusive: false}, + }, { + name: "invalid src type", + src: "not a range", + wantErr: true, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var r DecimalRange + err := r.Scan(tt.src) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, r) + }) + } +} + +func TestDecimalRange_Value(t *testing.T) { + tests := []struct { + name string + input DecimalRange + want wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]], + ] + }{{ + name: "both bounds inclusive", + input: DecimalRange{Lower: ptr(Decimal("1.5")), LowerInclusive: true, Upper: ptr(Decimal("9.99")), UpperInclusive: true}, + want: wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]], + ]{ + F0: wittypes.Some(wittypes.Tuple2[string, pg.RangeBoundKind]{F0: "1.5", F1: pg.RangeBoundKindInclusive}), + F1: wittypes.Some(wittypes.Tuple2[string, pg.RangeBoundKind]{F0: "9.99", F1: pg.RangeBoundKindInclusive}), + }, + }, { + name: "unbounded both", + input: DecimalRange{}, + want: wittypes.Tuple2[ + wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]], + wittypes.Option[wittypes.Tuple2[string, pg.RangeBoundKind]], + ]{ + F0: wittypes.None[wittypes.Tuple2[string, pg.RangeBoundKind]](), + F1: wittypes.None[wittypes.Tuple2[string, pg.RangeBoundKind]](), + }, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, _ := tt.input.Value() + assert.Equal(t, driver.Value(tt.want), got) + }) + } +} + +func TestDecimalRange_RoundTrip(t *testing.T) { + original := DecimalRange{Lower: ptr(Decimal("3.14")), LowerInclusive: true, Upper: ptr(Decimal("99.99")), UpperInclusive: false} + witVal, _ := original.Value() + + var recovered DecimalRange + recovered.Scan(witVal) + assert.Equal(t, original, recovered) +}