
Read pre-sorted/shuffled dataframe #21

@bnaul


One thing that didn't make it in from the original gist was the partition_field keyword. The motivating idea behind that feature was that sometimes there's a natural index column for your data, but ORDER BY in BigQuery does not scale well and ddf.set_index() requires a shuffle if the data is not already sorted. In the example of a date-partitioned table, pre-indexing the dataframe by that date field would speed up a lot of aggregations by date, which seems like a pretty common use case.
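
To make that concrete, here is a minimal sketch (synthetic data, not the actual BigQuery read; names are illustrative) of the kind of date aggregation that gets cheap once the index is already sorted and the divisions are known:

import pandas as pd
import dask.dataframe as dd

# Synthetic stand-in for a date-partitioned table that arrives pre-sorted by `date`
pdf = pd.DataFrame(
    {"value": range(90)},
    index=pd.date_range("2021-01-01", periods=90, freq="D", name="date"),
)
ddf = dd.from_pandas(pdf, npartitions=3)

# Because the index is sorted and the divisions are known, date-based
# aggregations like resample() work without a shuffle or a prior set_index()
print(ddf.known_divisions)  # True
weekly = ddf.resample("7D").sum()
print(weekly.compute().head())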

But now that I think about it, there's actually nothing particular to partitioned tables in the original logic: if you know the divisions of the dataframe index (or can compute them), then the same read logic should work regardless of whether the table is partitioned or not:

import dask
import dask.dataframe as dd
from dask_bigquery.core import bigquery_clients
from google.cloud import bigquery, bigquery_storage

table = "model-159019.sample_data.states"  # copy of `bigquery-public-data.geo_unit_boundaries.states`
npartitions = 10

# Compute divisions: APPROX_QUANTILES(col, n) returns n + 1 approximate boundaries,
# which is exactly the shape dask expects for `divisions`
bq = bigquery.Client()
divisions = (
    bq.query(f"SELECT APPROX_QUANTILES(state, {npartitions}) FROM {table}")
    .result()
    .to_dataframe()
    .squeeze()
    .tolist()
)
print(f"{divisions=}")

def read_rows(project_id, dataset_id, table_id, row_filter="", fields=()):
    """cf https://github.com/coiled/dask-bigquery/blob/main/dask_bigquery/core.py#L123-L134"""
    with bigquery_clients(project_id) as (_, bqs_client):
        parent = f"projects/{project_id}"
        session = bqs_client.create_read_session(
            parent=parent,
            read_session=bigquery_storage.types.ReadSession(
                data_format=bigquery_storage.types.DataFormat.ARROW,
                read_options=bigquery_storage.types.ReadSession.TableReadOptions(
                    row_restriction=row_filter, selected_fields=fields
                ),
                table=bigquery.Table(f"{project_id}.{dataset_id}.{table_id}").to_bqstorage(),
            ),
            max_stream_count=1,
        )
        reader = bqs_client.read_rows(session.streams[0].name)
        return reader.rows(session).to_dataframe()


@dask.delayed
def _read_partition(table, column, lower, upper):
    # NB: bounds are quoted as string literals, so this assumes a string-typed index column
    row_filter = f"{column} >= '{lower}'"
    if upper is not None:
        row_filter += f" AND {column} < '{upper}'"
    rows = read_rows(*table.split("."), row_filter=row_filter)
    return rows.set_index(column).sort_index()

ddf = dd.from_delayed(
    [
        # the last division is inclusive, so the final partition gets no upper bound
        _read_partition(table, "state", lower, upper if upper != divisions[-1] else None)
        for lower, upper in zip(divisions, divisions[1:])
    ],
    divisions=divisions,
)
print(f"{ddf.index=}")
print(f"{len(ddf)=}")
print(f'total_rows={bigquery.Client().get_table("model-159019.sample_data.states").num_rows}')

Output:

divisions=['AK', 'CA', 'GA', 'IL', 'MD', 'MP', 'NH', 'OK', 'SC', 'VI', 'WY']

ddf.index=Dask Index Structure:
npartitions=10
AK    object
CA       ...
       ...
VI       ...
WY       ...
Name: state, dtype: object
Dask Name: from-delayed, 30 tasks

len(ddf)=56
total_rows=56
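
As a quick follow-up (a sketch, not part of the original output): because the divisions are known, index-based selections only touch the partitions they need, which is the whole point of skipping the shuffle.

# Selecting a range that falls inside a single division only schedules one delayed read
subset = ddf.loc["CA":"FL"]
print(subset.npartitions)  # 1, since ["CA", "FL"] sits entirely within the ('CA', 'GA') division
print(subset.compute().head())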

I don't think there's any question that this is a useful bit of functionality...but it's not totally clear to me what the API should look like (part of read_gbq? its own function?). I can't think of any analogous dask patterns, but if anyone knows of something similar, that could be a place to borrow ideas from. pd.read_gbq has an index_col parameter, but in this case we'd probably want to support setting divisions or npartitions as well.
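
Purely as a sketch for discussion (read_gbq_sorted is a hypothetical name, and the parameter names are assumptions borrowed from pd.read_gbq / dd.read_sql_table, not an existing dask-bigquery API), one option that reuses the helpers above:

def read_gbq_sorted(project_id, dataset_id, table_id, index_col, divisions=None, npartitions=None):
    """Read a BigQuery table into a dask dataframe indexed (and sorted) by `index_col`.

    If `divisions` is not given, compute approximate boundaries with
    APPROX_QUANTILES(index_col, npartitions), then issue one filtered read
    per (lower, upper) range, exactly as in the snippet above.
    """
    table = f"{project_id}.{dataset_id}.{table_id}"
    if divisions is None:
        bq = bigquery.Client()
        divisions = (
            bq.query(f"SELECT APPROX_QUANTILES({index_col}, {npartitions}) FROM {table}")
            .result()
            .to_dataframe()
            .squeeze()
            .tolist()
        )
    parts = [
        _read_partition(table, index_col, lower, upper if upper != divisions[-1] else None)
        for lower, upper in zip(divisions, divisions[1:])
    ]
    return dd.from_delayed(parts, divisions=divisions)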
