diff --git a/handler.go b/handler.go index c748e74..886766a 100644 --- a/handler.go +++ b/handler.go @@ -19,6 +19,9 @@ type Handler interface { // jsonrpc.Handler that calls f. type HandlerFunc func(c context.Context, params *json.RawMessage) (result interface{}, err *Error) +// MiddlewareFunc defines a function to process middleware. +type MiddlewareFunc func(HandlerFunc) HandlerFunc + // ServeJSONRPC calls f(w, r). func (f HandlerFunc) ServeJSONRPC(c context.Context, params *json.RawMessage) (result interface{}, err *Error) { return f(c, params) @@ -64,7 +67,8 @@ func (mr *MethodRepository) InvokeMethod(c context.Context, r *Request) *Respons wrappedContext := WithRequestID(c, r.ID) wrappedContext = WithMethodName(wrappedContext, r.Method) wrappedContext = WithMetadata(wrappedContext, md) - res.Result, res.Error = md.Handler.ServeJSONRPC(wrappedContext, r.Params) + handler := applyMiddleware(md.Handler, md.Middlewares...) + res.Result, res.Error = handler.ServeJSONRPC(wrappedContext, r.Params) if res.Error != nil { res.Result = nil } diff --git a/handler_test.go b/handler_test.go index c098ad7..920c572 100644 --- a/handler_test.go +++ b/handler_test.go @@ -69,3 +69,39 @@ func TestHandler(t *testing.T) { require.NoError(t, err) assert.NotNil(t, res.Error) } + +func TestInvokeMethodMiddlewares(t *testing.T) { + ctx := context.Background() + id := json.RawMessage("test") + r := &Request{ + Version: "2.0", + Method: "test", + ID: &id, + } + + mr := NewMethodRepository() + err := mr.RegisterMethod("test", HandlerFunc(func(c context.Context, params *json.RawMessage) (result interface{}, err *Error) { + v := c.Value("key1") + require.NotNil(t, v) + v = c.Value("key2") + require.NotNil(t, v) + return "value3", nil + }), nil, nil, func(next HandlerFunc) HandlerFunc { + return func(c context.Context, params *json.RawMessage) (result interface{}, err *Error) { + c = context.WithValue(c, "key1", "value1") + return next(c, params) + } + }, func(next HandlerFunc) HandlerFunc { + return func(c context.Context, params *json.RawMessage) (result interface{}, err *Error) { + v := c.Value("key1") + require.NotNil(t, v) + c = context.WithValue(c, "key2", "value2") + return next(c, params) + } + }) + require.NoError(t, err) + + resp := mr.InvokeMethod(ctx, r) + require.Nil(t, resp.Error) + require.NotNil(t, resp.Result) +} diff --git a/method.go b/method.go index 241d824..8c9f640 100644 --- a/method.go +++ b/method.go @@ -8,14 +8,16 @@ import ( type ( // A MethodRepository has JSON-RPC method functions. MethodRepository struct { - m sync.RWMutex - r map[string]Metadata + middlewares []MiddlewareFunc + m sync.RWMutex + r map[string]Metadata } // Metadata has method meta data. Metadata struct { - Handler Handler - Params interface{} - Result interface{} + Middlewares []MiddlewareFunc + Handler Handler + Params interface{} + Result interface{} } ) @@ -53,15 +55,21 @@ func (mr *MethodRepository) TakeMethod(r *Request) (Handler, *Error) { } // RegisterMethod registers jsonrpc.Func to MethodRepository. -func (mr *MethodRepository) RegisterMethod(method string, h Handler, params, result interface{}) error { +func (mr *MethodRepository) RegisterMethod(method string, h Handler, params, result interface{}, middlewares ...MiddlewareFunc) error { if method == "" || h == nil { return errors.New("jsonrpc: method name and function should not be empty") } + + m := make([]MiddlewareFunc, 0, len(mr.middlewares)+len(middlewares)) + m = append(m, mr.middlewares...) + m = append(m, middlewares...) + mr.m.Lock() mr.r[method] = Metadata{ - Handler: h, - Params: params, - Result: result, + Handler: h, + Params: params, + Result: result, + Middlewares: m, } mr.m.Unlock() return nil diff --git a/method_test.go b/method_test.go index fb41c9d..d14a294 100644 --- a/method_test.go +++ b/method_test.go @@ -45,6 +45,14 @@ func TestRegisterMethod(t *testing.T) { err = mr.RegisterMethod("test", SampleHandler(), nil, nil) require.NoError(t, err) + + err = mr.RegisterMethod("test", SampleHandler(), nil, nil, nil) + require.NoError(t, err) + assert.Len(t, mr.r["test"].Middlewares, 1) + + err = mr.RegisterMethod("test", SampleHandler(), nil, nil, nil, nil) + require.NoError(t, err) + assert.Len(t, mr.r["test"].Middlewares, 2) } func TestMethods(t *testing.T) { diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..41bc582 --- /dev/null +++ b/middleware.go @@ -0,0 +1,14 @@ +package jsonrpc + +// UseMiddleware adds middlewares +func (mr *MethodRepository) UseMiddleware(middlewares ...MiddlewareFunc) { + mr.middlewares = append(mr.middlewares, middlewares...) +} + +// applyMiddleware applies middlewares to Handler +func applyMiddleware(h Handler, middleware ...MiddlewareFunc) Handler { + for i := len(middleware) - 1; i >= 0; i-- { + h = middleware[i](h.ServeJSONRPC) + } + return h +} diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 0000000..23cb06d --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,59 @@ +package jsonrpc + +import ( + "bytes" + "context" + "testing" + + "github.com/goccy/go-json" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUseMiddleware(t *testing.T) { + mr := NewMethodRepository() + assert.Len(t, mr.middlewares, 0) + mr.UseMiddleware(nil) + assert.Len(t, mr.middlewares, 1) + mr.UseMiddleware(nil, nil) + assert.Len(t, mr.middlewares, 3) +} + +func TestMiddlewareOrder(t *testing.T) { + buf := bytes.Buffer{} + mw := func(s string) MiddlewareFunc { + return func(next HandlerFunc) HandlerFunc { + return func(c context.Context, params *json.RawMessage) (result interface{}, err *Error) { + buf.WriteString(s) + return next(c, params) + } + } + } + + mr := NewMethodRepository() + mr.UseMiddleware(mw("-1"), mw("1")) + mr.UseMiddleware(mw("2")) + + ctx := context.Background() + id := json.RawMessage("test") + r := &Request{ + Version: "2.0", + Method: "test", + ID: &id, + } + + err := mr.RegisterMethod("test", HandlerFunc(func(c context.Context, params *json.RawMessage) (result interface{}, err *Error) { + return nil, nil + }), + nil, + nil, + mw("3"), + mw("4"), + mw("5"), + ) + require.NoError(t, err) + + resp := mr.InvokeMethod(ctx, r) + require.Nil(t, resp.Error) + assert.Equal(t, "-112345", buf.String()) +}