diff --git a/core/ast/node.py b/core/ast/node.py index 5e6231e..66babfc 100644 --- a/core/ast/node.py +++ b/core/ast/node.py @@ -1,299 +1,307 @@ -from datetime import datetime -from typing import List, Set, Optional, Union -from abc import ABC - -from .enums import NodeType, JoinType, SortOrder - -# ============================================================================ -# Base Node Structure -# ============================================================================ - -class Node(ABC): - """Base class for all nodes""" - def __init__(self, type: NodeType, children: Optional[Set['Node']|List['Node']] = None): - self.type = type - self.children = children if children is not None else set() - - def __eq__(self, other): - if not isinstance(other, Node): - return False - if self.type != other.type: - return False - if len(self.children) != len(other.children): - return False - # Compare children - if isinstance(self.children, set) and isinstance(other.children, set): - return self.children == other.children - elif isinstance(self.children, list) and isinstance(other.children, list): - return self.children == other.children - else: - return False - - def __hash__(self): - # Make nodes hashable by using their type and a hash of their children - if isinstance(self.children, set): - # For sets, create a deterministic hash by sorting children by their string representation - children_hash = hash(tuple(sorted(self.children, key=lambda x: str(x)))) - else: - # For lists, just hash the tuple directly - children_hash = hash(tuple(self.children)) - return hash((self.type, children_hash)) - - -# ============================================================================ -# Operand Nodes -# ============================================================================ - -class TableNode(Node): - """Table reference node""" - def __init__(self, _name: str, _alias: Optional[str] = None, **kwargs): - super().__init__(NodeType.TABLE, **kwargs) - self.name = _name - self.alias = _alias - - def __eq__(self, other): - if not isinstance(other, TableNode): - return False - return (super().__eq__(other) and - self.name == other.name and - self.alias == other.alias) - - def __hash__(self): - return hash((super().__hash__(), self.name, self.alias)) - - -# TODO - including query structure arguments (similar to QueryNode) in constructor. -class SubqueryNode(Node): - """Subquery node""" - def __init__(self, query: 'Node', _alias: Optional[str] = None, **kwargs): - super().__init__(NodeType.SUBQUERY, children={query}, **kwargs) - self.alias = _alias - - -class ColumnNode(Node): - """Column reference node""" - def __init__(self, _name: str, _alias: Optional[str] = None, _parent_alias: Optional[str] = None, _parent: Optional[TableNode|SubqueryNode] = None, **kwargs): - super().__init__(NodeType.COLUMN, **kwargs) - self.name = _name - self.alias = _alias - self.parent_alias = _parent_alias - self.parent = _parent - - def __eq__(self, other): - if not isinstance(other, ColumnNode): - return False - return (super().__eq__(other) and - self.name == other.name and - self.alias == other.alias and - self.parent_alias == other.parent_alias) - - def __hash__(self): - return hash((super().__hash__(), self.name, self.alias, self.parent_alias)) - - -class LiteralNode(Node): - """Literal value node""" - def __init__(self, _value: str|int|float|bool|datetime|None, **kwargs): - super().__init__(NodeType.LITERAL, **kwargs) - self.value = _value - - def __eq__(self, other): - if not isinstance(other, LiteralNode): - return False - return (super().__eq__(other) and - self.value == other.value) - - def __hash__(self): - return hash((super().__hash__(), self.value)) - - -class VarNode(Node): - """VarSQL variable node""" - def __init__(self, _name: str, **kwargs): - super().__init__(NodeType.VAR, **kwargs) - self.name = _name - - -class VarSetNode(Node): - """VarSQL variable set node""" - def __init__(self, _name: str, **kwargs): - super().__init__(NodeType.VARSET, **kwargs) - self.name = _name - - -class OperatorNode(Node): - """Operator node""" - def __init__(self, _left: Node, _name: str, _right: Optional[Node] = None, **kwargs): - children = [_left, _right] if _right else [_left] - super().__init__(NodeType.OPERATOR, children=children, **kwargs) - self.name = _name - - def __eq__(self, other): - if not isinstance(other, OperatorNode): - return False - return (super().__eq__(other) and - self.name == other.name) - - def __hash__(self): - return hash((super().__hash__(), self.name)) - - -class FunctionNode(Node): - """Function call node""" - def __init__(self, _name: str, _args: Optional[List[Node]] = None, _alias: Optional[str] = None, **kwargs): - if _args is None: - _args = [] - super().__init__(NodeType.FUNCTION, children=_args, **kwargs) - self.name = _name - self.alias = _alias - - def __eq__(self, other): - if not isinstance(other, FunctionNode): - return False - return (super().__eq__(other) and - self.name == other.name and - self.alias == other.alias) - - def __hash__(self): - return hash((super().__hash__(), self.name, self.alias)) - - -class JoinNode(Node): - """JOIN clause node""" - def __init__(self, _left_table: Union['TableNode', 'JoinNode'], _right_table: 'TableNode', _join_type: JoinType = JoinType.INNER, _on_condition: Optional['Node'] = None, **kwargs): - children = [_left_table, _right_table] - if _on_condition: - children.append(_on_condition) - super().__init__(NodeType.JOIN, children=children, **kwargs) - self.left_table = _left_table - self.right_table = _right_table - self.join_type = _join_type - self.on_condition = _on_condition - - def __eq__(self, other): - if not isinstance(other, JoinNode): - return False - return (super().__eq__(other) and - self.join_type == other.join_type) - - def __hash__(self): - return hash((super().__hash__(), self.join_type)) - -# ============================================================================ -# Query Structure Nodes -# ============================================================================ - -class SelectNode(Node): - """SELECT clause node""" - def __init__(self, _items: List['Node'], **kwargs): - super().__init__(NodeType.SELECT, children=_items, **kwargs) - - -# TODO - confine the valid NodeTypes as children of FromNode -class FromNode(Node): - """FROM clause node""" - def __init__(self, _sources: List['Node'], **kwargs): - super().__init__(NodeType.FROM, children=_sources, **kwargs) - - -class WhereNode(Node): - """WHERE clause node""" - def __init__(self, _predicates: List['Node'], **kwargs): - super().__init__(NodeType.WHERE, children=_predicates, **kwargs) - - -class GroupByNode(Node): - """GROUP BY clause node""" - def __init__(self, _items: List['Node'], **kwargs): - super().__init__(NodeType.GROUP_BY, children=_items, **kwargs) - - -class HavingNode(Node): - """HAVING clause node""" - def __init__(self, _predicates: List['Node'], **kwargs): - super().__init__(NodeType.HAVING, children=_predicates, **kwargs) - - -class OrderByItemNode(Node): - """Single ORDER BY item""" - def __init__(self, _column: Node, _sort: SortOrder = SortOrder.ASC, **kwargs): - super().__init__(NodeType.ORDER_BY_ITEM, children=[_column], **kwargs) - self.sort = _sort - - def __eq__(self, other): - if not isinstance(other, OrderByItemNode): - return False - return (super().__eq__(other) and - self.sort == other.sort) - - def __hash__(self): - return hash((super().__hash__(), self.sort)) - -class OrderByNode(Node): - """ORDER BY clause node""" - def __init__(self, _items: List[OrderByItemNode], **kwargs): - super().__init__(NodeType.ORDER_BY, children=_items, **kwargs) - - -class LimitNode(Node): - """LIMIT clause node""" - def __init__(self, _limit: int, **kwargs): - super().__init__(NodeType.LIMIT, **kwargs) - self.limit = _limit - - def __eq__(self, other): - if not isinstance(other, LimitNode): - return False - return (super().__eq__(other) and - self.limit == other.limit) - - def __hash__(self): - return hash((super().__hash__(), self.limit)) - - -class OffsetNode(Node): - """OFFSET clause node""" - def __init__(self, _offset: int, **kwargs): - super().__init__(NodeType.OFFSET, **kwargs) - self.offset = _offset - - def __eq__(self, other): - if not isinstance(other, OffsetNode): - return False - return (super().__eq__(other) and - self.offset == other.offset) - - def __hash__(self): - return hash((super().__hash__(), self.offset)) - - -class QueryNode(Node): - """Query root node""" - def __init__(self, - _select: Optional['Node'] = None, - _from: Optional['Node'] = None, - _where: Optional['Node'] = None, - _group_by: Optional['Node'] = None, - _having: Optional['Node'] = None, - _order_by: Optional['Node'] = None, - _limit: Optional['Node'] = None, - _offset: Optional['Node'] = None, - **kwargs): - children = [] - if _select: - children.append(_select) - if _from: - children.append(_from) - if _where: - children.append(_where) - if _group_by: - children.append(_group_by) - if _having: - children.append(_having) - if _order_by: - children.append(_order_by) - if _limit: - children.append(_limit) - if _offset: - children.append(_offset) - super().__init__(NodeType.QUERY, children=children, **kwargs) +from datetime import datetime +from typing import List, Set, Optional, Union +from abc import ABC + +from .enums import NodeType, JoinType, SortOrder + +# ============================================================================ +# Base Node Structure +# ============================================================================ + +class Node(ABC): + """Base class for all nodes""" + def __init__(self, type: NodeType, children: Optional[Set['Node']|List['Node']] = None): + self.type = type + self.children = children if children is not None else set() + + def __eq__(self, other): + if not isinstance(other, Node): + return False + if self.type != other.type: + return False + if len(self.children) != len(other.children): + return False + # Compare children + if isinstance(self.children, set) and isinstance(other.children, set): + return self.children == other.children + elif isinstance(self.children, list) and isinstance(other.children, list): + return self.children == other.children + else: + return False + + def __hash__(self): + # Make nodes hashable by using their type and a hash of their children + if isinstance(self.children, set): + # For sets, create a deterministic hash by sorting children by their string representation + children_hash = hash(tuple(sorted(self.children, key=lambda x: str(x)))) + else: + # For lists, just hash the tuple directly + children_hash = hash(tuple(self.children)) + return hash((self.type, children_hash)) + + +# ============================================================================ +# Operand Nodes +# ============================================================================ + +class TableNode(Node): + """Table reference node""" + def __init__(self, _name: str, _alias: Optional[str] = None, **kwargs): + super().__init__(NodeType.TABLE, **kwargs) + self.name = _name + self.alias = _alias + + def __eq__(self, other): + if not isinstance(other, TableNode): + return False + return (super().__eq__(other) and + self.name == other.name and + self.alias == other.alias) + + def __hash__(self): + return hash((super().__hash__(), self.name, self.alias)) + + +class SubqueryNode(Node): + """Subquery node""" + def __init__(self, query: 'Node', _alias: Optional[str] = None, **kwargs): + super().__init__(NodeType.SUBQUERY, children={query}, **kwargs) + self.alias = _alias + + def __eq__(self, other): + if not isinstance(other, SubqueryNode): + return False + return (super().__eq__(other) and + self.alias == other.alias) + + def __hash__(self): + return hash((super().__hash__(), self.alias)) + + +class ColumnNode(Node): + """Column reference node""" + def __init__(self, _name: str, _alias: Optional[str] = None, _parent_alias: Optional[str] = None, _parent: Optional[TableNode|SubqueryNode] = None, **kwargs): + super().__init__(NodeType.COLUMN, **kwargs) + self.name = _name + self.alias = _alias + self.parent_alias = _parent_alias + self.parent = _parent + + def __eq__(self, other): + if not isinstance(other, ColumnNode): + return False + return (super().__eq__(other) and + self.name == other.name and + self.alias == other.alias and + self.parent_alias == other.parent_alias) + + def __hash__(self): + return hash((super().__hash__(), self.name, self.alias, self.parent_alias)) + + +class LiteralNode(Node): + """Literal value node""" + def __init__(self, _value: str|int|float|bool|datetime|None, **kwargs): + super().__init__(NodeType.LITERAL, **kwargs) + self.value = _value + + def __eq__(self, other): + if not isinstance(other, LiteralNode): + return False + return (super().__eq__(other) and + self.value == other.value) + + def __hash__(self): + return hash((super().__hash__(), self.value)) + + +class VarNode(Node): + """VarSQL variable node""" + def __init__(self, _name: str, **kwargs): + super().__init__(NodeType.VAR, **kwargs) + self.name = _name + + +class VarSetNode(Node): + """VarSQL variable set node""" + def __init__(self, _name: str, **kwargs): + super().__init__(NodeType.VARSET, **kwargs) + self.name = _name + + +class OperatorNode(Node): + """Operator node""" + def __init__(self, _left: Node, _name: str, _right: Optional[Node] = None, **kwargs): + children = [_left, _right] if _right else [_left] + super().__init__(NodeType.OPERATOR, children=children, **kwargs) + self.name = _name + + def __eq__(self, other): + if not isinstance(other, OperatorNode): + return False + return (super().__eq__(other) and + self.name == other.name) + + def __hash__(self): + return hash((super().__hash__(), self.name)) + + +class FunctionNode(Node): + """Function call node""" + def __init__(self, _name: str, _args: Optional[List[Node]] = None, _alias: Optional[str] = None, **kwargs): + if _args is None: + _args = [] + super().__init__(NodeType.FUNCTION, children=_args, **kwargs) + self.name = _name + self.alias = _alias + + def __eq__(self, other): + if not isinstance(other, FunctionNode): + return False + return (super().__eq__(other) and + self.name == other.name and + self.alias == other.alias) + + def __hash__(self): + return hash((super().__hash__(), self.name, self.alias)) + + +class JoinNode(Node): + """JOIN clause node""" + def __init__(self, _left_table: Union['TableNode', 'JoinNode', 'SubqueryNode'], _right_table: Union['TableNode', 'SubqueryNode'], _join_type: JoinType = JoinType.INNER, _on_condition: Optional['Node'] = None, **kwargs): + children = [_left_table, _right_table] + if _on_condition: + children.append(_on_condition) + super().__init__(NodeType.JOIN, children=children, **kwargs) + self.left_table = _left_table + self.right_table = _right_table + self.join_type = _join_type + self.on_condition = _on_condition + + def __eq__(self, other): + if not isinstance(other, JoinNode): + return False + return (super().__eq__(other) and + self.join_type == other.join_type) + + def __hash__(self): + return hash((super().__hash__(), self.join_type)) + +# ============================================================================ +# Query Structure Nodes +# ============================================================================ + +class SelectNode(Node): + """SELECT clause node""" + def __init__(self, _items: List['Node'], **kwargs): + super().__init__(NodeType.SELECT, children=_items, **kwargs) + + +# TODO - confine the valid NodeTypes as children of FromNode +class FromNode(Node): + """FROM clause node""" + def __init__(self, _sources: List['Node'], **kwargs): + super().__init__(NodeType.FROM, children=_sources, **kwargs) + + +class WhereNode(Node): + """WHERE clause node""" + def __init__(self, _predicates: List['Node'], **kwargs): + super().__init__(NodeType.WHERE, children=_predicates, **kwargs) + + +class GroupByNode(Node): + """GROUP BY clause node""" + def __init__(self, _items: List['Node'], **kwargs): + super().__init__(NodeType.GROUP_BY, children=_items, **kwargs) + + +class HavingNode(Node): + """HAVING clause node""" + def __init__(self, _predicates: List['Node'], **kwargs): + super().__init__(NodeType.HAVING, children=_predicates, **kwargs) + + +class OrderByItemNode(Node): + """Single ORDER BY item""" + def __init__(self, _column: Node, _sort: SortOrder = SortOrder.ASC, **kwargs): + super().__init__(NodeType.ORDER_BY_ITEM, children=[_column], **kwargs) + self.sort = _sort + + def __eq__(self, other): + if not isinstance(other, OrderByItemNode): + return False + return (super().__eq__(other) and + self.sort == other.sort) + + def __hash__(self): + return hash((super().__hash__(), self.sort)) + +class OrderByNode(Node): + """ORDER BY clause node""" + def __init__(self, _items: List[OrderByItemNode], **kwargs): + super().__init__(NodeType.ORDER_BY, children=_items, **kwargs) + + +class LimitNode(Node): + """LIMIT clause node""" + def __init__(self, _limit: int, **kwargs): + super().__init__(NodeType.LIMIT, **kwargs) + self.limit = _limit + + def __eq__(self, other): + if not isinstance(other, LimitNode): + return False + return (super().__eq__(other) and + self.limit == other.limit) + + def __hash__(self): + return hash((super().__hash__(), self.limit)) + + +class OffsetNode(Node): + """OFFSET clause node""" + def __init__(self, _offset: int, **kwargs): + super().__init__(NodeType.OFFSET, **kwargs) + self.offset = _offset + + def __eq__(self, other): + if not isinstance(other, OffsetNode): + return False + return (super().__eq__(other) and + self.offset == other.offset) + + def __hash__(self): + return hash((super().__hash__(), self.offset)) + + +class QueryNode(Node): + """Query root node""" + def __init__(self, + _select: Optional['Node'] = None, + _from: Optional['Node'] = None, + _where: Optional['Node'] = None, + _group_by: Optional['Node'] = None, + _having: Optional['Node'] = None, + _order_by: Optional['Node'] = None, + _limit: Optional['Node'] = None, + _offset: Optional['Node'] = None, + **kwargs): + children = [] + if _select: + children.append(_select) + if _from: + children.append(_from) + if _where: + children.append(_where) + if _group_by: + children.append(_group_by) + if _having: + children.append(_having) + if _order_by: + children.append(_order_by) + if _limit: + children.append(_limit) + if _offset: + children.append(_offset) + super().__init__(NodeType.QUERY, children=children, **kwargs) \ No newline at end of file diff --git a/core/query_parser.py b/core/query_parser.py index deee3c1..88cb632 100644 --- a/core/query_parser.py +++ b/core/query_parser.py @@ -3,12 +3,20 @@ LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, OrderByNode, OrderByItemNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, JoinNode ) -# TODO: implement SubqueryNode, VarNode, VarSetNode +# TODO: implement VarNode, VarSetNode from core.ast.enums import JoinType, SortOrder import mo_sql_parsing as mosql import json class QueryParser: + # mo_sql_parsing operator keys -> SQL display name + _OPERATOR_KEY_TO_NAME = { + 'eq': '=', 'neq': '!=', 'ne': '!=', + 'gt': '>', 'gte': '>=', 'lt': '<', 'lte': '<=', + 'and': 'AND', 'or': 'OR', 'in': 'IN', + } + _LIST_OPERATOR_KEYS = frozenset(_OPERATOR_KEY_TO_NAME.keys()) + @staticmethod def normalize_to_list(value): """Normalize mo_sql_parsing output to a list format. @@ -33,57 +41,24 @@ def normalize_to_list(value): ) def parse(self, query: str) -> QueryNode: - # [1] Call mo_sql_parser - # str -> Any (JSON) + # str -> mo_sql_parsing -> QueryNode mosql_ast = mosql.parse(query) - - # [2] Our new code - # Any (JSON) -> AST (QueryNode) - # Aliases dictionary - aliases = {} - - select_clause = None - from_clause = None - where_clause = None - group_by_clause = None - having_clause = None - order_by_clause = None - limit_clause = None - offset_clause = None - - if 'select' in mosql_ast: - select_clause = self.parse_select(self.normalize_to_list(mosql_ast['select']), aliases) - if 'from' in mosql_ast: - from_clause = self.parse_from(self.normalize_to_list(mosql_ast['from']), aliases) - if 'where' in mosql_ast: - where_clause = self.parse_where(mosql_ast['where'], aliases) - if 'groupby' in mosql_ast: - group_by_clause = self.parse_group_by(self.normalize_to_list(mosql_ast['groupby']), aliases) - if 'having' in mosql_ast: - having_clause = self.parse_having(mosql_ast['having'], aliases) - if 'orderby' in mosql_ast: - order_by_clause = self.parse_order_by(self.normalize_to_list(mosql_ast['orderby']), aliases) - if 'limit' in mosql_ast: - limit_clause = LimitNode(mosql_ast['limit']) - if 'offset' in mosql_ast: - offset_clause = OffsetNode(mosql_ast['offset']) - - return QueryNode( - _select=select_clause, - _from=from_clause, - _where=where_clause, - _group_by=group_by_clause, - _having=having_clause, - _order_by=order_by_clause, - _limit=limit_clause, - _offset=offset_clause - ) + return self.parse_query_dict(mosql_ast, aliases={}) def parse_select(self, select_list: list, aliases: dict) -> SelectNode: items = [] for item in select_list: if isinstance(item, dict) and 'value' in item: - expression = self.parse_expression(item['value']) + value = item['value'] + # Check if value is a subquery + if isinstance(value, dict) and 'select' in value: + # This is a subquery in SELECT clause + # Subquery has its own alias scope (no leaking to/from outer query) + subquery_query = self.parse_query_dict(value, aliases={}) + expression = SubqueryNode(subquery_query) + else: + expression = self.parse_expression(value, aliases) + # Handle alias - set for any node that has alias attribute if 'name' in item: alias = item['name'] @@ -94,7 +69,7 @@ def parse_select(self, select_list: list, aliases: dict) -> SelectNode: items.append(expression) else: # Handle direct expression (string, int, etc.) - expression = self.parse_expression(item) + expression = self.parse_expression(item, aliases) items.append(expression) return SelectNode(items) @@ -119,39 +94,73 @@ def parse_from(self, from_list: list, aliases: dict) -> FromNode: if isinstance(join_info, str): table_name = join_info alias = None + right_source = TableNode(table_name, alias) + elif isinstance(join_info, dict): + # Derived table: {'value': {}, 'name': } + value = join_info.get('value') + if isinstance(value, dict) and 'select' in value: + subquery_query = self.parse_query_dict(value, aliases={}) + alias = join_info.get('name') + right_source = SubqueryNode(subquery_query, alias) + elif 'select' in join_info: + # Subquery at top level (alias in 'name' if present) + subquery_query = self.parse_query_dict(join_info, aliases={}) + alias = join_info.get('name') + right_source = SubqueryNode(subquery_query, alias) + else: + table_name = join_info.get('value', join_info) + alias = join_info.get('name') + right_source = TableNode(table_name, alias) else: - table_name = join_info['value'] if isinstance(join_info, dict) else join_info - alias = join_info.get('name') if isinstance(join_info, dict) else None + table_name = join_info + alias = None + right_source = TableNode(table_name, alias) - right_table = TableNode(table_name, alias) - # Track table alias + # Track alias if alias: - aliases[alias] = right_table + aliases[alias] = right_source on_condition = None if 'on' in item: - on_condition = self.parse_expression(item['on']) + on_condition = self.parse_expression(item['on'], aliases) # Create join node - left_source might be a table or a previous join join_type = self.parse_join_type(join_key) - join_node = JoinNode(left_source, right_table, join_type, on_condition) + join_node = JoinNode(left_source, right_source, join_type, on_condition) # The result of this JOIN becomes the new left source for potential next JOIN left_source = join_node elif 'value' in item: - # This is a table reference - table_name = item['value'] + # Check if value is a subquery + value = item['value'] alias = item.get('name') - table_node = TableNode(table_name, alias) - # Track table alias - if alias: - aliases[alias] = table_node - if left_source is None: - # First table becomes the left source - left_source = table_node + if isinstance(value, dict) and 'select' in value: + # This is a subquery in FROM clause + # Subquery has its own alias scope (no leaking to/from outer query) + subquery_query = self.parse_query_dict(value, aliases={}) + subquery_node = SubqueryNode(subquery_query, alias) + # Track subquery alias + if alias: + aliases[alias] = subquery_node + + if left_source is None: + left_source = subquery_node + else: + sources.append(subquery_node) else: - # Multiple tables without explicit JOIN (cross join) - sources.append(table_node) + # This is a table reference + table_name = value + table_node = TableNode(table_name, alias) + # Track table alias + if alias: + aliases[alias] = table_node + + if left_source is None: + # First table becomes the left source + left_source = table_node + else: + # Multiple tables without explicit JOIN (cross join) + sources.append(table_node) elif isinstance(item, str): # Simple string table name table_node = TableNode(item) @@ -160,28 +169,28 @@ def parse_from(self, from_list: list, aliases: dict) -> FromNode: else: sources.append(table_node) - # Add the final left source (which might be a single table or chain of joins) + # Prepend the first/left source so order is preserved if left_source is not None: - sources.append(left_source) - + sources.insert(0, left_source) + return FromNode(sources) def parse_where(self, where_dict: dict, aliases: dict) -> WhereNode: predicates = [] - predicates.append(self.parse_expression(where_dict)) + predicates.append(self.parse_expression(where_dict, aliases)) return WhereNode(predicates) def parse_group_by(self, group_by_list: list, aliases: dict) -> GroupByNode: items = [] for item in group_by_list: if isinstance(item, dict) and 'value' in item: - expr = self.parse_expression(item['value']) + expr = self.parse_expression(item['value'], aliases) # Resolve aliases expr = self.resolve_aliases(expr, aliases) items.append(expr) else: # Handle direct expression (string, int, etc.) - expr = self.parse_expression(item) + expr = self.parse_expression(item, aliases) expr = self.resolve_aliases(expr, aliases) items.append(expr) @@ -189,7 +198,7 @@ def parse_group_by(self, group_by_list: list, aliases: dict) -> GroupByNode: def parse_having(self, having_dict: dict, aliases: dict) -> HavingNode: predicates = [] - expr = self.parse_expression(having_dict) + expr = self.parse_expression(having_dict, aliases) # Check if this expression references an aliased function from SELECT expr = self.resolve_aliases(expr, aliases) @@ -207,7 +216,7 @@ def parse_order_by(self, order_by_list: list, aliases: dict) -> OrderByNode: column = aliases[value] else: # Parse normally for other cases - column = self.parse_expression(value) + column = self.parse_expression(value, aliases) # Get sort order (default is ASC) sort_order = SortOrder.ASC @@ -221,7 +230,7 @@ def parse_order_by(self, order_by_list: list, aliases: dict) -> OrderByNode: items.append(order_by_item) else: # Handle direct expression (string, int, etc.) - column = self.parse_expression(item) + column = self.parse_expression(item, aliases) order_by_item = OrderByItemNode(column, SortOrder.ASC) items.append(order_by_item) @@ -267,7 +276,10 @@ def resolve_aliases(self, expr: Node, aliases: dict) -> Node: else: return expr - def parse_expression(self, expr) -> Node: + def parse_expression(self, expr, aliases: dict = None) -> Node: + if aliases is None: + aliases = {} + if isinstance(expr, str): # Column reference if '.' in expr: @@ -280,10 +292,17 @@ def parse_expression(self, expr) -> Node: if isinstance(expr, list): # List literals (for IN clauses) - parsed = [self.parse_expression(item) for item in expr] + parsed = [self.parse_expression(item, aliases) for item in expr] return parsed if isinstance(expr, dict): + # Check if this is a subquery (has 'select' key) + if 'select' in expr: + # This is a subquery - parse it recursively + # Subquery has its own alias scope (no leaking to/from outer query) + subquery_query = self.parse_query_dict(expr, aliases={}) + return SubqueryNode(subquery_query) + # Special cases first if 'all_columns' in expr: return ColumnNode('*') @@ -300,34 +319,44 @@ def parse_expression(self, expr) -> Node: value = expr[key] op_name = self.normalize_operator_name(key) - - # Pattern 1: Binary/N-ary operator with list of operands + key_lower = key.lower() + + # Pattern 1: List value (either n-ary operator or multi-arg function) if isinstance(value, list): if len(value) == 0: return LiteralNode(None) if len(value) == 1: - return self.parse_expression(value[0]) - - # Parse all operands - operands = [self.parse_expression(v) for v in value] - - # Chain multiple operands with the same operator - result = operands[0] - for operand in operands[1:]: - result = OperatorNode(result, op_name, operand) - return result + return self.parse_expression(value[0], aliases) + + operands = [self.parse_expression(v, aliases) for v in value] + + # SQL operators that mo_sql_parsing represents as key: [args] + if key_lower in QueryParser._LIST_OPERATOR_KEYS: + result = operands[0] + for operand in operands[1:]: + result = OperatorNode(result, op_name, operand) + return result + # Otherwise treat as multi-arg function (e.g. COALESCE, GREATEST) + return FunctionNode(op_name, _args=operands) # Pattern 2: Unary operator if key == 'not': - return OperatorNode(self.parse_expression(value), 'NOT') + return OperatorNode(self.parse_expression(value, aliases), 'NOT') - # Pattern 3: Function call + # Pattern 3: EXISTS operator with subquery + if key == 'exists' and isinstance(value, dict) and 'select' in value: + # Subquery has its own alias scope (no leaking to/from outer query) + subquery_query = self.parse_query_dict(value, aliases={}) + subquery_node = SubqueryNode(subquery_query) + return OperatorNode(subquery_node, 'EXISTS') + + # Pattern 4: Function call # Special case: COUNT(*), SUM(*), etc. if value == '*': return FunctionNode(op_name, _args=[ColumnNode('*')]) # Regular function - args = [self.parse_expression(value)] + args = [self.parse_expression(value, aliases)] return FunctionNode(op_name, _args=args) # No valid key found @@ -336,16 +365,50 @@ def parse_expression(self, expr) -> Node: # Other types return LiteralNode(expr) + def parse_query_dict(self, query_dict: dict, aliases: dict) -> QueryNode: + """Parse a mo_sql_parsing query-dict into a QueryNode. + """ + select_clause = None + from_clause = None + where_clause = None + group_by_clause = None + having_clause = None + order_by_clause = None + limit_clause = None + offset_clause = None + + if 'select' in query_dict: + select_clause = self.parse_select(self.normalize_to_list(query_dict['select']), aliases) + if 'from' in query_dict: + from_clause = self.parse_from(self.normalize_to_list(query_dict['from']), aliases) + if 'where' in query_dict: + where_clause = self.parse_where(query_dict['where'], aliases) + if 'groupby' in query_dict: + group_by_clause = self.parse_group_by(self.normalize_to_list(query_dict['groupby']), aliases) + if 'having' in query_dict: + having_clause = self.parse_having(query_dict['having'], aliases) + if 'orderby' in query_dict: + order_by_clause = self.parse_order_by(self.normalize_to_list(query_dict['orderby']), aliases) + if 'limit' in query_dict: + limit_clause = LimitNode(query_dict['limit']) + if 'offset' in query_dict: + offset_clause = OffsetNode(query_dict['offset']) + + return QueryNode( + _select=select_clause, + _from=from_clause, + _where=where_clause, + _group_by=group_by_clause, + _having=having_clause, + _order_by=order_by_clause, + _limit=limit_clause, + _offset=offset_clause + ) + @staticmethod def normalize_operator_name(key: str) -> str: """Convert mo_sql_parsing operator keys to SQL operator names.""" - mapping = { - 'eq': '=', 'neq': '!=', 'ne': '!=', - 'gt': '>', 'gte': '>=', - 'lt': '<', 'lte': '<=', - 'and': 'AND', 'or': 'OR', - } - return mapping.get(key.lower(), key.upper()) + return QueryParser._OPERATOR_KEY_TO_NAME.get(key.lower(), key.upper()) @staticmethod def parse_join_type(join_key: str) -> JoinType: @@ -363,4 +426,4 @@ def parse_join_type(join_key: str) -> JoinType: elif 'cross' in key_lower: return JoinType.CROSS - return JoinType.INNER # By default \ No newline at end of file + return JoinType.INNER \ No newline at end of file diff --git a/tests/test_query_parser.py b/tests/test_query_parser.py index 1bae504..054148a 100644 --- a/tests/test_query_parser.py +++ b/tests/test_query_parser.py @@ -134,273 +134,6 @@ def test_subquery_parse(): _from=from_clause, _where=where_clause ) - - # qb_ast = parser.parse(sql) - # assert qb_ast == expected_ast - - -def test_parse_1(): - query = get_query(1) - sql = query['pattern'] - - qb_ast = parser.parse(sql) - # assert isinstance(qb_ast, QueryNode) - - # Check SELECT clause - - # select_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.SELECT: - # select_clause = child - # break - - # assert select_clause is not None - # assert len(select_clause.children) == 2 - - # Check FROM clause - # from_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.FROM: - # from_clause = child - # break - - # assert from_clause is not None - # table_node = next(iter(from_clause.children)) - # assert isinstance(table_node, TableNode) - # assert table_node.name == "tweets" - - # Check WHERE clause - # where_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.WHERE: - # where_clause = child - # break - - # assert where_clause is not None - # assert len(where_clause.children) == 1 - - # Check GROUP BY clause - # group_by_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.GROUP_BY: - # group_by_clause = child - # break - - # assert group_by_clause is not None - # assert len(group_by_clause.children) == 1 - - -def test_parse_2(): - query = get_query(6) - sql = query['pattern'] - - qb_ast = parser.parse(sql) - # assert isinstance(qb_ast, QueryNode) - - # Check FROM clause has multiple tables - # from_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.FROM: - # from_clause = child - # break - - # assert from_clause is not None - # assert len(from_clause.children) == 2 - - # Check WHERE clause has multiple conditions - # where_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.WHERE: - # where_clause = child - # break - - # assert where_clause is not None - # condition = next(iter(where_clause.children)) - # assert isinstance(condition, OperatorNode) - - -def test_parse_4(): - query = get_query(12) - sql = query['pattern'] - - qb_ast = parser.parse(sql) - # assert isinstance(qb_ast, QueryNode) - - # Check FROM clause has multiple JOINs - # from_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.FROM: - # from_clause = child - # break - - # assert from_clause is not None - # Check for JOIN nodes in the FROM clause - # join_count = 0 - # for child in from_clause.children: - # if hasattr(child, 'type') and 'JOIN' in str(child.type): - # join_count += 1 - # assert join_count >= 2 - - -def test_parse_5(): - query = get_query(16) - sql = query['pattern'] - - qb_ast = parser.parse(sql) - # assert isinstance(qb_ast, QueryNode) - - # Check SELECT clause has aggregation with subquery - # select_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.SELECT: - # select_clause = child - # break - - # assert select_clause is not None - # assert len(select_clause.children) == 3 - - # Check for MAX function - # for child in select_clause.children: - # if isinstance(child, FunctionNode) and child.name == "MAX": - # assert True - # break - - -def test_parse_6(): - query = get_query(18) - sql = query['pattern'] - - qb_ast = parser.parse(sql) - # assert isinstance(qb_ast, QueryNode) - - # Check SELECT clause has DISTINCT - # select_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.SELECT: - # select_clause = child - # break - - # assert select_clause is not None - # Check for DISTINCT keyword - # assert hasattr(select_clause, 'distinct') and select_clause.distinct - - # Check FROM clause has multiple tables - # from_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.FROM: - # from_clause = child - # break - - # assert from_clause is not None - # assert len(from_clause.children) == 2 - -def test_parse_7(): - query = get_query(25) - sql = query['pattern'] - qb_ast = parser.parse(sql) - # assert isinstance(qb_ast, QueryNode) - - # Check WHERE clause has boolean logic - # where_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.WHERE: - # where_clause = child - # break - - # assert where_clause is not None - # condition = next(iter(where_clause.children)) - # assert isinstance(condition, OperatorNode) - # assert condition.name == "AND" - - -def test_parse_8(): - query = get_query(29) - sql = query['pattern'] - - qb_ast = parser.parse(sql) - # assert isinstance(qb_ast, QueryNode) - - # Check for UNION operation (this query has UNION) - # Check if the query contains UNION - # assert 'UNION' in sql.upper() - - # Check for subqueries in WHERE clause - # where_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.WHERE: - # where_clause = child - # break - - # assert where_clause is not None - - -def test_parse_9(): - query = get_query(31) - sql = query['pattern'] - - qb_ast = parser.parse(sql) - # assert isinstance(qb_ast, QueryNode) - - # Check SELECT clause has complex aggregation - # select_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.SELECT: - # select_clause = child - # break - - # assert select_clause is not None - # assert len(select_clause.children) == 3 - - # Check for CASE statement - # for child in select_clause.children: - # if isinstance(child, FunctionNode) and child.name == "CASE": - # assert True - # break - - # Check GROUP BY clause - # group_by_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.GROUP_BY: - # group_by_clause = child - # break - - # assert group_by_clause is not None - - -def test_parse_10(): - query = get_query(42) - sql = query['pattern'] - - qb_ast = parser.parse(sql) - # assert isinstance(qb_ast, QueryNode) - - # Check SELECT clause - # select_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.SELECT: - # select_clause = child - # break - - # assert select_clause is not None - # assert len(select_clause.children) == 2 - - # Check WHERE clause has complex conditions - # where_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.WHERE: - # where_clause = child - # break - - # assert where_clause is not None - - # Check GROUP BY clause - # group_by_clause = None - # for child in qb_ast.children: - # if child.type == NodeType.GROUP_BY: - # group_by_clause = child - # break - - # assert group_by_clause is not None - # assert len(group_by_clause.children) == 2 \ No newline at end of file + assert qb_ast == expected_ast \ No newline at end of file