Skip to content

Commit e7135eb

Browse files
committed
Merge remote-tracking branch 'upstream/master' into ISSUE-60_implement_get_view_names
2 parents f2f3752 + 9f47d0b commit e7135eb

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

pybigquery/sqlalchemy_bigquery.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,19 @@ def visit_column(self, column, add_to_result_map=None,
185185
self.preparer.quote(tablename) + \
186186
"." + name
187187

188-
def visit_label(self, *args, **kwargs):
189-
# Use labels in GROUP BY clause
190-
if len(kwargs) == 0 or len(kwargs) == 1:
188+
def visit_label(self, *args, within_group_by=False, **kwargs):
189+
# Use labels in GROUP BY clause.
190+
#
191+
# Flag set in the group_by_clause method. Works around missing
192+
# equivalent to supports_simple_order_by_label for group by.
193+
if within_group_by:
191194
kwargs['render_label_as_label'] = args[0]
192-
result = super(BigQueryCompiler, self).visit_label(*args, **kwargs)
193-
return result
195+
return super(BigQueryCompiler, self).visit_label(*args, **kwargs)
196+
197+
def group_by_clause(self, select, **kw):
198+
return super(BigQueryCompiler, self).group_by_clause(
199+
select, **kw, within_group_by=True
200+
)
194201

195202

196203
class BigQueryTypeCompiler(GenericTypeCompiler):
@@ -207,6 +214,9 @@ def visit_text(self, type_, **kw):
207214
def visit_string(self, type_, **kw):
208215
return 'STRING'
209216

217+
def visit_ARRAY(self, type_, **kw):
218+
return "ARRAY<{}>".format(self.process(type_.item_type, **kw))
219+
210220
def visit_BINARY(self, type_, **kw):
211221
return 'BYTES'
212222

test/test_sqlalchemy_bigquery.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,14 @@ def query():
162162
def query(table):
163163
col1 = literal_column("TIMESTAMP_TRUNC(timestamp, DAY)").label("timestamp_label")
164164
col2 = func.sum(table.c.integer)
165+
# Test rendering of nested labels. Full expression should render in SELECT, but
166+
# ORDER/GROUP BY should use label only.
167+
col3 = func.sum(func.sum(table.c.integer.label("inner")).label("outer")).over().label('outer')
165168
query = (
166169
select([
167170
col1,
168171
col2,
172+
col3,
169173
])
170174
.where(col1 < '2017-01-01 00:00:00')
171175
.group_by(col1)
@@ -299,6 +303,33 @@ def test_group_by(session, table, session_using_test_dataset, table_using_test_d
299303
assert len(result) > 0
300304

301305

306+
def test_nested_labels(engine, table):
307+
col = table.c.integer
308+
exprs = [
309+
sqlalchemy.func.sum(
310+
sqlalchemy.func.sum(col.label("inner")
311+
).label("outer")).over(),
312+
sqlalchemy.func.sum(
313+
sqlalchemy.case([[
314+
sqlalchemy.literal(True),
315+
col.label("inner"),
316+
]]).label("outer")
317+
),
318+
sqlalchemy.func.sum(
319+
sqlalchemy.func.sum(
320+
sqlalchemy.case([[
321+
sqlalchemy.literal(True), col.label("inner")
322+
]]).label("middle")
323+
).label("outer")
324+
).over(),
325+
]
326+
for expr in exprs:
327+
sql = str(expr.compile(engine))
328+
assert "inner" not in sql
329+
assert "middle" not in sql
330+
assert "outer" not in sql
331+
332+
302333
def test_session_query(session, table, session_using_test_dataset, table_using_test_dataset):
303334
for session, table in [(session, table), (session_using_test_dataset, table_using_test_dataset)]:
304335
col_concat = func.concat(table.c.string).label('concat')
@@ -360,6 +391,16 @@ def test_compiled_query_literal_binds(engine, engine_using_test_dataset, table,
360391
assert len(result) > 0
361392

362393

394+
@pytest.mark.parametrize(["column", "processed"], [
395+
(types.String(), "STRING"),
396+
(types.NUMERIC(), "NUMERIC"),
397+
(types.ARRAY(types.String), "ARRAY<STRING>"),
398+
])
399+
def test_compile_types(engine, column, processed):
400+
result = engine.dialect.type_compiler.process(column)
401+
assert result == processed
402+
403+
363404
def test_joins(session, table, table_one_row):
364405
result = (session.query(table.c.string, func.count(table_one_row.c.integer))
365406
.join(table_one_row, table_one_row.c.string == table.c.string)

0 commit comments

Comments
 (0)