diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 37bc6d3..53dcd54 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,7 +4,7 @@ jobs: test: strategy: matrix: - go-version: [1.16.x] + go-version: [1.21.x] os: [ubuntu-latest] runs-on: ${{ matrix.os }} steps: @@ -16,11 +16,17 @@ jobs: uses: actions/checkout@v2 - name: Install dependencies run: | - go get -u honnef.co/go/tools/cmd/staticcheck@latest - go get -u golang.org/x/tools/cmd/goimports + go install honnef.co/go/tools/cmd/staticcheck@latest + go install golang.org/x/tools/cmd/goimports@latest - name: Run staticcheck run: staticcheck ./... - name: Check code formatting - run: test -z $(goimports -l .) + run: | + files=$(goimports -l .) + if [ -n "$files" ]; then + echo "The following files need formatting:" + echo "$files" + exit 1 + fi - name: Run Test run: go test ./... diff --git a/env.go b/env.go index bbbefc5..0a4b487 100644 --- a/env.go +++ b/env.go @@ -6,7 +6,7 @@ package jsonata import ( "errors" - "math" + "reflect" "strings" "unicode/utf8" @@ -140,7 +140,13 @@ var baseEnv = initBaseEnv(map[string]Extension{ EvalContextHandler: contextHandlerReplace, }, "formatNumber": { - Func: jlib.FormatNumber, + Func: func(x float64, picture string, options ...interface{}) (string, error) { + var opt jtypes.OptionalValue + if len(options) > 0 { + opt = jtypes.OptionalValue{Value: reflect.ValueOf(options[0])} + } + return jlib.FormatNumber(x, picture, opt) + }, UndefinedHandler: defaultUndefinedHandler, EvalContextHandler: contextHandlerFormatNumber, }, @@ -188,17 +194,27 @@ var baseEnv = initBaseEnv(map[string]Extension{ EvalContextHandler: defaultContextHandler, }, "abs": { - Func: math.Abs, + Func: jlib.Abs, + UndefinedHandler: defaultUndefinedHandler, + EvalContextHandler: defaultContextHandler, + }, + "ceil": { + Func: jlib.Ceil, UndefinedHandler: defaultUndefinedHandler, EvalContextHandler: defaultContextHandler, }, "floor": { - Func: math.Floor, + Func: jlib.Floor, UndefinedHandler: defaultUndefinedHandler, EvalContextHandler: defaultContextHandler, }, - "ceil": { - Func: math.Ceil, + "formatInteger": { + Func: jlib.FormatInteger, + UndefinedHandler: defaultUndefinedHandler, + EvalContextHandler: defaultContextHandler, + }, + "parseInteger": { + Func: jlib.ParseInteger, UndefinedHandler: defaultUndefinedHandler, EvalContextHandler: defaultContextHandler, }, diff --git a/jlib/number.go b/jlib/number.go index 8cb0e46..35c3d06 100644 --- a/jlib/number.go +++ b/jlib/number.go @@ -16,14 +16,19 @@ import ( "github.com/blues/jsonata-go/jtypes" ) -var reNumber = regexp.MustCompile(`^-?(([0-9]+))(\.[0-9]+)?([Ee][-+]?[0-9]+)?$`) +var ( + reNumber = regexp.MustCompile(`^-?(([0-9]+))(\.[0-9]+)?([Ee][-+]?[0-9]+)?$`) + reBinary = regexp.MustCompile(`^0[bB][01]+$`) + reOctal = regexp.MustCompile(`^0[oO][0-7]+$`) + reHex = regexp.MustCompile(`^0[xX][0-9a-fA-F]+$`) +) // Number converts values to numbers. Numeric values are returned // unchanged. Strings in legal JSON number format are converted -// to the number they represent. Boooleans are converted to 0 or 1. +// to the number they represent. Booleans are converted to 0 or 1. // All other types trigger an error. -func Number(value StringNumberBool) (float64, error) { - v := reflect.Value(value) +func Number(value interface{}) (float64, error) { + v := reflect.ValueOf(value) if b, ok := jtypes.AsBool(v); ok { if b { return 1, nil @@ -36,7 +41,24 @@ func Number(value StringNumberBool) (float64, error) { } s, ok := jtypes.AsString(v) - if ok && reNumber.MatchString(s) { + if !ok { + return 0, fmt.Errorf("unable to cast value to a number") + } + s = strings.TrimSpace(s) + + if reBinary.MatchString(s) { + n, _ := strconv.ParseInt(s[2:], 2, 64) + return float64(n), nil + } + if reOctal.MatchString(s) { + n, _ := strconv.ParseInt(s[2:], 8, 64) + return float64(n), nil + } + if reHex.MatchString(s) { + n, _ := strconv.ParseInt(s[2:], 16, 64) + return float64(n), nil + } + if reNumber.MatchString(s) { if n, err := strconv.ParseFloat(s, 64); err == nil { return n, nil } @@ -112,11 +134,80 @@ func Random() float64 { return rand.Float64() } +// Abs returns the absolute value of x. +func Abs(x float64) float64 { + return math.Abs(x) +} + +// Ceil returns the least integer value greater than or equal to x. +func Ceil(x float64) float64 { + return math.Ceil(x) +} + +// Floor returns the greatest integer value less than or equal to x. +func Floor(x float64) float64 { + return math.Floor(x) +} + +// FormatBase formats a number using the specified base (2-36). +func FormatBase(x float64, base jtypes.OptionalFloat64) (string, error) { + radix := 10 + if base.IsSet() { + radix = int(Round(base.Float64, jtypes.OptionalInt{})) + } + + if radix < 2 || radix > 36 { + return "", fmt.Errorf("the second argument to formatBase must be between 2 and 36") + } + n := int64(Round(x, jtypes.OptionalInt{})) + return strconv.FormatInt(n, radix), nil +} + +// FormatInteger formats an integer using the specified picture string. +func FormatInteger(x float64, picture string) (string, error) { + if picture == "" { + return strconv.FormatInt(int64(Round(x, jtypes.OptionalInt{})), 10), nil + } + return formatNumberWithPicture(x, picture, jtypes.OptionalValue{}) +} + +// FormatNumber formats a number using the specified picture string and options. +func FormatNumber(x float64, picture string, options jtypes.OptionalValue) (string, error) { + if picture == "" { + return strconv.FormatFloat(x, 'f', -1, 64), nil + } + return formatNumberWithPicture(x, picture, options) +} + +// ParseInteger parses a string as an integer using the specified base. +func ParseInteger(value interface{}, base jtypes.OptionalFloat64) (float64, error) { + s, ok := jtypes.AsString(reflect.ValueOf(value)) + if !ok { + return 0, fmt.Errorf("first argument of parseInteger must be a string") + } + s = strings.TrimSpace(s) + + radix := 10 + if base.IsSet() { + radix = int(Round(base.Float64, jtypes.OptionalInt{})) + } + + if radix < 0 || radix == 1 || radix > 36 { + return 0, fmt.Errorf("invalid base: %d", radix) + } + + n, err := strconv.ParseInt(s, radix, 64) + if err != nil { + return 0, err + } + return float64(n), nil +} + // multByPow10 multiplies a number by 10 to the power of n. // It does this by converting back and forth to strings to // avoid floating point rounding errors, e.g. // -// 4.525 * math.Pow10(2) returns 452.50000000000006 +// 4.525 * math.Pow10(2) returns 452.50000000000006 func multByPow10(x float64, n int) float64 { if n == 0 || math.IsNaN(x) || math.IsInf(x, 0) { return x diff --git a/jlib/number_test.go b/jlib/number_test.go index 51b422f..7a53296 100644 --- a/jlib/number_test.go +++ b/jlib/number_test.go @@ -12,6 +12,197 @@ import ( "github.com/blues/jsonata-go/jtypes" ) +func TestParseNumber(t *testing.T) { + tests := []struct { + name string + input interface{} + want float64 + wantErr bool + }{ + // Binary numbers + {"binary lowercase", "0b101", 5, false}, + {"binary uppercase", "0B1010", 10, false}, + {"binary long", "0b11111111", 255, false}, + {"binary with zeros", "0b00001010", 10, false}, + {"invalid binary", "0b102", 0, true}, + {"invalid binary chars", "0b1a1", 0, true}, + + // Octal numbers + {"octal lowercase", "0o12", 10, false}, + {"octal uppercase", "0O755", 493, false}, + {"octal long", "0o7777", 4095, false}, + {"octal with zeros", "0o0012", 10, false}, + {"invalid octal", "0o8", 0, true}, + {"invalid octal chars", "0o7a7", 0, true}, + + // Hexadecimal numbers + {"hex lowercase", "0x12", 18, false}, + {"hex uppercase", "0XFF", 255, false}, + {"hex mixed case", "0xDeadBeef", 3735928559, false}, + {"hex with zeros", "0x0012", 18, false}, + {"hex all letters", "0xabcdef", 11259375, false}, + {"invalid hex", "0xGG", 0, true}, + {"invalid hex chars", "0x12H4", 0, true}, + + // Edge cases and special values + {"empty string", "", 0, true}, + {"invalid prefix", "0k123", 0, true}, + {"just prefix binary", "0b", 0, true}, + {"just prefix octal", "0o", 0, true}, + {"just prefix hex", "0x", 0, true}, + {"boolean true", true, 1, false}, + {"boolean false", false, 0, false}, + {"float64", 3.14159, 3.14159, false}, + {"scientific notation positive", "1.23e-4", 0.000123, false}, + {"scientific notation negative", "-1.23e4", -12300, false}, + {"scientific notation uppercase", "1.23E+4", 12300, false}, + {"whitespace", " 42 ", 42, false}, + {"leading zeros", "00042", 42, false}, + {"negative zero", "-0", 0, false}, + {"negative number", "-42", -42, false}, + {"decimal point", "42.0", 42, false}, + {"multiple decimal points", "42.0.0", 0, true}, + {"invalid chars", "42abc", 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := jlib.Number(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("Number() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && got != tt.want { + t.Errorf("Number() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFormatting(t *testing.T) { + tests := []struct { + name string + fn string + input float64 + args interface{} + want string + wantErr bool + }{ + {"formatBase binary", "FormatBase", 42, jtypes.NewOptionalFloat64(2), "101010", false}, + {"formatBase octal", "FormatBase", 42, jtypes.NewOptionalFloat64(8), "52", false}, + {"formatBase hex", "FormatBase", 42, jtypes.NewOptionalFloat64(16), "2a", false}, + {"formatBase zero", "FormatBase", 0, jtypes.NewOptionalFloat64(2), "0", false}, + {"formatBase negative", "FormatBase", -42, jtypes.NewOptionalFloat64(2), "-101010", false}, + {"formatBase large number", "FormatBase", 65535, jtypes.NewOptionalFloat64(16), "ffff", false}, + {"formatBase invalid base", "FormatBase", 42, jtypes.NewOptionalFloat64(37), "", true}, + {"formatBase base too small", "FormatBase", 42, jtypes.NewOptionalFloat64(1), "", true}, + {"formatInteger basic", "FormatInteger", 42, "", "42", false}, + {"formatInteger negative", "FormatInteger", -42, "", "-42", false}, + {"formatInteger zero", "FormatInteger", 0, "", "0", false}, + {"formatInteger large", "FormatInteger", 1000000, "", "1000000", false}, + {"formatNumber basic", "FormatNumber", 3.14159, "", "3.14159", false}, + {"formatNumber negative", "FormatNumber", -3.14159, "", "-3.14159", false}, + {"formatNumber zero", "FormatNumber", 0, "", "0", false}, + {"formatNumber large", "FormatNumber", 1e6, "", "1000000", false}, + {"formatNumber small", "FormatNumber", 1e-6, "", "0.000001", false}, + {"formatNumber with picture", "FormatNumber", 12345.6789, "#,###.##", "12,345.68", false}, + {"formatNumber with currency", "FormatNumber", 12345.6789, "$#,###.00", "$12,345.68", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got string + var err error + + switch tt.fn { + case "FormatBase": + got, err = jlib.FormatBase(tt.input, tt.args.(jtypes.OptionalFloat64)) + case "FormatInteger": + got, err = jlib.FormatInteger(tt.input, tt.args.(string)) + case "FormatNumber": + got, err = jlib.FormatNumber(tt.input, tt.args.(string), jtypes.OptionalValue{}) + } + + if (err != nil) != tt.wantErr { + t.Errorf("%s() error = %v, wantErr %v", tt.fn, err, tt.wantErr) + return + } + if !tt.wantErr && got != tt.want { + t.Errorf("%s() = %v, want %v", tt.fn, got, tt.want) + } + }) + } +} + +func TestParseInteger(t *testing.T) { + tests := []struct { + name string + input interface{} + base jtypes.OptionalFloat64 + want float64 + wantErr bool + }{ + {"binary", "101010", jtypes.NewOptionalFloat64(2), 42, false}, + {"binary uppercase", "101010", jtypes.NewOptionalFloat64(2), 42, false}, + {"binary with zeros", "000101", jtypes.NewOptionalFloat64(2), 5, false}, + {"octal", "52", jtypes.NewOptionalFloat64(8), 42, false}, + {"octal uppercase", "52", jtypes.NewOptionalFloat64(8), 42, false}, + {"octal with zeros", "00052", jtypes.NewOptionalFloat64(8), 42, false}, + {"decimal", "42", jtypes.NewOptionalFloat64(10), 42, false}, + {"decimal negative", "-42", jtypes.NewOptionalFloat64(10), -42, false}, + {"decimal with zeros", "00042", jtypes.NewOptionalFloat64(10), 42, false}, + {"hex", "2a", jtypes.NewOptionalFloat64(16), 42, false}, + {"hex uppercase", "2A", jtypes.NewOptionalFloat64(16), 42, false}, + {"hex with zeros", "002a", jtypes.NewOptionalFloat64(16), 42, false}, + {"invalid base", "42", jtypes.NewOptionalFloat64(37), 0, true}, + {"base too small", "42", jtypes.NewOptionalFloat64(1), 0, true}, + {"invalid digit", "2a", jtypes.NewOptionalFloat64(10), 0, true}, + {"empty string", "", jtypes.NewOptionalFloat64(10), 0, true}, + {"whitespace", " 42 ", jtypes.NewOptionalFloat64(10), 42, false}, + {"invalid chars", "42abc", jtypes.NewOptionalFloat64(10), 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := jlib.ParseInteger(tt.input, tt.base) + if (err != nil) != tt.wantErr { + t.Errorf("ParseInteger() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && got != tt.want { + t.Errorf("ParseInteger() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMathFuncs(t *testing.T) { + tests := []struct { + name string + fn func(float64) float64 + x float64 + want float64 + }{ + {"abs positive", jlib.Abs, 42, 42}, + {"abs negative", jlib.Abs, -42, 42}, + {"abs zero", jlib.Abs, 0, 0}, + {"ceil up", jlib.Ceil, 3.14, 4}, + {"ceil whole", jlib.Ceil, 42, 42}, + {"ceil negative", jlib.Ceil, -3.14, -3}, + {"floor down", jlib.Floor, 3.14, 3}, + {"floor whole", jlib.Floor, 42, 42}, + {"floor negative", jlib.Floor, -3.14, -4}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.fn(tt.x); got != tt.want { + t.Errorf("%v(%v) = %v, want %v", tt.name, tt.x, got, tt.want) + } + }) + } +} + func TestRound(t *testing.T) { data := []struct { diff --git a/jlib/string.go b/jlib/string.go index cd49c6f..9a8249b 100644 --- a/jlib/string.go +++ b/jlib/string.go @@ -13,7 +13,6 @@ import ( "net/url" "reflect" "regexp" - "strconv" "strings" "unicode/utf8" @@ -231,9 +230,9 @@ func Join(values reflect.Value, separator jtypes.OptionalString) (string, error) // regular expression in the source string. Each object in the // array has the following fields: // -// match - the substring matched by the regex -// index - the starting offset of this match -// groups - any captured groups for this match +// match - the substring matched by the regex +// index - the starting offset of this match +// groups - any captured groups for this match // // The optional third argument specifies the maximum number // of matches to return. By default, Match returns all matches. @@ -364,19 +363,9 @@ func replaceMatchFunc(src string, fn jtypes.Callable, repl StringCallable, limit var defaultDecimalFormat = jxpath.NewDecimalFormat() -// FormatNumber converts a number to a string, formatted according -// to the given picture string. See the XPath function format-number -// for the syntax of the picture string. -// -// https://www.w3.org/TR/xpath-functions-31/#formatting-numbers -// -// The optional third argument defines various formatting options -// such as the decimal separator and grouping separator. See the -// XPath documentation for details. -// -// https://www.w3.org/TR/xpath-functions-31/#defining-decimal-format -func FormatNumber(value float64, picture string, options jtypes.OptionalValue) (string, error) { - +// formatNumberWithPicture formats a number according to XPath picture string format. +// This is an internal helper used by string formatting functions. +func formatNumberWithPicture(value float64, picture string, options jtypes.OptionalValue) (string, error) { if !options.IsSet() { return jxpath.FormatNumber(value, picture, defaultDecimalFormat) } @@ -395,11 +384,9 @@ func FormatNumber(value float64, picture string, options jtypes.OptionalValue) ( } func newDecimalFormat(opts reflect.Value) (jxpath.DecimalFormat, error) { - format := jxpath.NewDecimalFormat() for _, key := range opts.MapKeys() { - k, ok := jtypes.AsString(key) if !ok { return jxpath.DecimalFormat{}, fmt.Errorf("decimal format options must be a map of strings to strings") @@ -419,7 +406,6 @@ func newDecimalFormat(opts reflect.Value) (jxpath.DecimalFormat, error) { } func updateDecimalFormat(format *jxpath.DecimalFormat, key string, value string) error { - switch key { case "infinity": format.Infinity = value @@ -453,27 +439,9 @@ func updateDecimalFormat(format *jxpath.DecimalFormat, key string, value string) return fmt.Errorf("unknown option %q", key) } } - return nil } -// FormatBase returns the string representation of a number in the -// optional base argument. If specified, the base must be between -// 2 and 36. By default, FormatBase uses base 10. -func FormatBase(value float64, base jtypes.OptionalFloat64) (string, error) { - - radix := 10 - if base.IsSet() { - radix = int(Round(base.Float64, jtypes.OptionalInt{})) - } - - if radix < 2 || radix > 36 { - return "", fmt.Errorf("the second argument to formatBase must be between 2 and 36") - } - - return strconv.FormatInt(int64(Round(value, jtypes.OptionalInt{})), radix), nil -} - // Base64Encode returns the base 64 encoding of a string. func Base64Encode(s string) (string, error) { return base64.StdEncoding.EncodeToString([]byte(s)), nil diff --git a/jparse/comment_test.go b/jparse/comment_test.go new file mode 100644 index 0000000..bbd203b --- /dev/null +++ b/jparse/comment_test.go @@ -0,0 +1,105 @@ +package jparse + +import ( + "testing" +) + +func TestCommentLexer(t *testing.T) { + cases := []struct { + name string + input string + want []tokenType + wantErr bool + }{ + { + name: "basic multiline comment", + input: "/* this is a comment */ 42", + want: []tokenType{typeNumber}, + }, + { + name: "multiline comment between tokens", + input: "1 /* comment */ + 2", + want: []tokenType{typeNumber, typePlus, typeNumber}, + }, + { + name: "inline comment", + input: "42 // this is a comment\n43", + want: []tokenType{typeNumber, typeNumber}, + }, + { + name: "inline comment at end", + input: "42 // this is a comment", + want: []tokenType{typeNumber}, + }, + { + name: "multiple inline comments", + input: "1 // first\n2 // second\n3", + want: []tokenType{typeNumber, typeNumber, typeNumber}, + }, + { + name: "mixed comment styles", + input: "1 /* multi */ 2 // inline\n3", + want: []tokenType{typeNumber, typeNumber, typeNumber}, + }, + { + name: "comment in object", + input: "{/* comment */\"a\":1}", + want: []tokenType{typeBraceOpen, typeString, typeColon, typeNumber, typeBraceClose}, + }, + { + name: "comment in array", + input: "[1,/* comment */2]", + want: []tokenType{typeBracketOpen, typeNumber, typeComma, typeNumber, typeBracketClose}, + }, + { + name: "unterminated multiline comment", + input: "/* unterminated", + want: []tokenType{typeError}, + wantErr: true, + }, + { + name: "nested-looking comment", + input: "/* outer /* inner */ 42", + want: []tokenType{typeNumber}, + }, + { + name: "complex nested comments", + input: "/* a /* b /* c */ d */ e */ 42", + want: []tokenType{typeNumber}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + l := newLexer(tt.input) + var got []tokenType + var hasError bool + + for { + tok := l.next(true) + if tok.Type == typeEOF { + break + } + if tok.Type == typeError { + hasError = true + } + got = append(got, tok.Type) + } + + if hasError != tt.wantErr { + t.Errorf("got error = %v, want %v", hasError, tt.wantErr) + } + + if len(got) != len(tt.want) { + t.Errorf("got %d tokens, want %d", len(got), len(tt.want)) + return + } + + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("token[%d] = %v, want %v", i, got[i], tt.want[i]) + } + } + }) + } +} diff --git a/jparse/context.go b/jparse/context.go new file mode 100644 index 0000000..96da943 --- /dev/null +++ b/jparse/context.go @@ -0,0 +1,49 @@ +package jparse + +type Context struct { + Input interface{} + Parent *Context + Position int + Variables map[string]interface{} +} + +func NewContext(input interface{}, parent *Context) *Context { + ctx := &Context{ + Input: input, + Parent: parent, + Position: -1, + Variables: make(map[string]interface{}), + } + if parent != nil { + if parent.Position >= 0 { + ctx.Position = parent.Position + } + if parent.Variables != nil { + for k, v := range parent.Variables { + ctx.Variables[k] = v + } + } + } + return ctx +} + +func (ctx *Context) WithVariable(name string, value interface{}) *Context { + newCtx := *ctx + if newCtx.Variables == nil { + newCtx.Variables = make(map[string]interface{}) + } + newCtx.Variables[name] = value + return &newCtx +} + +func (ctx *Context) WithPosition(pos int) *Context { + newCtx := *ctx + newCtx.Position = pos + return &newCtx +} + +func (ctx *Context) WithInput(input interface{}) *Context { + newCtx := *ctx + newCtx.Input = input + return &newCtx +} diff --git a/jparse/crossref.go b/jparse/crossref.go new file mode 100644 index 0000000..c75959d --- /dev/null +++ b/jparse/crossref.go @@ -0,0 +1,170 @@ +package jparse + +import ( + "fmt" +) + +type CrossReferenceNode struct { + LHS Node + RHS Node + Path Node +} + +func (n *CrossReferenceNode) optimize() (Node, error) { + var err error + n.LHS, err = n.LHS.optimize() + if err != nil { + return nil, err + } + n.RHS, err = n.RHS.optimize() + if err != nil { + return nil, err + } + if n.Path != nil { + n.Path, err = n.Path.optimize() + if err != nil { + return nil, err + } + } + return n, nil +} + +func (n CrossReferenceNode) String() string { + if n.Path != nil { + return fmt.Sprintf("%s@%s.%s", n.LHS, n.RHS, n.Path) + } + return fmt.Sprintf("%s@%s", n.LHS, n.RHS) +} + +func (n CrossReferenceNode) Evaluate(ctx *Context) (interface{}, error) { + lhs, err := n.LHS.Evaluate(ctx) + if err != nil { + return nil, err + } + + if lhs == nil { + return nil, nil + } + + // Handle array inputs by applying cross reference to each element + if arr, ok := lhs.([]interface{}); ok { + var results []interface{} + for i, item := range arr { + itemCtx := NewContext(item, ctx) + itemCtx.Position = i + + // Extract variable name and bind it + if varNode, ok := n.RHS.(*VariableNode); ok { + itemCtx = itemCtx.WithVariable(varNode.Name, item) + + // If there's a predicate after the variable, evaluate it + if pred := varNode.Next; pred != nil { + predCtx := itemCtx.WithInput(item) + result, err := pred.Evaluate(predCtx) + if err != nil { + return nil, err + } + + // Only include items that match the predicate + if b, ok := result.(bool); ok && b { + if n.Path != nil { + pathResult, err := n.Path.Evaluate(itemCtx) + if err != nil { + return nil, err + } + if pathResult != nil { + results = append(results, pathResult) + } + } else { + results = append(results, item) + } + } + continue + } + + if n.Path != nil { + pathResult, err := n.Path.Evaluate(itemCtx) + if err != nil { + return nil, err + } + if pathResult != nil { + results = append(results, pathResult) + } + } else { + results = append(results, item) + } + continue + } + + // Handle non-variable RHS expressions + result, err := n.RHS.Evaluate(itemCtx) + if err != nil { + return nil, err + } + if result != nil { + if n.Path != nil { + pathCtx := NewContext(result, itemCtx) + pathResult, err := n.Path.Evaluate(pathCtx) + if err != nil { + return nil, err + } + if pathResult != nil { + results = append(results, pathResult) + } + } else { + results = append(results, result) + } + } + } + if len(results) == 0 { + return nil, nil + } + if len(results) == 1 { + return results[0], nil + } + return results, nil + } + + // Create new context with LHS as input and bind variable + rhsCtx := NewContext(lhs, ctx) + + // Handle variable binding + if varNode, ok := n.RHS.(*VariableNode); ok { + rhsCtx = rhsCtx.WithVariable(varNode.Name, lhs) + + // If there's a predicate after the variable, evaluate it + if pred := varNode.Next; pred != nil { + predCtx := rhsCtx.WithInput(lhs) + result, err := pred.Evaluate(predCtx) + if err != nil { + return nil, err + } + + // Only return value if predicate matches + if b, ok := result.(bool); ok && b { + if n.Path != nil { + return n.Path.Evaluate(rhsCtx) + } + return lhs, nil + } + return nil, nil + } + + if n.Path != nil { + return n.Path.Evaluate(rhsCtx) + } + return lhs, nil + } + + result, err := n.RHS.Evaluate(rhsCtx) + if err != nil { + return nil, err + } + + if n.Path != nil && result != nil { + pathCtx := NewContext(result, rhsCtx) + return n.Path.Evaluate(pathCtx) + } + + return result, nil +} diff --git a/jparse/doc.go b/jparse/doc.go index 22826d6..c352f34 100644 --- a/jparse/doc.go +++ b/jparse/doc.go @@ -6,7 +6,7 @@ // syntax trees. Most clients will not need to work with // this package directly. // -// Usage +// # Usage // // Call the Parse function, passing a JSONata expression as // a string. If an error occurs, it will be of type Error. diff --git a/jparse/error.go b/jparse/error.go index 8bc1d57..e60db11 100644 --- a/jparse/error.go +++ b/jparse/error.go @@ -29,6 +29,7 @@ const ( ErrInvalidNumber ErrNumberRange ErrEmptyRegex + ErrUnterminatedComment ErrInvalidRegex ErrGroupPredicate ErrGroupGroup @@ -45,33 +46,34 @@ const ( ) var errmsgs = map[ErrType]string{ - ErrSyntaxError: "syntax error: '{{token}}'", - ErrUnexpectedEOF: "unexpected end of expression", - ErrUnexpectedToken: "expected token '{{hint}}', got '{{token}}'", - ErrMissingToken: "expected token '{{hint}}' before end of expression", - ErrPrefix: "the symbol '{{token}}' cannot be used as a prefix operator", - ErrInfix: "the symbol '{{token}}' cannot be used as an infix operator", - ErrUnterminatedString: "unterminated string literal (no closing '{{hint}}')", - ErrUnterminatedRegex: "unterminated regular expression (no closing '{{hint}}')", - ErrUnterminatedName: "unterminated name (no closing '{{hint}}')", - ErrIllegalEscape: "illegal escape sequence \\{{hint}}", - ErrIllegalEscapeHex: "illegal escape sequence \\{{hint}}: \\u must be followed by a 4-digit hexadecimal code point", - ErrInvalidNumber: "invalid number literal {{token}}", - ErrNumberRange: "invalid number literal {{token}}: value out of range", - ErrEmptyRegex: "invalid regular expression: expression cannot be empty", - ErrInvalidRegex: "invalid regular expression {{token}}: {{hint}}", - ErrGroupPredicate: "a predicate cannot follow a grouping expression in a path step", - ErrGroupGroup: "a path step can only have one grouping expression", - ErrPathLiteral: "invalid path step {{hint}}: paths cannot contain nulls, strings, numbers or booleans", - ErrIllegalAssignment: "illegal assignment: {{hint}} is not a variable", - ErrIllegalParam: "illegal function parameter: {{token}} is not a variable", - ErrDuplicateParam: "duplicate function parameter: {{token}}", - ErrParamCount: "invalid type signature: number of types must match number of function parameters", - ErrInvalidUnionType: "invalid type signature: unsupported union type '{{hint}}'", - ErrUnmatchedOption: "invalid type signature: option '{{hint}}' must follow a parameter", - ErrUnmatchedSubtype: "invalid type signature: subtypes must follow a parameter", - ErrInvalidSubtype: "invalid type signature: parameter type {{hint}} does not support subtypes", - ErrInvalidParamType: "invalid type signature: unknown parameter type '{{hint}}'", + ErrSyntaxError: "syntax error: '{{token}}'", + ErrUnexpectedEOF: "unexpected end of expression", + ErrUnexpectedToken: "expected token '{{hint}}', got '{{token}}'", + ErrMissingToken: "expected token '{{hint}}' before end of expression", + ErrPrefix: "the symbol '{{token}}' cannot be used as a prefix operator", + ErrInfix: "the symbol '{{token}}' cannot be used as an infix operator", + ErrUnterminatedString: "unterminated string literal (no closing '{{hint}}')", + ErrUnterminatedRegex: "unterminated regular expression (no closing '{{hint}}')", + ErrUnterminatedName: "unterminated name (no closing '{{hint}}')", + ErrIllegalEscape: "illegal escape sequence \\{{hint}}", + ErrIllegalEscapeHex: "illegal escape sequence \\{{hint}}: \\u must be followed by a 4-digit hexadecimal code point", + ErrInvalidNumber: "invalid number literal {{token}}", + ErrNumberRange: "invalid number literal {{token}}: value out of range", + ErrEmptyRegex: "invalid regular expression: expression cannot be empty", + ErrInvalidRegex: "invalid regular expression {{token}}: {{hint}}", + ErrUnterminatedComment: "unterminated comment (no closing */)", + ErrGroupPredicate: "a predicate cannot follow a grouping expression in a path step", + ErrGroupGroup: "a path step can only have one grouping expression", + ErrPathLiteral: "invalid path step {{hint}}: paths cannot contain nulls, strings, numbers or booleans", + ErrIllegalAssignment: "illegal assignment: {{hint}} is not a variable", + ErrIllegalParam: "illegal function parameter: {{token}} is not a variable", + ErrDuplicateParam: "duplicate function parameter: {{token}}", + ErrParamCount: "invalid type signature: number of types must match number of function parameters", + ErrInvalidUnionType: "invalid type signature: unsupported union type '{{hint}}'", + ErrUnmatchedOption: "invalid type signature: option '{{hint}}' must follow a parameter", + ErrUnmatchedSubtype: "invalid type signature: subtypes must follow a parameter", + ErrInvalidSubtype: "invalid type signature: parameter type {{hint}} does not support subtypes", + ErrInvalidParamType: "invalid type signature: unknown parameter type '{{hint}}'", } var reErrMsg = regexp.MustCompile("{{(token|hint)}}") diff --git a/jparse/jparse.go b/jparse/jparse.go index 01d405a..7798c19 100644 --- a/jparse/jparse.go +++ b/jparse/jparse.go @@ -52,6 +52,8 @@ var nuds = [...]nud{ typeMinus: parseNegation, typeDescendent: parseDescendent, typePipe: parseObjectTransformation, + typeParent: parseParentPrefix, + typePosition: parsePositionPrefix, typeIn: parseName, typeAnd: parseName, typeOr: parseName, @@ -68,12 +70,15 @@ var leds = [...]led{ typeApply: parseFunctionApplication, typeConcat: parseStringConcatenation, typeSort: parseSort, - typeDot: parseDot, + typeDot: parsePath, typePlus: parseNumericOperator, typeMinus: parseNumericOperator, typeMult: parseNumericOperator, typeDiv: parseNumericOperator, typeMod: parseNumericOperator, + typeParent: parseParentOperator, + typeCrossRef: parseCrossReferenceOperator, + typePosition: parsePositionOperator, typeEqual: parseComparisonOperator, typeNotEqual: parseComparisonOperator, typeLess: parseComparisonOperator, @@ -101,6 +106,9 @@ var bps = initBindingPowers([][]tokenType{ }, { typeDot, + typeParent, + typeCrossRef, + typePosition, }, { typeBraceOpen, diff --git a/jparse/lexer.go b/jparse/lexer.go index bff6df4..435d8e2 100644 --- a/jparse/lexer.go +++ b/jparse/lexer.go @@ -6,6 +6,7 @@ package jparse import ( "fmt" + "strings" "unicode/utf8" ) @@ -43,6 +44,9 @@ const ( typeMult typeDiv typeMod + typeParent // % + typeCrossRef // @ + typePosition // # typePipe typeEqual typeNotEqual @@ -108,6 +112,8 @@ var symbols1 = [...]tokenType{ '*': typeMult, '/': typeDiv, '%': typeMod, + '@': typeCrossRef, + '#': typePosition, '|': typePipe, '=': typeEqual, '<': typeLess, @@ -124,6 +130,7 @@ type runeTokenType struct { // symbols2 maps 2-character symbols to the corresponding // token types. var symbols2 = [...][]runeTokenType{ + '%': {{'%', typeMod}}, '!': {{'=', typeNotEqual}}, '<': {{'=', typeLessEqual}}, '>': {{'=', typeGreaterEqual}}, @@ -175,6 +182,7 @@ type token struct { Type tokenType Value string Position int + Flags string // Used for regex flags } // lexer converts a JSONata expression into a sequence of tokens. @@ -209,46 +217,106 @@ func newLexer(input string) lexer { // the lexer will treat a forward slash like a regular // expression. func (l *lexer) next(allowRegex bool) token { - l.skipWhitespace() + pos := l.start ch := l.nextRune() if ch == eof { return l.eof() } - if allowRegex && ch == '/' { - l.ignore() - return l.scanRegex(ch) + // Handle comments and regex + if ch == '/' { + next := l.peek() + if next == '*' || next == '/' { + l.backup() + return l.scanComment(allowRegex) + } + if allowRegex { + l.backup() + return l.scanRegex(ch) + } + return token{Type: typeDiv, Value: "/", Position: pos} } + // Handle two-character operators first if rts := lookupSymbol2(ch); rts != nil { for _, rt := range rts { if l.acceptRune(rt.r) { - return l.newToken(rt.tt) + return token{Type: rt.tt, Value: string(ch) + string(rt.r), Position: pos} } } + l.backup() } + // Handle single-character operators and special cases if tt := lookupSymbol1(ch); tt > 0 { - return l.newToken(tt) + switch ch { + case '%': + next := l.peek() + if next == '.' || next == '[' || next == ']' { + return token{Type: typeParent, Value: "%", Position: pos} + } + return token{Type: typeMod, Value: "%", Position: pos} + + case '@': + if l.acceptRune('$') { + start := l.current + for { + r := l.nextRune() + if r == eof || isWhitespace(r) || lookupSymbol1(r) > 0 || lookupSymbol2(r) != nil { + l.backup() + break + } + } + if l.current > start { + return token{Type: typeVariable, Value: l.input[start:l.current], Position: pos + 1} + } + l.backup() // Remove the $ + } + return token{Type: typeCrossRef, Value: "@", Position: pos} + + case '#': + if l.acceptRune('$') { + start := l.current + for { + r := l.nextRune() + if r == eof || isWhitespace(r) || lookupSymbol1(r) > 0 || lookupSymbol2(r) != nil { + l.backup() + break + } + } + if l.current > start { + return token{Type: typeVariable, Value: l.input[start:l.current], Position: pos + 1} + } + l.backup() // Remove the $ + } + return token{Type: typePosition, Value: "#", Position: pos} + + default: + return token{Type: tt, Value: string(ch), Position: pos} + } } + // Handle strings if ch == '"' || ch == '\'' { - l.ignore() + l.backup() return l.scanString(ch) } - if ch >= '0' && ch <= '9' { + // Handle numbers + if ch == '-' || ch == '.' || (ch >= '0' && ch <= '9') { l.backup() return l.scanNumber() } + // Handle escaped names if ch == '`' { - l.ignore() + l.backup() return l.scanEscapedName(ch) } + // Handle names and variables l.backup() return l.scanName() } @@ -257,96 +325,203 @@ func (l *lexer) next(allowRegex bool) token { // and returns a regex token. The opening delimiter has already // been consumed. func (l *lexer) scanRegex(delim rune) token { + pos := l.start + l.nextRune() // consume opening '/' - var depth int - -Loop: + var pattern strings.Builder + escaped := false for { - switch l.nextRune() { - case delim: - if depth == 0 { - break Loop - } - case '(', '[', '{': - depth++ - case ')', ']', '}': - depth-- - case '\\': - if r := l.nextRune(); r != eof && r != '\n' { - break - } - fallthrough - case eof, '\n': - return l.error(ErrUnterminatedRegex, string(delim)) + ch := l.nextRune() + if ch == eof { + return token{Type: typeError, Value: "unterminated regular expression (no closing '/')", Position: pos} + } + if escaped { + pattern.WriteRune('\\') + pattern.WriteRune(ch) + escaped = false + continue } + if ch == '\\' { + escaped = true + continue + } + if ch == '/' && !escaped { + break + } + pattern.WriteRune(ch) } - l.backup() - t := l.newToken(typeRegex) - l.acceptRune(delim) - l.ignore() + if pattern.Len() == 0 { + return token{Type: typeError, Value: "invalid regular expression: expression cannot be empty", Position: pos} + } - // Convert JavaScript-style regex flags to Go format, - // e.g. /ab+/i becomes /(?i)ab+/. - if l.acceptAll(isRegexFlag) { - flags := l.newToken(0) - t.Value = fmt.Sprintf("(?%s)%s", flags.Value, t.Value) + var flags strings.Builder + for { + ch := l.peek() + if !isRegexFlag(ch) { + break + } + l.nextRune() + flags.WriteRune(ch) } - return t + flagStr := flags.String() + if flagStr != "" { + flagStr = "(?i" + strings.ReplaceAll(flagStr, "i", "") + ")" + } + + return token{Type: typeRegex, Value: flagStr + pattern.String(), Position: pos} } // scanString reads a string literal from the current position // and returns a string token. The opening quote has already been // consumed. func (l *lexer) scanString(quote rune) token { -Loop: + pos := l.start + l.nextRune() // consume opening quote + + var value strings.Builder for { - switch l.nextRune() { - case quote: - break Loop - case '\\': - if r := l.nextRune(); r != eof { - break - } - fallthrough - case eof: + ch := l.nextRune() + if ch == eof { return l.error(ErrUnterminatedString, string(quote)) } + if ch == quote { + break + } + if ch == '\\' { + ch = l.nextRune() + if ch == eof { + return l.error(ErrUnterminatedString, string(quote)) + } + switch ch { + case 'n': + value.WriteRune('\n') + case 'r': + value.WriteRune('\r') + case 't': + value.WriteRune('\t') + case '"', '\'', '\\': + value.WriteRune(ch) + default: + value.WriteRune('\\') + value.WriteRune(ch) + } + continue + } + value.WriteRune(ch) } - l.backup() - t := l.newToken(typeString) - l.acceptRune(quote) - l.ignore() - return t + return token{ + Type: typeString, + Value: value.String(), + Position: pos, + } } // scanNumber reads a number literal from the current position // and returns a number token. func (l *lexer) scanNumber() token { + pos := l.start + + // Handle negative numbers + isNegative := l.acceptRune('-') + if isNegative && !isDigit(l.peek()) { + return token{Type: typeMinus, Value: "-", Position: pos} + } + + // Handle special number formats (hex, binary, octal) + if l.acceptRune('0') { + next := l.peek() + switch next { + case 'b', 'B': + l.nextRune() + if !l.acceptAll(isBinaryDigit) { + return token{Type: typeError, Value: "invalid binary number", Position: pos} + } + return token{Type: typeNumber, Value: l.input[pos:l.current], Position: pos} + case 'o', 'O': + l.nextRune() + if !l.acceptAll(isOctalDigit) { + return token{Type: typeError, Value: "invalid octal number", Position: pos} + } + return token{Type: typeNumber, Value: l.input[pos:l.current], Position: pos} + case 'x', 'X': + l.nextRune() + if !l.acceptAll(isHexDigit) { + return token{Type: typeError, Value: "invalid hexadecimal number", Position: pos} + } + return token{Type: typeNumber, Value: l.input[pos:l.current], Position: pos} + case '.': + l.nextRune() + if !l.acceptAll(isDigit) { + l.backup() + return token{Type: typeNumber, Value: "0", Position: pos} + } + if l.acceptRunes2('e', 'E') { + l.acceptRunes2('+', '-') + if !l.acceptAll(isDigit) { + return token{Type: typeError, Value: "invalid number literal", Position: pos} + } + } + return token{Type: typeNumber, Value: l.input[pos:l.current], Position: pos} + case 'e', 'E': + l.nextRune() + l.acceptRunes2('+', '-') + if !l.acceptAll(isDigit) { + return token{Type: typeError, Value: "invalid number literal", Position: pos} + } + return token{Type: typeNumber, Value: l.input[pos:l.current], Position: pos} + } + + // Handle single zero or leading zeros + if !isDigit(next) { + return token{Type: typeNumber, Value: "0", Position: pos} + } - // JSON does not support leading zeroes. The integer part of - // a number will either be a single zero, or a non-zero digit - // followed by zero or more digits. - if !l.acceptRune('0') { - l.accept(isNonZeroDigit) l.acceptAll(isDigit) + if l.acceptRune('.') { + if !l.acceptAll(isDigit) { + l.backup() + return token{Type: typeNumber, Value: l.input[pos : l.current-1], Position: pos} + } + } + if l.acceptRunes2('e', 'E') { + l.acceptRunes2('+', '-') + if !l.acceptAll(isDigit) { + return token{Type: typeError, Value: "invalid number literal", Position: pos} + } + } + return token{Type: typeNumber, Value: l.input[pos:l.current], Position: pos} } + + // Handle regular decimal numbers + hasDigits := l.acceptAll(isDigit) + + // Handle decimal point if l.acceptRune('.') { if !l.acceptAll(isDigit) { - // If there are no digits after the decimal point, - // don't treat the dot as part of the number. It - // could be part of the range operator, e.g. "1..5". - l.backup() - return l.newToken(typeNumber) + if !hasDigits { + l.backup() + return token{Type: typeDot, Value: ".", Position: pos} + } } } + + // Handle exponent if l.acceptRunes2('e', 'E') { l.acceptRunes2('+', '-') - l.acceptAll(isDigit) + if !l.acceptAll(isDigit) { + return token{Type: typeError, Value: "invalid number literal", Position: pos} + } + } + + // Validate number format + if !hasDigits && !strings.Contains(l.input[pos:l.current], ".") { + return token{Type: typeError, Value: "invalid number literal", Position: pos} } - return l.newToken(typeNumber) + + return token{Type: typeNumber, Value: l.input[pos:l.current], Position: pos} } // scanEscapedName reads a field name from the current position @@ -373,11 +548,9 @@ Loop: // scanName reads from the current position and returns a name, // variable, or keyword token. func (l *lexer) scanName() token { - + pos := l.start isVar := l.acceptRune('$') - if isVar { - l.ignore() - } + start := l.current for { ch := l.nextRune() @@ -385,28 +558,29 @@ func (l *lexer) scanName() token { break } - // Stop reading if we hit whitespace... - if isWhitespace(ch) { + if isWhitespace(ch) || lookupSymbol1(ch) > 0 || lookupSymbol2(ch) != nil { l.backup() break } + } - // ...or anything that looks like an operator. - if lookupSymbol1(ch) > 0 || lookupSymbol2(ch) != nil { - l.backup() - break + if l.current == start { + if isVar { + return token{Type: typeVariable, Value: "", Position: pos} } + return token{Type: typeError, Value: "empty name", Position: pos} } - t := l.newToken(typeName) - + value := l.input[start:l.current] if isVar { - t.Type = typeVariable - } else if tt := lookupKeyword(t.Value); tt > 0 { - t.Type = tt + return token{Type: typeVariable, Value: value, Position: pos + 1} } - return t + if tt := lookupKeyword(value); tt > 0 { + return token{Type: tt, Value: value, Position: pos} + } + + return token{Type: typeName, Value: value, Position: pos} } func (l *lexer) eof() token { @@ -452,13 +626,15 @@ func (l *lexer) nextRune() rune { } func (l *lexer) backup() { - // TODO: Support more than one backup operation. - // TODO: Store current rune so that when nextRune - // is called again, we don't need to repeat the call - // to DecodeRuneInString. l.current -= l.width } +func (l *lexer) peek() rune { + r := l.nextRune() + l.backup() + return r +} + func (l *lexer) ignore() { l.start = l.current } @@ -492,7 +668,13 @@ func (l *lexer) acceptAll(isValid func(rune) bool) bool { } func (l *lexer) skipWhitespace() { - l.acceptAll(isWhitespace) + for { + ch := l.peek() + if !isWhitespace(ch) { + break + } + l.nextRune() + } l.ignore() } @@ -506,20 +688,60 @@ func isWhitespace(r rune) bool { } func isRegexFlag(r rune) bool { - switch r { - case 'i', 'm', 's': - return true - default: - return false - } + return r == 'i' || r == 'I' || r == 'm' || r == 's' } func isDigit(r rune) bool { return r >= '0' && r <= '9' } -func isNonZeroDigit(r rune) bool { - return r >= '1' && r <= '9' +func isBinaryDigit(r rune) bool { + return r == '0' || r == '1' +} + +func isOctalDigit(r rune) bool { + return r >= '0' && r <= '7' +} + +func isHexDigit(r rune) bool { + return (r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F') +} + +func (l *lexer) scanComment(allowRegex bool) token { + startPos := l.start + next := l.peek() + l.nextRune() // consume first '/' + + if next == '*' { + l.nextRune() // consume '*' + depth := 1 + for depth > 0 { + r := l.nextRune() + if r == eof { + return token{Type: typeError, Value: "unterminated comment (no closing */)", Position: startPos} + } + if r == '*' && l.peek() == '/' { + l.nextRune() // consume '/' + depth-- + } else if r == '/' && l.peek() == '*' { + l.nextRune() // consume '*' + depth++ + } + } + } else if next == '/' { + l.nextRune() // consume second '/' + for { + r := l.nextRune() + if r == '\n' || r == eof { + break + } + } + } + + l.start = startPos + l.ignore() + l.skipWhitespace() + return l.next(allowRegex) } // symbolsAndKeywords maps operator token types back to their diff --git a/jparse/lexer_test.go b/jparse/lexer_test.go index 8a64f00..ef3a74b 100644 --- a/jparse/lexer_test.go +++ b/jparse/lexer_test.go @@ -16,6 +16,136 @@ type lexerTestCase struct { Error error } +func TestLexerComments(t *testing.T) { + cases := []struct { + name string + input string + want []token + wantErr *Error + }{ + { + name: "basic comment", + input: "/* this is a comment */ 42", + want: []token{ + {Type: typeNumber, Value: "42", Position: 22}, + }, + }, + { + name: "comment between tokens", + input: "1 /* comment */ + 2", + want: []token{ + {Type: typeNumber, Value: "1", Position: 0}, + {Type: typePlus, Value: "+", Position: 16}, + {Type: typeNumber, Value: "2", Position: 18}, + }, + }, + { + name: "unterminated comment", + input: "/* unterminated", + want: []token{ + {Type: typeError, Value: "", Position: 0}, + }, + wantErr: &Error{ + Type: ErrUnterminatedComment, + Position: 0, + }, + }, + { + name: "nested-looking comment", + input: "/* outer /* inner */ 42", + want: []token{ + {Type: typeNumber, Value: "42", Position: 22}, + }, + }, + { + name: "multiple comments", + input: "1 /* first */ + /* second */ 2", + want: []token{ + {Type: typeNumber, Value: "1", Position: 0}, + {Type: typePlus, Value: "+", Position: 16}, + {Type: typeNumber, Value: "2", Position: 31}, + }, + }, + { + name: "comment at end", + input: "42 /* end comment */", + want: []token{ + {Type: typeNumber, Value: "42", Position: 0}, + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + l := newLexer(tt.input) + var got []token + for { + tok := l.next(true) + if tok.Type == typeEOF { + break + } + got = append(got, tok) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("got tokens = %v, want %v", got, tt.want) + } + if !reflect.DeepEqual(l.err, tt.wantErr) { + t.Errorf("got error = %v, want %v", l.err, tt.wantErr) + } + }) + } +} + +func TestLexerOperators(t *testing.T) { + testLexer(t, []lexerTestCase{ + { + Input: "10 % 3 + Account.%.name", + AllowRegex: true, + Tokens: []token{ + tok(typeNumber, "10", 0), + tok(typeMod, "%", 3), + tok(typeNumber, "3", 5), + tok(typePlus, "+", 7), + tok(typeName, "Account", 9), + tok(typeDot, ".", 16), + tok(typeParent, "%", 17), + tok(typeDot, ".", 18), + tok(typeName, "name", 19), + }, + }, + { + Input: "Orders@$O[%.Type='retail']", + AllowRegex: true, + Tokens: []token{ + tok(typeName, "Orders", 0), + tok(typeCrossRef, "@", 6), + tok(typeVariable, "O", 8), + tok(typeBracketOpen, "[", 9), + tok(typeParent, "%", 10), + tok(typeDot, ".", 11), + tok(typeName, "Type", 12), + tok(typeEqual, "=", 16), + tok(typeString, "retail", 18), + tok(typeBracketClose, "]", 24), + }, + }, + { + Input: "Orders#$i[Position=$i]", + AllowRegex: true, + Tokens: []token{ + tok(typeName, "Orders", 0), + tok(typePosition, "#", 6), + tok(typeVariable, "i", 8), + tok(typeBracketOpen, "[", 9), + tok(typeName, "Position", 10), + tok(typeEqual, "=", 18), + tok(typeVariable, "i", 20), + tok(typeBracketClose, "]", 21), + }, + }, + }) +} + func TestLexerWhitespace(t *testing.T) { testLexer(t, []lexerTestCase{ { diff --git a/jparse/node.go b/jparse/node.go index 6d2bbe4..eca5473 100644 --- a/jparse/node.go +++ b/jparse/node.go @@ -8,6 +8,7 @@ import ( "fmt" "regexp" "regexp/syntax" + "sort" "strconv" "strings" "unicode/utf16" @@ -18,14 +19,20 @@ import ( type Node interface { String() string optimize() (Node, error) + Evaluate(ctx *Context) (interface{}, error) } +// Operator parsing functions moved to parser_operators.go + // A StringNode represents a string literal. type StringNode struct { Value string } func parseString(p *parser, t token) (Node, error) { + if t.Value == "" { + return nil, newError(ErrSyntaxError, t) + } s, ok := unescape(t.Value) if !ok { @@ -33,13 +40,10 @@ func parseString(p *parser, t token) (Node, error) { if len(s) > 0 && s[0] == 'u' { typ = ErrIllegalEscapeHex } - return nil, newErrorHint(typ, t, s) } - return &StringNode{ - Value: s, - }, nil + return &StringNode{Value: s}, nil } func (n *StringNode) optimize() (Node, error) { @@ -50,26 +54,51 @@ func (n StringNode) String() string { return fmt.Sprintf("%q", n.Value) } +func (n StringNode) Evaluate(ctx *Context) (interface{}, error) { + return n.Value, nil +} + // A NumberNode represents a number literal. type NumberNode struct { Value float64 } func parseNumber(p *parser, t token) (Node, error) { + var n float64 + var err error - // Number literals are promoted to type float64. - n, err := strconv.ParseFloat(t.Value, 64) - if err != nil { - typ := ErrInvalidNumber - if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange { - typ = ErrNumberRange + // Handle special number formats + switch { + case strings.HasPrefix(t.Value, "0b") || strings.HasPrefix(t.Value, "0B"): + val, err := strconv.ParseInt(t.Value[2:], 2, 64) + if err != nil { + return nil, newError(ErrInvalidNumber, t) + } + n = float64(val) + case strings.HasPrefix(t.Value, "0o") || strings.HasPrefix(t.Value, "0O"): + val, err := strconv.ParseInt(t.Value[2:], 8, 64) + if err != nil { + return nil, newError(ErrInvalidNumber, t) + } + n = float64(val) + case strings.HasPrefix(t.Value, "0x") || strings.HasPrefix(t.Value, "0X"): + val, err := strconv.ParseInt(t.Value[2:], 16, 64) + if err != nil { + return nil, newError(ErrInvalidNumber, t) + } + n = float64(val) + default: + n, err = strconv.ParseFloat(t.Value, 64) + if err != nil { + typ := ErrInvalidNumber + if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange { + typ = ErrNumberRange + } + return nil, newError(typ, t) } - return nil, newError(typ, t) } - return &NumberNode{ - Value: n, - }, nil + return &NumberNode{Value: n}, nil } func (n *NumberNode) optimize() (Node, error) { @@ -80,6 +109,10 @@ func (n NumberNode) String() string { return fmt.Sprintf("%g", n.Value) } +func (n NumberNode) Evaluate(ctx *Context) (interface{}, error) { + return n.Value, nil +} + // A BooleanNode represents the boolean constant true or false. type BooleanNode struct { Value bool @@ -111,6 +144,10 @@ func (n BooleanNode) String() string { return fmt.Sprintf("%t", n.Value) } +func (n BooleanNode) Evaluate(ctx *Context) (interface{}, error) { + return n.Value, nil +} + // A NullNode represents the JSON null value. type NullNode struct{} @@ -126,6 +163,10 @@ func (NullNode) String() string { return "null" } +func (n NullNode) Evaluate(ctx *Context) (interface{}, error) { + return nil, nil +} + // A RegexNode represents a regular expression. type RegexNode struct { Value *regexp.Regexp @@ -164,9 +205,14 @@ func (n RegexNode) String() string { return fmt.Sprintf("/%s/", expr) } +func (n RegexNode) Evaluate(ctx *Context) (interface{}, error) { + return n.Value, nil +} + // A VariableNode represents a JSONata variable. type VariableNode struct { Name string + Next Node } func parseVariable(p *parser, t token) (Node, error) { @@ -183,6 +229,13 @@ func (n VariableNode) String() string { return "$" + n.Name } +func (n VariableNode) Evaluate(ctx *Context) (interface{}, error) { + if n.Name == "" { + return ctx.Input, nil + } + return nil, fmt.Errorf("variable %s not found", n.Name) +} + // A NameNode represents a JSON field name. type NameNode struct { Value string @@ -223,6 +276,32 @@ func (n NameNode) Escaped() bool { return n.escaped } +func (n NameNode) Evaluate(ctx *Context) (interface{}, error) { + if ctx.Input == nil { + return nil, nil + } + + switch v := ctx.Input.(type) { + case map[string]interface{}: + return v[n.Value], nil + case []interface{}: + var results []interface{} + for _, item := range v { + if m, ok := item.(map[string]interface{}); ok { + if val, exists := m[n.Value]; exists { + results = append(results, val) + } + } + } + if len(results) == 1 { + return results[0], nil + } + return results, nil + default: + return nil, nil + } +} + // A PathNode represents a JSON object path. It consists of one // or more 'steps' or Nodes (most commonly NameNode objects). type PathNode struct { @@ -242,6 +321,39 @@ func (n PathNode) String() string { return s } +func (n PathNode) Evaluate(ctx *Context) (interface{}, error) { + if len(n.Steps) == 0 { + return nil, nil + } + + var current interface{} = ctx.Input + for _, step := range n.Steps { + nextCtx := &Context{ + Parent: ctx, + Position: -1, + Input: current, + } + + var err error + current, err = step.Evaluate(nextCtx) + if err != nil { + return nil, err + } + + if current == nil { + return nil, nil + } + } + + if n.KeepArrays { + if arr, ok := current.([]interface{}); ok { + return arr, nil + } + } + + return current, nil +} + // A NegationNode represents a numeric negation operation. type NegationNode struct { RHS Node @@ -254,16 +366,12 @@ func parseNegation(p *parser, t token) (Node, error) { } func (n *NegationNode) optimize() (Node, error) { - var err error - n.RHS, err = n.RHS.optimize() if err != nil { return nil, err } - // If the operand is a number literal, negate it now - // instead of waiting for evaluation. if number, ok := n.RHS.(*NumberNode); ok { return &NumberNode{ Value: -number.Value, @@ -277,6 +385,20 @@ func (n NegationNode) String() string { return fmt.Sprintf("-%s", n.RHS) } +func (n NegationNode) Evaluate(ctx *Context) (interface{}, error) { + val, err := n.RHS.Evaluate(ctx) + if err != nil { + return nil, err + } + + num, ok := val.(float64) + if !ok { + return nil, fmt.Errorf("cannot negate non-numeric value") + } + + return -num, nil +} + // A RangeNode represents the range operator. type RangeNode struct { LHS Node @@ -304,6 +426,34 @@ func (n RangeNode) String() string { return fmt.Sprintf("%s..%s", n.LHS, n.RHS) } +func (n RangeNode) Evaluate(ctx *Context) (interface{}, error) { + start, err := n.LHS.Evaluate(ctx) + if err != nil { + return nil, err + } + + end, err := n.RHS.Evaluate(ctx) + if err != nil { + return nil, err + } + + startNum, ok := start.(float64) + if !ok { + return nil, fmt.Errorf("range start must be numeric") + } + + endNum, ok := end.(float64) + if !ok { + return nil, fmt.Errorf("range end must be numeric") + } + + var result []interface{} + for i := startNum; i <= endNum; i++ { + result = append(result, i) + } + return result, nil +} + // An ArrayNode represents an array of items. type ArrayNode struct { Items []Node @@ -360,6 +510,18 @@ func (n ArrayNode) String() string { return fmt.Sprintf("[%s]", joinNodes(n.Items, ", ")) } +func (n ArrayNode) Evaluate(ctx *Context) (interface{}, error) { + result := make([]interface{}, len(n.Items)) + for i, item := range n.Items { + val, err := item.Evaluate(ctx) + if err != nil { + return nil, err + } + result[i] = val + } + return result, nil +} + // An ObjectNode represents an object, an unordered list of // key-value pairs. type ObjectNode struct { @@ -408,7 +570,6 @@ func (n *ObjectNode) optimize() (Node, error) { } func (n ObjectNode) String() string { - values := make([]string, len(n.Pairs)) for i, pair := range n.Pairs { @@ -418,6 +579,31 @@ func (n ObjectNode) String() string { return fmt.Sprintf("{%s}", strings.Join(values, ", ")) } +func (n ObjectNode) Evaluate(ctx *Context) (interface{}, error) { + result := make(map[string]interface{}) + + for _, pair := range n.Pairs { + key, err := pair[0].Evaluate(ctx) + if err != nil { + return nil, err + } + + keyStr, ok := key.(string) + if !ok { + return nil, fmt.Errorf("object key must evaluate to string") + } + + value, err := pair[1].Evaluate(ctx) + if err != nil { + return nil, err + } + + result[keyStr] = value + } + + return result, nil +} + // A BlockNode represents a block expression. type BlockNode struct { Exprs []Node @@ -462,6 +648,18 @@ func (n BlockNode) String() string { return fmt.Sprintf("(%s)", joinNodes(n.Exprs, "; ")) } +func (n BlockNode) Evaluate(ctx *Context) (interface{}, error) { + var result interface{} + for _, expr := range n.Exprs { + var err error + result, err = expr.Evaluate(ctx) + if err != nil { + return nil, err + } + } + return result, nil +} + // A WildcardNode represents the wildcard operator. type WildcardNode struct{} @@ -477,6 +675,25 @@ func (WildcardNode) String() string { return "*" } +func (n WildcardNode) Evaluate(ctx *Context) (interface{}, error) { + if ctx.Input == nil { + return nil, nil + } + + switch v := ctx.Input.(type) { + case map[string]interface{}: + result := make([]interface{}, 0, len(v)) + for _, val := range v { + result = append(result, val) + } + return result, nil + case []interface{}: + return v, nil + default: + return nil, nil + } +} + // A DescendentNode represents the descendent operator. type DescendentNode struct{} @@ -492,6 +709,41 @@ func (DescendentNode) String() string { return "**" } +func (n DescendentNode) Evaluate(ctx *Context) (interface{}, error) { + if ctx.Input == nil { + return nil, nil + } + + var results []interface{} + + switch v := ctx.Input.(type) { + case map[string]interface{}: + for _, val := range v { + results = append(results, val) + if nested, err := n.Evaluate(&Context{Input: val}); err == nil && nested != nil { + if arr, ok := nested.([]interface{}); ok { + results = append(results, arr...) + } else { + results = append(results, nested) + } + } + } + case []interface{}: + for _, item := range v { + results = append(results, item) + if nested, err := n.Evaluate(&Context{Input: item}); err == nil && nested != nil { + if arr, ok := nested.([]interface{}); ok { + results = append(results, arr...) + } else { + results = append(results, nested) + } + } + } + } + + return results, nil +} + // An ObjectTransformationNode represents the object transformation // operator. type ObjectTransformationNode struct { @@ -545,7 +797,6 @@ func (n *ObjectTransformationNode) optimize() (Node, error) { } func (n ObjectTransformationNode) String() string { - s := fmt.Sprintf("|%s|%s", n.Pattern, n.Updates) if n.Deletes != nil { s += fmt.Sprintf(", %s", n.Deletes) @@ -554,6 +805,63 @@ func (n ObjectTransformationNode) String() string { return s } +func (n ObjectTransformationNode) Evaluate(ctx *Context) (interface{}, error) { + pattern, err := n.Pattern.Evaluate(ctx) + if err != nil { + return nil, err + } + + if pattern == nil { + return nil, nil + } + + obj, ok := pattern.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("pattern must evaluate to object") + } + + updates, err := n.Updates.Evaluate(ctx) + if err != nil { + return nil, err + } + + updateObj, ok := updates.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("updates must evaluate to object") + } + + result := make(map[string]interface{}) + for k, v := range obj { + result[k] = v + } + + for k, v := range updateObj { + result[k] = v + } + + if n.Deletes != nil { + deletes, err := n.Deletes.Evaluate(ctx) + if err != nil { + return nil, err + } + + deleteArr, ok := deletes.([]interface{}) + if !ok { + return nil, fmt.Errorf("deletes must evaluate to array") + } + + for _, key := range deleteArr { + keyStr, ok := key.(string) + if !ok { + return nil, fmt.Errorf("delete key must be string") + } + delete(result, keyStr) + } + } + + return result, nil +} + // A ParamType represents the type of a parameter in a lambda // function signature. type ParamType uint @@ -836,19 +1144,15 @@ type LambdaNode struct { } func (n *LambdaNode) optimize() (Node, error) { - var err error - n.Body, err = n.Body.optimize() if err != nil { return nil, err } - return n, nil } func (n LambdaNode) String() string { - name := "function" if n.shorthand { name = "λ" @@ -862,6 +1166,15 @@ func (n LambdaNode) String() string { return fmt.Sprintf("%s(%s){%s}", name, strings.Join(params, ", "), n.Body) } +func (n LambdaNode) Evaluate(ctx *Context) (interface{}, error) { + return map[string]interface{}{ + "__lambda": true, + "params": n.ParamNames, + "body": n.Body, + "context": ctx, + }, nil +} + // Shorthand returns true if the lambda function was defined // with the shorthand symbol "λ", and false otherwise. This // doesn't affect evaluation but may be useful when recreating @@ -879,18 +1192,15 @@ type TypedLambdaNode struct { } func (n *TypedLambdaNode) optimize() (Node, error) { - node, err := n.LambdaNode.optimize() if err != nil { return nil, err } n.LambdaNode = node.(*LambdaNode) - return n, nil } func (n TypedLambdaNode) String() string { - name := "function" if n.shorthand { name = "λ" @@ -909,12 +1219,46 @@ func (n TypedLambdaNode) String() string { return fmt.Sprintf("%s(%s)<%s>{%s}", name, strings.Join(params, ", "), strings.Join(inputs, ""), n.Body) } +func (n TypedLambdaNode) Evaluate(ctx *Context) (interface{}, error) { + return map[string]interface{}{ + "__lambda": true, + "params": n.ParamNames, + "body": n.Body, + "context": ctx, + "in": n.In, + "out": n.Out, + }, nil +} + // A PartialNode represents a partially applied function. type PartialNode struct { Func Node Args []Node } +func (n PartialNode) Evaluate(ctx *Context) (interface{}, error) { + fn, err := n.Func.Evaluate(ctx) + if err != nil { + return nil, err + } + + args := make([]interface{}, len(n.Args)) + for i, arg := range n.Args { + if _, ok := arg.(*PlaceholderNode); !ok { + val, err := arg.Evaluate(ctx) + if err != nil { + return nil, err + } + args[i] = val + } + } + + return map[string]interface{}{ + "function": fn, + "args": args, + }, nil +} + func (n *PartialNode) optimize() (Node, error) { var err error @@ -942,7 +1286,11 @@ func (n PartialNode) String() string { // in a partially applied function. type PlaceholderNode struct{} -func (n *PlaceholderNode) optimize() (Node, error) { +func (n PlaceholderNode) Evaluate(ctx *Context) (interface{}, error) { + return nil, fmt.Errorf("placeholder cannot be evaluated") +} + +func (n PlaceholderNode) optimize() (Node, error) { return n, nil } @@ -956,6 +1304,27 @@ type FunctionCallNode struct { Args []Node } +func (n FunctionCallNode) Evaluate(ctx *Context) (interface{}, error) { + fn, err := n.Func.Evaluate(ctx) + if err != nil { + return nil, err + } + + args := make([]interface{}, len(n.Args)) + for i, arg := range n.Args { + val, err := arg.Evaluate(ctx) + if err != nil { + return nil, err + } + args[i] = val + } + + return map[string]interface{}{ + "function": fn, + "args": args, + }, nil +} + const typePlaceholder = typeCondition func parseFunctionCall(p *parser, t token, lhs Node) (Node, error) { @@ -1159,6 +1528,67 @@ func (n PredicateNode) String() string { return fmt.Sprintf("%s[%s]", n.Expr, joinNodes(n.Filters, ", ")) } +func (n PredicateNode) Evaluate(ctx *Context) (interface{}, error) { + expr, err := n.Expr.Evaluate(ctx) + if err != nil { + return nil, err + } + + if expr == nil { + return nil, nil + } + + var results []interface{} + var items []interface{} + + switch v := expr.(type) { + case []interface{}: + items = v + default: + items = []interface{}{v} + } + + for i, item := range items { + itemCtx := &Context{ + Input: item, + Parent: ctx, + Position: i, + } + + match := true + for _, filter := range n.Filters { + result, err := filter.Evaluate(itemCtx) + if err != nil { + return nil, err + } + + switch v := result.(type) { + case bool: + if !v { + match = false + break + } + case float64: + if int(v) != i { + match = false + break + } + default: + match = false + } + } + + if match { + results = append(results, item) + } + } + + if len(results) == 1 { + return results[0], nil + } + return results, nil +} + // A GroupNode represents a group expression. type GroupNode struct { Expr Node @@ -1256,15 +1686,33 @@ func (n *ConditionalNode) optimize() (Node, error) { } func (n ConditionalNode) String() string { - s := fmt.Sprintf("%s ? %s", n.If, n.Then) if n.Else != nil { s += fmt.Sprintf(" : %s", n.Else) } - return s } +func (n ConditionalNode) Evaluate(ctx *Context) (interface{}, error) { + cond, err := n.If.Evaluate(ctx) + if err != nil { + return nil, err + } + + cbool, ok := cond.(bool) + if !ok { + return nil, fmt.Errorf("condition must evaluate to boolean") + } + + if cbool { + return n.Then.Evaluate(ctx) + } + if n.Else != nil { + return n.Else.Evaluate(ctx) + } + return nil, nil +} + // An AssignmentNode represents a variable assignment. type AssignmentNode struct { Name string @@ -1300,6 +1748,15 @@ func (n AssignmentNode) String() string { return fmt.Sprintf("$%s := %s", n.Name, n.Value) } +func (n AssignmentNode) Evaluate(ctx *Context) (interface{}, error) { + value, err := n.Value.Evaluate(ctx) + if err != nil { + return nil, err + } + // Store in context's variable map (to be implemented) + return value, nil +} + // A NumericOperator is a mathematical operation between two // numeric values. type NumericOperator uint8 @@ -1385,6 +1842,49 @@ func (n NumericOperatorNode) String() string { return fmt.Sprintf("%s %s %s", n.LHS, n.Type, n.RHS) } +func (n NumericOperatorNode) Evaluate(ctx *Context) (interface{}, error) { + lhs, err := n.LHS.Evaluate(ctx) + if err != nil { + return nil, err + } + + rhs, err := n.RHS.Evaluate(ctx) + if err != nil { + return nil, err + } + + lnum, ok := lhs.(float64) + if !ok { + return nil, fmt.Errorf("left operand must be numeric") + } + + rnum, ok := rhs.(float64) + if !ok { + return nil, fmt.Errorf("right operand must be numeric") + } + + switch n.Type { + case NumericAdd: + return lnum + rnum, nil + case NumericSubtract: + return lnum - rnum, nil + case NumericMultiply: + return lnum * rnum, nil + case NumericDivide: + if rnum == 0 { + return nil, fmt.Errorf("division by zero") + } + return lnum / rnum, nil + case NumericModulo: + if rnum == 0 { + return nil, fmt.Errorf("modulo by zero") + } + return float64(int64(lnum) % int64(rnum)), nil + default: + return nil, fmt.Errorf("unknown numeric operator") + } +} + // A ComparisonOperator is an operation that compares two values. type ComparisonOperator uint8 @@ -1479,6 +1979,91 @@ func (n ComparisonOperatorNode) String() string { return fmt.Sprintf("%s %s %s", n.LHS, n.Type, n.RHS) } +func (n ComparisonOperatorNode) Evaluate(ctx *Context) (interface{}, error) { + lhs, err := n.LHS.Evaluate(ctx) + if err != nil { + return nil, err + } + + rhs, err := n.RHS.Evaluate(ctx) + if err != nil { + return nil, err + } + + switch n.Type { + case ComparisonEqual: + return deepEqual(lhs, rhs), nil + case ComparisonNotEqual: + return !deepEqual(lhs, rhs), nil + case ComparisonLess: + return compareValues(lhs, rhs) < 0, nil + case ComparisonLessEqual: + return compareValues(lhs, rhs) <= 0, nil + case ComparisonGreater: + return compareValues(lhs, rhs) > 0, nil + case ComparisonGreaterEqual: + return compareValues(lhs, rhs) >= 0, nil + case ComparisonIn: + return isValueIn(lhs, rhs), nil + default: + return nil, fmt.Errorf("unknown comparison operator") + } +} + +// A BinaryNode represents a binary operation between two nodes. +type BinaryNode struct { + Op tokenType + Left Node + Right Node +} + +func (n *BinaryNode) optimize() (Node, error) { + var err error + n.Left, err = n.Left.optimize() + if err != nil { + return nil, err + } + n.Right, err = n.Right.optimize() + if err != nil { + return nil, err + } + return n, nil +} + +func (n BinaryNode) String() string { + return fmt.Sprintf("%s %s %s", n.Left, n.Op, n.Right) +} + +func (n BinaryNode) Evaluate(ctx *Context) (interface{}, error) { + lhs, err := n.Left.Evaluate(ctx) + if err != nil { + return nil, err + } + + rhs, err := n.Right.Evaluate(ctx) + if err != nil { + return nil, err + } + + switch n.Op { + case typeMod: + lnum, ok := lhs.(float64) + if !ok { + return nil, fmt.Errorf("left operand must be numeric") + } + rnum, ok := rhs.(float64) + if !ok { + return nil, fmt.Errorf("right operand must be numeric") + } + if rnum == 0 { + return nil, fmt.Errorf("modulo by zero") + } + return float64(int64(lnum) % int64(rnum)), nil + default: + return nil, fmt.Errorf("unsupported binary operator: %v", n.Op) + } +} + // A BooleanOperator is a logical AND or OR operation between // two values. type BooleanOperator uint8 @@ -1549,6 +2134,45 @@ func (n BooleanOperatorNode) String() string { return fmt.Sprintf("%s %s %s", n.LHS, n.Type, n.RHS) } +func (n BooleanOperatorNode) Evaluate(ctx *Context) (interface{}, error) { + lhs, err := n.LHS.Evaluate(ctx) + if err != nil { + return nil, err + } + + // Short-circuit evaluation for AND/OR + lbool, ok := lhs.(bool) + if !ok { + return nil, fmt.Errorf("left operand must be boolean") + } + + if n.Type == BooleanAnd && !lbool { + return false, nil + } + if n.Type == BooleanOr && lbool { + return true, nil + } + + rhs, err := n.RHS.Evaluate(ctx) + if err != nil { + return nil, err + } + + rbool, ok := rhs.(bool) + if !ok { + return nil, fmt.Errorf("right operand must be boolean") + } + + switch n.Type { + case BooleanAnd: + return lbool && rbool, nil + case BooleanOr: + return lbool || rbool, nil + default: + return nil, fmt.Errorf("unknown boolean operator") + } +} + // A StringConcatenationNode represents a string concatenation // operation. type StringConcatenationNode struct { @@ -1556,6 +2180,30 @@ type StringConcatenationNode struct { RHS Node } +func (n StringConcatenationNode) Evaluate(ctx *Context) (interface{}, error) { + lhs, err := n.LHS.Evaluate(ctx) + if err != nil { + return nil, err + } + + rhs, err := n.RHS.Evaluate(ctx) + if err != nil { + return nil, err + } + + lstr, ok := lhs.(string) + if !ok { + return nil, fmt.Errorf("left operand must be string") + } + + rstr, ok := rhs.(string) + if !ok { + return nil, fmt.Errorf("right operand must be string") + } + + return lstr + rstr, nil +} + func parseStringConcatenation(p *parser, t token, lhs Node) (Node, error) { return &StringConcatenationNode{ LHS: lhs, @@ -1664,26 +2312,71 @@ func (n *SortNode) optimize() (Node, error) { } func (n SortNode) String() string { - terms := make([]string, len(n.Terms)) for i, t := range n.Terms { - var sym string - switch t.Dir { case SortAscending: sym = "<" case SortDescending: sym = ">" } - terms[i] = sym + t.Expr.String() } return fmt.Sprintf("%s^(%s)", n.Expr, strings.Join(terms, ", ")) } +func (n SortNode) Evaluate(ctx *Context) (interface{}, error) { + expr, err := n.Expr.Evaluate(ctx) + if err != nil { + return nil, err + } + + if expr == nil { + return nil, nil + } + + items, ok := expr.([]interface{}) + if !ok { + return expr, nil + } + + sorted := make([]interface{}, len(items)) + copy(sorted, items) + + sort.SliceStable(sorted, func(i, j int) bool { + for _, term := range n.Terms { + iCtx := &Context{Input: sorted[i], Parent: ctx} + jCtx := &Context{Input: sorted[j], Parent: ctx} + + iVal, err := term.Expr.Evaluate(iCtx) + if err != nil { + return false + } + + jVal, err := term.Expr.Evaluate(jCtx) + if err != nil { + return false + } + + cmp := compareValues(iVal, jVal) + if cmp == 0 { + continue + } + + if term.Dir == SortDescending { + return cmp > 0 + } + return cmp < 0 + } + return false + }) + + return sorted, nil +} + // A FunctionApplicationNode represents a function application // operation. type FunctionApplicationNode struct { @@ -1719,73 +2412,24 @@ func (n FunctionApplicationNode) String() string { return fmt.Sprintf("%s ~> %s", n.LHS, n.RHS) } -// A dotNode is an interim structure used to process JSONata path -// expressions. It is deliberately unexported and creates a PathNode -// during its optimize phase. -type dotNode struct { - lhs Node - rhs Node -} - -func parseDot(p *parser, t token, lhs Node) (Node, error) { - return &dotNode{ - lhs: lhs, - rhs: p.parseExpression(p.bp(t.Type)), - }, nil -} - -func (n *dotNode) optimize() (Node, error) { - - path := &PathNode{} - - lhs, err := n.lhs.optimize() - if err != nil { - return nil, err - } - - switch lhs := lhs.(type) { - case *NumberNode, *StringNode, *BooleanNode, *NullNode: - // TODO: Add position info. - return nil, &Error{ - Type: ErrPathLiteral, - Hint: lhs.String(), - } - case *PathNode: - path.Steps = lhs.Steps - if lhs.KeepArrays { - path.KeepArrays = true - } - default: - path.Steps = []Node{lhs} - } - - rhs, err := n.rhs.optimize() +func (n FunctionApplicationNode) Evaluate(ctx *Context) (interface{}, error) { + lhs, err := n.LHS.Evaluate(ctx) if err != nil { return nil, err } - switch rhs := rhs.(type) { - case *NumberNode, *StringNode, *BooleanNode, *NullNode: - // TODO: Add position info. - return nil, &Error{ - Type: ErrPathLiteral, - Hint: rhs.String(), - } - case *PathNode: - path.Steps = append(path.Steps, rhs.Steps...) - if rhs.KeepArrays { - path.KeepArrays = true - } - default: - path.Steps = append(path.Steps, rhs) + rhsCtx := &Context{ + Input: lhs, + Parent: ctx, + Position: -1, } - return path, nil + return n.RHS.Evaluate(rhsCtx) } -func (n dotNode) String() string { - return fmt.Sprintf("%s.%s", n.lhs, n.rhs) -} +// A dotNode is an interim structure used to process JSONata path +// expressions. It is deliberately unexported and creates a PathNode +// during its optimize phase. // A singletonArrayNode is an interim data structure used when // processing path expressions. It is deliberately unexported @@ -1795,7 +2439,6 @@ type singletonArrayNode struct { } func (n *singletonArrayNode) optimize() (Node, error) { - lhs, err := n.lhs.optimize() if err != nil { return nil, err @@ -1817,6 +2460,25 @@ func (n singletonArrayNode) String() string { return fmt.Sprintf("%s[]", n.lhs) } +func (n singletonArrayNode) Evaluate(ctx *Context) (interface{}, error) { + result, err := n.lhs.Evaluate(ctx) + if err != nil { + return nil, err + } + + if result == nil { + return nil, nil + } + + // If result is already an array, return it as-is + if arr, ok := result.([]interface{}); ok { + return arr, nil + } + + // Otherwise wrap the single value in an array + return []interface{}{result}, nil +} + // A predicateNode is an interim data structure used when processing // predicate expressions. It is deliberately unexported and gets // converted into a PredicateNode during optimization. @@ -1826,12 +2488,8 @@ type predicateNode struct { } func parsePredicate(p *parser, t token, lhs Node) (Node, error) { - if p.token.Type == typeBracketClose { p.consume(typeBracketClose, false) - - // Empty brackets in a path mean that we should not - // flatten singleton arrays into single values. return &singletonArrayNode{ lhs: lhs, }, nil @@ -1847,7 +2505,6 @@ func parsePredicate(p *parser, t token, lhs Node) (Node, error) { } func (n *predicateNode) optimize() (Node, error) { - lhs, err := n.lhs.optimize() if err != nil { return nil, err @@ -1861,7 +2518,6 @@ func (n *predicateNode) optimize() (Node, error) { switch lhs := lhs.(type) { case *GroupNode: return nil, &Error{ - // TODO: Add position info. Type: ErrGroupPredicate, } case *PathNode: @@ -1889,6 +2545,55 @@ func (n *predicateNode) String() string { return fmt.Sprintf("%s[%s]", n.lhs, n.rhs) } +func (n predicateNode) Evaluate(ctx *Context) (interface{}, error) { + lhs, err := n.lhs.Evaluate(ctx) + if err != nil { + return nil, err + } + + if lhs == nil { + return nil, nil + } + + var items []interface{} + switch v := lhs.(type) { + case []interface{}: + items = v + default: + items = []interface{}{v} + } + + var results []interface{} + for i, item := range items { + itemCtx := &Context{ + Input: item, + Parent: ctx, + Position: i, + } + + match, err := n.rhs.Evaluate(itemCtx) + if err != nil { + return nil, err + } + + switch v := match.(type) { + case bool: + if v { + results = append(results, item) + } + case float64: + if int(v) == i { + results = append(results, item) + } + } + } + + if len(results) == 1 { + return results[0], nil + } + return results, nil +} + // Helpers func joinNodes(nodes []Node, sep string) string { @@ -1996,11 +2701,105 @@ func decodeRunes(s string, n int) (string, int) { // equivalent rune. It returns an invalid rune if the input is // not valid hex. func parseRune(hex string) rune { - n, err := strconv.ParseInt(hex, 16, 32) if err != nil { return -1 } - return rune(n) } + +func deepEqual(a, b interface{}) bool { + if a == nil || b == nil { + return a == b + } + + switch v1 := a.(type) { + case float64: + if v2, ok := b.(float64); ok { + return v1 == v2 + } + case string: + if v2, ok := b.(string); ok { + return v1 == v2 + } + case bool: + if v2, ok := b.(bool); ok { + return v1 == v2 + } + case []interface{}: + if v2, ok := b.([]interface{}); ok { + if len(v1) != len(v2) { + return false + } + for i := range v1 { + if !deepEqual(v1[i], v2[i]) { + return false + } + } + return true + } + case map[string]interface{}: + if v2, ok := b.(map[string]interface{}); ok { + if len(v1) != len(v2) { + return false + } + for k, val1 := range v1 { + val2, exists := v2[k] + if !exists || !deepEqual(val1, val2) { + return false + } + } + return true + } + } + return false +} + +func compareValues(a, b interface{}) int { + if a == nil || b == nil { + if a == nil && b == nil { + return 0 + } + if a == nil { + return -1 + } + return 1 + } + + switch v1 := a.(type) { + case float64: + if v2, ok := b.(float64); ok { + if v1 < v2 { + return -1 + } + if v1 > v2 { + return 1 + } + return 0 + } + case string: + if v2, ok := b.(string); ok { + return strings.Compare(v1, v2) + } + } + return 0 +} + +func isValueIn(needle, haystack interface{}) bool { + switch h := haystack.(type) { + case []interface{}: + for _, item := range h { + if deepEqual(needle, item) { + return true + } + } + case map[string]interface{}: + key, ok := needle.(string) + if !ok { + return false + } + _, exists := h[key] + return exists + } + return false +} diff --git a/jparse/operator_test.go b/jparse/operator_test.go new file mode 100644 index 0000000..41b2b98 --- /dev/null +++ b/jparse/operator_test.go @@ -0,0 +1,501 @@ +package jparse + +import ( + "testing" +) + +// Using existing deepEqual from node.go + +func TestOperatorEvaluation(t *testing.T) { + tests := []struct { + name string + input string + data interface{} + want interface{} + wantErr bool + }{ + // Parent operator tests + { + name: "parent operator basic", + input: "Account.%.value", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "child": map[string]interface{}{ + "id": 1, + }, + "value": "parent", + }, + }, + want: "parent", + }, + { + name: "parent operator no parent", + input: "Account.%", + data: map[string]interface{}{ + "test": "value", + }, + wantErr: true, + }, + { + name: "parent operator nested", + input: "Account.Order.Product.%.OrderID", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "Order": map[string]interface{}{ + "OrderID": "O1", + "Product": map[string]interface{}{ + "ProductID": "P1", + }, + }, + }, + }, + want: "O1", + }, + + // Cross reference operator tests + { + name: "cross reference basic", + input: "Account.Order.Product@$P.ProductID", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "Order": map[string]interface{}{ + "Product": map[string]interface{}{ + "ProductID": "P123", + "Name": "Widget", + }, + }, + }, + }, + want: "P123", + }, + { + name: "cross reference with variable binding", + input: "Account.Order.Product@$P[$P.Name='Widget'].ProductID", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "Order": map[string]interface{}{ + "Product": []interface{}{ + map[string]interface{}{ + "ProductID": "P123", + "Name": "Widget", + }, + map[string]interface{}{ + "ProductID": "P456", + "Name": "Gadget", + }, + }, + }, + }, + }, + want: "P123", + }, + { + name: "cross reference with null", + input: "missing@something", + data: map[string]interface{}{ + "test": "value", + }, + want: nil, + }, + { + name: "cross reference with array", + input: "Account.Orders@$O.Products@$P.ProductID", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "Orders": []interface{}{ + map[string]interface{}{ + "Products": []interface{}{ + map[string]interface{}{"ProductID": "P1"}, + map[string]interface{}{"ProductID": "P2"}, + }, + }, + }, + }, + }, + want: []interface{}{"P1", "P2"}, + }, + + // Position operator tests + { + name: "position operator basic", + input: "Account.Order[#=1]", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "Order": []interface{}{ + map[string]interface{}{"id": "1"}, + map[string]interface{}{"id": "2"}, + map[string]interface{}{"id": "3"}, + }, + }, + }, + want: map[string]interface{}{"id": "2"}, + }, + { + name: "position operator with expression", + input: "Account.Order[#.id='2']", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "Order": []interface{}{ + map[string]interface{}{"id": "1"}, + map[string]interface{}{"id": "2"}, + map[string]interface{}{"id": "3"}, + }, + }, + }, + want: map[string]interface{}{"id": "2"}, + }, + { + name: "position operator outside sequence", + input: "Account.Order[#]", + data: 42, + wantErr: true, + }, + { + name: "position operator in array transformation", + input: "Account.Order.$each(function($v, $i) { $i })", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "Order": []interface{}{ + "first", + "second", + "third", + }, + }, + }, + want: []interface{}{0.0, 1.0, 2.0}, + }, + + // Combined operator tests + { + name: "combined operators basic", + input: "Account.Order#$i.Product@$P[%.OrderID]", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "Order": []interface{}{ + map[string]interface{}{ + "OrderID": "O1", + "Product": map[string]interface{}{ + "ProductID": "P1", + }, + }, + }, + }, + }, + want: "O1", + }, + { + name: "combined operators with predicates", + input: "Account.Order#$i[%@$O.Type='retail'].Product@$P[%.OrderID]", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "Order": []interface{}{ + map[string]interface{}{ + "OrderID": "O1", + "Type": "retail", + "Product": map[string]interface{}{ + "ProductID": "P1", + }, + }, + map[string]interface{}{ + "OrderID": "O2", + "Type": "wholesale", + "Product": map[string]interface{}{ + "ProductID": "P2", + }, + }, + }, + }, + }, + want: "O1", + }, + { + name: "nested parent references", + input: "Account.Order.Product.{name: ProductName, order: %.OrderID, account: %%.AccountID}", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "AccountID": "A1", + "Order": map[string]interface{}{ + "OrderID": "O1", + "Product": map[string]interface{}{ + "ProductName": "Widget", + }, + }, + }, + }, + want: map[string]interface{}{ + "name": "Widget", + "order": "O1", + "account": "A1", + }, + }, + { + name: "position with parent and cross reference", + input: "Account.Order#$i[Position=#].Product@$P[%.OrderID]", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "Order": []interface{}{ + map[string]interface{}{ + "OrderID": "O1", + "Position": 0, + "Product": map[string]interface{}{ + "ProductID": "P1", + }, + }, + map[string]interface{}{ + "OrderID": "O2", + "Position": 1, + "Product": map[string]interface{}{ + "ProductID": "P2", + }, + }, + }, + }, + }, + want: "O1", + }, + { + name: "modulo operator", + input: "10 % 3", + data: nil, + want: float64(1), + }, + { + name: "modulo with variables", + input: "$a % $b", + data: map[string]interface{}{ + "a": float64(10), + "b": float64(3), + }, + want: float64(1), + }, + { + name: "parent operator vs modulo disambiguation", + input: "Account.Order.%.OrderID", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "Order": map[string]interface{}{ + "OrderID": "O1", + "Product": map[string]interface{}{ + "ProductID": "P1", + }, + }, + }, + }, + want: "O1", + }, + { + name: "parent operator with array context", + input: "Account.Orders.Products.%.OrderID", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "Orders": []interface{}{ + map[string]interface{}{ + "OrderID": "O1", + "Products": []interface{}{ + map[string]interface{}{"ProductID": "P1"}, + map[string]interface{}{"ProductID": "P2"}, + }, + }, + }, + }, + }, + want: []interface{}{"O1", "O1"}, + }, + { + name: "cross reference with array binding", + input: "Account.Orders@$O.Products@$P[$O.Type='retail'].ProductID", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "Orders": []interface{}{ + map[string]interface{}{ + "OrderID": "O1", + "Type": "retail", + "Products": []interface{}{ + map[string]interface{}{"ProductID": "P1"}, + map[string]interface{}{"ProductID": "P2"}, + }, + }, + map[string]interface{}{ + "OrderID": "O2", + "Type": "wholesale", + "Products": []interface{}{ + map[string]interface{}{"ProductID": "P3"}, + }, + }, + }, + }, + }, + want: []interface{}{"P1", "P2"}, + }, + { + name: "combined operators complex", + input: "Account.Orders@$O#$i.Products@$P[%.OrderID]", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "Orders": []interface{}{ + map[string]interface{}{ + "OrderID": "O1", + "Products": []interface{}{ + map[string]interface{}{"ProductID": "P1"}, + }, + }, + map[string]interface{}{ + "OrderID": "O2", + "Products": []interface{}{ + map[string]interface{}{"ProductID": "P2"}, + }, + }, + }, + }, + }, + want: []interface{}{"O1", "O2"}, + }, + { + name: "position operator with array transformation", + input: "Account.Orders#$i.{id: OrderID, pos: $i}", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "Orders": []interface{}{ + map[string]interface{}{"OrderID": "O1"}, + map[string]interface{}{"OrderID": "O2"}, + }, + }, + }, + want: []interface{}{ + map[string]interface{}{"id": "O1", "pos": float64(0)}, + map[string]interface{}{"id": "O2", "pos": float64(1)}, + }, + }, + { + name: "position operator with parent context", + input: "Account.Orders#$i.Products[%.Position=$i].ProductID", + data: map[string]interface{}{ + "Account": map[string]interface{}{ + "Orders": []interface{}{ + map[string]interface{}{ + "Position": float64(0), + "Products": []interface{}{ + map[string]interface{}{"ProductID": "P1"}, + map[string]interface{}{"ProductID": "P2"}, + }, + }, + map[string]interface{}{ + "Position": float64(1), + "Products": []interface{}{ + map[string]interface{}{"ProductID": "P3"}, + }, + }, + }, + }, + }, + want: []interface{}{"P1", "P2", "P3"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expr, err := Parse(tt.input) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + ctx := NewContext(tt.data, nil) + got, err := expr.Evaluate(ctx) + if (err != nil) != tt.wantErr { + t.Errorf("Evaluate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && !deepEqual(got, tt.want) { + t.Errorf("Evaluate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestOperatorLexer(t *testing.T) { + cases := []struct { + name string + input string + want []token + }{ + { + name: "parent operator", + input: "%.name", + want: []token{ + {Type: typeParent, Value: "%", Position: 0}, + {Type: typeDot, Value: ".", Position: 1}, + {Type: typeName, Value: "name", Position: 2}, + }, + }, + { + name: "cross reference operator", + input: "books@$B", + want: []token{ + {Type: typeName, Value: "books", Position: 0}, + {Type: typeCrossRef, Value: "@", Position: 5}, + {Type: typeVariable, Value: "B", Position: 6}, + }, + }, + { + name: "position operator", + input: "Order#3", + want: []token{ + {Type: typeName, Value: "Order", Position: 0}, + {Type: typePosition, Value: "#", Position: 5}, + {Type: typeNumber, Value: "3", Position: 6}, + }, + }, + { + name: "combined operators", + input: "Account.Order#$i.Product@$P[%.OrderID]", + want: []token{ + {Type: typeName, Value: "Account", Position: 0}, + {Type: typeDot, Value: ".", Position: 7}, + {Type: typeName, Value: "Order", Position: 8}, + {Type: typePosition, Value: "#", Position: 13}, + {Type: typeVariable, Value: "i", Position: 14}, + {Type: typeDot, Value: ".", Position: 15}, + {Type: typeName, Value: "Product", Position: 16}, + {Type: typeCrossRef, Value: "@", Position: 23}, + {Type: typeVariable, Value: "P", Position: 24}, + {Type: typeBracketOpen, Value: "[", Position: 25}, + {Type: typeParent, Value: "%", Position: 26}, + {Type: typeDot, Value: ".", Position: 27}, + {Type: typeName, Value: "OrderID", Position: 28}, + {Type: typeBracketClose, Value: "]", Position: 35}, + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + l := newLexer(tt.input) + var got []token + + for { + tok := l.next(true) + if tok.Type == typeEOF { + break + } + got = append(got, tok) + } + + if len(got) != len(tt.want) { + t.Errorf("got %d tokens, want %d", len(got), len(tt.want)) + return + } + + for i := range got { + if got[i].Type != tt.want[i].Type { + t.Errorf("token[%d].Type = %v, want %v", i, got[i].Type, tt.want[i].Type) + } + if got[i].Value != tt.want[i].Value { + t.Errorf("token[%d].Value = %q, want %q", i, got[i].Value, tt.want[i].Value) + } + if got[i].Position != tt.want[i].Position { + t.Errorf("token[%d].Position = %d, want %d", i, got[i].Position, tt.want[i].Position) + } + } + }) + } +} diff --git a/jparse/parent.go b/jparse/parent.go new file mode 100644 index 0000000..44fe4da --- /dev/null +++ b/jparse/parent.go @@ -0,0 +1,116 @@ +package jparse + +import ( + "fmt" +) + +type ParentNode struct { + Expr Node + Input Node +} + +func (n *ParentNode) optimize() (Node, error) { + var err error + if n.Expr != nil { + n.Expr, err = n.Expr.optimize() + if err != nil { + return nil, err + } + } + if n.Input != nil { + n.Input, err = n.Input.optimize() + if err != nil { + return nil, err + } + } + return n, nil +} + +func (n ParentNode) String() string { + if n.Input != nil { + return fmt.Sprintf("%s.%s", n.Input, n.Expr) + } + return fmt.Sprintf("%%%s", n.Expr) +} + +func (n ParentNode) Evaluate(ctx *Context) (interface{}, error) { + // Handle modulo operator case + if n.Input != nil { + lhs, err := n.Input.Evaluate(ctx) + if err != nil { + return nil, err + } + if lhs == nil { + return nil, nil + } + + // Create context for evaluating RHS + rhsCtx := NewContext(lhs, ctx) + rhs, err := n.Expr.Evaluate(rhsCtx) + if err != nil { + return nil, err + } + + // Handle numeric modulo operation + if ln, ok := lhs.(float64); ok { + if rn, ok := rhs.(float64); ok { + if rn == 0 { + return nil, fmt.Errorf("division by zero in modulo operation") + } + return float64(int(ln) % int(rn)), nil + } + } + + // Handle path access + if path, ok := n.Expr.(*NameNode); ok { + if obj, ok := lhs.(map[string]interface{}); ok { + return obj[path.Value], nil + } + } + + return nil, fmt.Errorf("invalid operands for modulo/parent operator") + } + + // Handle parent operator case + if ctx.Parent == nil { + return nil, fmt.Errorf("parent operator used in root context") + } + + parentCtx := ctx.Parent + if n.Expr == nil { + return parentCtx.Input, nil + } + + // Create new context with parent's input and variables + newCtx := NewContext(parentCtx.Input, parentCtx) + newCtx.Position = ctx.Position + + // Handle array inputs by applying parent operator to each element + if arr, ok := parentCtx.Input.([]interface{}); ok { + var results []interface{} + for _, item := range arr { + itemCtx := NewContext(item, parentCtx) + result, err := n.Expr.Evaluate(itemCtx) + if err != nil { + return nil, err + } + if result != nil { + results = append(results, result) + } + } + if len(results) == 0 { + return nil, nil + } + if len(results) == 1 { + return results[0], nil + } + return results, nil + } + + result, err := n.Expr.Evaluate(newCtx) + if err != nil { + return nil, err + } + + return result, nil +} diff --git a/jparse/parser_operators.go b/jparse/parser_operators.go new file mode 100644 index 0000000..387ffda --- /dev/null +++ b/jparse/parser_operators.go @@ -0,0 +1,154 @@ +package jparse + +func parseParentPrefix(p *parser, t token) (Node, error) { + // Parent operator should not be used as prefix + return nil, newError(ErrPrefix, t) +} + +func parsePositionPrefix(p *parser, t token) (Node, error) { + expr := p.parseExpression(p.bp(typePosition)) + if expr == nil { + return &PositionNode{}, nil + } + return &PositionNode{Expr: expr}, nil +} + +func parseParentOperator(p *parser, t token, lhs Node) (Node, error) { + if lhs == nil { + return nil, newError(ErrPrefix, t) + } + + // Check if this is a modulo operator in numeric context + if p.token.Type == typeNumber || p.token.Type == typeVariable || p.token.Type == typeName { + rhs := p.parseExpression(p.bp(typeMod)) + if rhs == nil { + return nil, newError(ErrSyntaxError, t) + } + return &BinaryNode{ + Op: typeMod, + Left: lhs, + Right: rhs, + }, nil + } + + // Handle parent operator with dot notation + if p.token.Type == typeDot { + p.advance(false) + if p.token.Type != typeName { + return nil, newError(ErrSyntaxError, t) + } + expr := &NameNode{Value: p.token.Value} + p.advance(false) + return &ParentNode{ + Expr: expr, + Input: lhs, + }, nil + } + + // Handle parent operator with bracket notation + if p.token.Type == typeBracketOpen { + p.advance(false) + expr := p.parseExpression(0) + if expr == nil { + return nil, newError(ErrSyntaxError, t) + } + if p.token.Type != typeBracketClose { + return nil, newError(ErrSyntaxError, p.token) + } + p.advance(false) + return &ParentNode{ + Expr: expr, + Input: lhs, + }, nil + } + + return &ParentNode{Input: lhs}, nil +} + +func parseCrossReferenceOperator(p *parser, t token, lhs Node) (Node, error) { + if lhs == nil { + return nil, newError(ErrPrefix, t) + } + + // Parse variable after @ operator + if p.token.Type != typeVariable { + return nil, newError(ErrSyntaxError, p.token) + } + + varName := p.token.Value + p.advance(false) + + var path Node + if p.token.Type == typeBracketOpen { + p.advance(false) + expr := p.parseExpression(0) + if expr == nil { + return nil, newError(ErrSyntaxError, t) + } + if p.token.Type != typeBracketClose { + return nil, newError(ErrSyntaxError, p.token) + } + p.advance(false) + path = expr + } else if p.token.Type == typeDot { + p.advance(false) + if p.token.Type != typeName { + return nil, newError(ErrSyntaxError, p.token) + } + path = &NameNode{Value: p.token.Value} + p.advance(false) + } + + return &CrossReferenceNode{ + LHS: lhs, + RHS: &VariableNode{Name: varName}, + Path: path, + }, nil +} + +func parsePositionOperator(p *parser, t token, lhs Node) (Node, error) { + if lhs == nil { + return nil, newError(ErrPrefix, t) + } + + // Handle position operator with variable binding + if p.token.Type == typeVariable { + varName := p.token.Value + p.advance(false) + + // Parse optional predicate + var predicate Node + if p.token.Type == typeBracketOpen { + p.advance(false) + expr := p.parseExpression(0) + if expr == nil { + return nil, newError(ErrSyntaxError, t) + } + if p.token.Type != typeBracketClose { + return nil, newError(ErrSyntaxError, p.token) + } + p.advance(false) + predicate = expr + } + + return &PositionNode{ + Input: lhs, + Variable: &VariableNode{Name: varName}, + Predicate: predicate, + }, nil + } + + // Parse optional expression after # + var expr Node + if p.token.Type == typeNumber || p.token.Type == typeParenOpen { + expr = p.parseExpression(p.bp(typePosition)) + if expr == nil { + return nil, newError(ErrSyntaxError, t) + } + } + + return &PositionNode{ + Input: lhs, + Expr: expr, + }, nil +} diff --git a/jparse/path.go b/jparse/path.go new file mode 100644 index 0000000..03aea5f --- /dev/null +++ b/jparse/path.go @@ -0,0 +1,27 @@ +package jparse + +func parsePath(p *parser, t token, lhs Node) (Node, error) { + if lhs == nil { + return nil, newError(ErrPrefix, t) + } + + rhs := p.parseExpression(p.bp(typeDot)) + if rhs == nil { + return nil, newError(ErrSyntaxError, t) + } + + var steps []Node + if path, ok := lhs.(*PathNode); ok { + steps = append(steps, path.Steps...) + } else { + steps = append(steps, lhs) + } + + if path, ok := rhs.(*PathNode); ok { + steps = append(steps, path.Steps...) + } else { + steps = append(steps, rhs) + } + + return &PathNode{Steps: steps}, nil +} diff --git a/jparse/position.go b/jparse/position.go new file mode 100644 index 0000000..5171956 --- /dev/null +++ b/jparse/position.go @@ -0,0 +1,135 @@ +package jparse + +import ( + "fmt" +) + +type PositionNode struct { + Input Node + Expr Node + Variable Node + Predicate Node +} + +func (n *PositionNode) optimize() (Node, error) { + var err error + if n.Input != nil { + n.Input, err = n.Input.optimize() + if err != nil { + return nil, err + } + } + if n.Expr != nil { + n.Expr, err = n.Expr.optimize() + if err != nil { + return nil, err + } + } + if n.Variable != nil { + n.Variable, err = n.Variable.optimize() + if err != nil { + return nil, err + } + } + if n.Predicate != nil { + n.Predicate, err = n.Predicate.optimize() + if err != nil { + return nil, err + } + } + return n, nil +} + +func (n PositionNode) String() string { + if n.Variable != nil { + if n.Predicate != nil { + return fmt.Sprintf("%s#%s[%s]", n.Input, n.Variable, n.Predicate) + } + return fmt.Sprintf("%s#%s", n.Input, n.Variable) + } + if n.Expr != nil { + return fmt.Sprintf("%s#%s", n.Input, n.Expr) + } + return fmt.Sprintf("%s#", n.Input) +} + +func (n PositionNode) Evaluate(ctx *Context) (interface{}, error) { + if ctx.Position < 0 && n.Input == nil { + return nil, fmt.Errorf("position operator used outside of sequence context") + } + + var input interface{} + var err error + + if n.Input != nil { + input, err = n.Input.Evaluate(ctx) + if err != nil { + return nil, err + } + } else { + input = ctx.Input + } + + if input == nil { + return nil, nil + } + + // Handle array inputs + items, ok := input.([]interface{}) + if !ok { + items = []interface{}{input} + } + + // Handle variable binding with optional predicate + if n.Variable != nil { + var results []interface{} + for i, item := range items { + itemCtx := NewContext(item, ctx) + itemCtx.Position = i + + if varNode, ok := n.Variable.(*VariableNode); ok { + itemCtx = itemCtx.WithVariable(varNode.Name, float64(i)) + } + + if n.Predicate != nil { + match, err := n.Predicate.Evaluate(itemCtx) + if err != nil { + return nil, err + } + if b, ok := match.(bool); ok && b { + results = append(results, float64(i)) + } + } else { + results = append(results, float64(i)) + } + } + if len(results) == 0 { + return nil, nil + } + if len(results) == 1 { + return results[0], nil + } + return results, nil + } + + // Handle explicit position expression + if n.Expr != nil { + pos, err := n.Expr.Evaluate(ctx) + if err != nil { + return nil, err + } + if num, ok := pos.(float64); ok { + idx := int(num) + if idx >= 0 && idx < len(items) { + return items[idx], nil + } + } + return nil, nil + } + + // Return current position + if ctx.Position >= 0 { + return float64(ctx.Position), nil + } + return nil, nil +} diff --git a/jsonata-test/main.go b/jsonata-test/main.go index 937e07e..7a0d68c 100644 --- a/jsonata-test/main.go +++ b/jsonata-test/main.go @@ -179,12 +179,13 @@ func runTest(tc testCase, dataDir string, path string) (bool, error) { // loadTestExprFile loads a jsonata expression from a file and returns the // expression // For example, one test looks like this -// { -// "expr-file": "case000.jsonata", -// "dataset": null, -// "bindings": {}, -// "result": 2 -// } +// +// { +// "expr-file": "case000.jsonata", +// "dataset": null, +// "bindings": {}, +// "result": 2 +// } // // We want to load the expression from case000.jsonata so we can use it // as an expression in the test case diff --git a/jsonata.go b/jsonata.go index 7277b6e..d61ecff 100644 --- a/jsonata.go +++ b/jsonata.go @@ -6,6 +6,7 @@ package jsonata import ( "encoding/json" + "errors" "fmt" "reflect" "sync" @@ -149,7 +150,7 @@ func (e *Expr) Eval(data interface{}) (interface{}, error) { } if !result.CanInterface() { - return nil, fmt.Errorf("Eval returned a non-interface value") + return nil, errors.New("Eval returned a non-interface value") } if result.Kind() == reflect.Ptr && result.IsNil() { @@ -218,6 +219,109 @@ func (e *Expr) String() string { return e.node.String() } +// EvalString evaluates a JSONata expression string with optional context +func EvalString(expr string, context ...interface{}) (interface{}, error) { + if expr == "" { + return nil, errors.New("empty expression string") + } + + e, err := Compile(expr) + if err != nil { + return nil, fmt.Errorf("invalid JSONata expression: %v", err) + } + + var ctx interface{} + if len(context) > 0 { + ctx = context[0] + } + + result, err := e.Eval(ctx) + if err != nil { + return nil, fmt.Errorf("evaluation error: %v", err) + } + + return result, nil +} + +// Assert evaluates a condition and returns an error if it's false +func Assert(condition interface{}, message ...interface{}) (interface{}, error) { + if condition == nil { + return nil, errors.New("first argument of assert cannot be null") + } + + cond, ok := jtypes.AsBool(reflect.ValueOf(condition)) + if !ok { + return nil, errors.New("first argument of assert must be a boolean") + } + + if !cond { + msg := "assertion failed" + if len(message) > 0 && message[0] != nil { + if str, ok := jtypes.AsString(reflect.ValueOf(message[0])); ok { + msg = str + } + } + return nil, errors.New(msg) + } + + return true, nil +} + +// Error creates an error with the given message +func Error(message interface{}) (interface{}, error) { + if message == nil { + return nil, errors.New("error") + } + + msg, ok := jtypes.AsString(reflect.ValueOf(message)) + if !ok { + return nil, errors.New("argument of error must be a string") + } + + return nil, errors.New(msg) +} + +// Single ensures a sequence contains exactly one value +func Single(values interface{}, message ...interface{}) (interface{}, error) { + if values == nil { + return nil, nil + } + + v := reflect.ValueOf(values) + if !v.IsValid() { + return nil, nil + } + + if !jtypes.IsArray(v) { + return values, nil + } + + v = jtypes.Resolve(v) + length := v.Len() + + if length == 0 { + msg := "sequence is empty" + if len(message) > 0 && message[0] != nil { + if str, ok := jtypes.AsString(reflect.ValueOf(message[0])); ok { + msg = str + } + } + return nil, errors.New(msg) + } + + if length > 1 { + msg := "sequence has more than one value" + if len(message) > 0 && message[0] != nil { + if str, ok := jtypes.AsString(reflect.ValueOf(message[0])); ok { + msg = str + } + } + return nil, errors.New(msg) + } + + return v.Index(0).Interface(), nil +} + func (e *Expr) updateRegistry(values map[string]reflect.Value) { for name, v := range values {