From 24f70e3c9de759032e10fc6f72c56428edc94924 Mon Sep 17 00:00:00 2001 From: topi314 Date: Sun, 10 Nov 2024 03:29:19 +0100 Subject: [PATCH] refactor: change Node pointers to values & return (Node, bool) --- node.go | 118 ++++++++++++------------ node_test.go | 244 ++++++++++++++++++++++++++----------------------- parser_test.go | 44 ++++----- query.go | 6 +- query_test.go | 2 +- tree.go | 8 +- tree_cursor.go | 4 +- tree_test.go | 60 ++++++------ 8 files changed, 253 insertions(+), 233 deletions(-) diff --git a/node.go b/node.go index 998b7a6..33b0e0d 100644 --- a/node.go +++ b/node.go @@ -13,11 +13,11 @@ type Node struct { _inner C.TSNode } -func newNode(node C.TSNode) *Node { +func newNode(node C.TSNode) (Node, bool) { if node.id == nil { - return nil + return Node{}, false } - return &Node{_inner: node} + return Node{_inner: node}, true } // Get a numeric id for this node that is unique. @@ -26,34 +26,34 @@ func newNode(node C.TSNode) *Node { // a new tree is created based on an older tree, and a node from the old // tree is reused in the process, then that node will have the same id in // both trees. -func (n *Node) Id() uintptr { +func (n Node) Id() uintptr { return uintptr(n._inner.id) } // Get this node's type as a numerical id. -func (n *Node) KindId() uint16 { +func (n Node) KindId() uint16 { return uint16(C.ts_node_symbol(n._inner)) } // Get the node's type as a numerical id as it appears in the grammar // ignoring aliases. -func (n *Node) GrammarId() uint16 { +func (n Node) GrammarId() uint16 { return uint16(C.ts_node_grammar_symbol(n._inner)) } // Get this node's type as a string. -func (n *Node) Kind() string { +func (n Node) Kind() string { return C.GoString(C.ts_node_type(n._inner)) } // Get this node's symbol name as it appears in the grammar ignoring // aliases as a string. -func (n *Node) GrammarName() string { +func (n Node) GrammarName() string { return C.GoString(C.ts_node_grammar_type(n._inner)) } // Get the [Language] that was used to parse this node's syntax tree. -func (n *Node) Language() *Language { +func (n Node) Language() *Language { return &Language{Inner: C.ts_node_language(n._inner)} } @@ -61,7 +61,7 @@ func (n *Node) Language() *Language { // // Named nodes correspond to named rules in the grammar, whereas // *anonymous* nodes correspond to string literals in the grammar. -func (n *Node) IsNamed() bool { +func (n Node) IsNamed() bool { return bool(C.ts_node_is_named(n._inner)) } @@ -69,18 +69,18 @@ func (n *Node) IsNamed() bool { // // Extra nodes represent things like comments, which are not required in the // grammar, but can appear anywhere. -func (n *Node) IsExtra() bool { +func (n Node) IsExtra() bool { return bool(C.ts_node_is_extra(n._inner)) } // Check if this node has been edited. -func (n *Node) HasChanges() bool { +func (n Node) HasChanges() bool { return bool(C.ts_node_has_changes(n._inner)) } // Check if this node represents a syntax error or contains any syntax // errors anywhere within it. -func (n *Node) HasError() bool { +func (n Node) HasError() bool { return bool(C.ts_node_has_error(n._inner)) } @@ -88,17 +88,17 @@ func (n *Node) HasError() bool { // // Syntax errors represent parts of the code that could not be incorporated // into a valid syntax tree. -func (n *Node) IsError() bool { +func (n Node) IsError() bool { return bool(C.ts_node_is_error(n._inner)) } // Get this node's parse state. -func (n *Node) ParseState() uint16 { +func (n Node) ParseState() uint16 { return uint16(C.ts_node_parse_state(n._inner)) } // Get the parse state after this node. -func (n *Node) NextParseState() uint16 { +func (n Node) NextParseState() uint16 { return uint16(C.ts_node_next_parse_state(n._inner)) } @@ -106,28 +106,28 @@ func (n *Node) NextParseState() uint16 { // // Missing nodes are inserted by the parser in order to recover from // certain kinds of syntax errors. -func (n *Node) IsMissing() bool { +func (n Node) IsMissing() bool { return bool(C.ts_node_is_missing(n._inner)) } // Get the byte offsets where this node starts. -func (n *Node) StartByte() uint { +func (n Node) StartByte() uint { return uint(C.ts_node_start_byte(n._inner)) } // Get the byte offsets where this node end. -func (n *Node) EndByte() uint { +func (n Node) EndByte() uint { return uint(C.ts_node_end_byte(n._inner)) } // Get the byte range of source code that this node represents. -func (n *Node) ByteRange() (uint, uint) { +func (n Node) ByteRange() (uint, uint) { return n.StartByte(), n.EndByte() } // Get the range of source code that this node represents, both in terms of // raw bytes and of row/column coordinates. -func (n *Node) Range() Range { +func (n Node) Range() Range { return Range{ StartByte: n.StartByte(), EndByte: n.EndByte(), @@ -137,14 +137,14 @@ func (n *Node) Range() Range { } // Get this node's start position in terms of rows and columns. -func (n *Node) StartPosition() Point { +func (n Node) StartPosition() Point { p := Point{} p.fromTSPoint(C.ts_node_start_point(n._inner)) return p } // Get this node's end position in terms of rows and columns. -func (n *Node) EndPosition() Point { +func (n Node) EndPosition() Point { p := Point{} p.fromTSPoint(C.ts_node_end_point(n._inner)) return p @@ -156,12 +156,12 @@ func (n *Node) EndPosition() Point { // This method is fairly fast, but its cost is technically log(i), so if // you might be iterating over a long list of children, you should use // [Node.Children] instead. -func (n *Node) Child(i uint) *Node { +func (n Node) Child(i uint) (Node, bool) { return newNode(C.ts_node_child(n._inner, C.uint(i))) } // Get this node's number of children. -func (n *Node) ChildCount() uint { +func (n Node) ChildCount() uint { return uint(C.ts_node_child_count(n._inner)) } @@ -171,14 +171,14 @@ func (n *Node) ChildCount() uint { // This method is fairly fast, but its cost is technically log(i), so if // you might be iterating over a long list of children, you should use // [Node.NamedChildren] instead. -func (n *Node) NamedChild(i uint) *Node { +func (n Node) NamedChild(i uint) (Node, bool) { return newNode(C.ts_node_named_child(n._inner, C.uint(i))) } // Get this node's number of *named* children. // // See also [Node.IsNamed]. -func (n *Node) NamedChildCount() uint { +func (n Node) NamedChildCount() uint { return uint(C.ts_node_named_child_count(n._inner)) } @@ -186,7 +186,7 @@ func (n *Node) NamedChildCount() uint { // // If multiple children may have the same field name, access them using // [Node.ChildrenByFieldName] -func (n *Node) ChildByFieldName(fieldName string) *Node { +func (n Node) ChildByFieldName(fieldName string) (Node, bool) { cFieldName := C.CString(fieldName) defer go_free(unsafe.Pointer(cFieldName)) return newNode(C.ts_node_child_by_field_name(n._inner, cFieldName, C.uint32_t(len(fieldName)))) @@ -196,12 +196,12 @@ func (n *Node) ChildByFieldName(fieldName string) *Node { // // See also [Node.ChildByFieldName]. You can // convert a field name to an id using [Language.FieldIdForName]. -func (n *Node) ChildByFieldId(fieldId uint16) *Node { +func (n Node) ChildByFieldId(fieldId uint16) (Node, bool) { return newNode(C.ts_node_child_by_field_id(n._inner, C.uint16_t(fieldId))) } // Get the field name of this node's child at the given index. -func (n *Node) FieldNameForChild(childIndex uint32) string { +func (n Node) FieldNameForChild(childIndex uint32) string { ptr := C.ts_node_field_name_for_child(n._inner, C.uint32_t(childIndex)) if ptr == nil { return "" @@ -210,7 +210,7 @@ func (n *Node) FieldNameForChild(childIndex uint32) string { } // Get the field name of this node's named child at the given index. -func (n *Node) FieldNameForNamedChild(namedChildIndex uint32) string { +func (n Node) FieldNameForNamedChild(namedChildIndex uint32) string { ptr := C.ts_node_field_name_for_named_child(n._inner, C.uint32_t(namedChildIndex)) if ptr == nil { return "" @@ -227,13 +227,13 @@ func (n *Node) FieldNameForNamedChild(namedChildIndex uint32) string { // // If you're walking the tree recursively, you may want to use the // [TreeCursor] APIs directly instead. -func (n *Node) Children(cursor *TreeCursor) []Node { - cursor.Reset(*n) +func (n Node) Children(cursor *TreeCursor) []Node { + cursor.Reset(n) cursor.GotoFirstChild() childCount := n.ChildCount() result := make([]Node, 0, childCount) for i := 0; i < int(childCount); i++ { - result = append(result, *cursor.Node()) + result = append(result, cursor.Node()) cursor.GotoNextSibling() } return result @@ -242,8 +242,8 @@ func (n *Node) Children(cursor *TreeCursor) []Node { // Iterate over this node's named children. // // See also [Node.Children]. -func (n *Node) NamedChildren(cursor *TreeCursor) []Node { - cursor.Reset(*n) +func (n Node) NamedChildren(cursor *TreeCursor) []Node { + cursor.Reset(n) cursor.GotoFirstChild() namedChildCount := n.NamedChildCount() result := make([]Node, 0, namedChildCount) @@ -253,7 +253,7 @@ func (n *Node) NamedChildren(cursor *TreeCursor) []Node { break } } - result = append(result, *cursor.Node()) + result = append(result, cursor.Node()) cursor.GotoNextSibling() } return result @@ -262,11 +262,11 @@ func (n *Node) NamedChildren(cursor *TreeCursor) []Node { // Iterate over this node's children with a given field name. // // See also [Node.Children]. -func (n *Node) ChildrenByFieldName(fieldName string, cursor *TreeCursor) []Node { +func (n Node) ChildrenByFieldName(fieldName string, cursor *TreeCursor) []Node { fieldId := n.Language().FieldIdForName(fieldName) done := fieldId == 0 if !done { - cursor.Reset(*n) + cursor.Reset(n) cursor.GotoFirstChild() } result := make([]Node, 0) @@ -276,7 +276,7 @@ func (n *Node) ChildrenByFieldName(fieldName string, cursor *TreeCursor) []Node return result } } - result = append(result, *cursor.Node()) + result = append(result, cursor.Node()) if !cursor.GotoNextSibling() { done = true } @@ -287,83 +287,83 @@ func (n *Node) ChildrenByFieldName(fieldName string, cursor *TreeCursor) []Node // Get this node's immediate parent. // Prefer [Node.ChildWithDescendant] // for iterating over this node's ancestors. -func (n *Node) Parent() *Node { +func (n Node) Parent() (Node, bool) { return newNode(C.ts_node_parent(n._inner)) } // Get the node that contains `descendant`. // Note that this can return `descendant` itself. -func (n *Node) ChildWithDescendant(descendant *Node) *Node { +func (n Node) ChildWithDescendant(descendant Node) (Node, bool) { return newNode(C.ts_node_child_with_descendant(n._inner, descendant._inner)) } // Get this node's next sibling. -func (n *Node) NextSibling() *Node { +func (n Node) NextSibling() (Node, bool) { return newNode(C.ts_node_next_sibling(n._inner)) } // Get this node's previous sibling. -func (n *Node) PrevSibling() *Node { +func (n Node) PrevSibling() (Node, bool) { return newNode(C.ts_node_prev_sibling(n._inner)) } // Get this node's next named sibling. -func (n *Node) NextNamedSibling() *Node { +func (n Node) NextNamedSibling() (Node, bool) { return newNode(C.ts_node_next_named_sibling(n._inner)) } // Get this node's previous named sibling. -func (n *Node) PrevNamedSibling() *Node { +func (n Node) PrevNamedSibling() (Node, bool) { return newNode(C.ts_node_prev_named_sibling(n._inner)) } // Get the node's first child that contains or starts after the given byte offset. -func (n *Node) FirstChildForByte(byteOffset uint) *Node { +func (n Node) FirstChildForByte(byteOffset uint) (Node, bool) { return newNode(C.ts_node_first_child_for_byte(n._inner, C.uint(byteOffset))) } // Get the node's first named child that contains or starts after the given byte offset. -func (n *Node) FirstNamedChildForByte(byteOffset uint) *Node { +func (n Node) FirstNamedChildForByte(byteOffset uint) (Node, bool) { return newNode(C.ts_node_first_named_child_for_byte(n._inner, C.uint(byteOffset))) } // Get the node's number of descendants, including one for the node itself. -func (n *Node) DescendantCount() uint { +func (n Node) DescendantCount() uint { return uint(C.ts_node_descendant_count(n._inner)) } // Get the smallest node within this node that spans the given range. -func (n *Node) DescendantForByteRange(start, end uint) *Node { +func (n Node) DescendantForByteRange(start, end uint) (Node, bool) { return newNode(C.ts_node_descendant_for_byte_range(n._inner, C.uint(start), C.uint(end))) } // Get the smallest named node within this node that spans the given range. -func (n *Node) NamedDescendantForByteRange(start, end uint) *Node { +func (n Node) NamedDescendantForByteRange(start, end uint) (Node, bool) { return newNode(C.ts_node_named_descendant_for_byte_range(n._inner, C.uint(start), C.uint(end))) } // Get the smallest node within this node that spans the given range. -func (n *Node) DescendantForPointRange(start, end Point) *Node { +func (n Node) DescendantForPointRange(start, end Point) (Node, bool) { return newNode(C.ts_node_descendant_for_point_range(n._inner, start.toTSPoint(), end.toTSPoint())) } // Get the smallest named node within this node that spans the given range. -func (n *Node) NamedDescendantForPointRange(start, end Point) *Node { +func (n Node) NamedDescendantForPointRange(start, end Point) (Node, bool) { return newNode(C.ts_node_named_descendant_for_point_range(n._inner, start.toTSPoint(), end.toTSPoint())) } -func (n *Node) ToSexp() string { +func (n Node) ToSexp() string { cString := C.ts_node_string(n._inner) result := C.GoString(cString) go_free(unsafe.Pointer(cString)) return result } -func (n *Node) Utf8Text(source []byte) string { +func (n Node) Utf8Text(source []byte) string { return string(source[n.StartByte():n.EndByte()]) } -func (n *Node) Utf16Text(source []uint16) []uint16 { +func (n Node) Utf16Text(source []uint16) []uint16 { return source[n.StartByte():n.EndByte()] } @@ -371,8 +371,8 @@ func (n *Node) Utf16Text(source []uint16) []uint16 { // // Note that the given node is considered the root of the cursor, // and the cursor cannot walk outside this node. -func (n *Node) Walk() *TreeCursor { - return newTreeCursor(*n) +func (n Node) Walk() *TreeCursor { + return newTreeCursor(n) } // Edit this node to keep it in-sync with source code that has been edited. @@ -387,6 +387,6 @@ func (n *Node) Edit(edit *InputEdit) { } // Check if two nodes are identical. -func (n *Node) Equals(other Node) bool { +func (n Node) Equals(other Node) bool { return bool(C.ts_node_eq(n._inner, other._inner)) } diff --git a/node_test.go b/node_test.go index 382865c..d4ac72f 100644 --- a/node_test.go +++ b/node_test.go @@ -48,11 +48,12 @@ func ExampleNode() { fmt.Println(rootNode.StartPosition()) fmt.Println(rootNode.EndPosition()) - functionNode := rootNode.Child(1) + functionNode, _ := rootNode.Child(1) fmt.Println(functionNode.Kind()) - fmt.Println(functionNode.ChildByFieldName("name").Kind()) + nameFieldNode, _ := functionNode.ChildByFieldName("name") + fmt.Println(nameFieldNode.Kind()) - functionNameNode := functionNode.Child(1) + functionNameNode, _ := functionNode.Child(1) fmt.Println(functionNameNode.StartPosition()) fmt.Println(functionNameNode.EndPosition()) @@ -68,7 +69,7 @@ func ExampleNode() { func TestNodeChild(t *testing.T) { tree := parseJsonExample() - arrayNode := tree.RootNode().Child(0) + arrayNode := nodeMust(tree.RootNode().Child(0)) assert.Equal(t, "array", arrayNode.Kind()) assert.EqualValues(t, 3, arrayNode.NamedChildCount()) @@ -78,13 +79,13 @@ func TestNodeChild(t *testing.T) { assert.Equal(t, Point{8, 1}, arrayNode.EndPosition()) assert.EqualValues(t, 7, arrayNode.ChildCount()) - leftBracketNode := arrayNode.Child(0) - numberNode := arrayNode.Child(1) - commaNode1 := arrayNode.Child(2) - falseNode := arrayNode.Child(3) - commaNode2 := arrayNode.Child(4) - objectNode := arrayNode.Child(5) - rightBracketNode := arrayNode.Child(6) + leftBracketNode := nodeMust(arrayNode.Child(0)) + numberNode := nodeMust(arrayNode.Child(1)) + commaNode1 := nodeMust(arrayNode.Child(2)) + falseNode := nodeMust(arrayNode.Child(3)) + commaNode2 := nodeMust(arrayNode.Child(4)) + objectNode := nodeMust(arrayNode.Child(5)) + rightBracketNode := nodeMust(arrayNode.Child(6)) assert.Equal(t, "[", leftBracketNode.Kind()) assert.Equal(t, "number", numberNode.Kind()) @@ -117,9 +118,9 @@ func TestNodeChild(t *testing.T) { assert.Equal(t, Point{7, 3}, objectNode.EndPosition()) assert.EqualValues(t, 3, objectNode.ChildCount()) - leftBraceNode := objectNode.Child(0) - pairNode := objectNode.Child(1) - rightBraceNode := objectNode.Child(2) + leftBraceNode := nodeMust(objectNode.Child(0)) + pairNode := nodeMust(objectNode.Child(1)) + rightBraceNode := nodeMust(objectNode.Child(2)) assert.Equal(t, "{", leftBraceNode.Kind()) assert.Equal(t, "pair", pairNode.Kind()) @@ -135,9 +136,9 @@ func TestNodeChild(t *testing.T) { assert.Equal(t, Point{6, 13}, pairNode.EndPosition()) assert.EqualValues(t, 3, pairNode.ChildCount()) - stringNode := pairNode.Child(0) - colonNode := pairNode.Child(1) - nullNode := pairNode.Child(2) + stringNode := nodeMust(pairNode.Child(0)) + colonNode := nodeMust(pairNode.Child(1)) + nullNode := nodeMust(pairNode.Child(2)) assert.Equal(t, "string", stringNode.Kind()) assert.Equal(t, ":", colonNode.Kind()) @@ -157,26 +158,28 @@ func TestNodeChild(t *testing.T) { assert.Equal(t, Point{6, 9}, nullNode.StartPosition()) assert.Equal(t, Point{6, 13}, nullNode.EndPosition()) - assert.Equal(t, pairNode, stringNode.Parent()) - assert.Equal(t, pairNode, nullNode.Parent()) - assert.Equal(t, objectNode, pairNode.Parent()) - assert.Equal(t, arrayNode, numberNode.Parent()) - assert.Equal(t, arrayNode, falseNode.Parent()) - assert.Equal(t, arrayNode, objectNode.Parent()) - assert.Equal(t, tree.RootNode(), arrayNode.Parent()) - assert.Nil(t, tree.RootNode().Parent()) - - assert.Equal(t, arrayNode, tree.RootNode().ChildWithDescendant(nullNode)) - assert.Equal(t, objectNode, arrayNode.ChildWithDescendant(nullNode)) - assert.Equal(t, pairNode, objectNode.ChildWithDescendant(nullNode)) - assert.Equal(t, nullNode, pairNode.ChildWithDescendant(nullNode)) - assert.Nil(t, nullNode.ChildWithDescendant(nullNode)) + rootNode := tree.RootNode() + + assert.Equal(t, pairNode, nodeMust(stringNode.Parent())) + assert.Equal(t, pairNode, nodeMust(nullNode.Parent())) + assert.Equal(t, objectNode, nodeMust(pairNode.Parent())) + assert.Equal(t, arrayNode, nodeMust(numberNode.Parent())) + assert.Equal(t, arrayNode, nodeMust(falseNode.Parent())) + assert.Equal(t, arrayNode, nodeMust(objectNode.Parent())) + assert.Equal(t, rootNode, nodeMust(arrayNode.Parent())) + assert.Nil(t, nodeMustNot(tree.RootNode().Parent())) + + assert.Equal(t, arrayNode, nodeMust(tree.RootNode().ChildWithDescendant(nullNode))) + assert.Equal(t, objectNode, nodeMust(arrayNode.ChildWithDescendant(nullNode))) + assert.Equal(t, pairNode, nodeMust(objectNode.ChildWithDescendant(nullNode))) + assert.Equal(t, nullNode, nodeMust(pairNode.ChildWithDescendant(nullNode))) + assert.Nil(t, nodeMustNot(nullNode.ChildWithDescendant(nullNode))) } func TestNodeChildren(t *testing.T) { tree := parseJsonExample() cursor := tree.Walk() - arrayNode := tree.RootNode().Child(0) + arrayNode := nodeMust(tree.RootNode().Child(0)) children := arrayNode.Children(cursor) var kinds []string @@ -227,14 +230,14 @@ func TestNodeChildrenByFieldName(t *testing.T) { tree := parser.Parse([]byte(source), nil) defer tree.Close() - node := tree.RootNode().Child(0) + node := nodeMust(tree.RootNode().Child(0)) assert.Equal(t, "if_statement", node.Kind()) cursor := tree.Walk() alternatives := node.ChildrenByFieldName("alternative", cursor) var alternativeTexts []string for _, alternative := range alternatives { - condition := alternative.ChildByFieldName("condition") + condition := nodeMust(alternative.ChildByFieldName("condition")) alternativeTexts = append(alternativeTexts, string(source[condition.StartByte():condition.EndByte()])) } assert.Equal(t, []string{"two", "three", "four"}, alternativeTexts) @@ -247,12 +250,12 @@ func TestNodeParentOfChildByFieldName(t *testing.T) { tree := parser.Parse([]byte("foo(a().b[0].c.d.e())"), nil) defer tree.Close() - callNode := tree.RootNode().NamedChild(0).NamedChild(0) + callNode := nodeMust(nodeMust(tree.RootNode().NamedChild(0)).NamedChild(0)) assert.Equal(t, "call_expression", callNode.Kind()) // Regression test - when a field points to a hidden node (in this case, `_expression`) // the hidden node should not be added to the node parent cache. - assert.Equal(t, callNode, callNode.ChildByFieldName("function").Parent()) + assert.Equal(t, callNode, nodeMust(nodeMust(callNode.ChildByFieldName("function")).Parent())) } func TestParentOfZeroWithNode(t *testing.T) { @@ -265,17 +268,17 @@ func TestParentOfZeroWithNode(t *testing.T) { tree := parser.Parse([]byte(code), nil) defer tree.Close() root := tree.RootNode() - functionDefinition := root.Child(0) - block := functionDefinition.Child(4) - blockParent := block.Parent() + functionDefinition := nodeMust(root.Child(0)) + block := nodeMust(functionDefinition.Child(4)) + blockParent := nodeMust(block.Parent()) assert.Equal(t, "(block)", block.ToSexp()) assert.Equal(t, "function_definition", blockParent.Kind()) assert.Equal(t, "(function_definition name: (identifier) parameters: (parameters (identifier)) body: (block))", blockParent.ToSexp()) - assert.Equal(t, functionDefinition, root.ChildWithDescendant(block)) - assert.Equal(t, block, functionDefinition.ChildWithDescendant(block)) - assert.Nil(t, block.ChildWithDescendant(block)) + assert.Equal(t, functionDefinition, nodeMust(root.ChildWithDescendant(block))) + assert.Equal(t, block, nodeMust(functionDefinition.ChildWithDescendant(block))) + assert.Nil(t, nodeMustNot(block.ChildWithDescendant(block))) } func TestFirstChildForOffset(t *testing.T) { @@ -285,12 +288,12 @@ func TestFirstChildForOffset(t *testing.T) { tree := parser.Parse([]byte("x10 + 100"), nil) defer tree.Close() - sumNode := tree.RootNode().Child(0).Child(0) + sumNode := nodeMust(nodeMust(tree.RootNode().Child(0)).Child(0)) - assert.Equal(t, "identifier", sumNode.FirstChildForByte(0).Kind()) - assert.Equal(t, "identifier", sumNode.FirstChildForByte(1).Kind()) - assert.Equal(t, "+", sumNode.FirstChildForByte(3).Kind()) - assert.Equal(t, "number", sumNode.FirstChildForByte(5).Kind()) + assert.Equal(t, "identifier", nodeMust(sumNode.FirstChildForByte(0)).Kind()) + assert.Equal(t, "identifier", nodeMust(sumNode.FirstChildForByte(1)).Kind()) + assert.Equal(t, "+", nodeMust(sumNode.FirstChildForByte(3)).Kind()) + assert.Equal(t, "number", nodeMust(sumNode.FirstChildForByte(5)).Kind()) } func TestFirstNamedChildForOffset(t *testing.T) { @@ -300,11 +303,11 @@ func TestFirstNamedChildForOffset(t *testing.T) { tree := parser.Parse([]byte("x10 + 100"), nil) defer tree.Close() - sumNode := tree.RootNode().Child(0).Child(0) + sumNode := nodeMust(nodeMust(tree.RootNode().Child(0)).Child(0)) - assert.Equal(t, "identifier", sumNode.FirstNamedChildForByte(0).Kind()) - assert.Equal(t, "identifier", sumNode.FirstNamedChildForByte(1).Kind()) - assert.Equal(t, "number", sumNode.FirstNamedChildForByte(3).Kind()) + assert.Equal(t, "identifier", nodeMust(sumNode.FirstNamedChildForByte(0)).Kind()) + assert.Equal(t, "identifier", nodeMust(sumNode.FirstNamedChildForByte(1)).Kind()) + assert.Equal(t, "number", nodeMust(sumNode.FirstNamedChildForByte(3)).Kind()) } func TestNodeFieldNameForChild(t *testing.T) { @@ -315,9 +318,9 @@ func TestNodeFieldNameForChild(t *testing.T) { tree := parser.Parse([]byte("int w = x + /* y is special! */ y;"), nil) defer tree.Close() translationUnitNode := tree.RootNode() - declarationNode := translationUnitNode.NamedChild(0) + declarationNode := nodeMust(translationUnitNode.NamedChild(0)) - binaryExpressionNode := declarationNode.ChildByFieldName("declarator").ChildByFieldName("value") + binaryExpressionNode := nodeMust(nodeMust(declarationNode.ChildByFieldName("declarator")).ChildByFieldName("value")) // ------------------- // left: (identifier) 0 @@ -343,9 +346,9 @@ func TestNodeFieldNameForNamedChild(t *testing.T) { tree := parser.Parse([]byte("int w = x + /* y is special! */ y;"), nil) defer tree.Close() translationUnitNode := tree.RootNode() - declarationNode := translationUnitNode.NamedChild(0) + declarationNode := nodeMust(translationUnitNode.NamedChild(0)) - binaryExpressionNode := declarationNode.ChildByFieldName("declarator").ChildByFieldName("value") + binaryExpressionNode := nodeMust(nodeMust(declarationNode.ChildByFieldName("declarator")).ChildByFieldName("value")) // ------------------- // left: (identifier) 0 @@ -373,18 +376,18 @@ func TestNodeChildByFieldNameWithExtraHiddenChildren(t *testing.T) { // Check that when searching for a child with a field name, we don't tree := parser.Parse([]byte("while a:\n pass"), nil) defer tree.Close() - whileNode := tree.RootNode().Child(0) + whileNode := nodeMust(tree.RootNode().Child(0)) assert.Equal(t, "while_statement", whileNode.Kind()) - assert.Equal(t, whileNode.Child(3), whileNode.ChildByFieldName("body")) + assert.Equal(t, nodeMust(whileNode.Child(3)), nodeMust(whileNode.ChildByFieldName("body"))) } func TestNodeNamedChild(t *testing.T) { tree := parseJsonExample() - arrayNode := tree.RootNode().Child(0) + arrayNode := nodeMust(tree.RootNode().Child(0)) - numberNode := arrayNode.NamedChild(0) - falseNode := arrayNode.NamedChild(1) - objectNode := arrayNode.NamedChild(2) + numberNode := nodeMust(arrayNode.NamedChild(0)) + falseNode := nodeMust(arrayNode.NamedChild(1)) + objectNode := nodeMust(arrayNode.NamedChild(2)) assert.Equal(t, "number", numberNode.Kind()) assert.EqualValues(t, strings.Index(JSON_EXAMPLE, "123"), numberNode.StartByte()) @@ -404,15 +407,15 @@ func TestNodeNamedChild(t *testing.T) { assert.Equal(t, Point{7, 3}, objectNode.EndPosition()) assert.EqualValues(t, 1, objectNode.NamedChildCount()) - pairNode := objectNode.NamedChild(0) + pairNode := nodeMust(objectNode.NamedChild(0)) assert.Equal(t, "pair", pairNode.Kind()) assert.EqualValues(t, strings.Index(JSON_EXAMPLE, "\"x\""), pairNode.StartByte()) assert.EqualValues(t, strings.Index(JSON_EXAMPLE, "null")+4, pairNode.EndByte()) assert.Equal(t, Point{6, 4}, pairNode.StartPosition()) assert.Equal(t, Point{6, 13}, pairNode.EndPosition()) - stringNode := pairNode.NamedChild(0) - nullNode := pairNode.NamedChild(1) + stringNode := nodeMust(pairNode.NamedChild(0)) + nullNode := nodeMust(pairNode.NamedChild(1)) assert.Equal(t, "string", stringNode.Kind()) assert.Equal(t, "null", nullNode.Kind()) @@ -427,20 +430,22 @@ func TestNodeNamedChild(t *testing.T) { assert.Equal(t, Point{6, 9}, nullNode.StartPosition()) assert.Equal(t, Point{6, 13}, nullNode.EndPosition()) - assert.Equal(t, pairNode, stringNode.Parent()) - assert.Equal(t, pairNode, nullNode.Parent()) - assert.Equal(t, objectNode, pairNode.Parent()) - assert.Equal(t, arrayNode, numberNode.Parent()) - assert.Equal(t, arrayNode, falseNode.Parent()) - assert.Equal(t, arrayNode, objectNode.Parent()) - assert.Equal(t, tree.RootNode(), arrayNode.Parent()) - assert.Nil(t, tree.RootNode().Parent()) - - assert.Equal(t, arrayNode, tree.RootNode().ChildWithDescendant(nullNode)) - assert.Equal(t, objectNode, arrayNode.ChildWithDescendant(nullNode)) - assert.Equal(t, pairNode, objectNode.ChildWithDescendant(nullNode)) - assert.Equal(t, nullNode, pairNode.ChildWithDescendant(nullNode)) - assert.Nil(t, nullNode.ChildWithDescendant(nullNode)) + rootNode := tree.RootNode() + + assert.Equal(t, pairNode, nodeMust(stringNode.Parent())) + assert.Equal(t, pairNode, nodeMust(nullNode.Parent())) + assert.Equal(t, objectNode, nodeMust(pairNode.Parent())) + assert.Equal(t, arrayNode, nodeMust(numberNode.Parent())) + assert.Equal(t, arrayNode, nodeMust(falseNode.Parent())) + assert.Equal(t, arrayNode, nodeMust(objectNode.Parent())) + assert.Equal(t, rootNode, nodeMust(arrayNode.Parent())) + assert.Nil(t, nodeMustNot(tree.RootNode().Parent())) + + assert.Equal(t, arrayNode, nodeMust(tree.RootNode().ChildWithDescendant(nullNode))) + assert.Equal(t, objectNode, nodeMust(arrayNode.ChildWithDescendant(nullNode))) + assert.Equal(t, pairNode, nodeMust(objectNode.ChildWithDescendant(nullNode))) + assert.Equal(t, nullNode, nodeMust(pairNode.ChildWithDescendant(nullNode))) + assert.Nil(t, nodeMustNot(nullNode.ChildWithDescendant(nullNode))) } func TestNodeDescendantCount(t *testing.T) { @@ -453,12 +458,12 @@ func TestNodeDescendantCount(t *testing.T) { cursor := valueNode.Walk() for i, node := range allNodes { cursor.GotoDescendant(uint32(i)) - assert.Equal(t, node, cursor.Node()) + assert.Equal(t, *node, cursor.Node()) } for i := len(allNodes) - 1; i >= 0; i-- { cursor.GotoDescendant(uint32(i)) - assert.Equal(t, allNodes[i], cursor.Node()) + assert.Equal(t, *allNodes[i], cursor.Node()) } } @@ -477,10 +482,10 @@ func TestDescendantCountSingleNodeTree(t *testing.T) { cursor.GotoDescendant(0) assert.EqualValues(t, 0, cursor.Depth()) - assert.Equal(t, allNodes[0], cursor.Node()) + assert.Equal(t, *allNodes[0], cursor.Node()) cursor.GotoDescendant(1) assert.EqualValues(t, 1, cursor.Depth()) - assert.Equal(t, allNodes[1], cursor.Node()) + assert.Equal(t, *allNodes[1], cursor.Node()) } func TestNodeDescendantForRange(t *testing.T) { @@ -489,7 +494,7 @@ func TestNodeDescendantForRange(t *testing.T) { // Leaf node exactly matches the given bounds - byte query colonIndex := strings.Index(JSON_EXAMPLE, ":") - colonNode := arrayNode.DescendantForByteRange(uint(colonIndex), uint(colonIndex+1)) + colonNode := nodeMust(arrayNode.DescendantForByteRange(uint(colonIndex), uint(colonIndex+1))) assert.NotNil(t, colonNode) assert.Equal(t, ":", colonNode.Kind()) assert.EqualValues(t, colonIndex, colonNode.StartByte()) @@ -498,7 +503,7 @@ func TestNodeDescendantForRange(t *testing.T) { assert.Equal(t, Point{6, 8}, colonNode.EndPosition()) // Leaf node exactly matches the given bounds - point query - colonNode = arrayNode.DescendantForPointRange(Point{6, 7}, Point{6, 8}) + colonNode = nodeMust(arrayNode.DescendantForPointRange(Point{6, 7}, Point{6, 8})) assert.NotNil(t, colonNode) assert.Equal(t, ":", colonNode.Kind()) assert.EqualValues(t, colonIndex, colonNode.StartByte()) @@ -507,7 +512,7 @@ func TestNodeDescendantForRange(t *testing.T) { assert.Equal(t, Point{6, 8}, colonNode.EndPosition()) // The given point is between two adjacent leaf nodes - byte query - colonNode = arrayNode.DescendantForByteRange(uint(colonIndex), uint(colonIndex)) + colonNode = nodeMust(arrayNode.DescendantForByteRange(uint(colonIndex), uint(colonIndex))) assert.NotNil(t, colonNode) assert.Equal(t, ":", colonNode.Kind()) assert.EqualValues(t, colonIndex, colonNode.StartByte()) @@ -516,7 +521,7 @@ func TestNodeDescendantForRange(t *testing.T) { assert.Equal(t, Point{6, 8}, colonNode.EndPosition()) // The given point is between two adjacent leaf nodes - point query - colonNode = arrayNode.DescendantForPointRange(Point{6, 7}, Point{6, 7}) + colonNode = nodeMust(arrayNode.DescendantForPointRange(Point{6, 7}, Point{6, 7})) assert.NotNil(t, colonNode) assert.Equal(t, ":", colonNode.Kind()) assert.EqualValues(t, colonIndex, colonNode.StartByte()) @@ -526,7 +531,7 @@ func TestNodeDescendantForRange(t *testing.T) { // Leaf node starts at the lower bound, ends after the upper bound - byte query stringIndex := strings.Index(JSON_EXAMPLE, "\"x\"") - stringNode := arrayNode.DescendantForByteRange(uint(stringIndex), uint(stringIndex+2)) + stringNode := nodeMust(arrayNode.DescendantForByteRange(uint(stringIndex), uint(stringIndex+2))) assert.NotNil(t, stringNode) assert.Equal(t, "string", stringNode.Kind()) assert.EqualValues(t, stringIndex, stringNode.StartByte()) @@ -535,7 +540,7 @@ func TestNodeDescendantForRange(t *testing.T) { assert.Equal(t, Point{6, 7}, stringNode.EndPosition()) // Leaf node starts at the lower bound, ends after the upper bound - point query - stringNode = arrayNode.DescendantForPointRange(Point{6, 4}, Point{6, 6}) + stringNode = nodeMust(arrayNode.DescendantForPointRange(Point{6, 4}, Point{6, 6})) assert.NotNil(t, stringNode) assert.Equal(t, "string", stringNode.Kind()) assert.EqualValues(t, stringIndex, stringNode.StartByte()) @@ -545,7 +550,7 @@ func TestNodeDescendantForRange(t *testing.T) { // Leaf node starts before the lower bound, ends at the upper bound - byte query nullIndex := strings.Index(JSON_EXAMPLE, "null") - nullNode := arrayNode.DescendantForByteRange(uint(nullIndex+1), uint(nullIndex+4)) + nullNode := nodeMust(arrayNode.DescendantForByteRange(uint(nullIndex+1), uint(nullIndex+4))) assert.NotNil(t, nullNode) assert.Equal(t, "null", nullNode.Kind()) assert.EqualValues(t, nullIndex, nullNode.StartByte()) @@ -554,7 +559,7 @@ func TestNodeDescendantForRange(t *testing.T) { assert.Equal(t, Point{6, 13}, nullNode.EndPosition()) // Leaf node starts before the lower bound, ends at the upper bound - point query - nullNode = arrayNode.DescendantForPointRange(Point{6, 11}, Point{6, 13}) + nullNode = nodeMust(arrayNode.DescendantForPointRange(Point{6, 11}, Point{6, 13})) assert.NotNil(t, nullNode) assert.Equal(t, "null", nullNode.Kind()) assert.EqualValues(t, nullIndex, nullNode.StartByte()) @@ -563,7 +568,7 @@ func TestNodeDescendantForRange(t *testing.T) { assert.Equal(t, Point{6, 13}, nullNode.EndPosition()) // The bounds span multiple leaf nodes - return the smallest node that does span it. - pairNode := arrayNode.DescendantForByteRange(uint(stringIndex+2), uint(stringIndex+4)) + pairNode := nodeMust(arrayNode.DescendantForByteRange(uint(stringIndex+2), uint(stringIndex+4))) assert.NotNil(t, pairNode) assert.Equal(t, "pair", pairNode.Kind()) assert.EqualValues(t, stringIndex, pairNode.StartByte()) @@ -571,10 +576,10 @@ func TestNodeDescendantForRange(t *testing.T) { assert.Equal(t, Point{6, 4}, pairNode.StartPosition()) assert.Equal(t, Point{6, 13}, pairNode.EndPosition()) - assert.Equal(t, colonNode.Parent(), pairNode) + assert.Equal(t, nodeMust(colonNode.Parent()), pairNode) // no leaf spans the given range - return the smallest node that does span it. - pairNode = arrayNode.NamedDescendantForPointRange(Point{6, 6}, Point{6, 8}) + pairNode = nodeMust(arrayNode.NamedDescendantForPointRange(Point{6, 6}, Point{6, 8})) assert.NotNil(t, pairNode) assert.Equal(t, "pair", pairNode.Kind()) assert.EqualValues(t, stringIndex, pairNode.StartByte()) @@ -583,8 +588,8 @@ func TestNodeDescendantForRange(t *testing.T) { assert.Equal(t, Point{6, 13}, pairNode.EndPosition()) // Negative test, start > end - assert.Nil(t, arrayNode.DescendantForByteRange(1, 0)) - assert.Nil(t, arrayNode.DescendantForPointRange(Point{6, 8}, Point{6, 7})) + assert.Nil(t, nodeMustNot(arrayNode.DescendantForByteRange(1, 0))) + assert.Nil(t, nodeMustNot(arrayNode.DescendantForPointRange(Point{6, 8}, Point{6, 7}))) } func TestNodeEdit(t *testing.T) { @@ -624,7 +629,7 @@ func TestRootNodeWithOffset(t *testing.T) { assert.Equal(t, Point{2, 4}, node.StartPosition()) assert.Equal(t, Point{2, 12}, node.EndPosition()) - child := node.Child(0).Child(2) + child := nodeMust(nodeMust(node.Child(0)).Child(2)) assert.Equal(t, "expression_statement", child.Kind()) assert.EqualValues(t, 15, child.StartByte()) assert.EqualValues(t, 16, child.EndByte()) @@ -651,7 +656,7 @@ func TestNodeIsExtra(t *testing.T) { defer tree.Close() rootNode := tree.RootNode() - commentNode := rootNode.DescendantForByteRange(7, 7) + commentNode := nodeMust(rootNode.DescendantForByteRange(7, 7)) assert.Equal(t, "program", rootNode.Kind()) assert.Equal(t, "comment", commentNode.Kind()) @@ -670,7 +675,7 @@ func TestNodeIsError(t *testing.T) { assert.Equal(t, "program", rootNode.Kind()) assert.True(t, rootNode.HasError()) - child := rootNode.Child(0) + child := nodeMust(rootNode.Child(0)) assert.Equal(t, "ERROR", child.Kind()) assert.True(t, child.IsError()) } @@ -683,9 +688,9 @@ func TestNodeSexp(t *testing.T) { defer tree.Close() rootNode := tree.RootNode() - ifNode := rootNode.DescendantForByteRange(0, 0) - parenNode := rootNode.DescendantForByteRange(3, 3) - identifierNode := rootNode.DescendantForByteRange(4, 4) + ifNode := nodeMust(rootNode.DescendantForByteRange(0, 0)) + parenNode := nodeMust(rootNode.DescendantForByteRange(3, 3)) + identifierNode := nodeMust(rootNode.DescendantForByteRange(4, 4)) assert.Equal(t, "if", ifNode.Kind()) assert.Equal(t, "(\"if\")", ifNode.ToSexp()) @@ -715,10 +720,10 @@ func TestNodeNumericSymbolsRespectSimpleAliases(t *testing.T) { root.ToSexp(), ) - outExprNode := root.Child(0).Child(0) + outExprNode := nodeMust(nodeMust(root.Child(0)).Child(0)) assert.Equal(t, "parenthesized_expression", outExprNode.Kind()) - innerExprNode := outExprNode.NamedChild(0).ChildByFieldName("arguments").NamedChild(0) + innerExprNode := nodeMust(nodeMust(nodeMust(outExprNode.NamedChild(0)).ChildByFieldName("arguments")).NamedChild(0)) assert.Equal(t, "parenthesized_expression", innerExprNode.Kind()) assert.Equal(t, outExprNode.KindId(), innerExprNode.KindId()) @@ -735,13 +740,13 @@ func TestNodeNumericSymbolsRespectSimpleAliases(t *testing.T) { root.ToSexp(), ) - binaryNode := root.Child(0) + binaryNode := nodeMust(root.Child(0)) assert.Equal(t, "binary", binaryNode.Kind()) - unaryMinusNode := binaryNode.ChildByFieldName("left").Child(0) + unaryMinusNode := nodeMust(nodeMust(binaryNode.ChildByFieldName("left")).Child(0)) assert.Equal(t, "-", unaryMinusNode.Kind()) - binaryMinusNode := binaryNode.ChildByFieldName("operator") + binaryMinusNode := nodeMust(binaryNode.ChildByFieldName("operator")) assert.Equal(t, "-", binaryMinusNode.Kind()) assert.Equal(t, unaryMinusNode.KindId(), binaryMinusNode.KindId()) } @@ -762,11 +767,11 @@ private: defer tree.Close() root := tree.RootNode() - classSpecifier := root.Child(0) - fieldDeclList := classSpecifier.ChildByFieldName("body") - fieldDecl := fieldDeclList.NamedChild(0) - fieldIdent := fieldDecl.ChildByFieldName("declarator") - assert.Equal(t, fieldIdent, fieldDecl.ChildWithDescendant(fieldIdent)) + classSpecifier := nodeMust(root.Child(0)) + fieldDeclList := nodeMust(classSpecifier.ChildByFieldName("body")) + fieldDecl := nodeMust(fieldDeclList.NamedChild(0)) + fieldIdent := nodeMust(fieldDecl.ChildByFieldName("declarator")) + assert.Equal(t, fieldIdent, nodeMust(fieldDecl.ChildWithDescendant(fieldIdent))) } func getAllNodes(tree *Tree) []*Node { @@ -775,7 +780,8 @@ func getAllNodes(tree *Tree) []*Node { cursor := tree.Walk() for { if !visitedChildren { - result = append(result, cursor.Node()) + node := cursor.Node() + result = append(result, &node) if !cursor.GotoFirstChild() { visitedChildren = true } @@ -794,3 +800,17 @@ func parseJsonExample() *Tree { parser.SetLanguage(getLanguage("json")) return parser.Parse([]byte(JSON_EXAMPLE), nil) } + +func nodeMust(node Node, ok bool) Node { + if !ok { + panic("node is nil") + } + return node +} + +func nodeMustNot(_ Node, ok bool) *Node { + if ok { + panic("node is not nil") + } + return nil +} diff --git a/parser_test.go b/parser_test.go index f91768b..2e7658f 100644 --- a/parser_test.go +++ b/parser_test.go @@ -133,7 +133,7 @@ func TestParsingSimpleString(t *testing.T) { assert.Equal(t, rootNode.ToSexp(), "(source_file (struct_item name: (type_identifier) body: (field_declaration_list)) (function_item name: (identifier) parameters: (parameters) body: (block)))") - structNode := rootNode.Child(0) + structNode := nodeMust(rootNode.Child(0)) assert.Equal(t, structNode.Kind(), "struct_item") } @@ -232,7 +232,7 @@ func TestParsingWithCustomUTF8Input(t *testing.T) { assert.Equal(t, root.ToSexp(), "(source_file (function_item (visibility_modifier) name: (identifier) parameters: (parameters) body: (block (integer_literal))))") assert.Equal(t, root.Kind(), "source_file") assert.False(t, root.HasError()) - assert.Equal(t, root.Child(0).Kind(), "function_item") + assert.Equal(t, nodeMust(root.Child(0)).Kind(), "function_item") } func TestParsingWithCustomUTF16LEInput(t *testing.T) { @@ -266,7 +266,7 @@ func TestParsingWithCustomUTF16LEInput(t *testing.T) { assert.Equal(t, root.ToSexp(), "(source_file (function_item (visibility_modifier) name: (identifier) parameters: (parameters) body: (block (integer_literal))))") assert.Equal(t, root.Kind(), "source_file") assert.False(t, root.HasError()) - assert.Equal(t, root.Child(0).Kind(), "function_item") + assert.Equal(t, nodeMust(root.Child(0)).Kind(), "function_item") } func TestParsingWithCustomUTF16BEInput(t *testing.T) { @@ -311,7 +311,7 @@ func TestParsingWithCustomUTF16BEInput(t *testing.T) { assert.Equal(t, root.ToSexp(), "(source_file (function_item (visibility_modifier) name: (identifier) parameters: (parameters) body: (block (integer_literal))))") assert.Equal(t, root.Kind(), "source_file") assert.False(t, root.HasError()) - assert.Equal(t, root.Child(0).Kind(), "function_item") + assert.Equal(t, nodeMust(root.Child(0)).Kind(), "function_item") } func TestParsingWithCallbackReturningOwnedStrings(t *testing.T) { @@ -717,7 +717,7 @@ func TestParsingWithTimeout(t *testing.T) { nil, nil, ) - assert.Equal(t, "array", tree.RootNode().Child(0).Kind()) + assert.Equal(t, "array", nodeMust(tree.RootNode().Child(0)).Kind()) } func TestParsingWithTimeoutAndReset(t *testing.T) { @@ -753,7 +753,7 @@ func TestParsingWithTimeoutAndReset(t *testing.T) { // it does not see the changes to the beginning of the source code. code = []byte("[null, 1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]") tree = parser.ParseWithOptions(callback, nil, nil) - assert.Equal(t, "string", tree.RootNode().NamedChild(0).NamedChild(0).Kind()) + assert.Equal(t, "string", nodeMust(nodeMust(tree.RootNode().NamedChild(0)).NamedChild(0)).Kind()) code = []byte("[\"ok\", 1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]") tree = parser.ParseWithOptions( @@ -770,7 +770,7 @@ func TestParsingWithTimeoutAndReset(t *testing.T) { parser.Reset() code = []byte("[null, 1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]") tree = parser.ParseWithOptions(callback, nil, nil) - assert.Equal(t, "null", tree.RootNode().NamedChild(0).NamedChild(0).Kind()) + assert.Equal(t, "null", nodeMust(nodeMust(tree.RootNode().NamedChild(0)).NamedChild(0)).Kind()) } func TestParsingWithTimeoutAndImplicitReset(t *testing.T) { @@ -805,7 +805,7 @@ func TestParsingWithTimeoutAndImplicitReset(t *testing.T) { parser.SetLanguage(getLanguage("json")) code = []byte("[null, 1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]") tree = parser.ParseWithOptions(callback, nil, nil) - assert.Equal(t, "null", tree.RootNode().NamedChild(0).NamedChild(0).Kind()) + assert.Equal(t, "null", nodeMust(nodeMust(tree.RootNode().NamedChild(0)).NamedChild(0)).Kind()) } func TestParsingWithTimeoutAndNoCompletion(t *testing.T) { @@ -951,7 +951,7 @@ func TestParsingWithOneIncludedRange(t *testing.T) { defer parser.Close() parser.SetLanguage(getLanguage("html")) htmlTree := parser.Parse([]byte(sourceCode), nil) - scriptContentNode := htmlTree.RootNode().Child(1).Child(1) + scriptContentNode := nodeMust(nodeMust(htmlTree.RootNode().Child(1)).Child(1)) assert.Equal(t, "raw_text", scriptContentNode.Kind()) assert.Equal(t, []Range{ @@ -985,16 +985,16 @@ func TestParsingWithMultipleIncludedRanges(t *testing.T) { defer parser.Close() parser.SetLanguage(getLanguage("javascript")) jsTree := parser.Parse([]byte(sourceCode), nil) - templateStringNode := jsTree.RootNode().DescendantForByteRange( + templateStringNode := nodeMust(jsTree.RootNode().DescendantForByteRange( uint(strings.Index(sourceCode, "`<")), uint(strings.Index(sourceCode, ">`")), - ) + )) assert.Equal(t, "template_string", templateStringNode.Kind()) - openQuoteNode := templateStringNode.Child(0) - interpolationNode1 := templateStringNode.Child(2) - interpolationNode2 := templateStringNode.Child(4) - closeQuoteNode := templateStringNode.Child(6) + openQuoteNode := nodeMust(templateStringNode.Child(0)) + interpolationNode1 := nodeMust(templateStringNode.Child(2)) + interpolationNode2 := nodeMust(templateStringNode.Child(4)) + closeQuoteNode := nodeMust(templateStringNode.Child(6)) parser.SetLanguage(getLanguage("html")) htmlRanges := []Range{ @@ -1027,11 +1027,11 @@ func TestParsingWithMultipleIncludedRanges(t *testing.T) { ) assert.Equal(t, htmlRanges, htmlTree.IncludedRanges()) - divElementNode := htmlTree.RootNode().Child(0) - helloTextNode := divElementNode.Child(1) - bElementNode := divElementNode.Child(2) - bStartTagNode := bElementNode.Child(0) - bEndTagNode := bElementNode.Child(1) + divElementNode := nodeMust(htmlTree.RootNode().Child(0)) + helloTextNode := nodeMust(divElementNode.Child(1)) + bElementNode := nodeMust(divElementNode.Child(2)) + bStartTagNode := nodeMust(bElementNode.Child(0)) + bEndTagNode := nodeMust(bElementNode.Child(1)) assert.Equal(t, "text", helloTextNode.Kind()) assert.Equal(t, uint(strings.Index(sourceCode, "Hello")), helloTextNode.StartByte()) @@ -1164,8 +1164,8 @@ func TestParsingWithExternalScannerThatUsesIncludedRangeBoundaries(t *testing.T) tree := parser.Parse([]byte(sourceCode), nil) defer tree.Close() root := tree.RootNode() - statement1 := root.Child(0) - statement2 := root.Child(1) + statement1 := nodeMust(root.Child(0)) + statement2 := nodeMust(root.Child(1)) assert.Equal( t, diff --git a/query.go b/query.go index 8e41ccf..adffef7 100644 --- a/query.go +++ b/query.go @@ -748,7 +748,7 @@ func (qc *QueryCursor) DidExceedMatchLimit() bool { // captures. Because multiple patterns can match the same set of nodes, // one match may contain captures that appear *before* some of the // captures from a previous match. -func (qc *QueryCursor) Matches(query *Query, node *Node, text []byte) QueryMatches { +func (qc *QueryCursor) Matches(query *Query, node Node, text []byte) QueryMatches { C.ts_query_cursor_exec(qc._inner, query._inner, node._inner) qm := QueryMatches{ _inner: qc._inner, @@ -783,7 +783,7 @@ func queryProgressCallback(state *C.TSQueryCursorState) C.bool { // captures. Because multiple patterns can match the same set of nodes, // one match may contain captures that appear *before* some of the // captures from a previous match. -func (qc *QueryCursor) MatchesWithOptions(query *Query, node *Node, text []byte, options QueryCursorOptions) QueryMatches { +func (qc *QueryCursor) MatchesWithOptions(query *Query, node Node, text []byte, options QueryCursorOptions) QueryMatches { cOptions := &C.TSQueryCursorOptions{ payload: pointer.Save(&options), progress_callback: (*[0]byte)(C.queryProgressCallback), @@ -813,7 +813,7 @@ func (qc *QueryCursor) MatchesWithOptions(query *Query, node *Node, text []byte, // // This is useful if you don't care about which pattern matched, and just // want a single, ordered sequence of captures. -func (qc *QueryCursor) Captures(query *Query, node *Node, text []byte) QueryCaptures { +func (qc *QueryCursor) Captures(query *Query, node Node, text []byte) QueryCaptures { C.ts_query_cursor_exec(qc._inner, query._inner, node._inner) return QueryCaptures{ _inner: qc._inner, diff --git a/query_test.go b/query_test.go index 602f8b2..5a78ac5 100644 --- a/query_test.go +++ b/query_test.go @@ -4887,7 +4887,7 @@ func TestQueryMaxStartDepthMore(t *testing.T) { for _, row := range rows { cursor.SetMaxStartDepth(&row.depth) - matches := cursor.Matches(query, &node, []byte(source)) + matches := cursor.Matches(query, node, []byte(source)) assert.Equal(t, row.matches, collectMatches(matches, query, source)) } } diff --git a/tree.go b/tree.go index 60db4e5..8c805a1 100644 --- a/tree.go +++ b/tree.go @@ -22,14 +22,14 @@ func newTree(inner *C.TSTree) *Tree { } // Get the root node of the syntax tree. -func (t *Tree) RootNode() *Node { - return &Node{_inner: C.ts_tree_root_node(t._inner)} +func (t *Tree) RootNode() Node { + return Node{_inner: C.ts_tree_root_node(t._inner)} } // Get the root node of the syntax tree, but with its position shifted // forward by the given offset. -func (t *Tree) RootNodeWithOffset(offsetBytes int, offsetExtent Point) *Node { - return &Node{_inner: C.ts_tree_root_node_with_offset(t._inner, C.uint(offsetBytes), offsetExtent.toTSPoint())} +func (t *Tree) RootNodeWithOffset(offsetBytes int, offsetExtent Point) Node { + return Node{_inner: C.ts_tree_root_node_with_offset(t._inner, C.uint(offsetBytes), offsetExtent.toTSPoint())} } // Get the language that was used to parse the syntax tree. diff --git a/tree_cursor.go b/tree_cursor.go index 3363fca..b253923 100644 --- a/tree_cursor.go +++ b/tree_cursor.go @@ -24,8 +24,8 @@ func (tc *TreeCursor) Copy() *TreeCursor { } // Get the tree cursor's current [Node]. -func (tc *TreeCursor) Node() *Node { - return newNode(C.ts_tree_cursor_current_node(&tc._inner)) +func (tc *TreeCursor) Node() Node { + return Node{_inner: C.ts_tree_cursor_current_node(&tc._inner)} } // Get the numerical field id of this tree cursor's current node. diff --git a/tree_test.go b/tree_test.go index 6f5c693..6624e2d 100644 --- a/tree_test.go +++ b/tree_test.go @@ -88,9 +88,9 @@ func TestTreeEdit(t *testing.T) { NewEndPosition: Point{Row: 0, Column: 2}, }) - expr := tree.RootNode().Child(0).Child(0) - child1 := expr.Child(0) - child2 := expr.Child(1) + expr := nodeMust(nodeMust(tree.RootNode().Child(0)).Child(0)) + child1 := nodeMust(expr.Child(0)) + child2 := nodeMust(expr.Child(1)) assert.True(t, expr.HasChanges()) assert.EqualValues(t, expr.StartByte(), 3) @@ -116,9 +116,9 @@ func TestTreeEdit(t *testing.T) { NewEndPosition: Point{Row: 0, Column: 5}, }) - expr := tree.RootNode().Child(0).Child(0) - child1 := expr.Child(0) - child2 := expr.Child(1) + expr := nodeMust(nodeMust(tree.RootNode().Child(0)).Child(0)) + child1 := nodeMust(expr.Child(0)) + child2 := nodeMust(expr.Child(1)) assert.True(t, expr.HasChanges()) assert.EqualValues(t, expr.StartByte(), 5) @@ -155,9 +155,9 @@ func TestTreeEdit(t *testing.T) { // assert!(!child2.has_changes()); // assert_eq!(child2.byte_range(), 9..12); - expr := tree.RootNode().Child(0).Child(0) - child1 := expr.Child(0) - child2 := expr.Child(1) + expr := nodeMust(nodeMust(tree.RootNode().Child(0)).Child(0)) + child1 := nodeMust(expr.Child(0)) + child2 := nodeMust(expr.Child(1)) assert.True(t, expr.HasChanges()) assert.EqualValues(t, expr.StartByte(), 4) @@ -183,9 +183,9 @@ func TestTreeEdit(t *testing.T) { NewEndPosition: Point{Row: 0, Column: 4}, }) - expr := tree.RootNode().Child(0).Child(0) - child1 := expr.Child(0) - child2 := expr.Child(1) + expr := nodeMust(nodeMust(tree.RootNode().Child(0)).Child(0)) + child1 := nodeMust(expr.Child(0)) + child2 := nodeMust(expr.Child(1)) assert.True(t, expr.HasChanges()) assert.EqualValues(t, expr.StartByte(), 4) @@ -211,10 +211,10 @@ func TestTreeEdit(t *testing.T) { NewEndPosition: Point{Row: 0, Column: 4}, }) - expr := tree.RootNode().Child(0).Child(0) - child1 := expr.Child(0) - child2 := expr.Child(1) - child3 := expr.Child(2) + expr := nodeMust(nodeMust(tree.RootNode().Child(0)).Child(0)) + child1 := nodeMust(expr.Child(0)) + child2 := nodeMust(expr.Child(1)) + child3 := nodeMust(expr.Child(2)) assert.True(t, expr.HasChanges()) assert.EqualValues(t, expr.StartByte(), 4) @@ -243,10 +243,10 @@ func TestTreeEdit(t *testing.T) { NewEndPosition: Point{Row: 0, Column: 16}, }) - expr := tree.RootNode().Child(0).Child(0) - child1 := expr.Child(0) - child2 := expr.Child(1) - child3 := expr.Child(2) + expr := nodeMust(nodeMust(tree.RootNode().Child(0)).Child(0)) + child1 := nodeMust(expr.Child(0)) + child2 := nodeMust(expr.Child(1)) + child3 := nodeMust(expr.Child(2)) assert.True(t, expr.HasChanges()) assert.EqualValues(t, expr.StartByte(), 2) @@ -275,10 +275,10 @@ func TestTreeEdit(t *testing.T) { NewEndPosition: Point{Row: 0, Column: 4}, }) - expr := tree.RootNode().Child(0).Child(0) - child1 := expr.Child(0) - child2 := expr.Child(1) - child3 := expr.Child(2) + expr := nodeMust(nodeMust(tree.RootNode().Child(0)).Child(0)) + child1 := nodeMust(expr.Child(0)) + child2 := nodeMust(expr.Child(1)) + child3 := nodeMust(expr.Child(2)) assert.True(t, expr.HasChanges()) assert.EqualValues(t, expr.StartByte(), 2) @@ -307,10 +307,10 @@ func TestTreeEdit(t *testing.T) { NewEndPosition: Point{Row: 0, Column: 8}, }) - expr := tree.RootNode().Child(0).Child(0) - child1 := expr.Child(0) - child2 := expr.Child(1) - child3 := expr.Child(2) + expr := nodeMust(nodeMust(tree.RootNode().Child(0)).Child(0)) + child1 := nodeMust(expr.Child(0)) + child2 := nodeMust(expr.Child(1)) + child3 := nodeMust(expr.Child(2)) assert.True(t, expr.HasChanges()) assert.EqualValues(t, expr.StartByte(), 2) @@ -602,8 +602,8 @@ func TestTreeNodeEquality(t *testing.T) { node2 := tree.RootNode() assert.Equal(t, node1, node2) - assert.Equal(t, node1.Child(0), node2.Child(0)) - assert.NotEqual(t, node1.Child(0), node2) + assert.Equal(t, nodeMust(node1.Child(0)), nodeMust(node2.Child(0))) + assert.NotEqual(t, nodeMust(node1.Child(0)), node2) } func TestGetChangedRanges(t *testing.T) {