Skip to content

Commit 5cfc280

Browse files
committed
feat: grouping sets, rollup and cube compatibility
1 parent fcd5755 commit 5cfc280

File tree

2 files changed

+95
-3
lines changed

2 files changed

+95
-3
lines changed

sqlalchemy_bigquery/base.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import sqlalchemy
3939
import sqlalchemy.sql.expression
4040
import sqlalchemy.sql.functions
41+
from sqlalchemy.sql.functions import rollup, cube, grouping_sets
4142
import sqlalchemy.sql.sqltypes
4243
import sqlalchemy.sql.type_api
4344
from sqlalchemy.exc import NoSuchTableError, NoSuchColumnError
@@ -340,9 +341,36 @@ def visit_label(self, *args, within_group_by=False, **kwargs):
340341
return super(BigQueryCompiler, self).visit_label(*args, **kwargs)
341342

342343
def group_by_clause(self, select, **kw):
343-
return super(BigQueryCompiler, self).group_by_clause(
344-
select, **kw, within_group_by=True
345-
)
344+
grouping_sets_exprs = []
345+
rollup_exprs = []
346+
cube_exprs = []
347+
348+
# Traverse select statement to extract grouping sets, rollup, and cube expressions
349+
for expr in select._group_by_clause:
350+
if isinstance(expr, grouping_sets):
351+
grouping_sets_exprs.append(
352+
self.process(expr.clauses)
353+
) # Assuming SQLAlchemy syntax
354+
elif isinstance(expr, rollup): # Assuming SQLAlchemy syntax
355+
rollup_exprs.append(self.process(expr.clauses))
356+
elif isinstance(expr, cube): # Assuming SQLAlchemy syntax
357+
cube_exprs.append(self.process(expr.clauses))
358+
else:
359+
# Handle regular group by expressions
360+
pass
361+
362+
clause = super(BigQueryCompiler, self).group_by_clause(select, **kw)
363+
364+
if grouping_sets_exprs:
365+
clause = (
366+
f"GROUP BY {clause} GROUPING SETS ({', '.join(grouping_sets_exprs)})"
367+
)
368+
if rollup_exprs:
369+
clause = f"GROUP BY {clause} ROLLUP ({', '.join(rollup_exprs)})"
370+
if cube_exprs:
371+
clause = f"GROUP BY {clause} CUBE ({', '.join(cube_exprs)})"
372+
373+
return clause
346374

347375
############################################################################
348376
# Handle parameters in in

tests/unit/test_compiler.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from .conftest import setup_table
2424
from .conftest import sqlalchemy_1_4_or_higher, sqlalchemy_before_1_4
25+
from sqlalchemy.sql.functions import rollup, cube, grouping_sets
2526

2627

2728
def test_constraints_are_ignored(faux_conn, metadata):
@@ -278,3 +279,66 @@ def test_no_implicit_join_for_inner_unnest_no_table2_column(faux_conn, metadata)
278279
)
279280
found_outer_sql = q.compile(faux_conn).string
280281
assert found_outer_sql == expected_outer_sql
282+
283+
284+
def test_grouping_sets(faux_conn, metadata):
285+
table = setup_table(
286+
faux_conn,
287+
"table1",
288+
metadata,
289+
sqlalchemy.Column("foo", sqlalchemy.Integer),
290+
sqlalchemy.Column("bar", sqlalchemy.Integer),
291+
)
292+
293+
q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
294+
grouping_sets(table.c.foo, table.c.bar)
295+
)
296+
297+
expected_sql = (
298+
"SELECT `table1`.`foo`, `table1`.`bar` \n"
299+
"FROM `table1` GROUP BY GROUPING SETS ((`table1`.`foo`), (`table1`.`bar`))"
300+
)
301+
found_sql = q.compile(faux_conn).string
302+
assert found_sql == expected_sql
303+
304+
305+
def test_rollup(faux_conn, metadata):
306+
table = setup_table(
307+
faux_conn,
308+
"table1",
309+
metadata,
310+
sqlalchemy.Column("foo", sqlalchemy.Integer),
311+
sqlalchemy.Column("bar", sqlalchemy.Integer),
312+
)
313+
314+
q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
315+
rollup(table.c.foo, table.c.bar)
316+
)
317+
318+
expected_sql = (
319+
"SELECT `table1`.`foo`, `table1`.`bar` \n"
320+
"FROM `table1` GROUP BY ROLLUP(`table1`.`foo`, `table1`.`bar`)"
321+
)
322+
found_sql = q.compile(faux_conn).string
323+
assert found_sql == expected_sql
324+
325+
326+
def test_cube(faux_conn, metadata):
327+
table = setup_table(
328+
faux_conn,
329+
"table1",
330+
metadata,
331+
sqlalchemy.Column("foo", sqlalchemy.Integer),
332+
sqlalchemy.Column("bar", sqlalchemy.Integer),
333+
)
334+
335+
q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
336+
cube(table.c.foo, table.c.bar)
337+
)
338+
339+
expected_sql = (
340+
"SELECT `table1`.`foo`, `table1`.`bar` \n"
341+
"FROM `table1` GROUP BY CUBE(`table1`.`foo`, `table1`.`bar`)"
342+
)
343+
found_sql = q.compile(faux_conn).string
344+
assert found_sql == expected_sql

0 commit comments

Comments
 (0)