|
28 | 28 | from sqlalchemy.sql.functions import rollup, cube, grouping_sets |
29 | 29 |
|
30 | 30 |
|
| 31 | +@pytest.fixture |
| 32 | +def table(faux_conn, metadata): |
| 33 | + # Fixture to create a sample table for testing |
| 34 | + |
| 35 | + table = setup_table( |
| 36 | + faux_conn, |
| 37 | + "table1", |
| 38 | + metadata, |
| 39 | + sqlalchemy.Column("foo", sqlalchemy.Integer), |
| 40 | + sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), |
| 41 | + ) |
| 42 | + |
| 43 | + yield table |
| 44 | + |
| 45 | + table.drop(faux_conn) |
| 46 | + |
| 47 | + |
31 | 48 | def test_constraints_are_ignored(faux_conn, metadata): |
32 | 49 | sqlalchemy.Table( |
33 | 50 | "ref", |
@@ -282,85 +299,92 @@ def test_no_implicit_join_for_inner_unnest_no_table2_column(faux_conn, metadata) |
282 | 299 | assert found_outer_sql == expected_outer_sql |
283 | 300 |
|
284 | 301 |
|
285 | | -def test_grouping_sets(faux_conn, metadata): |
286 | | - table = setup_table( |
287 | | - faux_conn, |
288 | | - "table1", |
289 | | - metadata, |
290 | | - sqlalchemy.Column("foo", sqlalchemy.Integer), |
291 | | - sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), |
| 302 | +grouping_ops = ( |
| 303 | + "grouping_op, grouping_op_func", |
| 304 | + [("GROUPING SETS", grouping_sets), ("ROLLUP", rollup), ("CUBE", cube)], |
| 305 | +) |
| 306 | + |
| 307 | + |
| 308 | +@pytest.mark.parametrize(*grouping_ops) |
| 309 | +def test_grouping_ops_vs_single_column(faux_conn, table, grouping_op, grouping_op_func): |
| 310 | + # Tests each of the grouping ops against a single column |
| 311 | + |
| 312 | + q = sqlalchemy.select(table.c.foo).group_by(grouping_op_func(table.c.foo)) |
| 313 | + found_sql = q.compile(faux_conn).string |
| 314 | + |
| 315 | + expected_sql = ( |
| 316 | + f"SELECT `table1`.`foo` \n" |
| 317 | + f"FROM `table1` GROUP BY {grouping_op}(`table1`.`foo`)" |
292 | 318 | ) |
293 | 319 |
|
| 320 | + assert found_sql == expected_sql |
| 321 | + |
| 322 | + |
| 323 | +@pytest.mark.parametrize(*grouping_ops) |
| 324 | +def test_grouping_ops_vs_multi_columns(faux_conn, table, grouping_op, grouping_op_func): |
| 325 | + # Tests each of the grouping ops against multiple columns |
| 326 | + |
294 | 327 | q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( |
295 | | - grouping_sets(table.c.foo, table.c.bar) |
| 328 | + grouping_op_func(table.c.foo, table.c.bar) |
296 | 329 | ) |
| 330 | + found_sql = q.compile(faux_conn).string |
297 | 331 |
|
298 | 332 | expected_sql = ( |
299 | | - "SELECT `table1`.`foo`, `table1`.`bar` \n" |
300 | | - "FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, `table1`.`bar`)" |
| 333 | + f"SELECT `table1`.`foo`, `table1`.`bar` \n" |
| 334 | + f"FROM `table1` GROUP BY {grouping_op}(`table1`.`foo`, `table1`.`bar`)" |
301 | 335 | ) |
302 | | - found_sql = q.compile(faux_conn).string |
| 336 | + |
303 | 337 | assert found_sql == expected_sql |
304 | 338 |
|
305 | 339 |
|
306 | | -def test_rollup(faux_conn, metadata): |
307 | | - table = setup_table( |
308 | | - faux_conn, |
309 | | - "table1", |
310 | | - metadata, |
311 | | - sqlalchemy.Column("foo", sqlalchemy.Integer), |
312 | | - sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), |
313 | | - ) |
| 340 | +@pytest.mark.parametrize(*grouping_ops) |
| 341 | +def test_grouping_op_with_grouping_op(faux_conn, table, grouping_op, grouping_op_func): |
| 342 | + # Tests multiple grouping ops in a single statement |
314 | 343 |
|
315 | 344 | q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( |
316 | | - rollup(table.c.foo, table.c.bar) |
| 345 | + grouping_op_func(table.c.foo, table.c.bar), grouping_op_func(table.c.foo) |
317 | 346 | ) |
| 347 | + found_sql = q.compile(faux_conn).string |
318 | 348 |
|
319 | 349 | expected_sql = ( |
320 | | - "SELECT `table1`.`foo`, `table1`.`bar` \n" |
321 | | - "FROM `table1` GROUP BY ROLLUP(`table1`.`foo`, `table1`.`bar`)" |
| 350 | + f"SELECT `table1`.`foo`, `table1`.`bar` \n" |
| 351 | + f"FROM `table1` GROUP BY {grouping_op}(`table1`.`foo`, `table1`.`bar`), {grouping_op}(`table1`.`foo`)" |
322 | 352 | ) |
323 | | - found_sql = q.compile(faux_conn).string |
| 353 | + |
324 | 354 | assert found_sql == expected_sql |
325 | 355 |
|
326 | 356 |
|
327 | | -def test_cube(faux_conn, metadata): |
328 | | - table = setup_table( |
329 | | - faux_conn, |
330 | | - "table1", |
331 | | - metadata, |
332 | | - sqlalchemy.Column("foo", sqlalchemy.Integer), |
333 | | - sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), |
334 | | - ) |
| 357 | +@pytest.mark.parametrize(*grouping_ops) |
| 358 | +def test_grouping_ops_vs_group_by(faux_conn, table, grouping_op, grouping_op_func): |
| 359 | + # Tests grouping op against regular group by statement |
335 | 360 |
|
336 | 361 | q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( |
337 | | - cube(table.c.foo, table.c.bar) |
| 362 | + table.c.foo, grouping_op_func(table.c.bar) |
338 | 363 | ) |
| 364 | + found_sql = q.compile(faux_conn).string |
339 | 365 |
|
340 | 366 | expected_sql = ( |
341 | | - "SELECT `table1`.`foo`, `table1`.`bar` \n" |
342 | | - "FROM `table1` GROUP BY CUBE(`table1`.`foo`, `table1`.`bar`)" |
| 367 | + f"SELECT `table1`.`foo`, `table1`.`bar` \n" |
| 368 | + f"FROM `table1` GROUP BY `table1`.`foo`, {grouping_op}(`table1`.`bar`)" |
343 | 369 | ) |
344 | | - found_sql = q.compile(faux_conn).string |
| 370 | + |
345 | 371 | assert found_sql == expected_sql |
346 | 372 |
|
347 | 373 |
|
348 | | -def test_multiple_grouping_sets(faux_conn, metadata): |
349 | | - table = setup_table( |
350 | | - faux_conn, |
351 | | - "table1", |
352 | | - metadata, |
353 | | - sqlalchemy.Column("foo", sqlalchemy.Integer), |
354 | | - sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), |
355 | | - ) |
| 374 | +@pytest.mark.parametrize(*grouping_ops) |
| 375 | +def test_complex_grouping_ops_vs_nested_grouping_ops( |
| 376 | + faux_conn, table, grouping_op, grouping_op_func |
| 377 | +): |
| 378 | + # Tests grouping ops nested within grouping ops |
356 | 379 |
|
357 | 380 | q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( |
358 | | - grouping_sets(table.c.foo, table.c.bar), grouping_sets(table.c.foo) |
| 381 | + grouping_sets(table.c.foo, grouping_op_func(table.c.bar)) |
359 | 382 | ) |
| 383 | + found_sql = q.compile(faux_conn).string |
360 | 384 |
|
361 | 385 | expected_sql = ( |
362 | | - "SELECT `table1`.`foo`, `table1`.`bar` \n" |
363 | | - "FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, `table1`.`bar`), GROUPING SETS(`table1`.`foo`)" |
| 386 | + f"SELECT `table1`.`foo`, `table1`.`bar` \n" |
| 387 | + f"FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, {grouping_op}(`table1`.`bar`))" |
364 | 388 | ) |
365 | | - found_sql = q.compile(faux_conn).string |
| 389 | + |
366 | 390 | assert found_sql == expected_sql |
0 commit comments