Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jsonrpc

import (
"context"
"net/http"

"github.com/goccy/go-json"
)
Expand All @@ -10,6 +11,8 @@ type (
requestIDKey struct{}
metadataIDKey struct{}
methodNameKey struct{}
requestKey struct{}
responseKey struct{}
)

// RequestID takes request id from context.
Expand Down Expand Up @@ -41,3 +44,33 @@ func MethodName(c context.Context) string {
func WithMethodName(c context.Context, name string) context.Context {
return context.WithValue(c, methodNameKey{}, name)
}

// WithRequest adds request to context.
func WithRequest(c context.Context, r *http.Request) context.Context {
return context.WithValue(c, requestKey{}, r)
}

// GetRequest takes request from context.
func GetRequest(c context.Context) *http.Request {
v := c.Value(requestKey{})
if r, ok := v.(*http.Request); ok {
return r
}

return nil
}

// WithResponse adds response to context.
func WithResponse(c context.Context, r http.ResponseWriter) context.Context {
return context.WithValue(c, responseKey{}, r)
}

// GetResponse takes response from context.
func GetResponse(c context.Context) http.ResponseWriter {
v := c.Value(responseKey{})
if r, ok := v.(http.ResponseWriter); ok {
return r
}

return nil
}
25 changes: 25 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package jsonrpc

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/goccy/go-json"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -39,3 +42,25 @@ func TestMethodName(t *testing.T) {
})
require.Equal(t, t.Name(), pick)
}

func TestRequest(t *testing.T) {
assert.NotPanics(t, func() {
r := GetRequest(context.Background())
assert.Nil(t, r)
})
c := context.Background()
r := httptest.NewRequest(http.MethodPost, "/", nil)
c = WithRequest(c, r)
assert.Equal(t, r, GetRequest(c))
}

func TestResponse(t *testing.T) {
assert.NotPanics(t, func() {
r := GetResponse(context.Background())
assert.Nil(t, r)
})
c := context.Background()
r := httptest.NewRecorder()
c = WithResponse(c, r)
assert.Equal(t, r, GetResponse(c))
}
6 changes: 4 additions & 2 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (mr *MethodRepository) ServeHTTP(w http.ResponseWriter, r *http.Request) {

resp := make([]*Response, len(rs))
for i := range rs {
resp[i] = mr.InvokeMethod(r.Context(), rs[i])
resp[i] = mr.InvokeMethod(r.Context(), rs[i], r, w)
}

if err := SendResponse(w, resp, batch); err != nil {
Expand All @@ -53,7 +53,7 @@ func (mr *MethodRepository) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

// InvokeMethod invokes JSON-RPC method.
func (mr *MethodRepository) InvokeMethod(c context.Context, r *Request) *Response {
func (mr *MethodRepository) InvokeMethod(c context.Context, r *Request, req *http.Request, w http.ResponseWriter) *Response {
var md Metadata
res := NewResponse(r)
md, res.Error = mr.TakeMethodMetadata(r)
Expand All @@ -64,6 +64,8 @@ func (mr *MethodRepository) InvokeMethod(c context.Context, r *Request) *Respons
wrappedContext := WithRequestID(c, r.ID)
wrappedContext = WithMethodName(wrappedContext, r.Method)
wrappedContext = WithMetadata(wrappedContext, md)
wrappedContext = WithRequest(wrappedContext, req)
wrappedContext = WithResponse(wrappedContext, w)
res.Result, res.Error = md.Handler.ServeJSONRPC(wrappedContext, r.Params)
if res.Error != nil {
res.Result = nil
Expand Down