diff --git a/file/error.go b/file/error.go index 8ff85dfa5..04253b6f6 100644 --- a/file/error.go +++ b/file/error.go @@ -3,7 +3,6 @@ package file import ( "fmt" "strings" - "unicode/utf8" ) type Error struct { @@ -19,43 +18,47 @@ func (e *Error) Error() string { return e.format() } +var tabReplacer = strings.NewReplacer("\t", " ") + func (e *Error) Bind(source Source) *Error { + src := source.String() + + var runeCount, lineStart int e.Line = 1 - for i, r := range source { - if i == e.From { + e.Column = 0 + for i, r := range src { + if runeCount == e.From { break } if r == '\n' { + lineStart = i e.Line++ e.Column = 0 - } else { - e.Column++ } + runeCount++ + e.Column++ + } + + lineEnd := lineStart + strings.IndexByte(src[lineStart:], '\n') + if lineEnd < lineStart { + lineEnd = len(src) + } + if lineStart == lineEnd { + return e } - if snippet, found := source.Snippet(e.Line); found { - snippet := strings.Replace(snippet, "\t", " ", -1) - srcLine := "\n | " + snippet - var bytes = []byte(snippet) - var indLine = "\n | " - for i := 0; i < e.Column && len(bytes) > 0; i++ { - _, sz := utf8.DecodeRune(bytes) - bytes = bytes[sz:] - if sz > 1 { - goto noind - } else { - indLine += "." - } - } - if _, sz := utf8.DecodeRune(bytes); sz > 1 { - goto noind - } else { - indLine += "^" - } - srcLine += indLine - noind: - e.Snippet = srcLine + const prefix = "\n | " + line := src[lineStart:lineEnd] + snippet := new(strings.Builder) + snippet.Grow(2*len(prefix) + len(line) + e.Column + 1) + snippet.WriteString(prefix) + tabReplacer.WriteString(snippet, line) + snippet.WriteString(prefix) + for i := 0; i < e.Column; i++ { + snippet.WriteByte('.') } + snippet.WriteByte('^') + e.Snippet = snippet.String() return e } diff --git a/file/source.go b/file/source.go index 8e2b2d154..b11bb5f9d 100644 --- a/file/source.go +++ b/file/source.go @@ -1,48 +1,36 @@ package file -import ( - "strings" - "unicode/utf8" -) +import "strings" -type Source []rune +type Source struct { + raw string +} func NewSource(contents string) Source { - return []rune(contents) + return Source{ + raw: contents, + } } func (s Source) String() string { - return string(s) + return s.raw } func (s Source) Snippet(line int) (string, bool) { - if s == nil { + if s.raw == "" { return "", false } - lines := strings.Split(string(s), "\n") - lineOffsets := make([]int, len(lines)) - var offset int - for i, line := range lines { - offset = offset + utf8.RuneCountInString(line) + 1 - lineOffsets[i] = offset - } - charStart, found := getLineOffset(lineOffsets, line) - if !found || len(s) == 0 { - return "", false + var start int + for i := 1; i < line; i++ { + pos := strings.IndexByte(s.raw[start:], '\n') + if pos < 0 { + return "", false + } + start += pos + 1 } - charEnd, found := getLineOffset(lineOffsets, line+1) - if found { - return string(s[charStart : charEnd-1]), true - } - return string(s[charStart:]), true -} - -func getLineOffset(lineOffsets []int, line int) (int, bool) { - if line == 1 { - return 0, true - } else if line > 1 && line <= len(lineOffsets) { - offset := lineOffsets[line-2] - return offset, true + end := start + strings.IndexByte(s.raw[start:], '\n') + if end < start { + end = len(s.raw) } - return -1, false + return s.raw[start:end], true } diff --git a/internal/ring/ring.go b/internal/ring/ring.go new file mode 100644 index 000000000..cc9e727b0 --- /dev/null +++ b/internal/ring/ring.go @@ -0,0 +1,85 @@ +package ring + +// Ring is a very simple ring buffer implementation that uses a slice. The +// internal slice will only grow, never shrink. When it grows, it grows in +// chunks of "chunkSize" (given as argument in the [New] function). Pointer and +// reference types can be safely used because memory is cleared. +type Ring[T any] struct { + data []T + back, len, chunkSize int +} + +func New[T any](chunkSize int) *Ring[T] { + if chunkSize < 1 { + panic("chunkSize must be greater than zero") + } + return &Ring[T]{ + chunkSize: chunkSize, + } +} + +func (r *Ring[T]) Len() int { + return r.len +} + +func (r *Ring[T]) Cap() int { + return len(r.data) +} + +func (r *Ring[T]) Reset() { + var zero T + for i := range r.data { + r.data[i] = zero // clear mem, optimized by the compiler, in Go 1.21 the "clear" builtin can be used + } + r.back = 0 + r.len = 0 +} + +// Nth returns the n-th oldest value (zero-based) in the ring without making +// any change. +func (r *Ring[T]) Nth(n int) (v T, ok bool) { + if n < 0 || n >= r.len || len(r.data) == 0 { + return v, false + } + n = (n + r.back) % len(r.data) + return r.data[n], true +} + +// Dequeue returns the oldest value. +func (r *Ring[T]) Dequeue() (v T, ok bool) { + if r.len == 0 { + return v, false + } + v, r.data[r.back] = r.data[r.back], v // retrieve and clear mem + r.len-- + r.back = (r.back + 1) % len(r.data) + return v, true +} + +// Enqueue adds an item to the ring. +func (r *Ring[T]) Enqueue(v T) { + if r.len == len(r.data) { + r.grow() + } + writePos := (r.back + r.len) % len(r.data) + r.data[writePos] = v + r.len++ +} + +func (r *Ring[T]) grow() { + s := make([]T, len(r.data)+r.chunkSize) + if r.len > 0 { + chunk1 := r.back + r.len + if chunk1 > len(r.data) { + chunk1 = len(r.data) + } + copied := copy(s, r.data[r.back:chunk1]) + + if copied < r.len { // wrapped slice + chunk2 := r.len - copied + copy(s[copied:], r.data[:chunk2]) + } + } + r.back = 0 + r.data = s +} diff --git a/internal/ring/ring_test.go b/internal/ring/ring_test.go new file mode 100644 index 000000000..b7457cd70 --- /dev/null +++ b/internal/ring/ring_test.go @@ -0,0 +1,140 @@ +package ring + +import ( + "fmt" + "testing" +) + +func TestRing(t *testing.T) { + type op = ringOp[int] + testRing(t, New[int](3), + // noops on empty ring + op{cap: 0, opType: opRst, value: 0, items: []int{}}, + op{cap: 0, opType: opDeq, value: 0, items: []int{}}, + + // basic + op{cap: 3, opType: opEnq, value: 1, items: []int{1}}, + op{cap: 3, opType: opDeq, value: 1, items: []int{}}, + + // wrapping + op{cap: 3, opType: opEnq, value: 2, items: []int{2}}, + op{cap: 3, opType: opEnq, value: 3, items: []int{2, 3}}, + op{cap: 3, opType: opEnq, value: 4, items: []int{2, 3, 4}}, + op{cap: 3, opType: opDeq, value: 2, items: []int{3, 4}}, + op{cap: 3, opType: opDeq, value: 3, items: []int{4}}, + op{cap: 3, opType: opDeq, value: 4, items: []int{}}, + + // resetting + op{cap: 3, opType: opEnq, value: 2, items: []int{2}}, + op{cap: 3, opType: opRst, value: 0, items: []int{}}, + op{cap: 3, opType: opDeq, value: 0, items: []int{}}, + + // growing without wrapping + op{cap: 3, opType: opEnq, value: 5, items: []int{5}}, + op{cap: 3, opType: opEnq, value: 6, items: []int{5, 6}}, + op{cap: 3, opType: opEnq, value: 7, items: []int{5, 6, 7}}, + op{cap: 6, opType: opEnq, value: 8, items: []int{5, 6, 7, 8}}, + op{cap: 6, opType: opRst, value: 0, items: []int{}}, + op{cap: 6, opType: opDeq, value: 0, items: []int{}}, + + // growing and wrapping + op{cap: 6, opType: opEnq, value: 9, items: []int{9}}, + op{cap: 6, opType: opEnq, value: 10, items: []int{9, 10}}, + op{cap: 6, opType: opEnq, value: 11, items: []int{9, 10, 11}}, + op{cap: 6, opType: opEnq, value: 12, items: []int{9, 10, 11, 12}}, + op{cap: 6, opType: opEnq, value: 13, items: []int{9, 10, 11, 12, 13}}, + op{cap: 6, opType: opEnq, value: 14, items: []int{9, 10, 11, 12, 13, 14}}, + op{cap: 6, opType: opDeq, value: 9, items: []int{10, 11, 12, 13, 14}}, + op{cap: 6, opType: opDeq, value: 10, items: []int{11, 12, 13, 14}}, + op{cap: 6, opType: opEnq, value: 15, items: []int{11, 12, 13, 14, 15}}, + op{cap: 6, opType: opEnq, value: 16, items: []int{11, 12, 13, 14, 15, 16}}, + op{cap: 9, opType: opEnq, value: 17, items: []int{11, 12, 13, 14, 15, 16, 17}}, // grows wrapped + op{cap: 9, opType: opDeq, value: 11, items: []int{12, 13, 14, 15, 16, 17}}, + op{cap: 9, opType: opDeq, value: 12, items: []int{13, 14, 15, 16, 17}}, + op{cap: 9, opType: opDeq, value: 13, items: []int{14, 15, 16, 17}}, + op{cap: 9, opType: opDeq, value: 14, items: []int{15, 16, 17}}, + op{cap: 9, opType: opDeq, value: 15, items: []int{16, 17}}, + op{cap: 9, opType: opDeq, value: 16, items: []int{17}}, + op{cap: 9, opType: opDeq, value: 17, items: []int{}}, + op{cap: 9, opType: opDeq, value: 0, items: []int{}}, + ) + + t.Run("should panic on invalid chunkSize", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatalf("should have panicked") + } + }() + New[int](0) + }) +} + +const ( + opEnq = iota // enqueue an item + opDeq // dequeue an item and an item was available + opRst // reset +) + +type ringOp[T comparable] struct { + cap int // expected values + opType int // opEnq or opDeq + value T // value to enqueue or value expected for dequeue; ignored for opRst + items []T // items left +} + +func testRing[T comparable](t *testing.T, r *Ring[T], ops ...ringOp[T]) { + for i, op := range ops { + testOK := t.Run(fmt.Sprintf("opIndex=%v", i), func(t *testing.T) { + testRingOp(t, r, op) + }) + if !testOK { + return + } + } +} + +func testRingOp[T comparable](t *testing.T, r *Ring[T], op ringOp[T]) { + var zero T + switch op.opType { + case opEnq: + r.Enqueue(op.value) + case opDeq: + shouldSucceed := r.Len() > 0 + v, ok := r.Dequeue() + switch { + case ok != shouldSucceed: + t.Fatalf("should have succeeded: %v", shouldSucceed) + case ok && v != op.value: + t.Fatalf("expected value: %v; got: %v", op.value, v) + case !ok && v != zero: + t.Fatalf("expected zero value; got: %v", v) + } + case opRst: + r.Reset() + } + if c := r.Cap(); c != op.cap { + t.Fatalf("expected cap: %v; got: %v", op.cap, c) + } + if l := r.Len(); l != len(op.items) { + t.Errorf("expected Len(): %v; got: %v", len(op.items), l) + } + var got []T + for i := 0; ; i++ { + v, ok := r.Nth(i) + if !ok { + break + } + got = append(got, v) + } + if l := len(got); l != len(op.items) { + t.Errorf("expected items: %v\ngot items: %v", op.items, got) + } + for i := range op.items { + if op.items[i] != got[i] { + t.Fatalf("expected items: %v\ngot items: %v", op.items, got) + } + } + if v, ok := r.Nth(len(op.items)); ok || v != zero { + t.Fatalf("expected no more items, got: v=%v; ok=%v", v, ok) + } +} diff --git a/parser/bench_test.go b/parser/bench_test.go new file mode 100644 index 000000000..0b9a69276 --- /dev/null +++ b/parser/bench_test.go @@ -0,0 +1,20 @@ +package parser + +import "testing" + +func BenchmarkParser(b *testing.B) { + const source = ` + /* + Showing worst case scenario + */ + let value = trim("contains escapes \n\"\\ \U0001F600 and non ASCII ñ"); // inline comment + len(value) == 0x2A + // let's introduce an error too + whatever + ` + b.ReportAllocs() + p := new(Parser) + for i := 0; i < b.N; i++ { + p.Parse(source, nil) + } +} diff --git a/parser/lexer/lexer.go b/parser/lexer/lexer.go index e6b06c09d..f417f02b6 100644 --- a/parser/lexer/lexer.go +++ b/parser/lexer/lexer.go @@ -2,152 +2,178 @@ package lexer import ( "fmt" + "io" "strings" + "unicode/utf8" "github.com/expr-lang/expr/file" + "github.com/expr-lang/expr/internal/ring" ) +const ringChunkSize = 10 + +// Lex will buffer and return the tokens of a disposable *[Lexer]. func Lex(source file.Source) ([]Token, error) { - l := &lexer{ - source: source, - tokens: make([]Token, 0), - start: 0, - end: 0, + tokens := make([]Token, 0, ringChunkSize) + l := New() + l.Reset(source) + for { + t, err := l.Next() + switch err { + case nil: + tokens = append(tokens, t) + case io.EOF: + return tokens, nil + default: + return nil, err + } } - l.commit() +} - for state := root; state != nil; { - state = state(l) +// New returns a reusable lexer. +func New() *Lexer { + return &Lexer{ + tokens: ring.New[Token](ringChunkSize), } +} - if l.err != nil { - return nil, l.err.Bind(source) +type Lexer struct { + state stateFn + source file.Source + tokens *ring.Ring[Token] + err *file.Error + start, end struct { + byte, rune int } + eof bool +} - return l.tokens, nil +func (l *Lexer) Reset(source file.Source) { + l.source = source + l.tokens.Reset() + l.state = root } -type lexer struct { - source file.Source - tokens []Token - start, end int - err *file.Error +func (l *Lexer) Next() (Token, error) { + for l.state != nil && l.err == nil && l.tokens.Len() == 0 { + l.state = l.state(l) + } + if l.err != nil { + return Token{}, l.err.Bind(l.source) + } + if t, ok := l.tokens.Dequeue(); ok { + return t, nil + } + return Token{}, io.EOF } const eof rune = -1 -func (l *lexer) commit() { +func (l *Lexer) commit() { l.start = l.end } -func (l *lexer) next() rune { - if l.end >= len(l.source) { - l.end++ +func (l *Lexer) next() rune { + if l.end.byte >= len(l.source.String()) { + l.eof = true return eof } - r := l.source[l.end] - l.end++ + r, sz := utf8.DecodeRuneInString(l.source.String()[l.end.byte:]) + l.end.rune++ + l.end.byte += sz return r } -func (l *lexer) peek() rune { - r := l.next() - l.backup() - return r +func (l *Lexer) peek() rune { + if l.end.byte < len(l.source.String()) { + r, _ := utf8.DecodeRuneInString(l.source.String()[l.end.byte:]) + return r + } + return eof } -func (l *lexer) backup() { - l.end-- +func (l *Lexer) peekByte() (byte, bool) { + if l.end.byte >= 0 && l.end.byte < len(l.source.String()) { + return l.source.String()[l.end.byte], true + } + return 0, false } -func (l *lexer) emit(t Kind) { +func (l *Lexer) backup() { + if l.eof { + l.eof = false + } else if l.end.rune > 0 { + _, sz := utf8.DecodeLastRuneInString(l.source.String()[:l.end.byte]) + l.end.byte -= sz + l.end.rune-- + } +} + +func (l *Lexer) emit(t Kind) { l.emitValue(t, l.word()) } -func (l *lexer) emitValue(t Kind, value string) { - l.tokens = append(l.tokens, Token{ - Location: file.Location{From: l.start, To: l.end}, +func (l *Lexer) emitValue(t Kind, value string) { + l.tokens.Enqueue(Token{ + Location: file.Location{From: l.start.rune, To: l.end.rune}, Kind: t, Value: value, }) l.commit() } -func (l *lexer) emitEOF() { - from := l.end - 2 +func (l *Lexer) emitEOF() { + from := l.end.rune - 1 if from < 0 { from = 0 } - to := l.end - 1 + to := l.end.rune - 0 if to < 0 { to = 0 } - l.tokens = append(l.tokens, Token{ + l.tokens.Enqueue(Token{ Location: file.Location{From: from, To: to}, Kind: EOF, }) l.commit() } -func (l *lexer) skip() { +func (l *Lexer) skip() { l.commit() } -func (l *lexer) word() string { - // TODO: boundary check is NOT needed here, but for some reason CI fuzz tests are failing. - if l.start > len(l.source) || l.end > len(l.source) { - return "__invalid__" - } - return string(l.source[l.start:l.end]) +func (l *Lexer) word() string { + return l.source.String()[l.start.byte:l.end.byte] } -func (l *lexer) accept(valid string) bool { - if strings.ContainsRune(valid, l.next()) { +func (l *Lexer) accept(valid string) bool { + if strings.ContainsRune(valid, l.peek()) { + l.next() return true } - l.backup() return false } -func (l *lexer) acceptRun(valid string) { - for strings.ContainsRune(valid, l.next()) { +func (l *Lexer) acceptRun(valid string) { + for l.accept(valid) { } - l.backup() } -func (l *lexer) skipSpaces() { - r := l.peek() - for ; r == ' '; r = l.peek() { - l.next() - } +func (l *Lexer) skipSpaces() { + l.acceptRun(" ") l.skip() } -func (l *lexer) acceptWord(word string) bool { - pos := l.end - - l.skipSpaces() - - for _, ch := range word { - if l.next() != ch { - l.end = pos - return false - } - } - if r := l.peek(); r != ' ' && r != eof { - l.end = pos - return false - } - - return true -} - -func (l *lexer) error(format string, args ...any) stateFn { +func (l *Lexer) error(format string, args ...any) stateFn { if l.err == nil { // show first error + end := l.end.rune + if l.eof { + end++ + } l.err = &file.Error{ Location: file.Location{ - From: l.end - 1, - To: l.end, + From: end - 1, + To: end, }, Message: fmt.Sprintf(format, args...), } @@ -167,7 +193,7 @@ func digitVal(ch rune) int { func lower(ch rune) rune { return ('a' - 'A') | ch } // returns lower-case ch iff ch is ASCII letter -func (l *lexer) scanDigits(ch rune, base, n int) rune { +func (l *Lexer) scanDigits(ch rune, base, n int) rune { for n > 0 && digitVal(ch) < base { ch = l.next() n-- @@ -178,7 +204,7 @@ func (l *lexer) scanDigits(ch rune, base, n int) rune { return ch } -func (l *lexer) scanEscape(quote rune) rune { +func (l *Lexer) scanEscape(quote rune) rune { ch := l.next() // read character after '/' switch ch { case 'a', 'b', 'f', 'n', 'r', 't', 'v', '\\', quote: @@ -198,7 +224,7 @@ func (l *lexer) scanEscape(quote rune) rune { return ch } -func (l *lexer) scanString(quote rune) (n int) { +func (l *Lexer) scanString(quote rune) (n int) { ch := l.next() // read character after quote for ch != quote { if ch == '\n' || ch == eof { @@ -215,7 +241,7 @@ func (l *lexer) scanString(quote rune) (n int) { return } -func (l *lexer) scanRawString(quote rune) (n int) { +func (l *Lexer) scanRawString(quote rune) (n int) { ch := l.next() // read character after back tick for ch != quote { if ch == eof { @@ -225,6 +251,6 @@ func (l *lexer) scanRawString(quote rune) (n int) { ch = l.next() n++ } - l.emitValue(String, string(l.source[l.start+1:l.end-1])) + l.emitValue(String, l.source.String()[l.start.byte+1:l.end.byte-1]) return } diff --git a/parser/lexer/lexer_test.go b/parser/lexer/lexer_test.go index db02d2acf..5171f4255 100644 --- a/parser/lexer/lexer_test.go +++ b/parser/lexer/lexer_test.go @@ -335,6 +335,7 @@ literal not terminated (1:10) früh ♥︎ unrecognized character: U+2665 '♥' (1:6) | früh ♥︎ + | .....^ ` func TestLex_error(t *testing.T) { diff --git a/parser/lexer/state.go b/parser/lexer/state.go index c694a2ca0..e5ad45bcd 100644 --- a/parser/lexer/state.go +++ b/parser/lexer/state.go @@ -6,9 +6,9 @@ import ( "github.com/expr-lang/expr/parser/utils" ) -type stateFn func(*lexer) stateFn +type stateFn func(*Lexer) stateFn -func root(l *lexer) stateFn { +func root(l *Lexer) stateFn { switch r := l.next(); { case r == eof: l.emitEOF() @@ -61,7 +61,7 @@ func root(l *lexer) stateFn { return root } -func number(l *lexer) stateFn { +func number(l *Lexer) stateFn { if !l.scanNumber() { return l.error("bad number syntax: %q", l.word()) } @@ -69,7 +69,7 @@ func number(l *lexer) stateFn { return root } -func (l *lexer) scanNumber() bool { +func (l *Lexer) scanNumber() bool { digits := "0123456789_" // Is it hex? if l.accept("0") { @@ -107,7 +107,7 @@ func (l *lexer) scanNumber() bool { return true } -func dot(l *lexer) stateFn { +func dot(l *Lexer) stateFn { l.next() if l.accept("0123456789") { l.backup() @@ -118,7 +118,7 @@ func dot(l *lexer) stateFn { return root } -func identifier(l *lexer) stateFn { +func identifier(l *Lexer) stateFn { loop: for { switch r := l.next(); { @@ -140,7 +140,7 @@ loop: return root } -func not(l *lexer) stateFn { +func not(l *Lexer) stateFn { l.emit(Operator) l.skipSpaces() @@ -167,13 +167,13 @@ func not(l *lexer) stateFn { return root } -func questionMark(l *lexer) stateFn { +func questionMark(l *Lexer) stateFn { l.accept(".?") l.emit(Operator) return root } -func slash(l *lexer) stateFn { +func slash(l *Lexer) stateFn { if l.accept("/") { return singleLineComment } @@ -184,7 +184,7 @@ func slash(l *lexer) stateFn { return root } -func singleLineComment(l *lexer) stateFn { +func singleLineComment(l *Lexer) stateFn { for { r := l.next() if r == eof || r == '\n' { @@ -195,7 +195,7 @@ func singleLineComment(l *lexer) stateFn { return root } -func multiLineComment(l *lexer) stateFn { +func multiLineComment(l *Lexer) stateFn { for { r := l.next() if r == eof { @@ -209,7 +209,7 @@ func multiLineComment(l *lexer) stateFn { return root } -func pointer(l *lexer) stateFn { +func pointer(l *Lexer) stateFn { l.accept("#") l.emit(Operator) for { diff --git a/parser/lexer/token.go b/parser/lexer/token.go index 459fa6905..c809c690e 100644 --- a/parser/lexer/token.go +++ b/parser/lexer/token.go @@ -31,17 +31,13 @@ func (t Token) String() string { } func (t Token) Is(kind Kind, values ...string) bool { - if len(values) == 0 { - return kind == t.Kind + if kind != t.Kind { + return false } - for _, v := range values { if v == t.Value { - goto found + return true } } - return false - -found: - return kind == t.Kind + return len(values) == 0 } diff --git a/parser/lexer/utils.go b/parser/lexer/utils.go index 5c9e6b59d..fdb8beaa1 100644 --- a/parser/lexer/utils.go +++ b/parser/lexer/utils.go @@ -36,7 +36,8 @@ func unescape(value string) (string, error) { if size >= math.MaxInt { return "", fmt.Errorf("too large string") } - buf := make([]byte, 0, size) + buf := new(strings.Builder) + buf.Grow(int(size)) for len(value) > 0 { c, multibyte, rest, err := unescapeChar(value) if err != nil { @@ -44,13 +45,13 @@ func unescape(value string) (string, error) { } value = rest if c < utf8.RuneSelf || !multibyte { - buf = append(buf, byte(c)) + buf.WriteByte(byte(c)) } else { n := utf8.EncodeRune(runeTmp[:], c) - buf = append(buf, runeTmp[:n]...) + buf.Write(runeTmp[:n]) } } - return string(buf), nil + return buf.String(), nil } // unescapeChar takes a string input and returns the following info: diff --git a/parser/parser.go b/parser/parser.go index 0a463fed5..e1dd111fc 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -1,7 +1,9 @@ package parser import ( + "errors" "fmt" + "io" "math" "strconv" "strings" @@ -10,6 +12,7 @@ import ( "github.com/expr-lang/expr/builtin" "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/file" + "github.com/expr-lang/expr/parser/lexer" . "github.com/expr-lang/expr/parser/lexer" "github.com/expr-lang/expr/parser/operator" "github.com/expr-lang/expr/parser/utils" @@ -44,17 +47,50 @@ var predicates = map[string]struct { "reduce": {[]arg{expr, predicate, expr | optional}}, } -type parser struct { - tokens []Token - current Token - pos int - err *file.Error - config *conf.Config - depth int // predicate call depth - nodeCount uint // tracks number of AST nodes created +// Parser is a reusable parser. The zero value is ready for use. +type Parser struct { + lexer *lexer.Lexer + current, stashed Token + hasStash bool + err *file.Error + config *conf.Config + depth int // predicate call depth + nodeCount uint // tracks number of AST nodes created } -func (p *parser) checkNodeLimit() error { +func (p *Parser) Parse(input string, config *conf.Config) (*Tree, error) { + if p.lexer == nil { + p.lexer = lexer.New() + } + p.config = config + source := file.NewSource(input) + p.lexer.Reset(source) + p.next() + node := p.parseSequenceExpression() + + if !p.current.Is(EOF) { + p.error("unexpected token %v", p.current) + } + + tree := &Tree{ + Node: node, + Source: source, + } + err := p.err + + // cleanup non-reusable pointer values and reset state + p.err = nil + p.config = nil + p.lexer.Reset(file.Source{}) + + if err != nil { + return tree, err.Bind(source) + } + + return tree, nil +} + +func (p *Parser) checkNodeLimit() error { p.nodeCount++ if p.config == nil { if p.nodeCount > conf.DefaultMaxNodes { @@ -70,7 +106,7 @@ func (p *parser) checkNodeLimit() error { return nil } -func (p *parser) createNode(n Node, loc file.Location) Node { +func (p *Parser) createNode(n Node, loc file.Location) Node { if err := p.checkNodeLimit(); err != nil { return nil } @@ -81,7 +117,7 @@ func (p *parser) createNode(n Node, loc file.Location) Node { return n } -func (p *parser) createMemberNode(n *MemberNode, loc file.Location) *MemberNode { +func (p *Parser) createMemberNode(n *MemberNode, loc file.Location) *MemberNode { if err := p.checkNodeLimit(); err != nil { return nil } @@ -102,42 +138,14 @@ func Parse(input string) (*Tree, error) { } func ParseWithConfig(input string, config *conf.Config) (*Tree, error) { - source := file.NewSource(input) - - tokens, err := Lex(source) - if err != nil { - return nil, err - } - - p := &parser{ - tokens: tokens, - current: tokens[0], - config: config, - } - - node := p.parseSequenceExpression() - - if !p.current.Is(EOF) { - p.error("unexpected token %v", p.current) - } - - tree := &Tree{ - Node: node, - Source: source, - } - - if p.err != nil { - return tree, p.err.Bind(source) - } - - return tree, nil + return new(Parser).Parse(input, config) } -func (p *parser) error(format string, args ...any) { +func (p *Parser) error(format string, args ...any) { p.errorAt(p.current, format, args...) } -func (p *parser) errorAt(token Token, format string, args ...any) { +func (p *Parser) errorAt(token Token, format string, args ...any) { if p.err == nil { // show first error p.err = &file.Error{ Location: token.Location, @@ -146,16 +154,32 @@ func (p *parser) errorAt(token Token, format string, args ...any) { } } -func (p *parser) next() { - p.pos++ - if p.pos >= len(p.tokens) { - p.error("unexpected end of expression") +func (p *Parser) next() { + if p.hasStash { + p.current = p.stashed + p.hasStash = false return } - p.current = p.tokens[p.pos] + + token, err := p.lexer.Next() + var e *file.Error + switch { + case err == nil: + p.current = token + case errors.Is(err, io.EOF): + p.error("unexpected end of expression") + case errors.As(err, &e): + p.err = e + default: + p.err = &file.Error{ + Location: p.current.Location, + Message: "unknown lexing error", + Prev: err, + } + } } -func (p *parser) expect(kind Kind, values ...string) { +func (p *Parser) expect(kind Kind, values ...string) { if p.current.Is(kind, values...) { p.next() return @@ -165,7 +189,7 @@ func (p *parser) expect(kind Kind, values ...string) { // parse functions -func (p *parser) parseSequenceExpression() Node { +func (p *Parser) parseSequenceExpression() Node { nodes := []Node{p.parseExpression(0)} for p.current.Is(Operator, ";") && p.err == nil { @@ -186,7 +210,7 @@ func (p *parser) parseSequenceExpression() Node { }, nodes[0].Location()) } -func (p *parser) parseExpression(precedence int) Node { +func (p *Parser) parseExpression(precedence int) Node { if p.err != nil { return nil } @@ -209,15 +233,16 @@ func (p *parser) parseExpression(precedence int) Node { // Handle "not *" operator, like "not in" or "not contains". if negate { - currentPos := p.pos + tokenBackup := p.current p.next() if operator.AllowedNegateSuffix(p.current.Value) { if op, ok := operator.Binary[p.current.Value]; ok && op.Precedence >= precedence { notToken = p.current opToken = p.current } else { - p.pos = currentPos - p.current = opToken + p.hasStash = true + p.stashed = p.current + p.current = tokenBackup break } } else { @@ -288,7 +313,7 @@ func (p *parser) parseExpression(precedence int) Node { return nodeLeft } -func (p *parser) parseVariableDeclaration() Node { +func (p *Parser) parseVariableDeclaration() Node { p.expect(Operator, "let") variableName := p.current p.expect(Identifier) @@ -303,7 +328,7 @@ func (p *parser) parseVariableDeclaration() Node { }, variableName.Location) } -func (p *parser) parseConditionalIf() Node { +func (p *Parser) parseConditionalIf() Node { p.next() nodeCondition := p.parseExpression(0) p.expect(Bracket, "{") @@ -322,7 +347,7 @@ func (p *parser) parseConditionalIf() Node { } -func (p *parser) parseConditional(node Node) Node { +func (p *Parser) parseConditional(node Node) Node { var expr1, expr2 Node for p.current.Is(Operator, "?") && p.err == nil { p.next() @@ -349,7 +374,7 @@ func (p *parser) parseConditional(node Node) Node { return node } -func (p *parser) parsePrimary() Node { +func (p *Parser) parsePrimary() Node { token := p.current if token.Is(Operator) { @@ -402,7 +427,7 @@ func (p *parser) parsePrimary() Node { return p.parseSecondary() } -func (p *parser) parseSecondary() Node { +func (p *Parser) parseSecondary() Node { var node Node token := p.current @@ -501,7 +526,7 @@ func (p *parser) parseSecondary() Node { return p.parsePostfixExpression(node) } -func (p *parser) toIntegerNode(number int64) Node { +func (p *Parser) toIntegerNode(number int64) Node { if number > math.MaxInt { p.error("integer literal is too large") return nil @@ -509,7 +534,7 @@ func (p *parser) toIntegerNode(number int64) Node { return p.createNode(&IntegerNode{Value: int(number)}, p.current.Location) } -func (p *parser) toFloatNode(number float64) Node { +func (p *Parser) toFloatNode(number float64) Node { if number > math.MaxFloat64 { p.error("float literal is too large") return nil @@ -517,7 +542,7 @@ func (p *parser) toFloatNode(number float64) Node { return p.createNode(&FloatNode{Value: number}, p.current.Location) } -func (p *parser) parseCall(token Token, arguments []Node, checkOverrides bool) Node { +func (p *Parser) parseCall(token Token, arguments []Node, checkOverrides bool) Node { var node Node isOverridden := false @@ -595,7 +620,7 @@ func (p *parser) parseCall(token Token, arguments []Node, checkOverrides bool) N return node } -func (p *parser) parseArguments(arguments []Node) []Node { +func (p *Parser) parseArguments(arguments []Node) []Node { // If pipe operator is used, the first argument is the left-hand side // of the operator, so we do not parse it as an argument inside brackets. offset := len(arguments) @@ -616,7 +641,7 @@ func (p *parser) parseArguments(arguments []Node) []Node { return arguments } -func (p *parser) parsePredicate() Node { +func (p *Parser) parsePredicate() Node { startToken := p.current withBrackets := false if p.current.Is(Bracket, "{") { @@ -648,7 +673,7 @@ func (p *parser) parsePredicate() Node { return predicateNode } -func (p *parser) parseArrayExpression(token Token) Node { +func (p *Parser) parseArrayExpression(token Token) Node { nodes := make([]Node, 0) p.expect(Bracket, "[") @@ -672,7 +697,7 @@ end: return node } -func (p *parser) parseMapExpression(token Token) Node { +func (p *Parser) parseMapExpression(token Token) Node { p.expect(Bracket, "{") nodes := make([]Node, 0) @@ -725,7 +750,7 @@ end: return node } -func (p *parser) parsePostfixExpression(node Node) Node { +func (p *Parser) parsePostfixExpression(node Node) Node { postfixToken := p.current for (postfixToken.Is(Operator) || postfixToken.Is(Bracket)) && p.err == nil { optional := postfixToken.Value == "?." @@ -855,7 +880,7 @@ func (p *parser) parsePostfixExpression(node Node) Node { } return node } -func (p *parser) parseComparison(left Node, token Token, precedence int) Node { +func (p *Parser) parseComparison(left Node, token Token, precedence int) Node { var rootNode Node for { comparator := p.parseExpression(precedence + 1) diff --git a/vm/vm_test.go b/vm/vm_test.go index 91752a419..817fc6cc2 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/expr-lang/expr/file" "github.com/expr-lang/expr/internal/testify/require" "github.com/expr-lang/expr" @@ -609,10 +610,10 @@ func TestVM_DirectCallOpcodes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { program := vm.NewProgram( - nil, // source - nil, // node - nil, // locations - 0, // variables + file.Source{}, // source + nil, // node + nil, // locations + 0, // variables tt.consts, tt.bytecode, tt.args, @@ -735,10 +736,10 @@ func TestVM_IndexAndCountOperations(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { program := vm.NewProgram( - nil, // source - nil, // node - nil, // locations - 0, // variables + file.Source{}, // source + nil, // node + nil, // locations + 0, // variables tt.consts, tt.bytecode, tt.args, @@ -1176,10 +1177,10 @@ func TestVM_DirectBasicOpcodes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { program := vm.NewProgram( - nil, // source - nil, // node - nil, // locations - 0, // variables + file.Source{}, // source + nil, // node + nil, // locations + 0, // variables tt.consts, tt.bytecode, tt.args,