diff --git a/data_check/processors/bigquery.py b/data_check/processors/bigquery.py index 0e1556e..fac581e 100644 --- a/data_check/processors/bigquery.py +++ b/data_check/processors/bigquery.py @@ -48,8 +48,31 @@ def with_statement_query_sampled(self) -> Select: return self.with_statement_query def check_input_is_sql(self, value: str) -> bool: - """Check if the input is a SQL query""" - return " select " in (" " + value).lower() and "from " in value.lower() + """Check if the input is a SQL query using sqlglot parser + + Returns True if the input is a valid SQL SELECT query. + Returns False if the input is a table reference (table, dataset.table, or project.dataset.table). + """ + from sqlglot import parse_one + + stripped_value = value.strip() + + try: + # Try to parse the input as SQL + parsed = parse_one(stripped_value, dialect=self.dialect, error_level=None) + + # If it's a Select expression with a FROM clause, it's valid SQL + if isinstance(parsed, Select): + return True + + # If parsing returns a non-Select type (Column, Sub, Identifier, etc.), + # it's a table reference - accept it + return False + + except Exception: + # If parsing fails completely, assume it's a table reference + # This handles edge cases where the table path might have special characters + return False def get_sql_exp_from_tablename(self, tablename: str) -> Select: return select("*").from_(tablename, dialect=self.dialect) diff --git a/tests/processors/test_bigquery.py b/tests/processors/test_bigquery.py index f7afcec..b1481af 100644 --- a/tests/processors/test_bigquery.py +++ b/tests/processors/test_bigquery.py @@ -148,3 +148,58 @@ def test_run_query_check_primary_keys_unique(): result = processor.run_query_check_primary_keys_unique(table="table1") assert result == (True, "") + + +def test_check_input_is_sql_with_complex_query(): + """Test that complex SQL queries with CTEs and subqueries are correctly detected as SQL""" + complex_query = """ + SELECT + * + FROM ( + SELECT + *, + ROW_NUMBER() OVER ( + PARTITION BY + page, query, timestamp, clicks, impressions, + ctr, country, device, position, page_content + ORDER BY timestamp DESC + ) AS row_num + FROM gorgias-growth-production.dreamdata_new.google_search + ) + WHERE row_num = 1; + """ + + table_path = "gorgias-growth-production.dreamdata_new.google_search" + + with patch('data_check.processors.bigquery.QueryBigQuery') as mock_client: + client_instance = Mock() + mock_client.return_value = client_instance + processor = BigQueryProcessor(complex_query, table_path) + + # Verify the complex query is correctly identified as SQL + assert processor.use_sql_query1 is True + assert processor._table1 is None + + # Verify the table path is correctly identified as a table + assert processor.use_sql_query2 is False + assert processor._table2 == table_path + + +def test_check_input_is_sql_method(): + """Test the check_input_is_sql method directly with various inputs""" + with patch('data_check.processors.bigquery.QueryBigQuery') as mock_client: + client_instance = Mock() + mock_client.return_value = client_instance + # Create a processor just to access the method + processor = BigQueryProcessor("dataset.table1", "dataset.table2") + + # Test SQL queries - should return True + assert processor.check_input_is_sql("SELECT * FROM table") is True + assert processor.check_input_is_sql("SELECT * FROM dataset.table") is True + assert processor.check_input_is_sql("WITH cte AS (SELECT 1) SELECT * FROM cte") is True + + # Test table paths - should return False + assert processor.check_input_is_sql("project.dataset.table") is False + assert processor.check_input_is_sql("dataset.table") is False + assert processor.check_input_is_sql("`project.dataset.table`") is False + assert processor.check_input_is_sql("simple_table") is False