From d5aafd07b1bdc25a8611978d2ef73cd563770dda Mon Sep 17 00:00:00 2001 From: Jeffrey Lovitz Date: Mon, 14 Feb 2022 14:51:28 -0500 Subject: [PATCH] Allow WHERE predicates in node and edge patterns --- lib/src/ast_node_pattern.c | 36 ++++++++++++++++++++++++++++--- lib/src/ast_rel_pattern.c | 36 +++++++++++++++++++++++++++---- lib/src/cypher-parser.h.in | 44 ++++++++++++++++++++++++++++++++------ lib/src/parser.c | 22 +++++++++++-------- lib/src/parser.leg | 19 ++++++++++------ lib/test/check_errors.c | 2 +- 6 files changed, 129 insertions(+), 30 deletions(-) diff --git a/lib/src/ast_node_pattern.c b/lib/src/ast_node_pattern.c index 3b06a5d..9687747 100644 --- a/lib/src/ast_node_pattern.c +++ b/lib/src/ast_node_pattern.c @@ -25,6 +25,7 @@ struct node_pattern cypher_astnode_t _astnode; const cypher_astnode_t *identifier; const cypher_astnode_t *properties; + const cypher_astnode_t *predicate; size_t nlabels; const cypher_astnode_t *labels[]; }; @@ -44,8 +45,9 @@ const struct cypher_astnode_vt cypher_node_pattern_astnode_vt = cypher_astnode_t *cypher_ast_node_pattern(const cypher_astnode_t *identifier, cypher_astnode_t * const *labels, unsigned int nlabels, - const cypher_astnode_t *properties, cypher_astnode_t **children, - unsigned int nchildren, struct cypher_input_range range) + const cypher_astnode_t *properties, const cypher_astnode_t *predicate, + cypher_astnode_t **children, unsigned int nchildren, + struct cypher_input_range range) { REQUIRE_CHILD_OPTIONAL(children, nchildren, identifier, CYPHER_AST_IDENTIFIER, NULL); @@ -54,6 +56,8 @@ cypher_astnode_t *cypher_ast_node_pattern(const cypher_astnode_t *identifier, cypher_astnode_instanceof(properties, CYPHER_AST_MAP) || cypher_astnode_instanceof(properties, CYPHER_AST_PARAMETER), NULL); REQUIRE_CONTAINS_OPTIONAL(children, nchildren, properties, NULL); + REQUIRE_CHILD_OPTIONAL(children, nchildren, predicate, + CYPHER_AST_EXPRESSION, NULL); struct node_pattern *node = calloc(1, sizeof(struct node_pattern) + nlabels * sizeof(cypher_astnode_t *)); @@ -70,6 +74,7 @@ cypher_astnode_t *cypher_ast_node_pattern(const cypher_astnode_t *identifier, memcpy(node->labels, labels, nlabels * sizeof(cypher_astnode_t *)); node->nlabels = nlabels; node->properties = properties; + node->predicate = predicate; return &(node->_astnode); int errsv; @@ -103,8 +108,12 @@ cypher_astnode_t *clone(const cypher_astnode_t *self, cypher_astnode_t *properties = (node->properties == NULL) ? NULL : children[child_index(self, node->properties)]; + cypher_astnode_t *predicate = (node->predicate == NULL) ? NULL : + children[child_index(self, node->predicate)]; + cypher_astnode_t *clone = cypher_ast_node_pattern(identifier, labels, - node->nlabels, properties, children, self->nchildren, self->range); + node->nlabels, properties, predicate, children, self->nchildren, + self->range); int errsv = errno; free(labels); errno = errsv; @@ -155,6 +164,16 @@ const cypher_astnode_t *cypher_ast_node_pattern_get_properties( } +const cypher_astnode_t *cypher_ast_node_pattern_get_predicate( + const cypher_astnode_t *astnode) +{ + REQUIRE_TYPE(astnode, CYPHER_AST_NODE_PATTERN, NULL); + struct node_pattern *node = container_of(astnode, struct node_pattern, + _astnode); + return node->predicate; +} + + ssize_t detailstr(const cypher_astnode_t *self, char *str, size_t size) { REQUIRE_TYPE(self, CYPHER_AST_NODE_PATTERN, -1); @@ -202,6 +221,17 @@ ssize_t detailstr(const cypher_astnode_t *self, char *str, size_t size) n += r; } + if (node->predicate != NULL) + { + r = snprintf(str+n, (n < size)? size-n : 0, ", where=@%u", + node->predicate->ordinal); + if (r < 0) + { + return -1; + } + n += r; + } + if (n < size) { str[n] = ')'; diff --git a/lib/src/ast_rel_pattern.c b/lib/src/ast_rel_pattern.c index 7f83143..0b1c6b4 100644 --- a/lib/src/ast_rel_pattern.c +++ b/lib/src/ast_rel_pattern.c @@ -27,6 +27,7 @@ struct rel_pattern const cypher_astnode_t *identifier; const cypher_astnode_t *varlength; const cypher_astnode_t *properties; + const cypher_astnode_t *predicate; size_t nreltypes; const cypher_astnode_t *reltypes[]; }; @@ -47,8 +48,9 @@ const struct cypher_astnode_vt cypher_rel_pattern_astnode_vt = cypher_astnode_t *cypher_ast_rel_pattern(enum cypher_rel_direction direction, const cypher_astnode_t *identifier, cypher_astnode_t * const *reltypes, unsigned int nreltypes, const cypher_astnode_t *properties, - const cypher_astnode_t *varlength, cypher_astnode_t **children, - unsigned int nchildren, struct cypher_input_range range) + const cypher_astnode_t *predicate, const cypher_astnode_t *varlength, + cypher_astnode_t **children, unsigned int nchildren, + struct cypher_input_range range) { REQUIRE_CHILD_OPTIONAL(children, nchildren, identifier, CYPHER_AST_IDENTIFIER, NULL); @@ -60,6 +62,8 @@ cypher_astnode_t *cypher_ast_rel_pattern(enum cypher_rel_direction direction, REQUIRE_CONTAINS_OPTIONAL(children, nchildren, properties, NULL); REQUIRE_CHILD_OPTIONAL(children, nchildren, varlength, CYPHER_AST_RANGE, NULL); + REQUIRE_CHILD_OPTIONAL(children, nchildren, predicate, + CYPHER_AST_EXPRESSION, NULL); struct rel_pattern *node = calloc(1, sizeof(struct rel_pattern) + nreltypes * sizeof(cypher_astnode_t *)); @@ -78,6 +82,7 @@ cypher_astnode_t *cypher_ast_rel_pattern(enum cypher_rel_direction direction, node->nreltypes = nreltypes; node->varlength = varlength; node->properties = properties; + node->predicate = predicate; return &(node->_astnode); int errsv; @@ -109,12 +114,14 @@ cypher_astnode_t *clone(const cypher_astnode_t *self, } cypher_astnode_t *properties = (node->properties == NULL) ? NULL : children[child_index(self, node->properties)]; + cypher_astnode_t *predicate = (node->predicate == NULL) ? NULL : + children[child_index(self, node->predicate)]; cypher_astnode_t *varlength = (node->varlength == NULL) ? NULL : children[child_index(self, node->varlength)]; cypher_astnode_t *clone = cypher_ast_rel_pattern(node->direction, - identifier, reltypes, node->nreltypes, properties, varlength, - children, self->nchildren, self->range); + identifier, reltypes, node->nreltypes, properties, predicate, + varlength, children, self->nchildren, self->range); int errsv = errno; free(reltypes); errno = errsv; @@ -185,6 +192,16 @@ const cypher_astnode_t *cypher_ast_rel_pattern_get_properties( } +const cypher_astnode_t *cypher_ast_rel_pattern_get_predicate( + const cypher_astnode_t *astnode) +{ + REQUIRE_TYPE(astnode, CYPHER_AST_REL_PATTERN, NULL); + struct rel_pattern *node = container_of(astnode, struct rel_pattern, + _astnode); + return node->predicate; +} + + ssize_t detailstr(const cypher_astnode_t *self, char *str, size_t size) { REQUIRE_TYPE(self, CYPHER_AST_REL_PATTERN, -1); @@ -243,6 +260,17 @@ ssize_t detailstr(const cypher_astnode_t *self, char *str, size_t size) n += r; } + if (node->predicate != NULL) + { + r = snprintf(str+n, (n < size)? size-n : 0, ", where=@%u", + node->predicate->ordinal); + if (r < 0) + { + return -1; + } + n += r; + } + r = snprintf(str+n, (n < size)? size-n : 0, "]-%s", (node->direction == CYPHER_REL_OUTBOUND)? ">" : ""); if (r < 0) diff --git a/lib/src/cypher-parser.h.in b/lib/src/cypher-parser.h.in index 96353d5..cb50d28 100644 --- a/lib/src/cypher-parser.h.in +++ b/lib/src/cypher-parser.h.in @@ -1984,9 +1984,9 @@ const cypher_astnode_t *cypher_ast_match_get_hint( const cypher_astnode_t *node, unsigned int index); /** - * Get the predicate of a `CYPHER_AST_PREDICATE` node. + * Get the predicate of a `CYPHER_AST_MATCH` node. * - * If the node is not an instance of `CYPHER_AST_PREDICATE` then the result + * If the node is not an instance of `CYPHER_AST_MATCH` then the result * will be undefined. * * @param [node] The AST node. @@ -5141,6 +5141,7 @@ const cypher_astnode_t *cypher_ast_pattern_path_get_element( * @param [nlabels] The number of labels in the pattern. * @param [properties] A `CYPHER_AST_MAP` node, a `CYPHER_AST_PARAMETER` node, * or null. + * @param [predicate] A `CYPHER_AST_EXPRESSION` node, or null. * @param [children] The children of the node. * @param [nchildren] The number of children. * @param [range] The input range. @@ -5149,8 +5150,9 @@ const cypher_astnode_t *cypher_ast_pattern_path_get_element( __cypherlang_must_check cypher_astnode_t *cypher_ast_node_pattern(const cypher_astnode_t *identifier, cypher_astnode_t * const *labels, unsigned int nlabels, - const cypher_astnode_t *properties, cypher_astnode_t **children, - unsigned int nchildren, struct cypher_input_range range); + const cypher_astnode_t *properties, const cypher_astnode_t *predicate, + cypher_astnode_t **children, unsigned int nchildren, + struct cypher_input_range range); /** * Get the identifier of a `CYPHER_AST_NODE_PATTERN` node. @@ -5205,6 +5207,20 @@ const cypher_astnode_t *cypher_ast_node_pattern_get_properties( const cypher_astnode_t *node); +/** + * Get the predicate of a `CYPHER_AST_NODE_PATTERN` node. + * + * If the node is not an instance of `CYPHER_AST_NODE_PATTERN` then the result + * will be undefined. + * + * @param [node] The AST node. + * @return A `CYPHER_AST_PREDICATE` node, or null. + */ +__cypherlang_pure +const cypher_astnode_t *cypher_ast_node_pattern_get_predicate( + const cypher_astnode_t *node); + + /** * Construct a `CYPHER_AST_REL_PATTERN` node. * @@ -5214,6 +5230,7 @@ const cypher_astnode_t *cypher_ast_node_pattern_get_properties( * `CYPHER_AST_RELTYPE`. * @param [nreltypes] The number of relationship types in the pattern. * @param [properties] A `CYPHER_AST_MAP` node, a `CYPHER_AST_PARAMETER` node, + * @param [predicate] A `CYPHER_AST_EXPRESSION` node, or null. * or null. * @param [varlength] A `CYPHER_AST_RANGE` node, or null. * @param [children] The children of the node. @@ -5225,8 +5242,9 @@ __cypherlang_must_check cypher_astnode_t *cypher_ast_rel_pattern(enum cypher_rel_direction direction, const cypher_astnode_t *identifier, cypher_astnode_t * const *reltypes, unsigned int nreltypes, const cypher_astnode_t *properties, - const cypher_astnode_t *varlength, cypher_astnode_t **children, - unsigned int nchildren, struct cypher_input_range range); + const cypher_astnode_t *predicate, const cypher_astnode_t *varlength, + cypher_astnode_t **children, unsigned int nchildren, + struct cypher_input_range range); /** @@ -5308,6 +5326,20 @@ const cypher_astnode_t *cypher_ast_rel_pattern_get_properties( const cypher_astnode_t *node); +/** + * Get the predicate of a `CYPHER_AST_REL_PATTERN` node. + * + * If the node is not an instance of `CYPHER_AST_REL_PATTERN` then the result + * will be undefined. + * + * @param [node] The AST node. + * @return A `CYPHER_AST_PREDICATE` node, or null. + */ +__cypherlang_pure +const cypher_astnode_t *cypher_ast_rel_pattern_get_predicate( + const cypher_astnode_t *node); + + /** * Construct a `CYPHER_AST_RANGE` node. * diff --git a/lib/src/parser.c b/lib/src/parser.c index 63d6b5b..3867b7e 100644 --- a/lib/src/parser.c +++ b/lib/src/parser.c @@ -460,15 +460,17 @@ static cypher_astnode_t *_shortest_path(yycontext *yy, bool single, cypher_astnode_t *path); #define pattern_path() _pattern_path(yy) static cypher_astnode_t *_pattern_path(yycontext *yy); -#define node_pattern(i, p) _node_pattern(yy, i, p) +#define node_pattern(i, p, c) _node_pattern(yy, i, p, c) static cypher_astnode_t *_node_pattern(yycontext *yy, - cypher_astnode_t *identifier, cypher_astnode_t *properties); + cypher_astnode_t *identifier, cypher_astnode_t *properties, + cypher_astnode_t *predicate); #define simple_rel_pattern(d) \ - _rel_pattern(yy, CYPHER_REL_##d, NULL, NULL, NULL) -#define rel_pattern(d, i, r, p) _rel_pattern(yy, CYPHER_REL_##d, i, r, p) + _rel_pattern(yy, CYPHER_REL_##d, NULL, NULL, NULL, NULL) +#define rel_pattern(d, i, r, p, c) _rel_pattern(yy, CYPHER_REL_##d, i, r, p, c) static cypher_astnode_t *_rel_pattern(yycontext *yy, enum cypher_rel_direction direction, cypher_astnode_t *identifier, - cypher_astnode_t *varlength, cypher_astnode_t *properties); + cypher_astnode_t *varlength, cypher_astnode_t *properties, + cypher_astnode_t *predicate); #define range(s, e) _range(yy, s, e) static cypher_astnode_t *_range(yycontext *yy, cypher_astnode_t *start, cypher_astnode_t *end); @@ -2958,13 +2960,13 @@ cypher_astnode_t *_pattern_path(yycontext *yy) cypher_astnode_t *_node_pattern(yycontext *yy, cypher_astnode_t *identifier, - cypher_astnode_t *properties) + cypher_astnode_t *properties, cypher_astnode_t *predicate) { assert(yy->prev_block != NULL && "An AST node can only be created immediately after a `>` in the grammar"); cypher_astnode_t *node = cypher_ast_node_pattern(identifier, astnodes_elements(&(yy->prev_block->sequence)), - astnodes_size(&(yy->prev_block->sequence)), properties, + astnodes_size(&(yy->prev_block->sequence)), properties, predicate, astnodes_elements(&(yy->prev_block->children)), astnodes_size(&(yy->prev_block->children)), yy->prev_block->range); @@ -2982,13 +2984,15 @@ cypher_astnode_t *_node_pattern(yycontext *yy, cypher_astnode_t *identifier, cypher_astnode_t *_rel_pattern(yycontext *yy, enum cypher_rel_direction direction, cypher_astnode_t *identifier, - cypher_astnode_t *varlength, cypher_astnode_t *properties) + cypher_astnode_t *varlength, cypher_astnode_t *properties, + cypher_astnode_t *predicate) { assert(yy->prev_block != NULL && "An AST node can only be created immediately after a `>` in the grammar"); cypher_astnode_t *node = cypher_ast_rel_pattern(direction, identifier, astnodes_elements(&(yy->prev_block->sequence)), - astnodes_size(&(yy->prev_block->sequence)), properties, varlength, + astnodes_size(&(yy->prev_block->sequence)), + properties, predicate, varlength, astnodes_elements(&(yy->prev_block->children)), astnodes_size(&(yy->prev_block->children)), yy->prev_block->range); diff --git a/lib/src/parser.leg b/lib/src/parser.leg index e6f1902..ca6b283 100644 --- a/lib/src/parser.leg +++ b/lib/src/parser.leg @@ -802,7 +802,8 @@ node-pattern = < LEFT-PAREN - (i:identifier | i:_null_) ( n:label { sequence_add(n); } )* (p:pattern-properties | p:_null_) - RIGHT-PAREN > { $$ = node_pattern(i, p); } + ( WHERE c:expression | c:_null_ ) + RIGHT-PAREN > { $$ = node_pattern(i, p, c); } relationship-pattern = < ( LEFT-ARROW-HEAD - DASH - @@ -812,9 +813,11 @@ relationship-pattern = ) | LEFT-SQ-PAREN - (i:identifier | i:_null_) rel-types? (l:rel-varlength | l:_null_) - (p:pattern-properties | p:_null_) RIGHT-SQ-PAREN - DASH - ( - RIGHT-ARROW-HEAD > { $$ = rel_pattern(BIDIRECTIONAL, i, l, p); } - | _empty_ > { $$ = rel_pattern(INBOUND, i, l, p); } + (p:pattern-properties | p:_null_) + ( WHERE c:expression | c:_null_ ) + RIGHT-SQ-PAREN - DASH + ( - RIGHT-ARROW-HEAD > { $$ = rel_pattern(BIDIRECTIONAL, i, l, p, c); } + | _empty_ > { $$ = rel_pattern(INBOUND, i, l, p, c); } ) ) | DASH - @@ -824,9 +827,11 @@ relationship-pattern = ) | LEFT-SQ-PAREN - (i:identifier | i:_null_) rel-types? (l:rel-varlength | l:_null_) - (p:pattern-properties | p:_null_) RIGHT-SQ-PAREN - DASH - ( - RIGHT-ARROW-HEAD > { $$ = rel_pattern(OUTBOUND, i, l, p); } - | _empty_ > { $$ = rel_pattern(BIDIRECTIONAL, i, l, p); } + (p:pattern-properties | p:_null_) + ( WHERE c:expression | c:_null_ ) + RIGHT-SQ-PAREN - DASH + ( - RIGHT-ARROW-HEAD > { $$ = rel_pattern(OUTBOUND, i, l, p, c); } + | _empty_ > { $$ = rel_pattern(BIDIRECTIONAL, i, l, p, c); } ) ) ) diff --git a/lib/test/check_errors.c b/lib/test/check_errors.c index 2585d16..f87a057 100644 --- a/lib/test/check_errors.c +++ b/lib/test/check_errors.c @@ -295,7 +295,7 @@ START_TEST (track_error_position_across_statements) ck_assert_int_eq(pos.column, 11); ck_assert_int_eq(pos.offset, 71); ck_assert_str_eq(cypher_parse_error_message(err), - "Invalid input 'e': expected a label, '{', a parameter or ')'"); + "Invalid input 'e': expected a label, '{', a parameter, WHERE or ')'"); } END_TEST