diff --git a/router_serve.go b/router_serve.go index 8f07909..5f77fee 100644 --- a/router_serve.go +++ b/router_serve.go @@ -32,7 +32,7 @@ func (rootRouter *Router) ServeHTTP(rw http.ResponseWriter, r *http.Request) { closure.Routers = make([]*Router, 1, rootRouter.maxChildrenDepth) closure.Routers[0] = rootRouter closure.Contexts = make([]reflect.Value, 1, rootRouter.maxChildrenDepth) - closure.Contexts[0] = reflect.New(rootRouter.contextType) + closure.Contexts[0] = reflect.New(rootRouter.contextType.Type) closure.currentMiddlewareLen = len(rootRouter.middleware) closure.RootRouter = rootRouter closure.Request.rootContext = closure.Contexts[0] @@ -220,13 +220,17 @@ func contextsFor(contexts []reflect.Value, routers []*Router) []reflect.Value { for i := 1; i < routersLen; i++ { var ctx reflect.Value - if routers[i].contextType == routers[i-1].contextType { + if routers[i].contextType.Type == routers[i-1].contextType.Type { ctx = contexts[i-1] } else { - ctx = reflect.New(routers[i].contextType) + ctxType := routers[i].contextType.Type // set the first field to the parent - f := reflect.Indirect(ctx).Field(0) - f.Set(contexts[i-1]) + if routers[i].contextType.IsDerived { + ctx = createDrivedContext(contexts[i-1], ctxType) + } else { + ctxType = reflect.PtrTo(ctxType) + ctx = getMatchedParentContext(contexts[i-1], ctxType) + } } contexts = append(contexts, ctx) } @@ -234,6 +238,35 @@ func contextsFor(contexts []reflect.Value, routers []*Router) []reflect.Value { return contexts } +func createDrivedContext(context reflect.Value, neededType reflect.Type) reflect.Value { + ctx := reflect.New(neededType) + childCtx := ctx + for { + f := reflect.Indirect(childCtx).Field(0) + if f.Type() != context.Type() && f.Kind() == reflect.Ptr { + childCtx = reflect.New(f.Type().Elem()) + f.Set(childCtx) + continue + } else { + f.Set(context) + break + } + } + return ctx +} + +func getMatchedParentContext(context reflect.Value, neededType reflect.Type) reflect.Value { + if neededType != context.Type() { + for { + context = reflect.Indirect(context).Field(0) + if context.Type() == neededType { + break + } + } + } + return context +} + // If there's a panic in the root middleware (so that we don't have a route/target), then invoke the root handler or default. // If there's a panic in other middleware, then invoke the target action's function. // If there's a panic in the action handler, then invoke the target action's function. @@ -250,19 +283,17 @@ func (rootRouter *Router) handlePanic(rw *appResponseWriter, req *Request, err i for !targetRouter.errorHandler.IsValid() && targetRouter.parent != nil { targetRouter = targetRouter.parent - - // Need to set context to the next context, UNLESS the context is the same type. - curContextStruct := reflect.Indirect(context) - if targetRouter.contextType != curContextStruct.Type() { - context = curContextStruct.Field(0) - if reflect.Indirect(context).Type() != targetRouter.contextType { - panic("bug: shouldn't get here") - } - } } } if targetRouter.errorHandler.IsValid() { + // Need to set context to the next context, UNLESS the context is the same type. + if _, err := validateContext(reflect.Indirect(reflect.New(targetRouter.contextType.Type)).Interface(), reflect.Indirect(context).Type()); err != nil { + panic(err) + } + + ctxType := reflect.PtrTo(targetRouter.contextType.Type) + context = getMatchedParentContext(context, ctxType) invoke(targetRouter.errorHandler, context, []reflect.Value{reflect.ValueOf(rw), reflect.ValueOf(req), reflect.ValueOf(err)}) } else { http.Error(rw, DefaultPanicResponse, http.StatusInternalServerError) diff --git a/router_setup.go b/router_setup.go index c248dfb..258ad54 100644 --- a/router_setup.go +++ b/router_setup.go @@ -1,6 +1,7 @@ package web import ( + "errors" "reflect" "strings" ) @@ -19,6 +20,11 @@ const ( var httpMethods = []httpMethod{httpMethodGet, httpMethodPost, httpMethodPut, httpMethodDelete, httpMethodPatch, httpMethodHead, httpMethodOptions} +type ContextSt struct { + Type reflect.Type + IsDerived bool //true if it's drived from main route, false if main route is drived from it +} + // Router implements net/http's Handler interface and is what you attach middleware, routes/handlers, and subrouters to. type Router struct { // Hierarchy: @@ -27,7 +33,7 @@ type Router struct { maxChildrenDepth int // For each request we'll create one of these objects - contextType reflect.Type + contextType ContextSt // Eg, "/" or "/admin". Any routes added to this router will be prefixed with this. pathPrefix string @@ -89,10 +95,10 @@ var emptyInterfaceType = reflect.TypeOf((*interface{})(nil)).Elem() // whose purpose is to communicate type information. On each request, an instance of this // context type will be automatically allocated and sent to handlers. func New(ctx interface{}) *Router { - validateContext(ctx, nil) + // validateContext(ctx, nil) r := &Router{} - r.contextType = reflect.TypeOf(ctx) + r.contextType = ContextSt{Type: reflect.TypeOf(ctx)} r.pathPrefix = "/" r.maxChildrenDepth = 1 r.root = make(map[httpMethod]*pathNode) @@ -116,10 +122,14 @@ func NewWithPrefix(ctx interface{}, pathPrefix string) *Router { // embed a pointer to the previous context in the first slot. You can also pass // a pathPrefix that each route will have. If "" is passed, then no path prefix is applied. func (r *Router) Subrouter(ctx interface{}, pathPrefix string) *Router { - validateContext(ctx, r.contextType) // Create new router, link up hierarchy newRouter := &Router{parent: r} + contextType, err := validateContext(ctx, r.contextType.Type) + if err != nil { + panic(err) + } + newRouter.contextType = *contextType r.children = append(r.children, newRouter) // Increment maxChildrenDepth if this is the first child of the router @@ -131,7 +141,6 @@ func (r *Router) Subrouter(ctx interface{}, pathPrefix string) *Router { } } - newRouter.contextType = reflect.TypeOf(ctx) newRouter.pathPrefix = appendPath(r.pathPrefix, pathPrefix) newRouter.root = r.root @@ -141,7 +150,7 @@ func (r *Router) Subrouter(ctx interface{}, pathPrefix string) *Router { // Middleware adds the specified middleware tot he router and returns the router. func (r *Router) Middleware(fn interface{}) *Router { vfn := reflect.ValueOf(fn) - validateMiddleware(vfn, r.contextType) + validateMiddleware(vfn, r.contextType.Type) if vfn.Type().NumIn() == 3 { r.middleware = append(r.middleware, &middlewareHandler{Generic: true, GenericMiddleware: fn.(func(ResponseWriter, *Request, NextMiddlewareFunc))}) } else { @@ -154,7 +163,7 @@ func (r *Router) Middleware(fn interface{}) *Router { // Error sets the specified function as the error handler (when panics happen) and returns the router. func (r *Router) Error(fn interface{}) *Router { vfn := reflect.ValueOf(fn) - validateErrorHandler(vfn, r.contextType) + validateErrorHandler(vfn, r.contextType.Type) r.errorHandler = vfn return r } @@ -166,7 +175,7 @@ func (r *Router) NotFound(fn interface{}) *Router { panic("You can only set a NotFoundHandler on the root router.") } vfn := reflect.ValueOf(fn) - validateNotFoundHandler(vfn, r.contextType) + validateNotFoundHandler(vfn, r.contextType.Type) r.notFoundHandler = vfn return r } @@ -178,7 +187,7 @@ func (r *Router) OptionsHandler(fn interface{}) *Router { panic("You can only set an OptionsHandler on the root router.") } vfn := reflect.ValueOf(fn) - validateOptionsHandler(vfn, r.contextType) + validateOptionsHandler(vfn, r.contextType.Type) r.optionsHandler = vfn return r } @@ -220,7 +229,7 @@ func (r *Router) Options(path string, fn interface{}) *Router { func (r *Router) addRoute(method httpMethod, path string, fn interface{}) *Router { vfn := reflect.ValueOf(fn) - validateHandler(vfn, r.contextType) + validateHandler(vfn, r.contextType.Type) fullPath := appendPath(r.pathPrefix, path) route := &route{Method: method, Path: fullPath, Router: r} if vfn.Type().NumIn() == 2 { @@ -249,26 +258,40 @@ func (r *Router) depth() int { // Private methods: // -// Panics unless validation is correct -func validateContext(ctx interface{}, parentCtxType reflect.Type) { - ctxType := reflect.TypeOf(ctx) - - if ctxType.Kind() != reflect.Struct { - panic("web: Context needs to be a struct type") - } - - if parentCtxType != nil && parentCtxType != ctxType { - if ctxType.NumField() == 0 { - panic("web: Context needs to have first field be a pointer to parent context") +// validate contexts +func validateContext(ctx interface{}, parentCtxType reflect.Type) (*ContextSt, error) { + doCheck := func(ctxType reflect.Type, parentCtxType reflect.Type) error { + for { + if ctxType.Kind() == reflect.Ptr { + ctxType = ctxType.Elem() + } + if ctxType.Kind() != reflect.Struct { + if ctxType == reflect.TypeOf(ctx) { + return errors.New("web: Context needs to be a struct type\n " + ctxType.String()) + } + return errors.New("web: Context needs to have first field be a pointer to parent context\n" + + "Main Context: " + parentCtxType.String() + " Given Context: " + reflect.TypeOf(ctx).String()) + + } + if ctxType == parentCtxType { + break + } + if ctxType.NumField() == 0 { + return errors.New("web: Context needs to have first field be a pointer to parent context") + } + ctxType = ctxType.Field(0).Type } + return nil + } - fldType := ctxType.Field(0).Type - - // Ensure fld is a pointer to parentCtxType - if fldType != reflect.PtrTo(parentCtxType) { - panic("web: Context needs to have first field be a pointer to parent context") + ctxType := reflect.TypeOf(ctx) + if err1 := doCheck(ctxType, parentCtxType); err1 != nil { + if err2 := doCheck(parentCtxType, ctxType); err2 != nil { + return nil, err1 } + return &ContextSt{ctxType, false}, nil } + return &ContextSt{ctxType, true}, nil } // Panics unless fn is a proper handler wrt ctxType @@ -338,8 +361,10 @@ func isValidHandler(vfn reflect.Value, ctxType reflect.Type, types ...reflect.Ty } else if numIn == (typesLen + 1) { // context, types firstArgType := fnType.In(0) - if firstArgType != reflect.PtrTo(ctxType) && firstArgType != emptyInterfaceType { - return false + if firstArgType != emptyInterfaceType { + if _, err := validateContext(reflect.Indirect(reflect.New(firstArgType.Elem())).Interface(), ctxType); err != nil { + return false + } } typesStartIdx = 1 } else {