Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 51 additions & 19 deletions python/gigl/common/data/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os
import time
from pathlib import Path
from typing import Final, Optional, Sequence
from typing import Final, Optional, Sequence, Iterable

import fastavro
import fastavro.types
Expand Down Expand Up @@ -52,12 +52,13 @@
]


class EmbeddingExporter:
class GcsExporter:
def __init__(
self,
export_dir: Uri,
file_prefix: Optional[str] = None,
min_shard_size_threshold_bytes: int = 0,
avro_schema: fastavro.types.Schema = AVRO_SCHEMA,
):
"""
Initializes an EmbeddingExporter instance.
Expand Down Expand Up @@ -104,6 +105,7 @@ def __init__(
self._file_utils = FileLoader()
self._prefix = file_prefix
self._min_shard_size_threshold_bytes = min_shard_size_threshold_bytes
self._avro_schema = avro_schema

if isinstance(
self._base_export_uri, LocalUri
Expand All @@ -113,11 +115,9 @@ def __init__(
)
Path(self._base_export_uri.uri).mkdir(parents=True, exist_ok=True)

def add_embedding(
def add_records(
self,
id_batch: torch.Tensor,
embedding_batch: torch.Tensor,
embedding_type: str,
records: Iterable[dict],
):
"""
Adds to the in-memory buffer the integer IDs and their corresponding embeddings.
Expand All @@ -132,23 +132,11 @@ def add_embedding(
# and Python list(s). This is faster than converting torch tensors
# directly to Python int(s) and Python list(s), as Numpy's implementation
# is more efficient.
ids = id_batch.numpy()
embeddings = embedding_batch.numpy()
self._num_records_written += len(ids)

batched_records = (
{
"node_id": int(node_id),
"node_type": embedding_type,
"emb": embedding.tolist(),
}
for node_id, embedding in zip(ids, embeddings)
)

# fastavro.writer can accept the generator directly.
# Doing this appends to self._buffer, so in order to *read* all data from the buffer
# we *must* call self._buffer.seek(0) before reading.
fastavro.writer(self._buffer, AVRO_SCHEMA, batched_records)
fastavro.writer(self._buffer, self._avro_schema, records)
self._write_time += time.perf_counter() - start

if (
Expand Down Expand Up @@ -225,13 +213,57 @@ def __exit__(self, exc_type, exc_value, traceback):
self._in_context = False


class EmbeddingExporter(GcsExporter):
    """
    Exporter specialised for node embeddings.

    Thin layer over `GcsExporter` that pins the Avro schema to the module-level
    `AVRO_SCHEMA` and converts batches of torch tensors into the record dicts
    (`node_id`, `node_type`, `emb`) that are handed to `add_records`.
    """

    def __init__(
        self,
        export_dir: Uri,
        file_prefix: Optional[str] = None,
        min_shard_size_threshold_bytes: int = 0,
    ) -> None:
        """
        Initializes an EmbeddingExporter instance.

        All arguments are forwarded unchanged to `GcsExporter.__init__`; only the
        Avro schema is fixed here (to `AVRO_SCHEMA`).

        Args:
            export_dir (Uri): Forwarded to `GcsExporter` as `export_dir`.
            file_prefix (Optional[str]): Forwarded to `GcsExporter` as `file_prefix`.
            min_shard_size_threshold_bytes (int): Forwarded to `GcsExporter` as
                `min_shard_size_threshold_bytes`.
        """
        super().__init__(
            export_dir,
            file_prefix,
            min_shard_size_threshold_bytes,
            AVRO_SCHEMA,
        )

    def add_embedding(
        self,
        id_batch: torch.Tensor,
        embedding_batch: torch.Tensor,
        embedding_type: str,
    ) -> None:
        """
        Adds to the in-memory buffer the integer IDs and their corresponding embeddings.

        Args:
            id_batch (torch.Tensor): A torch.Tensor containing integer IDs.
            embedding_batch (torch.Tensor): A torch.Tensor containing embeddings corresponding to the integer IDs in `id_batch`.
            embedding_type (str): A tag for the type of the embeddings, e.g., 'user', 'content', etc.

        Raises:
            ValueError: If `id_batch` and `embedding_batch` do not have the same
                length along their first dimension.
        """
        # Convert torch tensors to NumPy arrays, and then to Python int(s)
        # and Python list(s). This is faster than converting torch tensors
        # directly to Python int(s) and Python list(s), as Numpy's implementation
        # is more efficient.
        # NOTE(review): `Tensor.numpy()` is called without moving the tensor, so
        # callers presumably pass CPU tensors that do not require grad - confirm.
        ids = id_batch.numpy()
        embeddings = embedding_batch.numpy()

        # `zip` stops at the shorter input, so a length mismatch would silently
        # drop rows while the record counter below still advanced by `len(ids)`.
        # Fail loudly before anything is counted or buffered.
        if len(ids) != len(embeddings):
            raise ValueError(
                "id_batch and embedding_batch must have the same length, "
                f"got {len(ids)} ids and {len(embeddings)} embeddings."
            )

        self._num_records_written += len(ids)

        # Lazy generator: records are only materialised one at a time while
        # `add_records` consumes the iterable.
        batched_records = (
            {
                "node_id": int(node_id),
                "node_type": embedding_type,
                "emb": embedding.tolist(),
            }
            for node_id, embedding in zip(ids, embeddings)
        )

        self.add_records(batched_records)


# TODO(kmonte): We should migrate this over to `BqUtils.load_files_to_bq` once that is implemented.
def load_embeddings_to_bigquery(
gcs_folder: GcsUri,
project_id: str,
dataset_id: str,
table_id: str,
should_run_async: bool = False,
schema: Sequence[bigquery.SchemaField] = EMBEDDING_BIGQUERY_SCHEMA,
) -> LoadJob:
"""
Loads multiple Avro files containing GNN embeddings from GCS into BigQuery.
Expand Down