diff --git a/cosmotech/coal/store/store.py b/cosmotech/coal/store/store.py index d940fffa..059fdf9f 100644 --- a/cosmotech/coal/store/store.py +++ b/cosmotech/coal/store/store.py @@ -72,7 +72,7 @@ def add_table(self, table_name: str, data=pyarrow.Table, replace: bool = False): rows = curs.adbc_ingest(table_name, data, "replace" if replace else "create_append") LOGGER.debug(T("coal.common.data_transfer.rows_inserted").format(rows=rows, table_name=table_name)) - def execute_query(self, sql_query: str, parameters: list = (None,)) -> pyarrow.Table: + def execute_query(self, sql_query: str, parameters: list = None) -> pyarrow.Table: batch_size = 1024 batch_size_increment = 1024 while True: diff --git a/tests/integration/coal/test_store/test_store_store.py b/tests/integration/coal/test_store/test_integration_store_store.py similarity index 62% rename from tests/integration/coal/test_store/test_store_store.py rename to tests/integration/coal/test_store/test_integration_store_store.py index 126685e7..9409d7b2 100644 --- a/tests/integration/coal/test_store/test_store_store.py +++ b/tests/integration/coal/test_store/test_integration_store_store.py @@ -18,7 +18,7 @@ def store(): store.reset() -class TestStore: +class TestIntegrationStore: """Tests for the store class.""" def test_get_table(self, store): @@ -66,3 +66,37 @@ def test_add_get_table_with_upper_and_lower_case(self, store): assert upper_result assert UPPER_result assert upper_result == UPPER_result + + def test_execute_query_without_parameters(self, store): + """Test execute_query with a plain SQL query and no parameters""" + + # Arrange + table_name = "items" + table = pa.Table.from_arrays([pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id", "name"]) + store.add_table(table_name, table) + + # Act + result = store.execute_query(f'SELECT * FROM "{table_name}" ORDER BY id') + + # Assert + assert result is not None + assert result.num_rows == 3 + assert result.column("id").to_pylist() == [1, 2, 3] + assert result.column("name").to_pylist() == ["a", "b", "c"] + + def test_execute_query_with_parameters(self, store): + """Test execute_query with a parameterized SQL query""" + + # Arrange + table_name = "items" + table = pa.Table.from_arrays([pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id", "name"]) + store.add_table(table_name, table) + + # Act + result = store.execute_query(f'SELECT * FROM "{table_name}" WHERE id = ?', parameters=[2]) + + # Assert + assert result is not None + assert result.num_rows == 1 + assert result.column("id").to_pylist() == [2] + assert result.column("name").to_pylist() == ["b"] diff --git a/tests/unit/coal/test_store/test_store_store.py b/tests/unit/coal/test_store/test_store_store.py index 8e37f5b6..de1fb24e 100644 --- a/tests/unit/coal/test_store/test_store_store.py +++ b/tests/unit/coal/test_store/test_store_store.py @@ -270,7 +270,7 @@ def test_execute_query(self, mock_connect): # Assert mock_connect.assert_called_once() mock_cursor.adbc_statement.set_options.assert_called_once_with(**{"adbc.sqlite.query.batch_rows": "1024"}) - mock_cursor.execute.assert_called_once_with(sql_query, (None,)) + mock_cursor.execute.assert_called_once_with(sql_query, None) mock_cursor.fetch_arrow_table.assert_called_once() assert result == expected_table @@ -306,7 +306,7 @@ def test_execute_query_with_oserror(self, mock_connect): # First call with batch_size = 1024, second with batch_size = 2048 mock_cursor.adbc_statement.set_options.assert_any_call(**{"adbc.sqlite.query.batch_rows": "1024"}) mock_cursor.adbc_statement.set_options.assert_any_call(**{"adbc.sqlite.query.batch_rows": "2048"}) - mock_cursor.execute.assert_called_once_with(sql_query, (None,)) + mock_cursor.execute.assert_called_once_with(sql_query, None) mock_cursor.fetch_arrow_table.assert_called_once() assert result == expected_table