diff --git a/README.md b/README.md index 2861b2c..3909113 100644 --- a/README.md +++ b/README.md @@ -216,6 +216,32 @@ func init() { } } ``` +### Database error behavior configuration + +Having every DAO function returning an error can make things very cumbersone and repetitive quickly. Sometimes you just want to return a single value without having to check for errors because you are sure that any errors would be fatal, unrecoverable and require human intervention. + +```go +type UserRepository struct { + Fetch () func (ctx context.Context, q ContextQuerier)[]*models.Users `proq:"select * from users"` +} +``` +The above struct works fine, but if the database throws an error, you won't be able to know. The following configurations will control the behavior of such scenarios. + +By passing the following values in the `context.Context` to the `ShouldBuild` function. + +```go +c := context.WithValue(context.Background(), ContextKeyErrorBehavior, PanicAlways) +ShouldBuild(c, &dao, Postgres) +``` + +**`ErrorBehavior` Values**: + + +- `DoNothing` - proteus does not do anything when the underlying data source throws an error. +- `PanicWhenAbsent` - proteus will `panic`, if the DAO function being called does not have the `error` return type. +- `PanicAlways` - proteus will always `panic` if there is an error from the data source, whether or not the DAO function being called indicates an `error` return type. + + ## Struct Tags diff --git a/proteus.go b/proteus.go index 9c5822a..3661b48 100644 --- a/proteus.go +++ b/proteus.go @@ -52,6 +52,12 @@ If the entity is a primitive, then the first value returned for a row must be of */ +type ContextKey string + +const ( + ContextKeyErrorBehavior ContextKey = ContextKey("errorBehavior") +) + type Error struct { FuncName string FieldOrder int @@ -352,7 +358,7 @@ func makeImplementation(c context.Context, funcType reflect.Type, query string, case fType.Implements(qType): return makeQuerierImplementation(c, funcType, fixedQuery, paramOrder) } - //this should impossible, since we already validated that the first parameter is either an executor or a querier + //this should be impossible, since we already validated that the first parameter is either an executor or a querier return nil, stackerr.New("first parameter must be of type Executor or Querier") } diff --git a/proteus_function.go b/proteus_function.go index 0d6ad7e..f84ed4a 100644 --- a/proteus_function.go +++ b/proteus_function.go @@ -12,9 +12,23 @@ import ( "github.com/jonbodner/stackerr" ) +type ErrorBehavior string + +const ( + // (Default) proteus will do nothing the the query executor returns an error + DoNothing ErrorBehavior = "do_nothing" + + // proteus will always panic when the query executor returns an error + PanicAlways ErrorBehavior = "panic_always" + + // proteus will panic only if the calling function does not specify error in one of its return types + PanicWhenAbsent ErrorBehavior = "panic_if_absent" +) + type Builder struct { - adapter ParamAdapter - mappers []QueryMapper + adapter ParamAdapter + mappers []QueryMapper + errorBehavior ErrorBehavior } func NewBuilder(adapter ParamAdapter, mappers ...QueryMapper) Builder { diff --git a/runner.go b/runner.go index 2d3ebe6..38b2bad 100644 --- a/runner.go +++ b/runner.go @@ -43,8 +43,16 @@ var ( sqlResultType = reflect.TypeOf((*sql.Result)(nil)).Elem() ) +func getErrorBehaviorFromContext(ctx context.Context) ErrorBehavior { + v := ctx.Value(ContextKeyErrorBehavior) + if v != nil { + return v.(ErrorBehavior) + } + return DoNothing +} + func makeContextExecutorImplementation(c context.Context, funcType reflect.Type, query queryHolder, paramOrder []paramInfo) func(args []reflect.Value) []reflect.Value { - buildRetVals := makeExecutorReturnVals(funcType) + buildRetVals := makeExecutorReturnVals(funcType, getErrorBehaviorFromContext(c)) return func(args []reflect.Value) []reflect.Value { executor := args[1].Interface().(ContextExecutor) @@ -83,7 +91,7 @@ func makeContextExecutorImplementation(c context.Context, funcType reflect.Type, } func makeExecutorImplementation(c context.Context, funcType reflect.Type, query queryHolder, paramOrder []paramInfo) func(args []reflect.Value) []reflect.Value { - buildRetVals := makeExecutorReturnVals(funcType) + buildRetVals := makeExecutorReturnVals(funcType, getErrorBehaviorFromContext(c)) return func(args []reflect.Value) []reflect.Value { executor := args[0].Interface().(Executor) @@ -108,12 +116,29 @@ func makeExecutorImplementation(c context.Context, funcType reflect.Type, query } } -func makeExecutorReturnVals(funcType reflect.Type) func(sql.Result, error) []reflect.Value { +// Checks if query executor returns an error and handles the case appropriately +// based on configured behavior +func checkError(err error, behavior ErrorBehavior, isAbsent bool) { + if err != nil { + switch behavior { + case DoNothing: + case PanicAlways: + panic(stackerr.Wrap(err)) + case PanicWhenAbsent: + if isAbsent { + panic(stackerr.Wrap(err)) + } + } + } +} + +func makeExecutorReturnVals(funcType reflect.Type, errorBehavior ErrorBehavior) func(sql.Result, error) []reflect.Value { numOut := funcType.NumOut() //handle the 0,1,2 out parameter cases if numOut == 0 { - return func(sql.Result, error) []reflect.Value { + return func(_ sql.Result, err error) []reflect.Value { + checkError(err, errorBehavior, true) return []reflect.Value{} } } @@ -121,6 +146,8 @@ func makeExecutorReturnVals(funcType reflect.Type) func(sql.Result, error) []ref sType := funcType.Out(0) if numOut == 1 { return func(result sql.Result, err error) []reflect.Value { + // (sType == errType) in case error is the only single return value + checkError(err, errorBehavior, sType != errType) if err != nil { if sType == sqlResultType { return []reflect.Value{zeroSQLResult} @@ -131,6 +158,7 @@ func makeExecutorReturnVals(funcType reflect.Type) func(sql.Result, error) []ref return []reflect.Value{reflect.ValueOf(result)} } val, err := result.RowsAffected() + checkError(err, errorBehavior, sType != errType) if err != nil { return []reflect.Value{zeroInt64} } @@ -140,6 +168,7 @@ func makeExecutorReturnVals(funcType reflect.Type) func(sql.Result, error) []ref if numOut == 2 { return func(result sql.Result, err error) []reflect.Value { eType := funcType.Out(1) + checkError(err, errorBehavior, false) if sType == sqlResultType { if err != nil { return []reflect.Value{zeroSQLResult, reflect.ValueOf(err).Convert(eType)} @@ -150,6 +179,7 @@ func makeExecutorReturnVals(funcType reflect.Type) func(sql.Result, error) []ref return []reflect.Value{zeroInt64, reflect.ValueOf(err).Convert(eType)} } val, err := result.RowsAffected() + checkError(err, errorBehavior, false) if err != nil { return []reflect.Value{zeroInt64, reflect.ValueOf(err).Convert(eType)} } @@ -279,9 +309,12 @@ func makeQuerierImplementation(c context.Context, funcType reflect.Type, query q func makeQuerierReturnVals(c context.Context, funcType reflect.Type, builder mapper.Builder) func(*sql.Rows, error) []reflect.Value { numOut := funcType.NumOut() + errorBehavior := getErrorBehaviorFromContext(c) + //handle the 0,1,2 out parameter cases if numOut == 0 { - return func(*sql.Rows, error) []reflect.Value { + return func(_ *sql.Rows, err error) []reflect.Value { + checkError(err, errorBehavior, true) return []reflect.Value{} } } @@ -290,6 +323,7 @@ func makeQuerierReturnVals(c context.Context, funcType reflect.Type, builder map qZero := reflect.Zero(sType) if numOut == 1 { return func(rows *sql.Rows, err error) []reflect.Value { + checkError(err, errorBehavior, sType != errType) if err != nil { return []reflect.Value{qZero} } @@ -307,6 +341,7 @@ func makeQuerierReturnVals(c context.Context, funcType reflect.Type, builder map if numOut == 2 { return func(rows *sql.Rows, err error) []reflect.Value { eType := funcType.Out(1) + checkError(err, errorBehavior, false) if err != nil { return []reflect.Value{qZero, reflect.ValueOf(err).Convert(eType)} } diff --git a/runner_test.go b/runner_test.go index 9def1fb..75f4f6b 100644 --- a/runner_test.go +++ b/runner_test.go @@ -3,6 +3,7 @@ package proteus import ( "context" "database/sql" + "fmt" "reflect" "testing" @@ -10,6 +11,17 @@ import ( "github.com/jonbodner/proteus/mapper" ) +type faultyDb struct { +} + +func (db *faultyDb) crash(query string) error { + return fmt.Errorf("error: %s", query) +} + +func (db faultyDb) QueryContext(c context.Context, query string, args ...any) (*sql.Rows, error) { + return nil, db.crash(query) +} + func Test_getQArgs(t *testing.T) { type args struct { args []reflect.Value @@ -112,3 +124,64 @@ func Test_handleMapping(t *testing.T) { } } } + +func Test_errorBehaviorDoNothing(t *testing.T) { + dao := struct { + Foo func(c context.Context, q ContextQuerier) (*sql.Rows, error) `proq:"drop table users"` + }{} + + db := faultyDb{} + + c := context.WithValue(context.Background(), ContextKeyErrorBehavior, DoNothing) + if err := ShouldBuild(c, &dao, Postgres); err != nil { + t.Fatal(err) + return + } + _, _ = dao.Foo(c, db) +} + +func Test_errorBehaviorPanicAlways(t *testing.T) { + dao := struct { + Foo func(c context.Context, q ContextQuerier) (*sql.Rows, error) `proq:"drop table users"` + }{} + + db := faultyDb{} + + c := context.WithValue(context.Background(), ContextKeyErrorBehavior, PanicAlways) + if err := ShouldBuild(c, &dao, Postgres); err != nil { + t.Fatal(err) + return + } + recovery := func() { + err := recover() + t.Logf("This is a pass %s", err) + } + defer recovery() + + _, _ = dao.Foo(c, db) + + t.Fatalf("failed") +} + +func Test_errorBehaviorPanicWhenAbsent(t *testing.T) { + dao := struct { + Foo func(c context.Context, q ContextQuerier) *sql.Rows `proq:"hot potato!"` + }{} + + db := faultyDb{} + + c := context.WithValue(context.Background(), ContextKeyErrorBehavior, PanicWhenAbsent) + if err := ShouldBuild(c, &dao, Postgres); err != nil { + t.Fatal(err) + return + } + recovery := func() { + err := recover() + t.Logf("This is a pass %s", err) + } + defer recovery() + + _ = dao.Foo(c, db) + + t.Fatalf("failed") +}