diff --git a/jsonrpc/params.go b/jsonrpc/params.go index f36f690..d3c8f9a 100644 --- a/jsonrpc/params.go +++ b/jsonrpc/params.go @@ -1,9 +1,13 @@ package jsonrpc import ( + "bytes" "encoding/json" - + "fmt" "github.com/pkg/errors" + "io" + "reflect" + "strings" ) // Params is an ARRAY of json.RawMessages. This is because *Ethereum* RPCs always use @@ -74,7 +78,8 @@ func MakeParams(params ...interface{}) (Params, error) { // UnmarshalInto will decode Params into the passed in values, which // must be pointer receivers. The type of the passed in value is used to Unmarshal the data. // UnmarshalInto will fail if the parameters cannot be converted to the passed-in types. -// +// Check each type of each param, return an error if it's not the right one and which argument. + // Example: // // var blockNum string @@ -83,6 +88,7 @@ func MakeParams(params ...interface{}) (Params, error) { // // IMPORTANT: While Go will compile with non-pointer receivers, the Unmarshal attempt will // *always* fail with an error. + func (p Params) UnmarshalInto(receivers ...interface{}) error { if p == nil { return nil @@ -92,11 +98,30 @@ func (p Params) UnmarshalInto(receivers ...interface{}) error { return errors.New("not enough params to decode") } + // Return an array of the receivers' types and check if the receiver is a ptr + receiversType, err := listTypes(receivers) + if err != nil { + return err + } + + var paramElement []string + for _, i := range p { + paramElement = append(paramElement, string(i)) + } + + // Return p Params in json.RawMessage type with [] to be parsed + rawParams := json.RawMessage("[" + strings.Join(paramElement, ",") + "]") + + receiversValues, err := parsePositionalArguments(rawParams, receiversType) + if err != nil { + return err + } + for i, r := range receivers { - err := json.Unmarshal(p[i], r) - if err != nil { - return err + if receiversValues[i].IsZero() { + continue } + reflect.ValueOf(r).Elem().Set(receiversValues[i].Elem()) } return nil @@ -117,3 +142,69 @@ func (p Params) UnmarshalSingleParam(pos int, receiver interface{}) error { err := json.Unmarshal(param, receiver) return err } + +// parsePositionalArguments tries to parse the given args to an array of values with the +// given types. It returns the parsed values or an error when the args could not be +// parsed. Missing optional arguments are returned as reflect.Zero values. +func parsePositionalArguments(rawArgs json.RawMessage, types []reflect.Type) ([]reflect.Value, error) { + dec := json.NewDecoder(bytes.NewReader(rawArgs)) + var args []reflect.Value + tok, err := dec.Token() + switch { + case err == io.EOF || tok == nil && err == nil: + // "params" is optional and may be empty. Also allow "params":null even though it's + // not in the spec because our own client used to send it. + case err != nil: + return nil, err + case tok == json.Delim('['): + // Read argument array. + if args, err = parseArgumentArray(dec, types); err != nil { + return nil, err + } + default: + return nil, errors.New("non-array args") + } + // Set any missing args to nil. + for i := len(args); i < len(types); i++ { + if types[i].Kind() != reflect.Ptr { + return nil, fmt.Errorf("missing value for required argument %d", i) + } + args = append(args, reflect.Zero(types[i])) + } + return args, nil +} + +func parseArgumentArray(dec *json.Decoder, types []reflect.Type) ([]reflect.Value, error) { + args := make([]reflect.Value, 0, len(types)) + + for i := 0; dec.More(); i++ { + if i >= len(types) { //no error when decoding a subset of param + return args, nil + } + argval := reflect.New(types[i]) + + if err := dec.Decode(argval.Interface()); err != nil { + return args, fmt.Errorf("invalid argument %d: %v", i, err) + } + + if argval.IsNil() && types[i].Kind() != reflect.Ptr { + return args, fmt.Errorf("missing value for required argument %d", i) + } + args = append(args, argval.Elem()) + } + // Read end of args array. + _, err := dec.Token() + return args, err +} + +func listTypes(a []interface{}) ([]reflect.Type, error) { + var arrayType []reflect.Type + for _, i := range a { + v := reflect.ValueOf(i).Type() + if v.Kind() != reflect.Ptr { + return nil, fmt.Errorf("the receiver %d is not a pointer", i) + } + arrayType = append(arrayType, v) + } + return arrayType, nil +} diff --git a/jsonrpc/params_test.go b/jsonrpc/params_test.go index da42ce3..58b95da 100644 --- a/jsonrpc/params_test.go +++ b/jsonrpc/params_test.go @@ -1,6 +1,12 @@ package jsonrpc import ( + "bytes" + "encoding/json" + "fmt" + "github.com/INFURA/go-ethlibs/eth" + "github.com/pkg/errors" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -107,6 +113,27 @@ func TestParams_DecodeInto(t *testing.T) { return []interface{}{str}, err }, }, + { + Description: "receiver's type is a struct", + Expected: []interface{}{eth.LogFilter{FromBlock: eth.MustBlockNumberOrTag("0x3456789"), ToBlock: eth.MustBlockNumberOrTag("0x3456"), BlockHash: (*eth.Data32)(nil), Address: []eth.Address(nil), Topics: [][]eth.Data32(nil)}}, + Input: MustParams(ð.LogFilter{FromBlock: eth.MustBlockNumberOrTag("0x3456789"), ToBlock: eth.MustBlockNumberOrTag("0x3456")}), + Test: func(tc *testCase) ([]interface{}, error) { + var rec eth.LogFilter + err := tc.Input.UnmarshalInto(&rec) + return []interface{}{rec}, err + }, + }, + { + Description: "multiple element in params", + Expected: []interface{}{eth.LogFilter{FromBlock: eth.MustBlockNumberOrTag("0x3456789"), ToBlock: eth.MustBlockNumberOrTag("0x3456")}, eth.LogFilter{FromBlock: eth.MustBlockNumberOrTag("0x5678"), ToBlock: eth.MustBlockNumberOrTag("0x1234")}}, + Input: MustParams(ð.LogFilter{FromBlock: eth.MustBlockNumberOrTag("0x3456789"), ToBlock: eth.MustBlockNumberOrTag("0x3456")}, ð.LogFilter{FromBlock: eth.MustBlockNumberOrTag("0x5678"), ToBlock: eth.MustBlockNumberOrTag("0x1234")}), + Test: func(tc *testCase) ([]interface{}, error) { + var rec1 eth.LogFilter + var rec2 eth.LogFilter + err := tc.Input.UnmarshalInto(&rec1, &rec2) + return []interface{}{rec1, rec2}, err + }, + }, } for _, testCase := range testCases { @@ -134,3 +161,175 @@ func TestParams_DecodeInto(t *testing.T) { object := Object{} assert.Error(t, multiple.UnmarshalSingleParam(3, &object), "should have failed") } + +func TestParams_DecodeInto_Fail(t *testing.T) { + + type expected struct { + output []interface{} + err error + } + type testCase struct { + Description string + Expected expected + Input Params + Test func(tc *testCase) ([]interface{}, error) + } + + testCases := []testCase{ + { + Description: "params null", + Expected: expected{output: nil, err: nil}, + Input: nil, + Test: func(tc *testCase) ([]interface{}, error) { + var str string + err := tc.Input.UnmarshalInto(str) + return nil, err + }, + }, + { + Description: "len(p)