diff --git a/bigframes/session/loader.py b/bigframes/session/loader.py index 9c18d727c8..c7b433b554 100644 --- a/bigframes/session/loader.py +++ b/bigframes/session/loader.py @@ -49,6 +49,7 @@ import google.cloud.bigquery as bigquery import google.cloud.bigquery.table from google.cloud.bigquery_storage_v1 import types as bq_storage_types +from google.cloud.bigquery_storage_v1 import writer as bq_storage_writer import pandas import pyarrow as pa @@ -488,31 +489,29 @@ def stream_worker(work: Iterator[pa.RecordBatch]) -> str: stream = self._write_client.create_write_stream( parent=parent, write_stream=requested_stream ) - stream_name = stream.name - - def request_generator(): - current_offset = 0 - for batch in work: - request = bq_storage_types.AppendRowsRequest( - write_stream=stream.name, offset=current_offset - ) + base_request = bq_storage_types.AppendRowsRequest( + write_stream=stream.name, + ) + base_request.arrow_rows.writer_schema.serialized_schema = serialized_schema - request.arrow_rows.writer_schema.serialized_schema = ( - serialized_schema - ) - request.arrow_rows.rows.serialized_record_batch = ( - batch.serialize().to_pybytes() - ) + stream_manager = bq_storage_writer.AppendRowsStream( + client=self._write_client, initial_request_template=base_request + ) + stream_name = stream.name + current_offset = 0 + futures: list[bq_storage_writer.AppendRowsFuture] = [] + for batch in work: + request = bq_storage_types.AppendRowsRequest(offset=current_offset) + request.arrow_rows.rows.serialized_record_batch = ( + batch.serialize().to_pybytes() + ) - yield request - current_offset += batch.num_rows + futures.append(stream_manager.send(request)) + current_offset += batch.num_rows + for future in futures: + future.result() - responses = self._write_client.append_rows(requests=request_generator()) - for resp in responses: - if resp.row_errors: - raise ValueError( - f"Errors in stream {stream_name}: {resp.row_errors}" - ) + stream_manager.close() self._write_client.finalize_write_stream(name=stream_name) return stream_name