@@ -41,8 +41,12 @@ func (fn RendererFunc) Render(ctx *Context, err error, result interface{}) {
4141//
4242// func(ctx context.Context) R
4343//
44+ // func(ctx context.Context) error
45+ //
4446// func(ctx context.Context, req T) R
4547//
48+ // func(ctx context.Context, req T) error
49+ //
4650// func(ctx context.Context, req T) (R, error)
4751//
4852// func(writer http.ResponseWriter, request *http.Request)
@@ -53,18 +57,20 @@ func Bind(fn interface{}, render Renderer) http.HandlerFunc {
5357
5458 switch h := fn .(type ) {
5559 case http.HandlerFunc :
56- return h
60+ return warpHandlerCtx ( h )
5761 case http.Handler :
58- return h .ServeHTTP
62+ return warpHandlerCtx ( h .ServeHTTP )
5963 case func (http.ResponseWriter , * http.Request ):
60- return h
64+ return warpHandlerCtx ( h )
6165 default :
6266 // valid func
6367 if err := validMappingFunc (fnType ); nil != err {
6468 panic (err )
6569 }
6670 }
6771
72+ firstOutIsErrorType := 1 == fnType .NumOut () && utils .IsErrorType (fnType .Out (0 ))
73+
6874 return func (writer http.ResponseWriter , request * http.Request ) {
6975
7076 // param of context
@@ -128,14 +134,15 @@ func Bind(fn interface{}, render Renderer) http.HandlerFunc {
128134 // nothing
129135 return
130136 case 1 :
131- // write response
132- result = returnValues [0 ].Interface ()
137+ if firstOutIsErrorType {
138+ err , _ = returnValues [0 ].Interface ().(error )
139+ } else {
140+ result = returnValues [0 ].Interface ()
141+ }
133142 case 2 :
134143 // check error
135144 result = returnValues [0 ].Interface ()
136- if e , ok := returnValues [1 ].Interface ().(error ); ok && nil != e {
137- err = e
138- }
145+ err , _ = returnValues [1 ].Interface ().(error )
139146 default :
140147 panic ("unreachable here" )
141148 }
@@ -149,7 +156,9 @@ func Bind(fn interface{}, render Renderer) http.HandlerFunc {
149156func validMappingFunc (fnType reflect.Type ) error {
150157 // func(ctx context.Context)
151158 // func(ctx context.Context) R
159+ // func(ctx context.Context) error
152160 // func(ctx context.Context, req T) R
161+ // func(ctx context.Context, req T) error
153162 // func(ctx context.Context, req T) (R, error)
154163 if ! utils .IsFuncType (fnType ) {
155164 return fmt .Errorf ("%s: not a func" , fnType .String ())
@@ -174,13 +183,30 @@ func validMappingFunc(fnType reflect.Type) error {
174183 }
175184 }
176185
177- if 0 < fnType .NumOut () && utils .IsErrorType (fnType .Out (0 )) {
178- return fmt .Errorf ("%s: first output param type not be error" , fnType .String ())
179- }
186+ switch fnType .NumOut () {
187+ case 0 : // nothing
188+ case 1 : // R | error
189+ case 2 : // (R, error)
190+ if utils .IsErrorType (fnType .Out (0 )) {
191+ return fmt .Errorf ("%s: first output param type not be error" , fnType .String ())
192+ }
180193
181- if 1 < fnType .NumOut () && ! utils .IsErrorType (fnType .Out (1 )) {
182- return fmt .Errorf ("%s: second output type (%s) must a error" , fnType .String (), fnType .Out (1 ).String ())
194+ if ! utils .IsErrorType (fnType .Out (1 )) {
195+ return fmt .Errorf ("%s: second output type (%s) must a error" , fnType .String (), fnType .Out (1 ).String ())
196+ }
183197 }
184198
185199 return nil
186200}
201+
202+ func warpHandlerCtx (handler http.HandlerFunc ) http.HandlerFunc {
203+ return func (writer http.ResponseWriter , request * http.Request ) {
204+ webCtx := & Context {Writer : writer , Request : request }
205+ handler .ServeHTTP (writer , requestWithCtx (request , webCtx ))
206+ }
207+ }
208+
209+ func requestWithCtx (r * http.Request , webCtx * Context ) * http.Request {
210+ ctx := WithContext (r .Context (), webCtx )
211+ return r .WithContext (ctx )
212+ }
0 commit comments