Skip to content

Commit 9f47d0b

Browse files
authored
Merge pull request #47 from JacobHayes/fix-nested-labels
Fix visit_label override to handle nested labels
2 parents db971fb + efd11fe commit 9f47d0b

File tree

2 files changed

+43
-5
lines changed

2 files changed

+43
-5
lines changed

pybigquery/sqlalchemy_bigquery.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,19 @@ def visit_column(self, column, add_to_result_map=None,
183183
self.preparer.quote(tablename) + \
184184
"." + name
185185

186-
def visit_label(self, *args, **kwargs):
187-
# Use labels in GROUP BY clause
188-
if len(kwargs) == 0 or len(kwargs) == 1:
186+
def visit_label(self, *args, within_group_by=False, **kwargs):
187+
# Use labels in GROUP BY clause.
188+
#
189+
# Flag set in the group_by_clause method. Works around missing
190+
# equivalent to supports_simple_order_by_label for group by.
191+
if within_group_by:
189192
kwargs['render_label_as_label'] = args[0]
190-
result = super(BigQueryCompiler, self).visit_label(*args, **kwargs)
191-
return result
193+
return super(BigQueryCompiler, self).visit_label(*args, **kwargs)
194+
195+
def group_by_clause(self, select, **kw):
196+
return super(BigQueryCompiler, self).group_by_clause(
197+
select, **kw, within_group_by=True
198+
)
192199

193200

194201
class BigQueryTypeCompiler(GenericTypeCompiler):

test/test_sqlalchemy_bigquery.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,14 @@ def query():
162162
def query(table):
163163
col1 = literal_column("TIMESTAMP_TRUNC(timestamp, DAY)").label("timestamp_label")
164164
col2 = func.sum(table.c.integer)
165+
# Test rendering of nested labels. Full expression should render in SELECT, but
166+
# ORDER/GROUP BY should use label only.
167+
col3 = func.sum(func.sum(table.c.integer.label("inner")).label("outer")).over().label('outer')
165168
query = (
166169
select([
167170
col1,
168171
col2,
172+
col3,
169173
])
170174
.where(col1 < '2017-01-01 00:00:00')
171175
.group_by(col1)
@@ -297,6 +301,33 @@ def test_group_by(session, table, session_using_test_dataset, table_using_test_d
297301
assert len(result) > 0
298302

299303

304+
def test_nested_labels(engine, table):
305+
col = table.c.integer
306+
exprs = [
307+
sqlalchemy.func.sum(
308+
sqlalchemy.func.sum(col.label("inner")
309+
).label("outer")).over(),
310+
sqlalchemy.func.sum(
311+
sqlalchemy.case([[
312+
sqlalchemy.literal(True),
313+
col.label("inner"),
314+
]]).label("outer")
315+
),
316+
sqlalchemy.func.sum(
317+
sqlalchemy.func.sum(
318+
sqlalchemy.case([[
319+
sqlalchemy.literal(True), col.label("inner")
320+
]]).label("middle")
321+
).label("outer")
322+
).over(),
323+
]
324+
for expr in exprs:
325+
sql = str(expr.compile(engine))
326+
assert "inner" not in sql
327+
assert "middle" not in sql
328+
assert "outer" not in sql
329+
330+
300331
def test_session_query(session, table, session_using_test_dataset, table_using_test_dataset):
301332
for session, table in [(session, table), (session_using_test_dataset, table_using_test_dataset)]:
302333
col_concat = func.concat(table.c.string).label('concat')

0 commit comments

Comments
 (0)