From d9591000c3eea0978baebc9134f33dad544e96b8 Mon Sep 17 00:00:00 2001 From: jonbodner Date: Mon, 23 Feb 2026 17:38:42 -0500 Subject: [PATCH 1/3] Add structured error types with errors.Is/errors.As support Replace all inline errors.New/fmt.Errorf calls with typed error structs (ValidationError, QueryError, IdentifierError, ExtractError, AssignError) grouped by class with iota-based Kind fields. The zero value of each Kind acts as a wildcard for errors.Is matching, enabling callers to distinguish error conditions programmatically without string matching. Delete the cmp package, which did string-based error comparison. Co-Authored-By: Claude Sonnet 4.6 --- builder.go | 26 +++--- cmp/errors.go | 11 --- errors.go | 166 +++++++++++++++++++++++++++++++++++++++ errors_test.go | 140 +++++++++++++++++++++++++++++++++ mapper/errors.go | 136 ++++++++++++++++++++++++++++++++ mapper/errors_test.go | 114 +++++++++++++++++++++++++++ mapper/extract.go | 27 +++---- mapper/extract_test.go | 22 +++--- mapper/mapper.go | 19 ++--- mapper_test.go | 6 +- proteus.go | 32 ++++---- proteus_function.go | 8 +- proteus_function_test.go | 2 +- proteus_test.go | 30 +++---- runner.go | 16 ++-- 15 files changed, 640 insertions(+), 115 deletions(-) delete mode 100644 cmp/errors.go create mode 100644 errors.go create mode 100644 errors_test.go create mode 100644 mapper/errors.go create mode 100644 mapper/errors_test.go diff --git a/builder.go b/builder.go index bddc337..34b17a0 100644 --- a/builder.go +++ b/builder.go @@ -86,7 +86,7 @@ func buildFixedQueryAndParamOrder(ctx context.Context, query string, nameOrderMa if inVar { if curVar.Len() == 0 { //error! must have a something - return nil, nil, fmt.Errorf("empty variable declaration at position %d", k) + return nil, nil, QueryError{Kind: EmptyVariable, Position: k} } curVarS := curVar.String() id, err := validIdentifier(ctx, curVarS) @@ -109,13 +109,13 @@ func buildFixedQueryAndParamOrder(ctx context.Context, query string, nameOrderMa paramType := funcType.In(paramPos) if len(path) > 1 { if paramType == nil { - return nil, nil, fmt.Errorf("query Parameter %s has a path, but the incoming parameter is nil", paramName) + return nil, nil, QueryError{Kind: NilParameterPath, Name: paramName} } switch paramType.Kind() { case reflect.Map, reflect.Struct: //do nothing default: - return nil, nil, fmt.Errorf("query Parameter %s has a path, but the incoming parameter is not a map or a struct it is %s", paramName, paramType.Kind()) + return nil, nil, QueryError{Kind: InvalidParameterType, Name: paramName, TypeKind: paramType.Kind().String()} } } pathType, err := mapper.ExtractType(ctx, paramType, path) @@ -131,7 +131,7 @@ func buildFixedQueryAndParamOrder(ctx context.Context, query string, nameOrderMa } paramOrder = append(paramOrder, paramInfo{id, paramPos, isSlice}) } else { - return nil, nil, fmt.Errorf("query Parameter %s cannot be found in the incoming parameters", paramName) + return nil, nil, QueryError{Kind: ParameterNotFound, Name: paramName} } inVar = false @@ -148,7 +148,7 @@ func buildFixedQueryAndParamOrder(ctx context.Context, query string, nameOrderMa } } if inVar { - return nil, nil, fmt.Errorf("missing a closing : somewhere: %s", query) + return nil, nil, QueryError{Kind: MissingClosingColon, Query: query} } queryString := out.String() @@ -232,7 +232,7 @@ func addSlice(sliceName string) string { func validIdentifier(ctx context.Context, curVar string) (string, error) { if strings.Contains(curVar, ";") { - return "", fmt.Errorf("; is not allowed in an identifier: %s", curVar) + return "", IdentifierError{Kind: SemicolonInIdentifier, Identifier: curVar} } curVarB := []byte(curVar) @@ -253,7 +253,7 @@ loop: switch tok { case token.EOF: if first || lastPeriod { - return "", fmt.Errorf("identifiers cannot be empty or end with a .: %s", curVar) + return "", IdentifierError{Kind: EmptyOrTrailingDotIdentifier, Identifier: curVar} } break loop case token.SEMICOLON: @@ -262,7 +262,7 @@ loop: continue case token.IDENT: if !first && !lastPeriod && !lastFloat { - return "", fmt.Errorf(". missing between parts of an identifier: %s", curVar) + return "", IdentifierError{Kind: MissingDotInIdentifier, Identifier: curVar} } first = false lastPeriod = false @@ -270,7 +270,7 @@ loop: identifier += lit case token.PERIOD: if first || lastPeriod { - return "", fmt.Errorf("identifier cannot start with . or have two . in a row: %s", curVar) + return "", IdentifierError{Kind: LeadingOrDoubleDotIdentifier, Identifier: curVar} } lastPeriod = true identifier += "." @@ -282,10 +282,10 @@ loop: first = false continue } - return "", fmt.Errorf("invalid character found in identifier: %s", curVar) + return "", IdentifierError{Kind: InvalidCharacterInIdentifier, Identifier: curVar} case token.INT: if !dollar || first { - return "", fmt.Errorf("invalid character found in identifier: %s", curVar) + return "", IdentifierError{Kind: InvalidCharacterInIdentifier, Identifier: curVar} } identifier += lit if dollar { @@ -299,7 +299,7 @@ loop: // returns .0 as the lit value //Only valid for $ notation and array/slice references. if first { - return "", fmt.Errorf("invalid character found in identifier: %s", curVar) + return "", IdentifierError{Kind: InvalidCharacterInIdentifier, Identifier: curVar} } identifier += lit if dollar { @@ -310,7 +310,7 @@ loop: lastPeriod = true } default: - return "", fmt.Errorf("invalid character found in identifier: %s", curVar) + return "", IdentifierError{Kind: InvalidCharacterInIdentifier, Identifier: curVar} } } return identifier, nil diff --git a/cmp/errors.go b/cmp/errors.go deleted file mode 100644 index 24effc9..0000000 --- a/cmp/errors.go +++ /dev/null @@ -1,11 +0,0 @@ -package cmp - -func Errors(e1, e2 error) bool { - if e1 == nil || e2 == nil { - if e1 != nil || e2 != nil { - return false - } - return true - } - return e1.Error() == e2.Error() -} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..455d8e0 --- /dev/null +++ b/errors.go @@ -0,0 +1,166 @@ +package proteus + +import "fmt" + +// ValidationErrorKind identifies the specific validation failure. +// The zero value (AnyValidation) acts as a wildcard for errors.Is matching. +type ValidationErrorKind int + +const ( + AnyValidation ValidationErrorKind = iota + NotPointer // "not a pointer" + NotPointerToStruct // "not a pointer to struct" + NotPointerToFunc // "not a pointer to func" + NeedExecutorOrQuerier // "need to supply an Executor or Querier parameter" + InvalidFirstParam // "first parameter must be of type context.Context, Executor, or Querier" + ChannelInputParam // "no input parameter can be a channel" + TooManyReturnValues // "must return 0, 1, or 2 values" + SecondReturnNotError // "2nd output parameter must be of type error" + FirstReturnIsChannel // "1st output parameter cannot be a channel" + ExecutorReturnType // "the 1st output parameter of an Executor must be int64 or sql.Result" + SQLResultWithQuerier // "output parameters of type sql.Result must be combined with Executor" + RowsMustBeNonNil // "rows must be non-nil" + NoValuesFromQuery // "no values returned from query" + ShouldNeverGetHere // "should never get here" +) + +var validationMessages = map[ValidationErrorKind]string{ + NotPointer: "not a pointer", + NotPointerToStruct: "not a pointer to struct", + NotPointerToFunc: "not a pointer to func", + NeedExecutorOrQuerier: "need to supply an Executor or Querier parameter", + InvalidFirstParam: "first parameter must be of type context.Context, Executor, or Querier", + ChannelInputParam: "no input parameter can be a channel", + TooManyReturnValues: "must return 0, 1, or 2 values", + SecondReturnNotError: "2nd output parameter must be of type error", + FirstReturnIsChannel: "1st output parameter cannot be a channel", + ExecutorReturnType: "the 1st output parameter of an Executor must be int64 or sql.Result", + SQLResultWithQuerier: "output parameters of type sql.Result must be combined with Executor", + RowsMustBeNonNil: "rows must be non-nil", + NoValuesFromQuery: "no values returned from query", + ShouldNeverGetHere: "should never get here", +} + +// ValidationError is returned when a struct, function signature, or type passed +// to Build/ShouldBuild/BuildFunction fails validation, or when a runtime +// invariant is violated. +type ValidationError struct { + Kind ValidationErrorKind +} + +func (e ValidationError) Error() string { + if msg, ok := validationMessages[e.Kind]; ok { + return msg + } + return "unknown validation error" +} + +// Is matches any ValidationError when target has AnyValidation kind, +// or matches the exact kind otherwise. +func (e ValidationError) Is(target error) bool { + t, ok := target.(ValidationError) + if !ok { + return false + } + return t.Kind == AnyValidation || e.Kind == t.Kind +} + +// QueryErrorKind identifies the specific query or parameter processing failure. +// The zero value (AnyQuery) acts as a wildcard for errors.Is matching. +type QueryErrorKind int + +const ( + AnyQuery QueryErrorKind = iota + QueryNotFound // Name: the missing query name + MissingClosingColon // Query: the full query string + EmptyVariable // Position: byte offset of the empty :: + ParameterNotFound // Name: the parameter name + NilParameterPath // Name: the parameter name + InvalidParameterType // Name: the parameter name; TypeKind: the actual kind +) + +// QueryError is returned when a query string or its parameters cannot be +// processed (missing query, bad syntax, unknown parameter). +type QueryError struct { + Kind QueryErrorKind + Name string // query or parameter name + Query string // full query string (MissingClosingColon) + Position int // byte offset (EmptyVariable) + TypeKind string // reflect.Kind string (InvalidParameterType) +} + +func (e QueryError) Error() string { + switch e.Kind { + case QueryNotFound: + return fmt.Sprintf("no query found for name %s", e.Name) + case MissingClosingColon: + return fmt.Sprintf("missing a closing : somewhere: %s", e.Query) + case EmptyVariable: + return fmt.Sprintf("empty variable declaration at position %d", e.Position) + case ParameterNotFound: + return fmt.Sprintf("query parameter %s cannot be found in the incoming parameters", e.Name) + case NilParameterPath: + return fmt.Sprintf("query parameter %s has a path, but the incoming parameter is nil", e.Name) + case InvalidParameterType: + return fmt.Sprintf("query parameter %s has a path, but the incoming parameter is not a map or a struct it is %s", e.Name, e.TypeKind) + default: + return "unknown query error" + } +} + +// Is matches any QueryError when target has AnyQuery kind, +// or matches the exact kind otherwise. +func (e QueryError) Is(target error) bool { + t, ok := target.(QueryError) + if !ok { + return false + } + return t.Kind == AnyQuery || e.Kind == t.Kind +} + +// IdentifierErrorKind identifies the specific identifier parsing failure. +// The zero value (AnyIdentifier) acts as a wildcard for errors.Is matching. +type IdentifierErrorKind int + +const ( + AnyIdentifier IdentifierErrorKind = iota + SemicolonInIdentifier // "; is not allowed in an identifier" + EmptyOrTrailingDotIdentifier // "identifiers cannot be empty or end with a ." + MissingDotInIdentifier // ". missing between parts of an identifier" + LeadingOrDoubleDotIdentifier // "identifier cannot start with . or have two . in a row" + InvalidCharacterInIdentifier // "invalid character found in identifier" +) + +// IdentifierError is returned when an identifier in a query parameter fails +// syntax validation. +type IdentifierError struct { + Kind IdentifierErrorKind + Identifier string +} + +func (e IdentifierError) Error() string { + switch e.Kind { + case SemicolonInIdentifier: + return fmt.Sprintf("; is not allowed in an identifier: %s", e.Identifier) + case EmptyOrTrailingDotIdentifier: + return fmt.Sprintf("identifiers cannot be empty or end with a .: %s", e.Identifier) + case MissingDotInIdentifier: + return fmt.Sprintf(". missing between parts of an identifier: %s", e.Identifier) + case LeadingOrDoubleDotIdentifier: + return fmt.Sprintf("identifier cannot start with . or have two . in a row: %s", e.Identifier) + case InvalidCharacterInIdentifier: + return fmt.Sprintf("invalid character found in identifier: %s", e.Identifier) + default: + return "unknown identifier error" + } +} + +// Is matches any IdentifierError when target has AnyIdentifier kind, +// or matches the exact kind otherwise. +func (e IdentifierError) Is(target error) bool { + t, ok := target.(IdentifierError) + if !ok { + return false + } + return t.Kind == AnyIdentifier || e.Kind == t.Kind +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..5cacf36 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,140 @@ +package proteus + +import ( + "errors" + "testing" +) + +func TestValidationErrorIdentity(t *testing.T) { + err := ValidationError{Kind: NotPointer} + // exact kind match + if !errors.Is(err, ValidationError{Kind: NotPointer}) { + t.Error("ValidationError{NotPointer} should match itself") + } + // wildcard match + if !errors.Is(err, ValidationError{}) { + t.Error("ValidationError{NotPointer} should match any-ValidationError wildcard") + } + // different kind should not match + if errors.Is(err, ValidationError{Kind: NotPointerToStruct}) { + t.Error("ValidationError{NotPointer} should not match ValidationError{NotPointerToStruct}") + } +} + +func TestValidationErrorMessage(t *testing.T) { + cases := []struct { + kind ValidationErrorKind + want string + }{ + {NotPointer, "not a pointer"}, + {NotPointerToStruct, "not a pointer to struct"}, + {NeedExecutorOrQuerier, "need to supply an Executor or Querier parameter"}, + {InvalidFirstParam, "first parameter must be of type context.Context, Executor, or Querier"}, + {RowsMustBeNonNil, "rows must be non-nil"}, + } + for _, c := range cases { + e := ValidationError{Kind: c.kind} + if e.Error() != c.want { + t.Errorf("Kind %d: expected %q, got %q", c.kind, c.want, e.Error()) + } + } +} + +func TestQueryErrorKinds(t *testing.T) { + err := QueryError{Kind: QueryNotFound, Name: "foo"} + if !errors.Is(err, QueryError{Kind: QueryNotFound}) { + t.Error("QueryError{QueryNotFound} should match exact kind") + } + if !errors.Is(err, QueryError{}) { + t.Error("QueryError should match any-QueryError wildcard") + } + if errors.Is(err, QueryError{Kind: ParameterNotFound}) { + t.Error("QueryError{QueryNotFound} should not match QueryError{ParameterNotFound}") + } + if err.Error() != "no query found for name foo" { + t.Errorf("unexpected message: %s", err.Error()) + } +} + +func TestQueryErrorMessages(t *testing.T) { + cases := []struct { + err QueryError + want string + }{ + {QueryError{Kind: MissingClosingColon, Query: "select *"}, "missing a closing : somewhere: select *"}, + {QueryError{Kind: EmptyVariable, Position: 10}, "empty variable declaration at position 10"}, + {QueryError{Kind: ParameterNotFound, Name: "p"}, "query parameter p cannot be found in the incoming parameters"}, + {QueryError{Kind: NilParameterPath, Name: "p"}, "query parameter p has a path, but the incoming parameter is nil"}, + {QueryError{Kind: InvalidParameterType, Name: "p", TypeKind: "int"}, "query parameter p has a path, but the incoming parameter is not a map or a struct it is int"}, + } + for _, c := range cases { + if c.err.Error() != c.want { + t.Errorf("expected %q, got %q", c.want, c.err.Error()) + } + } +} + +func TestIdentifierErrorKinds(t *testing.T) { + err := IdentifierError{Kind: InvalidCharacterInIdentifier, Identifier: "a,b"} + if !errors.Is(err, IdentifierError{Kind: InvalidCharacterInIdentifier}) { + t.Error("IdentifierError should match exact kind") + } + if !errors.Is(err, IdentifierError{}) { + t.Error("IdentifierError should match any-IdentifierError wildcard") + } + if errors.Is(err, IdentifierError{Kind: SemicolonInIdentifier}) { + t.Error("IdentifierError{InvalidCharacter} should not match IdentifierError{Semicolon}") + } +} + +func TestIdentifierErrorMessages(t *testing.T) { + cases := []struct { + err IdentifierError + want string + }{ + {IdentifierError{Kind: SemicolonInIdentifier, Identifier: "a;b"}, "; is not allowed in an identifier: a;b"}, + {IdentifierError{Kind: EmptyOrTrailingDotIdentifier, Identifier: "a."}, "identifiers cannot be empty or end with a .: a."}, + {IdentifierError{Kind: MissingDotInIdentifier, Identifier: "ab"}, ". missing between parts of an identifier: ab"}, + {IdentifierError{Kind: LeadingOrDoubleDotIdentifier, Identifier: ".a"}, "identifier cannot start with . or have two . in a row: .a"}, + {IdentifierError{Kind: InvalidCharacterInIdentifier, Identifier: "a,b"}, "invalid character found in identifier: a,b"}, + } + for _, c := range cases { + if c.err.Error() != c.want { + t.Errorf("expected %q, got %q", c.want, c.err.Error()) + } + } +} + +func TestValidationErrorPropagation(t *testing.T) { + wrapped := Error{ + FuncName: "TestFunc", + FieldOrder: 0, + OriginalError: ValidationError{Kind: NotPointer}, + } + if !errors.Is(wrapped, ValidationError{Kind: NotPointer}) { + t.Error("ValidationError should be reachable through Error.Unwrap()") + } + if !errors.Is(wrapped, ValidationError{}) { + t.Error("any-ValidationError wildcard should be reachable through Error.Unwrap()") + } +} + +func TestErrorsAsExtraction(t *testing.T) { + err := QueryError{Kind: QueryNotFound, Name: "myquery"} + var qe QueryError + if !errors.As(err, &qe) { + t.Fatal("errors.As should succeed for QueryError") + } + if qe.Name != "myquery" { + t.Errorf("expected Name=myquery, got %s", qe.Name) + } + + err2 := IdentifierError{Kind: SemicolonInIdentifier, Identifier: "a;b"} + var ie IdentifierError + if !errors.As(err2, &ie) { + t.Fatal("errors.As should succeed for IdentifierError") + } + if ie.Identifier != "a;b" { + t.Errorf("expected Identifier=a;b, got %s", ie.Identifier) + } +} diff --git a/mapper/errors.go b/mapper/errors.go new file mode 100644 index 0000000..8a74070 --- /dev/null +++ b/mapper/errors.go @@ -0,0 +1,136 @@ +package mapper + +import ( + "fmt" + "reflect" +) + +// ExtractErrorKind identifies the specific path-extraction failure. +// The zero value (AnyExtract) acts as a wildcard for errors.Is matching. +type ExtractErrorKind int + +const ( + AnyExtract ExtractErrorKind = iota + NoPathRemaining // type extraction: no path left + SubfieldOfNil // cannot descend into nil + SubfieldUnsupportedKind // non-map/struct/slice/array subfield + ValueNoPathRemaining // value extraction: no path left + ValueMapNonStringKey // map key is not a string + ValueContainedNonMapStruct // cannot descend into non-map/struct + NoSuchFieldType // Field: struct field name not found (type extraction) + NoSuchMapKey // Key: map key not found + NoSuchField // Field: struct field name not found (value extraction) + InvalidIndex // Index: non-integer or out-of-range index; Err: strconv error if present +) + +// ExtractError is returned when navigating a dot-separated path through a +// value or type fails. +type ExtractError struct { + Kind ExtractErrorKind + Value string // field name (NoSuchFieldType, NoSuchField, NoSuchMapKey) or index (InvalidIndex) + Err error // InvalidIndex: wrapped strconv error (may be nil) +} + +func (e ExtractError) Error() string { + switch e.Kind { + case NoPathRemaining: + return "cannot extract type; no path remaining" + case SubfieldOfNil: + return "cannot find the type for the subfield of a nil" + case SubfieldUnsupportedKind: + return "cannot find the type for the subfield of anything other than a map, struct, slice, or array" + case ValueNoPathRemaining: + return "cannot extract value; no path remaining" + case ValueMapNonStringKey: + return "cannot extract value; map does not have a string key" + case ValueContainedNonMapStruct: + return "cannot extract value; only maps and structs can have contained values" + case NoSuchFieldType: + return "cannot find the type; no such field " + e.Value + case NoSuchMapKey: + return "cannot extract value; no such map key " + e.Value + case NoSuchField: + return "cannot extract value; no such field " + e.Value + case InvalidIndex: + if e.Err != nil { + return fmt.Sprintf("invalid index: %s :%v", e.Value, e.Err) + } + return fmt.Sprintf("invalid index: %s", e.Value) + default: + return "unknown extract error" + } +} + +// Is matches any ExtractError when target has AnyExtract kind, +// or matches the exact kind otherwise. +func (e ExtractError) Is(target error) bool { + t, ok := target.(ExtractError) + if !ok { + return false + } + return t.Kind == AnyExtract || e.Kind == t.Kind +} + +// Unwrap returns the underlying strconv error for InvalidIndex, nil otherwise. +func (e ExtractError) Unwrap() error { + return e.Err +} + +// AssignErrorKind identifies the specific value-assignment failure. +// The zero value (AnyAssign) acts as a wildcard for errors.Is matching. +type AssignErrorKind int + +const ( + AnyAssign AssignErrorKind = iota + InvalidOutputType // output type passed to MakeBuilder is nil + InvalidMapKeyType // map key type is not string + NilReturnForNonPointer // ToType: the non-pointer type that got a nil value + MapAssign // Value, FromType, ToType, Key + StructPointerAssign // Value, FromType, FieldName, ToType + StructNilAssign // FieldName, ToType + StructAssign // Value, FromType, FieldName, ToType + PrimitiveAssign // Value, FromType, ToType +) + +// AssignError is returned when a database value cannot be assigned to the +// target Go type or field. +type AssignError struct { + Kind AssignErrorKind + Value any + FromType reflect.Type + ToType reflect.Type + Field string // map key (MapAssign) or struct field name (Struct*Assign kinds) +} + +func (e AssignError) Error() string { + switch e.Kind { + case InvalidOutputType: + return "sType cannot be nil" + case InvalidMapKeyType: + return "only maps with string keys are supported" + case NilReturnForNonPointer: + return fmt.Sprintf("attempting to return nil for non-pointer type %v", e.ToType) + case MapAssign: + return fmt.Sprintf("unable to assign value %v of type %v to map value of type %v with key %s", e.Value, e.FromType, e.ToType, e.Field) + case StructPointerAssign: + return fmt.Sprintf("unable to assign pointer to value %v of type %v to struct field %s of type %v", e.Value, e.FromType, e.Field, e.ToType) + case StructNilAssign: + return fmt.Sprintf("unable to assign nil value to non-pointer struct field %s of type %v", e.Field, e.ToType) + case StructAssign: + return fmt.Sprintf("unable to assign value %v of type %v to struct field %s of type %v", e.Value, e.FromType, e.Field, e.ToType) + case PrimitiveAssign: + return fmt.Sprintf("unable to assign value %v of type %v to return type of type %v", e.Value, e.FromType, e.ToType) + default: + return "unknown assign error" + } +} + +// Is matches any AssignError when target has AnyAssign kind, +// or matches the exact kind otherwise. +func (e AssignError) Is(target error) bool { + t, ok := target.(AssignError) + if !ok { + return false + } + return t.Kind == AnyAssign || e.Kind == t.Kind +} diff --git a/mapper/errors_test.go b/mapper/errors_test.go new file mode 100644 index 0000000..25be223 --- /dev/null +++ b/mapper/errors_test.go @@ -0,0 +1,114 @@ +package mapper + +import ( + "errors" + "reflect" + "strconv" + "testing" +) + +func TestAssignErrorKinds(t *testing.T) { + stringType := reflect.TypeOf("") + err := AssignError{Kind: NilReturnForNonPointer, ToType: stringType} + + if !errors.Is(err, AssignError{Kind: NilReturnForNonPointer}) { + t.Error("AssignError should match exact kind") + } + if !errors.Is(err, AssignError{}) { + t.Error("AssignError should match any-AssignError wildcard") + } + if errors.Is(err, AssignError{Kind: MapAssign}) { + t.Error("AssignError{NilReturnForNonPointer} should not match AssignError{MapAssign}") + } +} + +func TestAssignErrorMessages(t *testing.T) { + stringType := reflect.TypeOf("") + intType := reflect.TypeOf(0) + cases := []struct { + err AssignError + want string + }{ + {AssignError{Kind: InvalidOutputType}, "sType cannot be nil"}, + {AssignError{Kind: InvalidMapKeyType}, "only maps with string keys are supported"}, + {AssignError{Kind: NilReturnForNonPointer, ToType: stringType}, "attempting to return nil for non-pointer type string"}, + {AssignError{Kind: StructNilAssign, Field: "Name", ToType: stringType}, "unable to assign nil value to non-pointer struct field Name of type string"}, + {AssignError{Kind: PrimitiveAssign, Value: 42, FromType: intType, ToType: stringType}, "unable to assign value 42 of type int to return type of type string"}, + } + for _, c := range cases { + if c.err.Error() != c.want { + t.Errorf("expected %q, got %q", c.want, c.err.Error()) + } + } +} + +func TestAssignErrorsAsExtraction(t *testing.T) { + stringType := reflect.TypeOf("") + err := AssignError{Kind: NilReturnForNonPointer, ToType: stringType} + var ae AssignError + if !errors.As(err, &ae) { + t.Fatal("errors.As should succeed for AssignError") + } + if ae.ToType != stringType { + t.Errorf("expected ToType=string, got %v", ae.ToType) + } +} + +func TestExtractErrorKinds(t *testing.T) { + err := ExtractError{Kind: NoSuchField, Value: "Name"} + if !errors.Is(err, ExtractError{Kind: NoSuchField}) { + t.Error("ExtractError should match exact kind") + } + if !errors.Is(err, ExtractError{}) { + t.Error("ExtractError should match any-ExtractError wildcard") + } + if errors.Is(err, ExtractError{Kind: NoSuchMapKey}) { + t.Error("ExtractError{NoSuchField} should not match ExtractError{NoSuchMapKey}") + } +} + +func TestExtractErrorMessages(t *testing.T) { + cases := []struct { + err ExtractError + want string + }{ + {ExtractError{Kind: NoPathRemaining}, "cannot extract type; no path remaining"}, + {ExtractError{Kind: ValueNoPathRemaining}, "cannot extract value; no path remaining"}, + {ExtractError{Kind: NoSuchField, Value: "Foo"}, "cannot extract value; no such field Foo"}, + {ExtractError{Kind: NoSuchMapKey, Value: "bar"}, "cannot extract value; no such map key bar"}, + {ExtractError{Kind: NoSuchFieldType, Value: "Baz"}, "cannot find the type; no such field Baz"}, + {ExtractError{Kind: InvalidIndex, Value: "xyz"}, "invalid index: xyz"}, + } + for _, c := range cases { + if c.err.Error() != c.want { + t.Errorf("expected %q, got %q", c.want, c.err.Error()) + } + } +} + +func TestExtractErrorInvalidIndexUnwrap(t *testing.T) { + _, parseErr := strconv.Atoi("abc") + err := ExtractError{Kind: InvalidIndex, Value: "abc", Err: parseErr} + + if !errors.Is(err, ExtractError{Kind: InvalidIndex}) { + t.Error("InvalidIndex ExtractError should match its kind") + } + if !errors.Is(err, strconv.ErrSyntax) { + t.Error("InvalidIndex ExtractError should unwrap to strconv.ErrSyntax") + } + noWrap := ExtractError{Kind: InvalidIndex, Value: "5"} + if noWrap.Unwrap() != nil { + t.Error("Unwrap should return nil when Err is nil") + } +} + +func TestExtractErrorsAsExtraction(t *testing.T) { + err := ExtractError{Kind: NoSuchField, Value: "MyField"} + var ee ExtractError + if !errors.As(err, &ee) { + t.Fatal("errors.As should succeed for ExtractError") + } + if ee.Value != "MyField" { + t.Errorf("expected Field=MyField, got %s", ee.Value) + } +} diff --git a/mapper/extract.go b/mapper/extract.go index d5df39b..f330ab0 100644 --- a/mapper/extract.go +++ b/mapper/extract.go @@ -3,18 +3,15 @@ package mapper import ( "context" "database/sql/driver" - "fmt" "log/slog" "reflect" "strconv" - - "errors" ) func ExtractType(ctx context.Context, curType reflect.Type, path []string) (reflect.Type, error) { // error case path length == 0 if len(path) == 0 { - return nil, errors.New("cannot extract type; no path remaining") + return nil, ExtractError{Kind: NoPathRemaining} } ss := fromPtrType(curType) // base case path length == 1 @@ -23,7 +20,7 @@ func ExtractType(ctx context.Context, curType reflect.Type, path []string) (refl } // length > 1, find a match for path[1], and recurse if ss == nil { - return nil, errors.New("cannot find the type for the subfield of a nil") + return nil, ExtractError{Kind: SubfieldOfNil} } switch ss.Kind() { case reflect.Map: @@ -34,23 +31,23 @@ func ExtractType(ctx context.Context, curType reflect.Type, path []string) (refl if f, exists := ss.FieldByName(path[1]); exists { return ExtractType(ctx, f.Type, path[1:]) } - return nil, errors.New("cannot find the type; no such field " + path[1]) + return nil, ExtractError{Kind: NoSuchFieldType, Value: path[1]} case reflect.Array, reflect.Slice: // handle slices and arrays _, err := strconv.Atoi(path[1]) if err != nil { - return nil, fmt.Errorf("invalid index: %s :%w", path[1], err) + return nil, ExtractError{Kind: InvalidIndex, Value: path[1], Err: err} } return ExtractType(ctx, ss.Elem(), path[1:]) default: - return nil, errors.New("cannot find the type for the subfield of anything other than a map, struct, slice, or array") + return nil, ExtractError{Kind: SubfieldUnsupportedKind} } } func Extract(ctx context.Context, s any, path []string) (any, error) { // error case path length == 0 if len(path) == 0 { - return nil, errors.New("cannot extract value; no path remaining") + return nil, ExtractError{Kind: ValueNoPathRemaining} } // base case path length == 1 if len(path) == 1 { @@ -66,19 +63,19 @@ func Extract(ctx context.Context, s any, path []string) (any, error) { switch sv.Kind() { case reflect.Map: if sv.Type().Key().Kind() != reflect.String { - return nil, errors.New("cannot extract value; map does not have a string key") + return nil, ExtractError{Kind: ValueMapNonStringKey} } slog.DebugContext(ctx, "map extract", "key", path[1], "availableKeys", sv.MapKeys()) v := sv.MapIndex(reflect.ValueOf(path[1])) slog.DebugContext(ctx, "map extract result", "value", v) if !v.IsValid() { - return nil, errors.New("cannot extract value; no such map key " + path[1]) + return nil, ExtractError{Kind: NoSuchMapKey, Value: path[1]} } return Extract(ctx, v.Interface(), path[1:]) case reflect.Struct: //make sure the field exists if _, exists := sv.Type().FieldByName(path[1]); !exists { - return nil, errors.New("cannot extract value; no such field " + path[1]) + return nil, ExtractError{Kind: NoSuchField, Value: path[1]} } v := sv.FieldByName(path[1]) @@ -87,15 +84,15 @@ func Extract(ctx context.Context, s any, path []string) (any, error) { // handle slices and arrays pos, err := strconv.Atoi(path[1]) if err != nil { - return nil, fmt.Errorf("invalid index: %s :%w", path[1], err) + return nil, ExtractError{Kind: InvalidIndex, Value: path[1], Err: err} } if pos < 0 || pos >= sv.Len() { - return nil, fmt.Errorf("invalid index: %s", path[1]) + return nil, ExtractError{Kind: InvalidIndex, Value: path[1]} } v := sv.Index(pos) return Extract(ctx, v.Interface(), path[1:]) default: - return nil, errors.New("cannot extract value; only maps and structs can have contained values") + return nil, ExtractError{Kind: ValueContainedNonMapStruct} } } diff --git a/mapper/extract_test.go b/mapper/extract_test.go index 466f1e3..2966d3e 100644 --- a/mapper/extract_test.go +++ b/mapper/extract_test.go @@ -2,13 +2,10 @@ package mapper import ( "context" + "errors" "fmt" "reflect" "testing" - - "errors" - - "github.com/jonbodner/proteus/cmp" ) func TestExtractPointer(t *testing.T) { @@ -119,30 +116,29 @@ func TestExtract(t *testing.T) { func TestExtractFail(t *testing.T) { ctx := context.Background() - f := func(in any, path []string, msg string) { + f := func(in any, path []string, expected error) { _, err := Extract(ctx, in, path) if err == nil { - t.Errorf("Expected an error %s, got none", msg) + t.Errorf("Expected an error, got none") } - eExp := errors.New(msg) - if !cmp.Errors(err, eExp) { - t.Errorf("Expected error %s, got %s", eExp, err) + if !errors.Is(err, expected) { + t.Errorf("Expected error %v, got %v", expected, err) } } //base case no path - f(10, []string{}, "cannot extract value; no path remaining") + f(10, []string{}, ExtractError{Kind: ValueNoPathRemaining}) //path too long - f(10, []string{"A", "B"}, "cannot extract value; only maps and structs can have contained values") + f(10, []string{"A", "B"}, ExtractError{Kind: ValueContainedNonMapStruct}) //invalid map - f(map[int]any{10: "Hello"}, []string{"m", "10"}, "cannot extract value; map does not have a string key") + f(map[int]any{10: "Hello"}, []string{"m", "10"}, ExtractError{Kind: ValueMapNonStringKey}) //no such field case type Bar struct { A int } - f(Bar{A: 20}, []string{"b", "B"}, "cannot extract value; no such field B") + f(Bar{A: 20}, []string{"b", "B"}, ExtractError{Kind: NoSuchField}) } func TestExtractType(t *testing.T) { diff --git a/mapper/mapper.go b/mapper/mapper.go index c09ba23..d1b4fc8 100644 --- a/mapper/mapper.go +++ b/mapper/mapper.go @@ -3,12 +3,9 @@ package mapper import ( "context" "database/sql" - "fmt" "log/slog" "reflect" "strings" - - "errors" ) func ptrConverter(ctx context.Context, isPtr bool, sType reflect.Type, out reflect.Value, err error) (any, error) { @@ -28,14 +25,14 @@ func ptrConverter(ctx context.Context, isPtr bool, sType reflect.Type, out refle } k := out.Type().Kind() if (k == reflect.Pointer || k == reflect.Interface) && out.IsNil() { - return nil, fmt.Errorf("attempting to return nil for non-pointer type %v", sType) + return nil, AssignError{Kind: NilReturnForNonPointer, ToType: sType} } return out.Interface(), nil } func MakeBuilder(ctx context.Context, sType reflect.Type) (Builder, error) { if sType == nil { - return nil, errors.New("sType cannot be nil") + return nil, AssignError{Kind: InvalidOutputType} } isPtr := false @@ -52,7 +49,7 @@ func MakeBuilder(ctx context.Context, sType reflect.Type) (Builder, error) { switch sType.Kind() { case reflect.Map: if sType.Key().Kind() != reflect.String { - return nil, errors.New("only maps with string keys are supported") + return nil, AssignError{Kind: InvalidMapKeyType} } return func(cols []string, vals []any) (any, error) { out, err := buildMap(ctx, sType, cols, vals) @@ -136,7 +133,7 @@ func buildMap(ctx context.Context, sType reflect.Type, cols []string, vals []any if rv.Elem().Elem().Type().ConvertibleTo(sType.Elem()) { out.SetMapIndex(reflect.ValueOf(v), rv.Elem().Elem().Convert(sType.Elem())) } else { - return out, fmt.Errorf("unable to assign value %v of type %v to map value of type %v with key %s", rv.Elem().Elem(), rv.Elem().Elem().Type(), sType.Elem(), v) + return out, AssignError{Kind: MapAssign, Value: rv.Elem().Elem().Interface(), FromType: rv.Elem().Elem().Type(), ToType: sType.Elem(), Field: v} } } return out, nil @@ -184,7 +181,7 @@ func buildStructInner(ctx context.Context, sType reflect.Type, out reflect.Value field.Elem().Set(rv.Elem().Elem().Convert(curFieldType.Elem())) } else { slog.ErrorContext(ctx, "can't find the field") - return fmt.Errorf("unable to assign pointer to value %v of type %v to struct field %s of type %v", rv.Elem().Elem(), rv.Elem().Elem().Type(), sf.name[depth], curFieldType) + return AssignError{Kind: StructPointerAssign, Value: rv.Elem().Elem().Interface(), FromType: rv.Elem().Elem().Type(), Field: sf.name[depth], ToType: curFieldType} } } else { if reflect.PointerTo(curFieldType).Implements(scannerType) { @@ -201,12 +198,12 @@ func buildStructInner(ctx context.Context, sType reflect.Type, out reflect.Value } } else if rv.Elem().IsNil() { slog.ErrorContext(ctx, "attempting to assign nil to non-pointer field") - return fmt.Errorf("unable to assign nil value to non-pointer struct field %s of type %v", sf.name[depth], curFieldType) + return AssignError{Kind: StructNilAssign, Field: sf.name[depth], ToType: curFieldType} } else if rv.Elem().Elem().Type().ConvertibleTo(curFieldType) { field.Set(rv.Elem().Elem().Convert(curFieldType)) } else { slog.ErrorContext(ctx, "can't find the field") - return fmt.Errorf("unable to assign value %v of type %v to struct field %s of type %v", rv.Elem().Elem(), rv.Elem().Elem().Type(), sf.name[depth], curFieldType) + return AssignError{Kind: StructAssign, Value: rv.Elem().Elem().Interface(), FromType: rv.Elem().Elem().Type(), Field: sf.name[depth], ToType: curFieldType} } } return nil @@ -222,7 +219,7 @@ func buildPrimitive(ctx context.Context, sType reflect.Type, cols []string, vals if rv.Elem().Elem().Type().ConvertibleTo(sType) { out.Set(rv.Elem().Elem().Convert(sType)) } else { - return out, fmt.Errorf("unable to assign value %v of type %v to return type of type %v", rv.Elem().Elem(), rv.Elem().Elem().Type(), sType) + return out, AssignError{Kind: PrimitiveAssign, Value: rv.Elem().Elem().Interface(), FromType: rv.Elem().Elem().Type(), ToType: sType} } return out, nil } diff --git a/mapper_test.go b/mapper_test.go index 2c78b7f..4d6c4e0 100644 --- a/mapper_test.go +++ b/mapper_test.go @@ -11,7 +11,6 @@ import ( "errors" - "github.com/jonbodner/proteus/cmp" "github.com/jonbodner/proteus/mapper" ) @@ -23,9 +22,8 @@ func TestMapRows(t *testing.T) { if v != nil { t.Error("Expected nil when passing in nil rows") } - eExp := errors.New("rows must be non-nil") - if !cmp.Errors(err, eExp) { - t.Errorf("Expected error %s, got %s", eExp, err) + if !errors.Is(err, ValidationError{Kind: RowsMustBeNonNil}) { + t.Errorf("Expected RowsMustBeNonNil, got %s", err) } } diff --git a/proteus.go b/proteus.go index 7782fb6..943e612 100644 --- a/proteus.go +++ b/proteus.go @@ -73,12 +73,12 @@ func ShouldBuild(ctx context.Context, dao any, paramAdapter ParamAdapter, mapper daoPointerType := reflect.TypeOf(dao) //must be a pointer to struct if daoPointerType.Kind() != reflect.Pointer { - return errors.New("not a pointer") + return ValidationError{Kind: NotPointer} } daoType := daoPointerType.Elem() //if not a struct, error out if daoType.Kind() != reflect.Struct { - return errors.New("not a pointer to struct") + return ValidationError{Kind: NotPointerToStruct} } var out error funcs := make([]reflect.Value, daoType.NumField()) @@ -161,12 +161,12 @@ func Build(dao any, paramAdapter ParamAdapter, mappers ...QueryMapper) error { daoPointerType := reflect.TypeOf(dao) //must be a pointer to struct if daoPointerType.Kind() != reflect.Pointer { - return errors.New("not a pointer") + return ValidationError{Kind: NotPointer} } daoType := daoPointerType.Elem() //if not a struct, error out if daoType.Kind() != reflect.Struct { - return errors.New("not a pointer to struct") + return ValidationError{Kind: NotPointerToStruct} } daoPointerValue := reflect.ValueOf(dao) daoValue := reflect.Indirect(daoPointerValue) @@ -249,7 +249,7 @@ var ( func validateFunction(funcType reflect.Type) (bool, error) { //the first parameter is Executor if funcType.NumIn() == 0 { - return false, errors.New("need to supply an Executor or Querier parameter") + return false, ValidationError{Kind: NeedExecutorOrQuerier} } var isExec bool var hasContext bool @@ -261,7 +261,7 @@ func validateFunction(funcType reflect.Type) (bool, error) { case fType.Implements(qType): //do nothing isExec is false default: - return false, errors.New("first parameter must be of type context.Context, Executor, or Querier") + return false, ValidationError{Kind: InvalidFirstParam} } start := 1 if hasContext { @@ -272,41 +272,41 @@ func validateFunction(funcType reflect.Type) (bool, error) { case fType.Implements(conQType): //do nothing isExec is false default: - return false, errors.New("first parameter must be of type context.Context, Executor, or Querier") + return false, ValidationError{Kind: InvalidFirstParam} } } //no in parameter can be a channel for i := start; i < funcType.NumIn(); i++ { if funcType.In(i).Kind() == reflect.Chan { - return false, errors.New("no input parameter can be a channel") + return false, ValidationError{Kind: ChannelInputParam} } } //has 0, 1, or 2 return values if funcType.NumOut() > 2 { - return false, errors.New("must return 0, 1, or 2 values") + return false, ValidationError{Kind: TooManyReturnValues} } //if 2 return values, second is error if funcType.NumOut() == 2 { if !funcType.Out(1).Implements(errType) { - return false, errors.New("2nd output parameter must be of type error") + return false, ValidationError{Kind: SecondReturnNotError} } } //if 1 or 2, the 1st param is not a channel (handle map, I guess) if funcType.NumOut() > 0 { if funcType.Out(0).Kind() == reflect.Chan { - return false, errors.New("1st output parameter cannot be a channel") + return false, ValidationError{Kind: FirstReturnIsChannel} } if isExec && funcType.Out(0).Kind() != reflect.Int64 && funcType.Out(0) != sqlResultType { - return false, errors.New("the 1st output parameter of an Executor must be int64 or sql.Result") + return false, ValidationError{Kind: ExecutorReturnType} } //sql.Result only useful with executor. if !isExec && funcType.Out(0) == sqlResultType { - return false, errors.New("output parameters of type sql.Result must be combined with Executor") + return false, ValidationError{Kind: SQLResultWithQuerier} } } return hasContext, nil @@ -331,8 +331,8 @@ func makeImplementation(ctx context.Context, funcType reflect.Type, query string case fType.Implements(qType): return makeQuerierImplementation(ctx, funcType, fixedQuery, paramOrder) } - //this should be impossible, since we already validated that the first parameter is either an executor or a querier - return nil, errors.New("first parameter must be of type Executor or Querier") + //this should be impossible, since we already validated that the first parameter is either an executor, a querier, or a context + return nil, ValidationError{Kind: InvalidFirstParam} } func lookupQuery(query string, mappers []QueryMapper) (string, error) { @@ -345,5 +345,5 @@ func lookupQuery(query string, mappers []QueryMapper) (string, error) { return q, nil } } - return "", fmt.Errorf("no query found for name %s", name) + return "", QueryError{Kind: QueryNotFound, Name: name} } diff --git a/proteus_function.go b/proteus_function.go index 135980b..36eb686 100644 --- a/proteus_function.go +++ b/proteus_function.go @@ -7,8 +7,6 @@ import ( "reflect" "strings" - "errors" - "github.com/jonbodner/proteus/mapper" ) @@ -29,12 +27,12 @@ func (fb Builder) BuildFunction(ctx context.Context, f any, query string, names funcPointerType := reflect.TypeOf(f) //must be a pointer to func if funcPointerType.Kind() != reflect.Pointer { - return errors.New("not a pointer") + return ValidationError{Kind: NotPointer} } funcType := funcPointerType.Elem() //if not a func, error out if funcType.Kind() != reflect.Func { - return errors.New("not a pointer to func") + return ValidationError{Kind: NotPointerToFunc} } //validate to make sure that the function matches what we expect @@ -108,7 +106,7 @@ func (fb Builder) Query(ctx context.Context, q ContextQuerier, query string, par // make sure that output is a pointer to something outputPointerType := reflect.TypeOf(output) if outputPointerType.Kind() != reflect.Pointer { - return errors.New("not a pointer") + return ValidationError{Kind: NotPointer} } finalQuery, queryArgs, err := fb.setupDynamicQueries(ctx, query, params) diff --git a/proteus_function_test.go b/proteus_function_test.go index 1299f14..1a3d286 100644 --- a/proteus_function_test.go +++ b/proteus_function_test.go @@ -56,7 +56,7 @@ func TestBuilder_BuildFunctionErrors(t *testing.T) { f: &f2, query: "SELECT * FROM PERSON WHERE id = :id:", params: nil, - errMsg: "query Parameter id cannot be found in the incoming parameters", + errMsg: "query parameter id cannot be found in the incoming parameters", }, } for _, v := range data { diff --git a/proteus_test.go b/proteus_test.go index 6b79c39..84a4d82 100644 --- a/proteus_test.go +++ b/proteus_test.go @@ -10,11 +10,8 @@ import ( "time" - "fmt" "github.com/google/go-cmp/cmp" - pcmp "github.com/jonbodner/proteus/cmp" - _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" ) @@ -89,7 +86,7 @@ func TestConvertToPositionalParameters(t *testing.T) { reflect.TypeOf(f3), "", nil, - fmt.Errorf("missing a closing : somewhere: %s", `select * from Product where name=:name: and cost=:cost`), + QueryError{Kind: MissingClosingColon, Query: `select * from Product where name=:name: and cost=:cost`}, }, //empty :: `select * from Product where name=:: and cost=:cost`: inner{ @@ -97,7 +94,7 @@ func TestConvertToPositionalParameters(t *testing.T) { reflect.TypeOf(f3), "", nil, - errors.New("empty variable declaration at position 34"), + QueryError{Kind: EmptyVariable, Position: 34}, }, //invalid identifier `select * from Product where name=:a,b,c: and cost=:cost`: inner{ @@ -105,7 +102,7 @@ func TestConvertToPositionalParameters(t *testing.T) { reflect.TypeOf(f3), "", nil, - errors.New("invalid character found in identifier: a,b,c"), + IdentifierError{Kind: InvalidCharacterInIdentifier, Identifier: "a,b,c"}, }, //escaped character (invalid sql, but not the problem at hand) `select * from Pr\:oduct where name=:name: and cost=:cost:`: inner{ @@ -124,7 +121,7 @@ func TestConvertToPositionalParameters(t *testing.T) { if err == nil { qSimple, _ = q.finalize(ctx, nil) } - if qSimple != v.query || !reflect.DeepEqual(qps, v.qps) || !pcmp.Errors(err, v.err) { + if qSimple != v.query || !reflect.DeepEqual(qps, v.qps) || !errors.Is(err, v.err) { t.Errorf("failed for %s -> %#v: %v", k, v, err) } } @@ -159,14 +156,13 @@ func TestBuildParamMap(t *testing.T) { // This still needs tests for context... func TestValidateFunction(t *testing.T) { - f := func(fType reflect.Type, msg string) { + f := func(fType reflect.Type, expected error) { hasCtx, err := validateFunction(fType) if err == nil { t.Fatalf("Expected err") } - eExp := errors.New(msg) - if !pcmp.Errors(err, eExp) { - t.Errorf("Wrong error expected %s, got %s", eExp, err) + if !errors.Is(err, expected) { + t.Errorf("Wrong error: expected %v, got %v", expected, err) } if hasCtx { t.Errorf("Expected no context, has one") @@ -185,15 +181,15 @@ func TestValidateFunction(t *testing.T) { //invalid -- no parameters var f1 func() - f(reflect.TypeOf(f1), "need to supply an Executor or Querier parameter") + f(reflect.TypeOf(f1), ValidationError{Kind: NeedExecutorOrQuerier}) //invalid -- wrong first parameter type var f2 func(int) - f(reflect.TypeOf(f2), "first parameter must be of type context.Context, Executor, or Querier") + f(reflect.TypeOf(f2), ValidationError{Kind: InvalidFirstParam}) //invalid -- has a channel input param var f3 func(Executor, chan int) - f(reflect.TypeOf(f3), "no input parameter can be a channel") + f(reflect.TypeOf(f3), ValidationError{Kind: ChannelInputParam}) //valid -- only an Executor var g1 func(Executor) @@ -225,7 +221,7 @@ func TestValidateFunction(t *testing.T) { D bool }, error) //invalid for Exec - f(reflect.TypeOf(g4), "the 1st output parameter of an Executor must be int64 or sql.Result") + f(reflect.TypeOf(g4), ValidationError{Kind: ExecutorReturnType}) //valid for query var g4q func(Querier, int, map[string]any, struct { @@ -243,7 +239,7 @@ func TestValidateFunction(t *testing.T) { // invalid -- a querier, returning an sql.Result var r2 func(Querier) sql.Result - f(reflect.TypeOf(r2), "output parameters of type sql.Result must be combined with Executor") + f(reflect.TypeOf(r2), ValidationError{Kind: SQLResultWithQuerier}) } func TestBuild(t *testing.T) { @@ -756,7 +752,7 @@ func TestShouldBuild(t *testing.T) { if err2.Error() != `error in field #0 (Insert): missing a closing : somewhere: insert into Product(name) values(:p.Name) error in field #1 (Insert2): first parameter must be of type context.Context, Executor, or Querier error in field #3 (Insert3): no query found for name nope -error in field #5 (InsertNoP): query Parameter p cannot be found in the incoming parameters` { +error in field #5 (InsertNoP): query parameter p cannot be found in the incoming parameters` { t.Error(err2) } } diff --git a/runner.go b/runner.go index e46cc07..92c0cb5 100644 --- a/runner.go +++ b/runner.go @@ -8,8 +8,6 @@ import ( "database/sql" - "errors" - "github.com/jonbodner/proteus/mapper" ) @@ -42,8 +40,8 @@ func buildQueryArgs(ctx context.Context, funcArgs []reflect.Value, paramOrder [] var ( errZero = reflect.Zero(errType) - zeroInt64 = reflect.ValueOf(int64(0)) - zeroSQLResult = reflect.ValueOf((sql.Result)(nil)) + zeroInt64 = reflect.Zero(reflect.TypeFor[int64]()) + zeroSQLResult = reflect.Zero(reflect.TypeFor[sql.Result]()) sqlResultType = reflect.TypeFor[sql.Result]() ) @@ -151,7 +149,7 @@ func makeExecutorReturnVals(funcType reflect.Type) func(sql.Result, error) []ref // impossible case since validation should happen first, but be safe return func(result sql.Result, err error) []reflect.Value { - impossibleErr := reflect.ValueOf(errors.New("should never get here")) + impossibleErr := reflect.ValueOf(ValidationError{Kind: ShouldNeverGetHere}) if sType == sqlResultType { return []reflect.Value{zeroSQLResult, impossibleErr} } @@ -307,13 +305,13 @@ func makeQuerierReturnVals(ctx context.Context, funcType reflect.Type, builder m // impossible case since validation should happen first, but be safe return func(*sql.Rows, error) []reflect.Value { - return []reflect.Value{qZero, reflect.ValueOf(errors.New("should never get here"))} + return []reflect.Value{qZero, reflect.ValueOf(ValidationError{Kind: ShouldNeverGetHere})} } } func handleMapping(ctx context.Context, sType reflect.Type, rows *sql.Rows, builder mapper.Builder) (any, error) { if rows == nil { - return nil, errors.New("rows must be non-nil") + return nil, ValidationError{Kind: RowsMustBeNonNil} } defer rows.Close() var val any @@ -350,7 +348,7 @@ func handleMapping(ctx context.Context, sType reflect.Type, rows *sql.Rows, buil func mapRows(ctx context.Context, rows *sql.Rows, builder mapper.Builder) (any, error) { //fmt.Println(sType) if rows == nil { - return nil, errors.New("rows must be non-nil") + return nil, ValidationError{Kind: RowsMustBeNonNil} } if !rows.Next() { if err := rows.Err(); err != nil { @@ -365,7 +363,7 @@ func mapRows(ctx context.Context, rows *sql.Rows, builder mapper.Builder) (any, } if len(cols) == 0 { - return nil, errors.New("no values returned from query") + return nil, ValidationError{Kind: NoValuesFromQuery} } vals := make([]any, len(cols)) From 27fb1901a0d4681f95664d0775f1b5ced028d158 Mon Sep 17 00:00:00 2001 From: jonbodner Date: Mon, 23 Feb 2026 17:46:45 -0500 Subject: [PATCH 2/3] Refactor tests to use `errors.AsType` instead of `errors.As` for error extraction --- errors_test.go | 22 +++++++++++----------- mapper/errors_test.go | 20 ++++++++++---------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/errors_test.go b/errors_test.go index 5cacf36..6c085d4 100644 --- a/errors_test.go +++ b/errors_test.go @@ -121,20 +121,20 @@ func TestValidationErrorPropagation(t *testing.T) { func TestErrorsAsExtraction(t *testing.T) { err := QueryError{Kind: QueryNotFound, Name: "myquery"} - var qe QueryError - if !errors.As(err, &qe) { - t.Fatal("errors.As should succeed for QueryError") - } - if qe.Name != "myquery" { - t.Errorf("expected Name=myquery, got %s", qe.Name) + if qe, ok := errors.AsType[QueryError](err); ok { + if qe.Name != "myquery" { + t.Errorf("expected Name=myquery, got %s", qe.Name) + } + } else { + t.Fatal("errors.AsType should succeed for QueryError") } err2 := IdentifierError{Kind: SemicolonInIdentifier, Identifier: "a;b"} - var ie IdentifierError - if !errors.As(err2, &ie) { + if ie, ok := errors.AsType[IdentifierError](err2); ok { + if ie.Identifier != "a;b" { + t.Errorf("expected Identifier=a;b, got %s", ie.Identifier) + } + } else { t.Fatal("errors.As should succeed for IdentifierError") } - if ie.Identifier != "a;b" { - t.Errorf("expected Identifier=a;b, got %s", ie.Identifier) - } } diff --git a/mapper/errors_test.go b/mapper/errors_test.go index 25be223..5d5282b 100644 --- a/mapper/errors_test.go +++ b/mapper/errors_test.go @@ -45,13 +45,13 @@ func TestAssignErrorMessages(t *testing.T) { func TestAssignErrorsAsExtraction(t *testing.T) { stringType := reflect.TypeOf("") err := AssignError{Kind: NilReturnForNonPointer, ToType: stringType} - var ae AssignError - if !errors.As(err, &ae) { + if ae, ok := errors.AsType[AssignError](err); ok { + if ae.ToType != stringType { + t.Errorf("expected ToType=string, got %v", ae.ToType) + } + } else { t.Fatal("errors.As should succeed for AssignError") } - if ae.ToType != stringType { - t.Errorf("expected ToType=string, got %v", ae.ToType) - } } func TestExtractErrorKinds(t *testing.T) { @@ -104,11 +104,11 @@ func TestExtractErrorInvalidIndexUnwrap(t *testing.T) { func TestExtractErrorsAsExtraction(t *testing.T) { err := ExtractError{Kind: NoSuchField, Value: "MyField"} - var ee ExtractError - if !errors.As(err, &ee) { + if ee, ok := errors.AsType[ExtractError](err); ok { + if ee.Value != "MyField" { + t.Errorf("expected Field=MyField, got %s", ee.Value) + } + } else { t.Fatal("errors.As should succeed for ExtractError") } - if ee.Value != "MyField" { - t.Errorf("expected Field=MyField, got %s", ee.Value) - } } From 81067a59dc3e2bcc83322ec0657bbb8bdfd87252 Mon Sep 17 00:00:00 2001 From: jonbodner Date: Mon, 23 Feb 2026 17:48:18 -0500 Subject: [PATCH 3/3] Mark item 9 (structured error types) as done in MODERNIZATION.md Co-Authored-By: Claude Sonnet 4.6 --- MODERNIZATION.md | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/MODERNIZATION.md b/MODERNIZATION.md index fca8140..ede2218 100644 --- a/MODERNIZATION.md +++ b/MODERNIZATION.md @@ -102,11 +102,17 @@ Replaced `reflect.NewAt(sType, unsafe.Pointer(nil))` with `reflect.Zero(reflect. --- -## 9. Delete the `cmp/errors.go` Package +## ~~9. Delete the `cmp/errors.go` Package and Add Structured Error Types~~ (DONE) -**File:** `cmp/errors.go` +Deleted `cmp/errors.go`, which compared errors by string — a fragile anti-pattern. Replaced all inline `errors.New`/`fmt.Errorf` calls with five typed error structs grouped by class: -This package contains a single function that compares errors by their `.Error()` string — a fragile anti-pattern. With proper use of sentinel errors, `errors.Is`, and `errors.As`, this package becomes unnecessary. The tests that use it should be updated to compare errors structurally. +- **`ValidationError{Kind ValidationErrorKind}`** — struct/function signature validation failures in `Build`/`ShouldBuild`/`BuildFunction` +- **`QueryError{Kind QueryErrorKind, ...}`** — query lookup and parameter processing failures (with `Name`, `Query`, `Position`, `TypeKind` fields as applicable) +- **`IdentifierError{Kind IdentifierErrorKind, Identifier string}`** — identifier syntax validation failures in query parameters +- **`ExtractError{Kind ExtractErrorKind, Value string, Err error}`** — path-extraction failures in `mapper/extract.go`; implements `Unwrap()` to surface wrapped strconv errors for `InvalidIndex` +- **`AssignError{Kind AssignErrorKind, ...}`** — value-assignment failures in `mapper/mapper.go` + +Each type's zero-value `Kind` (the `AnyXxx` constant) acts as a wildcard: `errors.Is(err, ValidationError{})` matches any `ValidationError`; `errors.Is(err, ValidationError{Kind: NotPointer})` matches exactly. All tests updated to use `errors.Is`/`errors.As` instead of string comparison. --- @@ -371,7 +377,7 @@ If `Build` returns an error, `productDao` will have nil function fields. Subsequ **Lower priority (cleanup):** - ~~#5 — `strings.ReplaceAll`~~ *(DONE)* - ~~#6 — `strings.Builder`~~ *(DONE)* -- #9 — Delete `cmp/errors.go` +- ~~#9 — Delete `cmp/errors.go` and add structured error types~~ *(DONE)* - #10 — Deprecation annotations - #12 — Testing improvements - #13 — Reduce duplication