diff --git a/CHANGELOG.md b/CHANGELOG.md index d69236e6..0a458048 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ # Changelog ## [Unreleased] +### Added +- PREWHERE clause support for efficient pre-filtering in ClickHouse queries. Core SQL: `.prewhere()` method, ORM: `.prefilter()` and `.prefilter_by()` methods. Follows the same pattern as the FINAL clause implementation. ## [0.3.2] - 2024-06-12 ### Added diff --git a/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py b/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py index efe3dab7..febdb7cb 100644 --- a/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py +++ b/clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py @@ -304,6 +304,11 @@ def _compose_select_body( if final_clause is not None: text += self.final_clause() + prewhere_clause = getattr(select, '_prewhere_clause', None) + + if prewhere_clause is not None: + text += self.prewhere_clause(select, **kwargs) + if select._where_criteria: t = self._generate_delimited_and_list( select._where_criteria, from_linter=from_linter, **kwargs @@ -346,6 +351,11 @@ def sample_clause(self, select, **kw): def final_clause(self): return " \nFINAL" + def prewhere_clause(self, select, **kw): + if select._prewhere_clause is None: + return "" + return " \nPREWHERE " + self.process(select._prewhere_clause, **kw) + def group_by_clause(self, select, **kw): text = "" diff --git a/clickhouse_sqlalchemy/ext/clauses.py b/clickhouse_sqlalchemy/ext/clauses.py index d0d6cda7..b9ae1e72 100644 --- a/clickhouse_sqlalchemy/ext/clauses.py +++ b/clickhouse_sqlalchemy/ext/clauses.py @@ -1,5 +1,5 @@ from sqlalchemy import exc -from sqlalchemy.sql import type_api, roles +from sqlalchemy.sql import type_api, roles, and_ from sqlalchemy.sql.elements import ( BindParameter, ColumnElement, @@ -30,6 +30,15 @@ def sample_clause(element): return SampleParam(None, element, unique=True) +def prewhere_clause(*clauses): + if not clauses: + return None + elif len(clauses) == 1: + return clauses[0] + else: + return and_(*clauses) + + class LimitByClause: def __init__(self, by_clauses, limit, offset): diff --git a/clickhouse_sqlalchemy/orm/query.py b/clickhouse_sqlalchemy/orm/query.py index 1b6b51ec..43ab251c 100644 --- a/clickhouse_sqlalchemy/orm/query.py +++ b/clickhouse_sqlalchemy/orm/query.py @@ -1,6 +1,7 @@ from functools import partial from sqlalchemy import exc +from sqlalchemy.sql import and_ from sqlalchemy.sql.base import _generative from sqlalchemy.orm.query import Query as BaseQuery @@ -9,6 +10,7 @@ LeftArrayJoin, LimitByClause, sample_clause, + prewhere_clause, ) @@ -23,6 +25,7 @@ def _compile_state_factory(orig_compile_state_factory, query, statement, new_stmt._sample_clause = sample_clause(query._sample) new_stmt._limit_by_clause = query._limit_by new_stmt._array_join = query._array_join + new_stmt._prewhere_clause = query._prewhere return rv @@ -34,6 +37,7 @@ class Query(BaseQuery): _sample = None _limit_by = None _array_join = None + _prewhere = None def _statement_20(self, *args, **kwargs): statement = super(Query, self)._statement_20(*args, **kwargs) @@ -106,6 +110,30 @@ def final(self): self._final = True return self + @_generative + def prefilter(self, *clauses): + self._prewhere = prewhere_clause(*clauses) + return self + + @_generative + def prefilter_by(self, **kwargs): + clauses = [] + + for key, value in kwargs.items(): + entity = self._entities[0] + if hasattr(entity, 'entity'): + entity = entity.entity + + if hasattr(entity, key): + column = getattr(entity, key) + clauses.append(column == value) + + if len(clauses) == 1: + self._prewhere = clauses[0] + else: + self._prewhere = and_(*clauses) + return self + @_generative def sample(self, sample): self._sample = sample diff --git a/clickhouse_sqlalchemy/sql/selectable.py b/clickhouse_sqlalchemy/sql/selectable.py index 2336c981..3585d735 100644 --- a/clickhouse_sqlalchemy/sql/selectable.py +++ b/clickhouse_sqlalchemy/sql/selectable.py @@ -8,6 +8,7 @@ LeftArrayJoin, LimitByClause, sample_clause, + prewhere_clause, ) @@ -22,6 +23,7 @@ class Select(StandardSelect): _sample_clause = None _limit_by_clause = None _array_join = None + _prewhere_clause = None @_generative def with_cube(self): @@ -43,6 +45,11 @@ def final(self): self._final_clause = True return self + @_generative + def prewhere(self, *clauses): + self._prewhere_clause = prewhere_clause(*clauses) + return self + @_generative def sample(self, sample): self._sample_clause = sample_clause(sample) diff --git a/docs/features.rst b/docs/features.rst index 20dce7e8..eb45ae71 100644 --- a/docs/features.rst +++ b/docs/features.rst @@ -933,6 +933,45 @@ becomes (respectively) SELECT ... FROM ... GROUP BY ... WITH ROLLUP SELECT ... FROM ... GROUP BY ... WITH TOTALS +PREWHERE ++++++++++ + +PREWHERE clause allows efficient pre-filtering of data before reading from disk. + + .. code-block:: python + + session.query(table.c.x).prefilter(table.c.x > 10) + +or + + .. code-block:: python + + select([table.c.x]).prewhere(table.c.x > 10) + +becomes + + .. code-block:: sql + + SELECT ... FROM ... PREWHERE x > 10 + +PREWHERE can be combined with WHERE clause for additional filtering: + + .. code-block:: python + + session.query(table.c.x).prefilter(table.c.x > 10).filter(table.c.x < 100) + +or + + .. code-block:: python + + select([table.c.x]).prewhere(table.c.x > 10).where(table.c.x < 100) + +becomes + + .. code-block:: sql + + SELECT ... FROM ... PREWHERE x > 10 WHERE x < 100 + FINAL +++++ diff --git a/tests/orm/test_select.py b/tests/orm/test_select.py index c21cbdb4..a11ccff5 100644 --- a/tests/orm/test_select.py +++ b/tests/orm/test_select.py @@ -215,6 +215,104 @@ def test_final(self): 'SELECT t1.x AS t1_x FROM t1 FINAL GROUP BY t1.x' ) + def test_prefilter(self): + table = self._make_table() + + query = self.session.query(table.c.x).prefilter(table.c.x > 10) + self.assertEqual( + self.compile(query), + 'SELECT t1.x AS t1_x FROM t1 PREWHERE t1.x > %(param_1)s' + ) + self.assertEqual( + self.compile(query, literal_binds=True), + 'SELECT t1.x AS t1_x FROM t1 PREWHERE t1.x > 10' + ) + + def test_prefilter_multiple_clauses(self): + table = self._make_table() + + query = self.session.query(table.c.x).prefilter( + table.c.x > 10, table.c.x < 100 + ) + self.assertEqual( + self.compile(query), + 'SELECT t1.x AS t1_x FROM t1 PREWHERE t1.x > %(param_1)s AND ' + 't1.x < %(param_2)s' + ) + self.assertEqual( + self.compile(query, literal_binds=True), + 'SELECT t1.x AS t1_x FROM t1 PREWHERE t1.x > 10 AND t1.x < 100' + ) + + def test_prefilter_by(self): + table = self._make_table() + + query = self.session.query(table.c.x).prefilter_by(x=10) + self.assertEqual( + self.compile(query), + 'SELECT t1.x AS t1_x FROM t1 PREWHERE t1.x = %(param_1)s' + ) + self.assertEqual( + self.compile(query, literal_binds=True), + 'SELECT t1.x AS t1_x FROM t1 PREWHERE t1.x = 10' + ) + + def test_prefilter_with_filter(self): + table = self._make_table() + + query = self.session.query(table.c.x).prefilter( + table.c.x > 10 + ).filter(table.c.x < 100) + self.assertEqual( + self.compile(query), + 'SELECT t1.x AS t1_x FROM t1 PREWHERE t1.x > %(param_1)s WHERE ' + 't1.x < %(param_2)s' + ) + self.assertEqual( + self.compile(query, literal_binds=True), + 'SELECT t1.x AS t1_x FROM t1 PREWHERE t1.x > 10 WHERE t1.x < 100' + ) + + def test_prefilter_with_final(self): + table = self._make_table() + + query = self.session.query(table.c.x).prefilter(table.c.x > 10).final() + self.assertEqual( + self.compile(query), + 'SELECT t1.x AS t1_x FROM t1 FINAL PREWHERE t1.x > %(param_1)s' + ) + self.assertEqual( + self.compile(query, literal_binds=True), + 'SELECT t1.x AS t1_x FROM t1 FINAL PREWHERE t1.x > 10' + ) + + def test_prefilter_empty_clauses(self): + table = self._make_table() + + query = self.session.query(table.c.x).prefilter() + self.assertEqual( + self.compile(query), + 'SELECT t1.x AS t1_x FROM t1' + ) + + def test_prefilter_by_empty_kwargs(self): + table = self._make_table() + + query = self.session.query(table.c.x).prefilter_by() + self.assertEqual( + self.compile(query), + 'SELECT t1.x AS t1_x FROM t1' + ) + + def test_prefilter_invalid_clause_type(self): + table = self._make_table() + + query = self.session.query(table.c.x).prefilter("invalid_string") + self.assertEqual( + self.compile(query), + 'SELECT t1.x AS t1_x FROM t1 PREWHERE %(param_1)s' + ) + def test_limit_by(self): table = self._make_table() diff --git a/tests/sql/test_selectable.py b/tests/sql/test_selectable.py index 975853d5..bfb58fae 100644 --- a/tests/sql/test_selectable.py +++ b/tests/sql/test_selectable.py @@ -84,6 +84,82 @@ def test_final(self): 'SELECT t1.x FROM t1 FINAL GROUP BY t1.x' ) + def test_prewhere(self): + table = self._make_table() + + query = select(table.c.x).prewhere(table.c.x > 10) + self.assertEqual( + self.compile(query), + 'SELECT t1.x FROM t1 PREWHERE t1.x > %(param_1)s' + ) + self.assertEqual( + self.compile(query, literal_binds=True), + 'SELECT t1.x FROM t1 PREWHERE t1.x > 10' + ) + + def test_prewhere_multiple_clauses(self): + table = self._make_table() + + query = select(table.c.x).prewhere( + table.c.x > 10, table.c.x < 100 + ) + self.assertEqual( + self.compile(query), + 'SELECT t1.x FROM t1 PREWHERE t1.x > %(param_1)s AND ' + 't1.x < %(param_2)s' + ) + self.assertEqual( + self.compile(query, literal_binds=True), + 'SELECT t1.x FROM t1 PREWHERE t1.x > 10 AND t1.x < 100' + ) + + def test_prewhere_with_where(self): + table = self._make_table() + + query = select(table.c.x).prewhere( + table.c.x > 10 + ).where(table.c.x < 100) + self.assertEqual( + self.compile(query), + 'SELECT t1.x FROM t1 PREWHERE t1.x > %(param_1)s WHERE ' + 't1.x < %(param_2)s' + ) + self.assertEqual( + self.compile(query, literal_binds=True), + 'SELECT t1.x FROM t1 PREWHERE t1.x > 10 WHERE t1.x < 100' + ) + + def test_prewhere_with_final(self): + table = self._make_table() + + query = select(table.c.x).prewhere(table.c.x > 10).final() + self.assertEqual( + self.compile(query), + 'SELECT t1.x FROM t1 FINAL PREWHERE t1.x > %(param_1)s' + ) + self.assertEqual( + self.compile(query, literal_binds=True), + 'SELECT t1.x FROM t1 FINAL PREWHERE t1.x > 10' + ) + + def test_prewhere_empty_clauses(self): + table = self._make_table() + + query = select(table.c.x).prewhere() + self.assertEqual( + self.compile(query), + 'SELECT t1.x FROM t1' + ) + + def test_prewhere_invalid_clause_type(self): + table = self._make_table() + + query = select(table.c.x).prewhere("invalid_string") + self.assertEqual( + self.compile(query), + 'SELECT t1.x FROM t1 PREWHERE %(param_1)s' + ) + def test_limit_by(self): table = self._make_table()