Skip to content

Commit bd38a5e

Browse files
committed
fixed render label as label assignment
1 parent 68afc39 commit bd38a5e

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

sqlalchemy_bigquery/base.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,11 +332,25 @@ def visit_column(
332332

333333
return self.preparer.quote(tablename) + "." + name
334334

335-
def visit_label(self, *args, **kwargs):
335+
def visit_label(self, *args, within_group_by=False, **kwargs):
336+
# Use labels in GROUP BY clause.
337+
#
338+
# Flag set in the group_by_clause method. Works around missing
339+
# equivalent to supports_simple_order_by_label for group by.
340+
if within_group_by:
341+
if all(
342+
keyword not in str(args[0])
343+
for keyword in ("GROUPING SETS", "ROLLUP", "CUBE")
344+
):
345+
kwargs["render_label_as_label"] = args[0]
336346
return super(BigQueryCompiler, self).visit_label(*args, **kwargs)
337347

338-
def group_by_clause(self, select, **kwargs):
339-
return super(BigQueryCompiler, self).group_by_clause(select, **kwargs)
348+
def group_by_clause(self, select, **kw):
349+
return super(BigQueryCompiler, self).group_by_clause(
350+
select,
351+
**kw,
352+
within_group_by=True,
353+
)
340354

341355
############################################################################
342356
# Handle parameters in in

tests/unit/test_compiler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def test_grouping_sets(faux_conn, metadata):
288288
"table1",
289289
metadata,
290290
sqlalchemy.Column("foo", sqlalchemy.Integer),
291-
sqlalchemy.Column("bar", sqlalchemy.Integer),
291+
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
292292
)
293293

294294
q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
@@ -309,7 +309,7 @@ def test_rollup(faux_conn, metadata):
309309
"table1",
310310
metadata,
311311
sqlalchemy.Column("foo", sqlalchemy.Integer),
312-
sqlalchemy.Column("bar", sqlalchemy.Integer),
312+
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
313313
)
314314

315315
q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
@@ -330,7 +330,7 @@ def test_cube(faux_conn, metadata):
330330
"table1",
331331
metadata,
332332
sqlalchemy.Column("foo", sqlalchemy.Integer),
333-
sqlalchemy.Column("bar", sqlalchemy.Integer),
333+
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
334334
)
335335

336336
q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(

0 commit comments

Comments
 (0)