Skip to content

Commit a3cacd3

Browse files
authored
chore: add more grouping sets/rollup/cube tests (#1029)
* chore: add more tests for grouping functions fix * reformatted tests
1 parent 87a75dc commit a3cacd3

File tree

1 file changed

+71
-47
lines changed

1 file changed

+71
-47
lines changed

tests/unit/test_compiler.py

Lines changed: 71 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,23 @@
2828
from sqlalchemy.sql.functions import rollup, cube, grouping_sets
2929

3030

31+
@pytest.fixture
32+
def table(faux_conn, metadata):
33+
# Fixture to create a sample table for testing
34+
35+
table = setup_table(
36+
faux_conn,
37+
"table1",
38+
metadata,
39+
sqlalchemy.Column("foo", sqlalchemy.Integer),
40+
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
41+
)
42+
43+
yield table
44+
45+
table.drop(faux_conn)
46+
47+
3148
def test_constraints_are_ignored(faux_conn, metadata):
3249
sqlalchemy.Table(
3350
"ref",
@@ -282,85 +299,92 @@ def test_no_implicit_join_for_inner_unnest_no_table2_column(faux_conn, metadata)
282299
assert found_outer_sql == expected_outer_sql
283300

284301

285-
def test_grouping_sets(faux_conn, metadata):
286-
table = setup_table(
287-
faux_conn,
288-
"table1",
289-
metadata,
290-
sqlalchemy.Column("foo", sqlalchemy.Integer),
291-
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
302+
grouping_ops = (
303+
"grouping_op, grouping_op_func",
304+
[("GROUPING SETS", grouping_sets), ("ROLLUP", rollup), ("CUBE", cube)],
305+
)
306+
307+
308+
@pytest.mark.parametrize(*grouping_ops)
309+
def test_grouping_ops_vs_single_column(faux_conn, table, grouping_op, grouping_op_func):
310+
# Tests each of the grouping ops against a single column
311+
312+
q = sqlalchemy.select(table.c.foo).group_by(grouping_op_func(table.c.foo))
313+
found_sql = q.compile(faux_conn).string
314+
315+
expected_sql = (
316+
f"SELECT `table1`.`foo` \n"
317+
f"FROM `table1` GROUP BY {grouping_op}(`table1`.`foo`)"
292318
)
293319

320+
assert found_sql == expected_sql
321+
322+
323+
@pytest.mark.parametrize(*grouping_ops)
324+
def test_grouping_ops_vs_multi_columns(faux_conn, table, grouping_op, grouping_op_func):
325+
# Tests each of the grouping ops against multiple columns
326+
294327
q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
295-
grouping_sets(table.c.foo, table.c.bar)
328+
grouping_op_func(table.c.foo, table.c.bar)
296329
)
330+
found_sql = q.compile(faux_conn).string
297331

298332
expected_sql = (
299-
"SELECT `table1`.`foo`, `table1`.`bar` \n"
300-
"FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, `table1`.`bar`)"
333+
f"SELECT `table1`.`foo`, `table1`.`bar` \n"
334+
f"FROM `table1` GROUP BY {grouping_op}(`table1`.`foo`, `table1`.`bar`)"
301335
)
302-
found_sql = q.compile(faux_conn).string
336+
303337
assert found_sql == expected_sql
304338

305339

306-
def test_rollup(faux_conn, metadata):
307-
table = setup_table(
308-
faux_conn,
309-
"table1",
310-
metadata,
311-
sqlalchemy.Column("foo", sqlalchemy.Integer),
312-
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
313-
)
340+
@pytest.mark.parametrize(*grouping_ops)
341+
def test_grouping_op_with_grouping_op(faux_conn, table, grouping_op, grouping_op_func):
342+
# Tests multiple grouping ops in a single statement
314343

315344
q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
316-
rollup(table.c.foo, table.c.bar)
345+
grouping_op_func(table.c.foo, table.c.bar), grouping_op_func(table.c.foo)
317346
)
347+
found_sql = q.compile(faux_conn).string
318348

319349
expected_sql = (
320-
"SELECT `table1`.`foo`, `table1`.`bar` \n"
321-
"FROM `table1` GROUP BY ROLLUP(`table1`.`foo`, `table1`.`bar`)"
350+
f"SELECT `table1`.`foo`, `table1`.`bar` \n"
351+
f"FROM `table1` GROUP BY {grouping_op}(`table1`.`foo`, `table1`.`bar`), {grouping_op}(`table1`.`foo`)"
322352
)
323-
found_sql = q.compile(faux_conn).string
353+
324354
assert found_sql == expected_sql
325355

326356

327-
def test_cube(faux_conn, metadata):
328-
table = setup_table(
329-
faux_conn,
330-
"table1",
331-
metadata,
332-
sqlalchemy.Column("foo", sqlalchemy.Integer),
333-
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
334-
)
357+
@pytest.mark.parametrize(*grouping_ops)
358+
def test_grouping_ops_vs_group_by(faux_conn, table, grouping_op, grouping_op_func):
359+
# Tests grouping op against regular group by statement
335360

336361
q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
337-
cube(table.c.foo, table.c.bar)
362+
table.c.foo, grouping_op_func(table.c.bar)
338363
)
364+
found_sql = q.compile(faux_conn).string
339365

340366
expected_sql = (
341-
"SELECT `table1`.`foo`, `table1`.`bar` \n"
342-
"FROM `table1` GROUP BY CUBE(`table1`.`foo`, `table1`.`bar`)"
367+
f"SELECT `table1`.`foo`, `table1`.`bar` \n"
368+
f"FROM `table1` GROUP BY `table1`.`foo`, {grouping_op}(`table1`.`bar`)"
343369
)
344-
found_sql = q.compile(faux_conn).string
370+
345371
assert found_sql == expected_sql
346372

347373

348-
def test_multiple_grouping_sets(faux_conn, metadata):
349-
table = setup_table(
350-
faux_conn,
351-
"table1",
352-
metadata,
353-
sqlalchemy.Column("foo", sqlalchemy.Integer),
354-
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
355-
)
374+
@pytest.mark.parametrize(*grouping_ops)
375+
def test_complex_grouping_ops_vs_nested_grouping_ops(
376+
faux_conn, table, grouping_op, grouping_op_func
377+
):
378+
# Tests grouping ops nested within grouping ops
356379

357380
q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
358-
grouping_sets(table.c.foo, table.c.bar), grouping_sets(table.c.foo)
381+
grouping_sets(table.c.foo, grouping_op_func(table.c.bar))
359382
)
383+
found_sql = q.compile(faux_conn).string
360384

361385
expected_sql = (
362-
"SELECT `table1`.`foo`, `table1`.`bar` \n"
363-
"FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, `table1`.`bar`), GROUPING SETS(`table1`.`foo`)"
386+
f"SELECT `table1`.`foo`, `table1`.`bar` \n"
387+
f"FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, {grouping_op}(`table1`.`bar`))"
364388
)
365-
found_sql = q.compile(faux_conn).string
389+
366390
assert found_sql == expected_sql

0 commit comments

Comments
 (0)