diff --git a/core/query_formatter.py b/core/query_formatter.py index 437f316..b592324 100644 --- a/core/query_formatter.py +++ b/core/query_formatter.py @@ -222,6 +222,9 @@ def format_expression(node: Node): return node.name elif node.type == NodeType.LITERAL: + if isinstance(node.value, str): + return {'literal': node.value} + return node.value elif node.type == NodeType.FUNCTION: @@ -230,6 +233,10 @@ def format_expression(node: Node): args = [format_expression(arg) for arg in node.children] return {func_name: args[0] if len(args) == 1 else args} + elif node.type == NodeType.SUBQUERY: + subquery_node = list(node.children)[0] + return ast_to_json(subquery_node) + elif node.type == NodeType.OPERATOR: # format: {'operator': [left, right]} op_map = { diff --git a/tests/test_query_formatter.py b/tests/test_query_formatter.py index 2fe2c56..f3f3dae 100644 --- a/tests/test_query_formatter.py +++ b/tests/test_query_formatter.py @@ -5,17 +5,10 @@ OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, JoinNode ) from core.ast.enums import JoinType, SortOrder -from re import sub +from mo_sql_parsing import parse formatter = QueryFormatter() -def normalize_sql(s): - """Remove extra whitespace and normalize SQL string to be used in comparisons""" - s = s.strip() - s = sub(r'\s+', ' ', s) - - return s - def test_basic_format(): # Construct expected AST # Tables @@ -82,7 +75,7 @@ def test_basic_format(): sql = formatter.format(ast) sql = sql.strip() - assert normalize_sql(sql) == normalize_sql(expected_sql) + assert parse(sql) == parse(expected_sql) def test_subquery_format(): @@ -142,9 +135,9 @@ def test_subquery_format(): ) AND 1 = 1 """ - # expected_sql = expected_sql.strip() + expected_sql = expected_sql.strip() - # sql = formatter.format(ast) - # sql = sql.strip() + sql = formatter.format(ast) + sql = sql.strip() - # assert normalize_sql(sql) == normalize_sql(expected_sql) \ No newline at end of file + assert parse(sql) == parse(expected_sql) \ No newline at end of file