Skip to content

Commit 72fbe2b

Browse files
authored
Merge pull request #63 from bluelabsio/project_dataset_table_parsing
Allow use of schema argument for project and dataset
2 parents 8c1450f + 12fb2f3 commit 72fbe2b

File tree

3 files changed

+80
-10
lines changed

3 files changed

+80
-10
lines changed

README.rst

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,14 @@ To specify location of your datasets pass ``location`` to ``create_engine()``:
5454
Table names
5555
___________
5656

57-
To query tables from non-default projects, use the following format for the table name: ``project.dataset.table``, e.g.:
57+
To query tables from non-default projects or datasets, use the following format for the SQLAlchemy schema name: ``[project.]dataset``, e.g.:
5858

5959
.. code-block:: python
6060
61-
sample_table = Table('bigquery-public-data.samples.natality')
61+
# If neither dataset nor project are the default
62+
sample_table_1 = Table('natality', schema='bigquery-public-data.samples')
63+
# If just dataset is not the default
64+
sample_table_2 = Table('natality', schema='bigquery-public-data')
6265
6366
Batch size
6467
__________
@@ -85,7 +88,7 @@ When using a default dataset, don't include the dataset name in the table name,
8588
8689
table = Table('table_name')
8790
88-
Note that specyfing a default dataset doesn't restrict execution of queries to that particular dataset when using raw queries, e.g.:
91+
Note that specifying a default dataset doesn't restrict execution of queries to that particular dataset when using raw queries, e.g.:
8992

9093
.. code-block:: python
9194

pybigquery/sqlalchemy_bigquery.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -400,22 +400,51 @@ def _split_table_name(full_table_name):
400400
dataset, table_name = table_name_split
401401
elif len(table_name_split) == 3:
402402
project, dataset, table_name = table_name_split
403+
else:
404+
raise ValueError("Did not understand table_name: {}".format(full_table_name))
403405

404406
return (project, dataset, table_name)
405407

408+
def _table_reference(self, provided_schema_name, provided_table_name,
409+
client_project):
410+
project_id_from_table, dataset_id_from_table, table_id = self._split_table_name(provided_table_name)
411+
project_id_from_schema = None
412+
dataset_id_from_schema = None
413+
if provided_schema_name is not None:
414+
provided_schema_name_split = provided_schema_name.split('.')
415+
if len(provided_schema_name_split) == 0:
416+
pass
417+
elif len(provided_schema_name_split) == 1:
418+
if dataset_id_from_table:
419+
project_id_from_schema = provided_schema_name_split[0]
420+
else:
421+
dataset_id_from_schema = provided_schema_name_split[0]
422+
elif len(provided_schema_name_split) == 2:
423+
project_id_from_schema = provided_schema_name_split[0]
424+
dataset_id_from_schema = provided_schema_name_split[1]
425+
else:
426+
raise ValueError("Did not understand schema: {}".format(provided_schema_name))
427+
if (dataset_id_from_schema and dataset_id_from_table and
428+
dataset_id_from_schema != dataset_id_from_table):
429+
raise ValueError("dataset_id specified in schema and table_name disagree: got {} in schema, and {} in table_name".format(dataset_id_from_schema, dataset_id_from_table))
430+
if (project_id_from_schema and project_id_from_table and
431+
project_id_from_schema != project_id_from_table):
432+
raise ValueError("project_id specified in schema and table_name disagree: got {} in schema, and {} in table_name".format(project_id_from_schema, project_id_from_table))
433+
project_id = project_id_from_schema or project_id_from_table or client_project
434+
dataset_id = dataset_id_from_schema or dataset_id_from_table or self.dataset_id
435+
436+
table_ref = TableReference.from_string("{}.{}.{}".format(
437+
project_id, dataset_id, table_id
438+
))
439+
return table_ref
440+
406441
def _get_table(self, connection, table_name, schema=None):
407442
if isinstance(connection, Engine):
408443
connection = connection.connect()
409444

410445
client = connection.connection._client
411446

412-
project_id, dataset_id, table_id = self._split_table_name(table_name)
413-
project_id = project_id or client.project
414-
dataset_id = dataset_id or schema or self.dataset_id
415-
416-
table_ref = TableReference.from_string("{}.{}.{}".format(
417-
project_id, dataset_id, table_id
418-
))
447+
table_ref = self._table_reference(schema, table_name, client.project)
419448
try:
420449
table = client.get_table(table_ref)
421450
except NotFound:

test/test_sqlalchemy_bigquery.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from google.api_core.exceptions import BadRequest
55
from pybigquery.api import ApiClient
6+
from pybigquery.sqlalchemy_bigquery import BigQueryDialect
67
from sqlalchemy.engine import create_engine
78
from sqlalchemy.schema import Table, MetaData, Column
89
from sqlalchemy.ext.declarative import declarative_base
@@ -102,6 +103,11 @@ def engine():
102103
return engine
103104

104105

106+
@pytest.fixture(scope='session')
107+
def dialect():
108+
return BigQueryDialect()
109+
110+
105111
@pytest.fixture(scope='session')
106112
def engine_using_test_dataset():
107113
engine = create_engine('bigquery:///test_pybigquery', echo=True)
@@ -532,6 +538,38 @@ def test_get_columns(inspector, inspector_using_test_dataset):
532538
assert col['type'].__class__.__name__ == sample_col['type'].__class__.__name__
533539

534540

541+
@pytest.mark.parametrize('provided_schema_name,provided_table_name,client_project',
542+
[
543+
('dataset', 'table', 'project'),
544+
(None, 'dataset.table', 'project'),
545+
(None, 'project.dataset.table', 'other_project'),
546+
('project', 'dataset.table', 'other_project'),
547+
('project.dataset', 'table', 'other_project'),
548+
])
549+
def test_table_reference(dialect, provided_schema_name,
550+
provided_table_name, client_project):
551+
ref = dialect._table_reference(provided_schema_name,
552+
provided_table_name,
553+
client_project)
554+
assert ref.table_id == 'table'
555+
assert ref.dataset_id == 'dataset'
556+
assert ref.project == 'project'
557+
558+
@pytest.mark.parametrize('provided_schema_name,provided_table_name,client_project',
559+
[
560+
('project.dataset', 'other_dataset.table', 'project'),
561+
('project.dataset', 'other_project.dataset.table', 'project'),
562+
('project.dataset.something_else', 'table', 'project'),
563+
(None, 'project.dataset.table.something_else', 'project'),
564+
])
565+
def test_invalid_table_reference(dialect, provided_schema_name,
566+
provided_table_name, client_project):
567+
with pytest.raises(ValueError):
568+
dialect._table_reference(provided_schema_name,
569+
provided_table_name,
570+
client_project)
571+
572+
535573
def test_has_table(engine, engine_using_test_dataset):
536574
assert engine.has_table('sample', 'test_pybigquery') is True
537575
assert engine.has_table('test_pybigquery.sample') is True

0 commit comments

Comments
 (0)