Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Changelog

## [Unreleased]
### Added
- PREWHERE clause support for efficient pre-filtering in ClickHouse queries. Core SQL: `.prewhere()` method, ORM: `.prefilter()` and `.prefilter_by()` methods. Follows the same pattern as the FINAL clause implementation.

## [0.3.2] - 2024-06-12
### Added
Expand Down
10 changes: 10 additions & 0 deletions clickhouse_sqlalchemy/drivers/compilers/sqlcompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,11 @@ def _compose_select_body(
if final_clause is not None:
text += self.final_clause()

prewhere_clause = getattr(select, '_prewhere_clause', None)

if prewhere_clause is not None:
text += self.prewhere_clause(select, **kwargs)

if select._where_criteria:
t = self._generate_delimited_and_list(
select._where_criteria, from_linter=from_linter, **kwargs
Expand Down Expand Up @@ -346,6 +351,11 @@ def sample_clause(self, select, **kw):
def final_clause(self):
    """Return the SQL fragment that adds the ``FINAL`` modifier."""
    fragment = " \nFINAL"
    return fragment

def prewhere_clause(self, select, **kw):
    """Render the ``PREWHERE`` part of a SELECT statement.

    :param select: statement being compiled; its ``_prewhere_clause``
        attribute, when present and not ``None``, is the expression to
        render.
    :param kw: compiler keyword arguments, forwarded to ``self.process``.
    :return: the processed expression preceded by the ``PREWHERE``
        keyword, or an empty string when there is nothing to render.
    """
    # Look the attribute up defensively so that a statement object which
    # never defined ``_prewhere_clause`` renders nothing instead of raising
    # AttributeError.
    clause = getattr(select, '_prewhere_clause', None)
    if clause is None:
        return ""
    return " \nPREWHERE " + self.process(clause, **kw)

def group_by_clause(self, select, **kw):
text = ""

Expand Down
11 changes: 10 additions & 1 deletion clickhouse_sqlalchemy/ext/clauses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from sqlalchemy import exc
from sqlalchemy.sql import type_api, roles
from sqlalchemy.sql import type_api, roles, and_
from sqlalchemy.sql.elements import (
BindParameter,
ColumnElement,
Expand Down Expand Up @@ -30,6 +30,15 @@ def sample_clause(element):
return SampleParam(None, element, unique=True)


def prewhere_clause(*clauses):
    """Collapse PREWHERE conditions into a single expression.

    Returns ``None`` when called without arguments, the condition itself
    (unchanged) when exactly one is given, and otherwise all conditions
    joined with ``and_``.
    """
    if len(clauses) > 1:
        return and_(*clauses)
    return clauses[0] if clauses else None


class LimitByClause:

def __init__(self, by_clauses, limit, offset):
Expand Down
28 changes: 28 additions & 0 deletions clickhouse_sqlalchemy/orm/query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import partial

from sqlalchemy import exc
from sqlalchemy.sql import and_
from sqlalchemy.sql.base import _generative
from sqlalchemy.orm.query import Query as BaseQuery

Expand All @@ -9,6 +10,7 @@
LeftArrayJoin,
LimitByClause,
sample_clause,
prewhere_clause,
)


Expand All @@ -23,6 +25,7 @@ def _compile_state_factory(orig_compile_state_factory, query, statement,
new_stmt._sample_clause = sample_clause(query._sample)
new_stmt._limit_by_clause = query._limit_by
new_stmt._array_join = query._array_join
new_stmt._prewhere_clause = query._prewhere
return rv


Expand All @@ -34,6 +37,7 @@ class Query(BaseQuery):
_sample = None
_limit_by = None
_array_join = None
_prewhere = None

def _statement_20(self, *args, **kwargs):
statement = super(Query, self)._statement_20(*args, **kwargs)
Expand Down Expand Up @@ -106,6 +110,30 @@ def final(self):
self._final = True
return self

@_generative
def prefilter(self, *clauses):
    """Set the query's PREWHERE condition from ``clauses``.

    The conditions are combined by ``prewhere_clause`` and the result
    replaces whatever was stored in ``_prewhere`` before.
    """
    combined = prewhere_clause(*clauses)
    self._prewhere = combined
    return self

@_generative
def prefilter_by(self, **kwargs):
    """Set the PREWHERE condition from keyword equality filters.

    Each ``name=value`` pair becomes ``<entity>.<name> == value``, where
    the entity is taken from the first entry of ``self._entities`` (or its
    ``entity`` attribute when it has one). Names for which the entity has
    no attribute are skipped.

    When no condition results (no keyword arguments, or every name was
    skipped) the PREWHERE condition is set to ``None`` so that no empty
    ``PREWHERE`` is rendered.

    :param kwargs: attribute names mapped to the values they must equal.
    :return: ``self``.
    """
    clauses = []

    if kwargs:
        # The target entity does not depend on the key, so resolve it once.
        # NOTE(review): ``_entities`` is a private Query attribute --
        # confirm it exists on the SQLAlchemy versions this package supports.
        entity = self._entities[0]
        entity = getattr(entity, 'entity', entity)

        # NOTE(review): unknown names are silently ignored here; confirm
        # that is preferred over raising for a misspelled attribute.
        clauses = [
            getattr(entity, key) == value
            for key, value in kwargs.items()
            if hasattr(entity, key)
        ]

    # prewhere_clause() returns None for no conditions, the condition itself
    # for one, and and_(...) for several -- the same rules prefilter() uses.
    self._prewhere = prewhere_clause(*clauses)
    return self

@_generative
def sample(self, sample):
self._sample = sample
Expand Down
7 changes: 7 additions & 0 deletions clickhouse_sqlalchemy/sql/selectable.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
LeftArrayJoin,
LimitByClause,
sample_clause,
prewhere_clause,
)


Expand All @@ -22,6 +23,7 @@ class Select(StandardSelect):
_sample_clause = None
_limit_by_clause = None
_array_join = None
_prewhere_clause = None

@_generative
def with_cube(self):
Expand All @@ -43,6 +45,11 @@ def final(self):
self._final_clause = True
return self

@_generative
def prewhere(self, *clauses):
    """Set the statement's PREWHERE condition from ``clauses``.

    The conditions are combined by ``prewhere_clause`` and the result
    replaces whatever was stored in ``_prewhere_clause`` before.
    """
    combined = prewhere_clause(*clauses)
    self._prewhere_clause = combined
    return self

@_generative
def sample(self, sample):
self._sample_clause = sample_clause(sample)
Expand Down
39 changes: 39 additions & 0 deletions docs/features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,45 @@ becomes (respectively)
SELECT ... FROM ... GROUP BY ... WITH ROLLUP
SELECT ... FROM ... GROUP BY ... WITH TOTALS

PREWHERE
+++++++++

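The PREWHERE clause is a ClickHouse filtering optimization: only the columns needed to evaluate the PREWHERE expression are read first, and the remaining columns are read only for the blocks where that expression is true.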

.. code-block:: python

session.query(table.c.x).prefilter(table.c.x > 10)

or

.. code-block:: python

select([table.c.x]).prewhere(table.c.x > 10)

becomes

.. code-block:: sql

SELECT ... FROM ... PREWHERE x > 10

PREWHERE can be combined with a WHERE clause for additional filtering:

.. code-block:: python

session.query(table.c.x).prefilter(table.c.x > 10).filter(table.c.x < 100)

or

.. code-block:: python

select([table.c.x]).prewhere(table.c.x > 10).where(table.c.x < 100)

becomes

.. code-block:: sql

SELECT ... FROM ... PREWHERE x > 10 WHERE x < 100

FINAL
+++++

Expand Down
98 changes: 98 additions & 0 deletions tests/orm/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,104 @@ def test_final(self):
'SELECT t1.x AS t1_x FROM t1 FINAL GROUP BY t1.x'
)

def test_prefilter(self):
    """A single ``prefilter()`` condition renders as a PREWHERE clause."""
    t1 = self._make_table()
    q = self.session.query(t1.c.x).prefilter(t1.c.x > 10)

    expected_bound = (
        'SELECT t1.x AS t1_x FROM t1 PREWHERE t1.x > %(param_1)s'
    )
    expected_inlined = 'SELECT t1.x AS t1_x FROM t1 PREWHERE t1.x > 10'

    self.assertEqual(self.compile(q), expected_bound)
    self.assertEqual(self.compile(q, literal_binds=True), expected_inlined)

def test_prefilter_multiple_clauses(self):
    """Several ``prefilter()`` conditions are joined with AND."""
    t1 = self._make_table()
    lower_bound = t1.c.x > 10
    upper_bound = t1.c.x < 100
    q = self.session.query(t1.c.x).prefilter(lower_bound, upper_bound)

    cases = (
        (
            {},
            'SELECT t1.x AS t1_x FROM t1 PREWHERE t1.x > %(param_1)s AND '
            't1.x < %(param_2)s',
        ),
        (
            {'literal_binds': True},
            'SELECT t1.x AS t1_x FROM t1 PREWHERE t1.x > 10 AND t1.x < 100',
        ),
    )
    for compile_kw, expected in cases:
        self.assertEqual(self.compile(q, **compile_kw), expected)

def test_prefilter_by(self):
    """``prefilter_by(x=10)`` renders an equality PREWHERE condition."""
    t1 = self._make_table()
    q = self.session.query(t1.c.x).prefilter_by(x=10)

    expected_bound = (
        'SELECT t1.x AS t1_x FROM t1 PREWHERE t1.x = %(param_1)s'
    )
    expected_inlined = 'SELECT t1.x AS t1_x FROM t1 PREWHERE t1.x = 10'

    self.assertEqual(self.compile(q), expected_bound)
    self.assertEqual(self.compile(q, literal_binds=True), expected_inlined)

def test_prefilter_with_filter(self):
    """PREWHERE is rendered ahead of the WHERE produced by ``filter()``."""
    t1 = self._make_table()
    base = self.session.query(t1.c.x)
    q = base.prefilter(t1.c.x > 10).filter(t1.c.x < 100)

    cases = (
        (
            {},
            'SELECT t1.x AS t1_x FROM t1 PREWHERE t1.x > %(param_1)s WHERE '
            't1.x < %(param_2)s',
        ),
        (
            {'literal_binds': True},
            'SELECT t1.x AS t1_x FROM t1 PREWHERE t1.x > 10 WHERE t1.x < 100',
        ),
    )
    for compile_kw, expected in cases:
        self.assertEqual(self.compile(q, **compile_kw), expected)

def test_prefilter_with_final(self):
    """FINAL precedes PREWHERE even though ``prefilter()`` is called first."""
    t1 = self._make_table()
    q = self.session.query(t1.c.x).prefilter(t1.c.x > 10).final()

    expected_bound = (
        'SELECT t1.x AS t1_x FROM t1 FINAL PREWHERE t1.x > %(param_1)s'
    )
    expected_inlined = (
        'SELECT t1.x AS t1_x FROM t1 FINAL PREWHERE t1.x > 10'
    )

    self.assertEqual(self.compile(q), expected_bound)
    self.assertEqual(self.compile(q, literal_binds=True), expected_inlined)

def test_prefilter_empty_clauses(self):
    """``prefilter()`` without conditions adds no PREWHERE clause."""
    t1 = self._make_table()
    q = self.session.query(t1.c.x).prefilter()

    compiled = self.compile(q)
    self.assertEqual(compiled, 'SELECT t1.x AS t1_x FROM t1')

def test_prefilter_by_empty_kwargs(self):
    """``prefilter_by()`` without keywords adds no PREWHERE clause."""
    t1 = self._make_table()
    q = self.session.query(t1.c.x).prefilter_by()

    compiled = self.compile(q)
    self.assertEqual(compiled, 'SELECT t1.x AS t1_x FROM t1')

def test_prefilter_invalid_clause_type(self):
    """Pin the rendering of a plain string passed to ``prefilter()``.

    NOTE(review): the name says "invalid", yet the assertion expects the
    query to compile with a single bound parameter after PREWHERE --
    confirm this is the intended contract.
    """
    t1 = self._make_table()
    q = self.session.query(t1.c.x).prefilter("invalid_string")

    compiled = self.compile(q)
    self.assertEqual(
        compiled, 'SELECT t1.x AS t1_x FROM t1 PREWHERE %(param_1)s'
    )

def test_limit_by(self):
table = self._make_table()

Expand Down
76 changes: 76 additions & 0 deletions tests/sql/test_selectable.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,82 @@ def test_final(self):
'SELECT t1.x FROM t1 FINAL GROUP BY t1.x'
)

def test_prewhere(self):
    """A single ``prewhere()`` condition renders as a PREWHERE clause."""
    t1 = self._make_table()
    stmt = select(t1.c.x).prewhere(t1.c.x > 10)

    expected_bound = 'SELECT t1.x FROM t1 PREWHERE t1.x > %(param_1)s'
    expected_inlined = 'SELECT t1.x FROM t1 PREWHERE t1.x > 10'

    self.assertEqual(self.compile(stmt), expected_bound)
    self.assertEqual(
        self.compile(stmt, literal_binds=True), expected_inlined
    )

def test_prewhere_multiple_clauses(self):
    """Several ``prewhere()`` conditions are joined with AND."""
    t1 = self._make_table()
    lower_bound = t1.c.x > 10
    upper_bound = t1.c.x < 100
    stmt = select(t1.c.x).prewhere(lower_bound, upper_bound)

    cases = (
        (
            {},
            'SELECT t1.x FROM t1 PREWHERE t1.x > %(param_1)s AND '
            't1.x < %(param_2)s',
        ),
        (
            {'literal_binds': True},
            'SELECT t1.x FROM t1 PREWHERE t1.x > 10 AND t1.x < 100',
        ),
    )
    for compile_kw, expected in cases:
        self.assertEqual(self.compile(stmt, **compile_kw), expected)

def test_prewhere_with_where(self):
    """PREWHERE is rendered ahead of the WHERE produced by ``where()``."""
    t1 = self._make_table()
    stmt = select(t1.c.x).prewhere(t1.c.x > 10).where(t1.c.x < 100)

    cases = (
        (
            {},
            'SELECT t1.x FROM t1 PREWHERE t1.x > %(param_1)s WHERE '
            't1.x < %(param_2)s',
        ),
        (
            {'literal_binds': True},
            'SELECT t1.x FROM t1 PREWHERE t1.x > 10 WHERE t1.x < 100',
        ),
    )
    for compile_kw, expected in cases:
        self.assertEqual(self.compile(stmt, **compile_kw), expected)

def test_prewhere_with_final(self):
    """FINAL precedes PREWHERE even though ``prewhere()`` is called first."""
    t1 = self._make_table()
    stmt = select(t1.c.x).prewhere(t1.c.x > 10).final()

    expected_bound = (
        'SELECT t1.x FROM t1 FINAL PREWHERE t1.x > %(param_1)s'
    )
    expected_inlined = 'SELECT t1.x FROM t1 FINAL PREWHERE t1.x > 10'

    self.assertEqual(self.compile(stmt), expected_bound)
    self.assertEqual(
        self.compile(stmt, literal_binds=True), expected_inlined
    )

def test_prewhere_empty_clauses(self):
    """``prewhere()`` without conditions adds no PREWHERE clause."""
    t1 = self._make_table()
    stmt = select(t1.c.x).prewhere()

    compiled = self.compile(stmt)
    self.assertEqual(compiled, 'SELECT t1.x FROM t1')

def test_prewhere_invalid_clause_type(self):
    """Pin the rendering of a plain string passed to ``prewhere()``.

    NOTE(review): the name says "invalid", yet the assertion expects the
    statement to compile with a single bound parameter after PREWHERE --
    confirm this is the intended contract.
    """
    t1 = self._make_table()
    stmt = select(t1.c.x).prewhere("invalid_string")

    compiled = self.compile(stmt)
    self.assertEqual(compiled, 'SELECT t1.x FROM t1 PREWHERE %(param_1)s')

def test_limit_by(self):
table = self._make_table()

Expand Down