From bb6759f77887d4158c4073d1758a135fdc6f7ce6 Mon Sep 17 00:00:00 2001 From: Shubham Vij Date: Wed, 22 Oct 2025 18:40:30 -0700 Subject: [PATCH 1/2] test Co-authored-by: Shubham Vij --- python/gigl/common/data/export.py | 75 +++++++++++++++++++++++-------- 1 file changed, 56 insertions(+), 19 deletions(-) diff --git a/python/gigl/common/data/export.py b/python/gigl/common/data/export.py index 2d69d91fd..f41481cca 100644 --- a/python/gigl/common/data/export.py +++ b/python/gigl/common/data/export.py @@ -10,7 +10,7 @@ import os import time from pathlib import Path -from typing import Final, Optional, Sequence +from typing import Final, Optional, Sequence, Iterable import fastavro import fastavro.types @@ -52,12 +52,18 @@ ] -class EmbeddingExporter: + + + + + +class GcsExporter: def __init__( self, export_dir: Uri, file_prefix: Optional[str] = None, min_shard_size_threshold_bytes: int = 0, + avro_schema: fastavro.types.Schema = AVRO_SCHEMA, ): """ Initializes an EmbeddingExporter instance. @@ -104,6 +110,7 @@ def __init__( self._file_utils = FileLoader() self._prefix = file_prefix self._min_shard_size_threshold_bytes = min_shard_size_threshold_bytes + self._avro_schema = avro_schema if isinstance( self._base_export_uri, LocalUri @@ -113,11 +120,9 @@ def __init__( ) Path(self._base_export_uri.uri).mkdir(parents=True, exist_ok=True) - def add_embedding( + def add_records( self, - id_batch: torch.Tensor, - embedding_batch: torch.Tensor, - embedding_type: str, + records: Iterable[dict], ): """ Adds to the in-memory buffer the integer IDs and their corresponding embeddings. @@ -132,23 +137,11 @@ def add_embedding( # and Python list(s). This is faster than converting torch tensors # directly to Python int(s) and Python list(s), as Numpy's implementation # is more efficient. 
- ids = id_batch.numpy() - embeddings = embedding_batch.numpy() - self._num_records_written += len(ids) - - batched_records = ( - { - "node_id": int(node_id), - "node_type": embedding_type, - "emb": embedding.tolist(), - } - for node_id, embedding in zip(ids, embeddings) - ) # fastavro.writer can accept the generator directly. # Doing this appends to self._buffer, so in order to *read* all data from the buffer # we *must* call self._buffer.seek(0) before reading. - fastavro.writer(self._buffer, AVRO_SCHEMA, batched_records) + fastavro.writer(self._buffer, self._avro_schema, records) self._write_time += time.perf_counter() - start if ( @@ -225,6 +218,49 @@ def __exit__(self, exc_type, exc_value, traceback): self._in_context = False +class EmbeddingExporter(GcsExporter): + def __init__( + self, + export_dir: Uri, + file_prefix: Optional[str] = None, + min_shard_size_threshold_bytes: int = 0, + ): + super().__init__(export_dir, file_prefix, min_shard_size_threshold_bytes, AVRO_SCHEMA) + + def add_embedding( + self, + id_batch: torch.Tensor, + embedding_batch: torch.Tensor, + embedding_type: str, + ): + """ + Adds to the in-memory buffer the integer IDs and their corresponding embeddings. + + Args: + id_batch (torch.Tensor): A torch.Tensor containing integer IDs. + embedding_batch (torch.Tensor): A torch.Tensor containing embeddings corresponding to the integer IDs in `id_batch`. + embedding_type (str): A tag for the type of the embeddings, e.g., 'user', 'content', etc. + """ + # Convert torch tensors to NumPy arrays, and then to Python int(s) + # and Python list(s). This is faster than converting torch tensors + # directly to Python int(s) and Python list(s), as NumPy's implementation + # is more efficient. 
+ ids = id_batch.numpy() + embeddings = embedding_batch.numpy() + self._num_records_written += len(ids) + + batched_records = ( + { + "node_id": int(node_id), + "node_type": embedding_type, + "emb": embedding.tolist(), + } + for node_id, embedding in zip(ids, embeddings) + ) + + self.add_records(batched_records) + + # TODO(kmonte): We should migrate this over to `BqUtils.load_files_to_bq` once that is implemented. def load_embeddings_to_bigquery( gcs_folder: GcsUri, @@ -232,6 +268,7 @@ def load_embeddings_to_bigquery( dataset_id: str, table_id: str, should_run_async: bool = False, + schema: Sequence[bigquery.SchemaField] = EMBEDDING_BIGQUERY_SCHEMA, ) -> LoadJob: """ Loads multiple Avro files containing GNN embeddings from GCS into BigQuery. From 9b1c5ef549cee2877b2af878440e632dc4c58cae Mon Sep 17 00:00:00 2001 From: Shubham Vij Date: Wed, 22 Oct 2025 18:42:08 -0700 Subject: [PATCH 2/2] test Co-authored-by: Shubham Vij --- python/gigl/common/data/export.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/gigl/common/data/export.py b/python/gigl/common/data/export.py index f41481cca..8866cfa24 100644 --- a/python/gigl/common/data/export.py +++ b/python/gigl/common/data/export.py @@ -52,11 +52,6 @@ ] - - - - - class GcsExporter: def __init__( self,