From dfd967683acf0774206cb1067aea09857dbcbed5 Mon Sep 17 00:00:00 2001 From: nshah Date: Thu, 11 Sep 2025 17:39:37 +0000 Subject: [PATCH 01/10] add edge_dataset creation utils --- .../lib/data/edge_dataset.py | 343 ++++++++++++++++++ 1 file changed, 343 insertions(+) create mode 100644 python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py diff --git a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py new file mode 100644 index 000000000..65b534d55 --- /dev/null +++ b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py @@ -0,0 +1,343 @@ +from enum import Enum +from typing import Dict, List + +import applied_tasks.knowledge_graph_embedding.lib.constants.gcs as gcs_constants +import torch.distributed as dist +from applied_tasks.knowledge_graph_embedding.common.graph_dataset import ( + CONDENSED_EDGE_TYPE_FIELD, + DST_FIELD, + SRC_FIELD, + BigQueryHeterogeneousGraphIterableDataset, + GcsJSONLHeterogeneousGraphIterableDataset, + GcsParquetHeterogeneousGraphIterableDataset, +) +from torch.utils.data import IterableDataset + +from gigl.common.logger import Logger +from gigl.common.types.uri.gcs_uri import GcsUri +from gigl.common.utils.gcs import GcsUtils +from gigl.common.utils.torch_training import is_distributed_available_and_initialized +from gigl.distributed.dist_context import DistributedContext +from gigl.env.pipelines_config import get_resource_config +from gigl.src.common.types import AppliedTaskIdentifier +from gigl.src.common.types.dataset_split import DatasetSplit +from gigl.src.common.types.graph_data import EdgeType +from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper +from gigl.src.common.utils.bq import BqUtils +from gigl.src.data_preprocessor.lib.enumerate.utils import EnumeratorEdgeTypeMetadata +from gigl.src.data_preprocessor.lib.ingest.bigquery import BigqueryEdgeDataReference + +logger = Logger() + + +def _build_intermediate_edges_table( + enumerated_edge_metadata: List[EnumeratorEdgeTypeMetadata], + applied_task_identifier: AppliedTaskIdentifier, + output_bq_dataset: str, + graph_metadata: GraphMetadataPbWrapper, + bq_utils: BqUtils, + split_columns: List[str] = list(), + train_split_clause: str = "rand_split BETWEEN 0 AND 0.8", + val_split_clause: str = "rand_split BETWEEN 0.8 AND 0.9", + test_split_clause: str = "rand_split BETWEEN 0.9 AND 1", +) -> str: + """ + Build an intermediate edges table by unioning multiple edge tables with split metadata. + + This function creates a BigQuery table that combines all edge tables from the enumerated + edge metadata into a single intermediate table. Each edge is mapped to a condensed edge + type and includes split information (either from provided split columns or random splits). + + Args: + enumerated_edge_metadata: List of metadata objects containing edge table references + and identifiers for source and destination nodes. + applied_task_identifier: Unique identifier for the current applied task, used in + the intermediate table name. + output_bq_dataset: BigQuery dataset where the intermediate table will be created. + graph_metadata: Wrapper containing graph metadata including edge type mappings. + bq_utils: BigQuery utilities instance for executing queries. + split_columns: Optional list of column names to use for data splitting. If empty, + a random split column will be generated. 
+ train_split_clause: SQL WHERE clause defining the training split condition. + val_split_clause: SQL WHERE clause defining the validation split condition. + test_split_clause: SQL WHERE clause defining the test split condition. + + Returns: + str: The fully qualified BigQuery table path of the created intermediate edges table. + """ + # Create an intermediate edges table with some split-related metadata. + has_split_columns = len(split_columns) > 0 + split_column_selector = ( + ", ".join(split_columns) if has_split_columns else "RAND() AS rand_split" + ) + if has_split_columns: + logger.info(f"Using split columns: {split_columns}") + else: + logger.info("No split columns provided. Using random transductive split.") + + logger.info( + f"Using train/val/test clauses: '{train_split_clause}', '{val_split_clause}', '{test_split_clause}'" + ) + + edge_table_queries: List[str] = list() + for edge_metadata in enumerated_edge_metadata: + enumerated_reference = edge_metadata.enumerated_edge_data_reference + edge_table = BqUtils.format_bq_path(bq_path=enumerated_reference.reference_uri) + condensed_edge_type = graph_metadata.edge_type_to_condensed_edge_type_map[ + enumerated_reference.edge_type + ] + edge_table_query = f""" + SELECT + {enumerated_reference.src_identifier} AS {SRC_FIELD}, + {enumerated_reference.dst_identifier} AS {DST_FIELD}, + {condensed_edge_type} AS {CONDENSED_EDGE_TYPE_FIELD}, + {split_column_selector} + FROM + `{edge_table}` + """ + edge_table_queries.append(edge_table_query) + + union_edges_query = " UNION ALL ".join(edge_table_queries) + logger.info(f"Will write train/val/test datasets to BQ dataset {output_bq_dataset}") + intermediate_edges_table = BqUtils.join_path( + BqUtils.format_bq_path(output_bq_dataset), + f"intermediate_{applied_task_identifier}", + ) + bq_utils.run_query( + query=union_edges_query, + destination=intermediate_edges_table, + write_disposition="WRITE_TRUNCATE", + labels={}, + ) + + return intermediate_edges_table + + +class EdgeDatasetFormat(str, Enum): + """ + Enumeration of supported edge dataset output formats. + + This enum defines the different formats in which edge datasets can be stored + and accessed. Each format has different performance characteristics and use cases: + + - GCS_JSONL: Stores data as JSONL (JSON Lines) files in Google Cloud Storage. + Good for debugging and human-readable data inspection. + - GCS_PARQUET: Stores data as Parquet files in Google Cloud Storage. + Optimized for analytical workloads with efficient compression and columnar storage. + - BIGQUERY: Keeps data in BigQuery tables for direct querying. + Best for large-scale datasets that benefit from BigQuery's distributed processing. + """ + GCS_JSONL = "JSONL" + GCS_PARQUET = "PARQUET" + BIGQUERY = "BIGQUERY" + + +def build_edge_datasets( + distributed_context: DistributedContext, + enumerated_edge_metadata: List[EnumeratorEdgeTypeMetadata], + applied_task_identifier: AppliedTaskIdentifier, + output_bq_dataset: str, + graph_metadata: GraphMetadataPbWrapper, + split_columns: List[str] = list(), + train_split_clause: str = "rand_split BETWEEN 0 AND 0.8", + val_split_clause: str = "rand_split BETWEEN 0.8 AND 0.9", + test_split_clause: str = "rand_split BETWEEN 0.9 AND 1", + format: EdgeDatasetFormat = EdgeDatasetFormat.GCS_PARQUET, +) -> Dict[DatasetSplit, IterableDataset]: + """ + Build edge datasets for training, validation, and testing. 
This function + reads edge data from BigQuery, filters it based on the provided split clauses, + and writes the filtered data to either BigQuery or GCS in the specified format. + + This function is designed to work in a distributed environment, where + multiple processes may be running in parallel. It ensures that the resources + are created only once and that all processes wait for each other to finish + before proceeding. It uses PyTorch's distributed package to manage the + distributed context. It also handles the initialization and destruction of + the distributed process group if necessary. + + Args: + distributed_context: The distributed context for the current process. + enumerated_edge_metadata: Metadata for the edges to be processed. + applied_task_identifier: Identifier for the applied task. + output_bq_dataset: BigQuery dataset to write the output to. + graph_metadata: Metadata for the graph. + split_columns: List of columns to use for splitting the data. + train_split_clause: SQL clause for training data split. + val_split_clause: SQL clause for validation data split. + test_split_clause: SQL clause for testing data split. + format: Format of the output datasets (GCS or BigQuery). + project: GCP project ID. + """ + + # Only init torch distributed if not already initialized + we_initialized = False + if not is_distributed_available_and_initialized(): + logger.info( + f"Building edge datasets -- Initializing torch distributed for {distributed_context.global_rank}..." + ) + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", + world_size=distributed_context.global_world_size, + rank=distributed_context.global_rank, + init_method=f"tcp://{distributed_context.main_worker_ip_address}:23456", + ) + logger.info( + f"Using backend: {dist.get_backend()} for distributed dataset building." + ) + we_initialized = True + + bq_utils = BqUtils(project=get_resource_config().project) + gcs_utils = GcsUtils(project=get_resource_config().project) + + MIXED_EDGE_TYPE = EdgeType("mixed", "mixed", "mixed") + heterogeneous_kwargs = { + "src_field": SRC_FIELD, + "dst_field": DST_FIELD, + "condensed_edge_type_field": CONDENSED_EDGE_TYPE_FIELD, + } + + split_info = [ + (DatasetSplit.TRAIN, train_split_clause), + (DatasetSplit.VAL, val_split_clause), + (DatasetSplit.TEST, test_split_clause), + ] + + def create_resources() -> None: + """ + Create the required resources for edge datasets. + + This nested function handles the creation of all necessary resources for the edge + datasets. It first builds an intermediate edges table by combining all edge tables, + then creates separate train/validation/test tables by filtering the intermediate + table according to the provided split clauses. If the output format is GCS-based + (JSONL or PARQUET), it also exports the BigQuery tables to GCS. + + The function creates: + 1. An intermediate edges table containing all edges with split metadata + 2. Separate BigQuery tables for train/validation/test splits + 3. 
GCS exports of the split tables (if format is GCS_JSONL or GCS_PARQUET) + """ + + intermediate_edges_table = _build_intermediate_edges_table( + enumerated_edge_metadata=enumerated_edge_metadata, + applied_task_identifier=applied_task_identifier, + output_bq_dataset=output_bq_dataset, + graph_metadata=graph_metadata, + bq_utils=bq_utils, + split_columns=split_columns, + train_split_clause=train_split_clause, + val_split_clause=val_split_clause, + test_split_clause=test_split_clause, + ) + for split, split_clause in split_info: + table_reference = BigqueryEdgeDataReference( + reference_uri=BqUtils.join_path( + BqUtils.format_bq_path(output_bq_dataset), + f"{split.value}_edges_{applied_task_identifier}", + ), + src_identifier=SRC_FIELD, + dst_identifier=DST_FIELD, + edge_type=MIXED_EDGE_TYPE, + ) + random_column_field = "row_id" + maybe_extra_field_selector = ( + f", RAND() as {random_column_field}" + if format == EdgeDatasetFormat.BIGQUERY + else "" + ) + query = f"SELECT * {maybe_extra_field_selector} FROM `{intermediate_edges_table}` WHERE {split_clause} ORDER BY RAND()" + + bq_utils.run_query( + query=query, + destination=table_reference.reference_uri, + write_disposition="WRITE_TRUNCATE", + labels=dict(), + ) + if format in (EdgeDatasetFormat.GCS_JSONL, EdgeDatasetFormat.GCS_PARQUET): + gcs_target_path = GcsUri.join( + gcs_constants.get_edge_dataset_output_path( + applied_task_identifier=applied_task_identifier, + ), + f"{split.value}_edges", + ) + destination_glob_path = GcsUri.join(gcs_target_path, "shard-*") + bq_utils.export_to_gcs( + bq_table_path=table_reference.reference_uri, + destination_gcs_uri=destination_glob_path, + destination_format="NEWLINE_DELIMITED_JSON" + if format == EdgeDatasetFormat.GCS_JSONL + else "PARQUET", + ) + + def instantiate_datasets() -> Dict[DatasetSplit, IterableDataset]: + """ + Instantiate and return the edge datasets for each data split. + + This nested function creates IterableDataset instances for train, validation, + and test splits. The type of dataset created depends on the specified format: + - BIGQUERY: Creates BigQueryHeterogeneousGraphIterableDataset instances that + read directly from BigQuery tables + - GCS_JSONL/GCS_PARQUET: Creates GcsJSONLHeterogeneousGraphIterableDataset or + GcsParquetHeterogeneousGraphIterableDataset instances that read from GCS files + + For GCS-based datasets, the function lists all shard files at the expected + GCS path and passes them to the dataset constructor. + + Returns: + Dict[DatasetSplit, IterableDataset]: A dictionary mapping each data split + (TRAIN, VAL, TEST) to its corresponding IterableDataset instance. 
+ """ + + datasets: dict = dict() + for split, _ in split_info: + table_reference = BigqueryEdgeDataReference( + reference_uri=BqUtils.join_path( + BqUtils.format_bq_path(output_bq_dataset), + f"{split.value}_edges_{applied_task_identifier}", + ), + src_identifier=SRC_FIELD, + dst_identifier=DST_FIELD, + edge_type=MIXED_EDGE_TYPE, + ) + random_column_field = "row_id" + if format == EdgeDatasetFormat.BIGQUERY: + datasets[split] = BigQueryHeterogeneousGraphIterableDataset( + table=table_reference.reference_uri, + random_column=random_column_field, + project=get_resource_config().project, + **heterogeneous_kwargs, + ) + elif format in (EdgeDatasetFormat.GCS_JSONL, EdgeDatasetFormat.GCS_PARQUET): + gcs_target_path = GcsUri.join( + gcs_constants.get_edge_dataset_output_path( + applied_task_identifier=applied_task_identifier, + ), + f"{split.value}_edges", + ) + files_at_glob_path = gcs_utils.list_uris_with_gcs_path_pattern( + gcs_path=gcs_target_path, pattern=".*shard-\d+" + ) + dataset_cls = { + EdgeDatasetFormat.GCS_JSONL: GcsJSONLHeterogeneousGraphIterableDataset, + EdgeDatasetFormat.GCS_PARQUET: GcsParquetHeterogeneousGraphIterableDataset, + }[format] + datasets[split] = dataset_cls( + file_uris=files_at_glob_path, **heterogeneous_kwargs + ) + return datasets + + # Rank 0 will create the resources, and all ranks will wait for it to finish. + # This is to ensure that resource creation doesn't happen across multiple ranks, + # since this will create redundant resources and potentially cause issues. + if distributed_context.global_rank == 0: + create_resources() + dist.barrier() # Ensure all ranks have created the resources + datasets = instantiate_datasets() + if we_initialized: + logger.info( + f"Finished building edge datasets -- tearing down torch distributed for {distributed_context.global_rank}..." 
+ ) + dist.destroy_process_group() + + return datasets From 9385ea27fbb3e169bb55c44dc45fe4042d87987e Mon Sep 17 00:00:00 2001 From: nshah Date: Thu, 11 Sep 2025 20:47:28 +0000 Subject: [PATCH 02/10] bump --- .../knowledge_graph_embedding/lib/data/edge_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py index 65b534d55..0a379d397 100644 --- a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py +++ b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py @@ -1,9 +1,9 @@ from enum import Enum from typing import Dict, List -import applied_tasks.knowledge_graph_embedding.lib.constants.gcs as gcs_constants +import gigl.experimental.knowledge_graph_embedding.lib.constants.gcs as gcs_constants import torch.distributed as dist -from applied_tasks.knowledge_graph_embedding.common.graph_dataset import ( +from gigl.experimental.knowledge_graph_embedding.common.graph_dataset import ( CONDENSED_EDGE_TYPE_FIELD, DST_FIELD, SRC_FIELD, From e0f2079145ab16fbac4c674b729401d217beb6d7 Mon Sep 17 00:00:00 2001 From: nshah Date: Thu, 11 Sep 2025 20:54:02 +0000 Subject: [PATCH 03/10] refactoring --- .../lib/data/edge_dataset.py | 629 ++++++++++-------- 1 file changed, 366 insertions(+), 263 deletions(-) diff --git a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py index 0a379d397..05e3e5978 100644 --- a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py +++ b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py @@ -1,24 +1,26 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass from enum import Enum -from typing import Dict, List +from typing import Callable, Dict, List -import gigl.experimental.knowledge_graph_embedding.lib.constants.gcs as gcs_constants import torch.distributed as dist -from gigl.experimental.knowledge_graph_embedding.common.graph_dataset import ( - CONDENSED_EDGE_TYPE_FIELD, - DST_FIELD, - SRC_FIELD, - BigQueryHeterogeneousGraphIterableDataset, - GcsJSONLHeterogeneousGraphIterableDataset, - GcsParquetHeterogeneousGraphIterableDataset, -) from torch.utils.data import IterableDataset +import gigl.experimental.knowledge_graph_embedding.lib.constants.gcs as gcs_constants from gigl.common.logger import Logger from gigl.common.types.uri.gcs_uri import GcsUri from gigl.common.utils.gcs import GcsUtils from gigl.common.utils.torch_training import is_distributed_available_and_initialized from gigl.distributed.dist_context import DistributedContext from gigl.env.pipelines_config import get_resource_config +from gigl.experimental.knowledge_graph_embedding.common.graph_dataset import ( + CONDENSED_EDGE_TYPE_FIELD, + DST_FIELD, + SRC_FIELD, + BigQueryHeterogeneousGraphIterableDataset, + GcsJSONLHeterogeneousGraphIterableDataset, + GcsParquetHeterogeneousGraphIterableDataset, +) from gigl.src.common.types import AppliedTaskIdentifier from gigl.src.common.types.dataset_split import DatasetSplit from gigl.src.common.types.graph_data import EdgeType @@ -30,89 +32,6 @@ logger = Logger() -def _build_intermediate_edges_table( - enumerated_edge_metadata: List[EnumeratorEdgeTypeMetadata], - applied_task_identifier: AppliedTaskIdentifier, - output_bq_dataset: str, - graph_metadata: GraphMetadataPbWrapper, - bq_utils: 
BqUtils, - split_columns: List[str] = list(), - train_split_clause: str = "rand_split BETWEEN 0 AND 0.8", - val_split_clause: str = "rand_split BETWEEN 0.8 AND 0.9", - test_split_clause: str = "rand_split BETWEEN 0.9 AND 1", -) -> str: - """ - Build an intermediate edges table by unioning multiple edge tables with split metadata. - - This function creates a BigQuery table that combines all edge tables from the enumerated - edge metadata into a single intermediate table. Each edge is mapped to a condensed edge - type and includes split information (either from provided split columns or random splits). - - Args: - enumerated_edge_metadata: List of metadata objects containing edge table references - and identifiers for source and destination nodes. - applied_task_identifier: Unique identifier for the current applied task, used in - the intermediate table name. - output_bq_dataset: BigQuery dataset where the intermediate table will be created. - graph_metadata: Wrapper containing graph metadata including edge type mappings. - bq_utils: BigQuery utilities instance for executing queries. - split_columns: Optional list of column names to use for data splitting. If empty, - a random split column will be generated. - train_split_clause: SQL WHERE clause defining the training split condition. - val_split_clause: SQL WHERE clause defining the validation split condition. - test_split_clause: SQL WHERE clause defining the test split condition. - - Returns: - str: The fully qualified BigQuery table path of the created intermediate edges table. - """ - # Create an intermediate edges table with some split-related metadata. - has_split_columns = len(split_columns) > 0 - split_column_selector = ( - ", ".join(split_columns) if has_split_columns else "RAND() AS rand_split" - ) - if has_split_columns: - logger.info(f"Using split columns: {split_columns}") - else: - logger.info("No split columns provided. Using random transductive split.") - - logger.info( - f"Using train/val/test clauses: '{train_split_clause}', '{val_split_clause}', '{test_split_clause}'" - ) - - edge_table_queries: List[str] = list() - for edge_metadata in enumerated_edge_metadata: - enumerated_reference = edge_metadata.enumerated_edge_data_reference - edge_table = BqUtils.format_bq_path(bq_path=enumerated_reference.reference_uri) - condensed_edge_type = graph_metadata.edge_type_to_condensed_edge_type_map[ - enumerated_reference.edge_type - ] - edge_table_query = f""" - SELECT - {enumerated_reference.src_identifier} AS {SRC_FIELD}, - {enumerated_reference.dst_identifier} AS {DST_FIELD}, - {condensed_edge_type} AS {CONDENSED_EDGE_TYPE_FIELD}, - {split_column_selector} - FROM - `{edge_table}` - """ - edge_table_queries.append(edge_table_query) - - union_edges_query = " UNION ALL ".join(edge_table_queries) - logger.info(f"Will write train/val/test datasets to BQ dataset {output_bq_dataset}") - intermediate_edges_table = BqUtils.join_path( - BqUtils.format_bq_path(output_bq_dataset), - f"intermediate_{applied_task_identifier}", - ) - bq_utils.run_query( - query=union_edges_query, - destination=intermediate_edges_table, - write_disposition="WRITE_TRUNCATE", - labels={}, - ) - - return intermediate_edges_table - - class EdgeDatasetFormat(str, Enum): """ Enumeration of supported edge dataset output formats. @@ -127,11 +46,339 @@ class EdgeDatasetFormat(str, Enum): - BIGQUERY: Keeps data in BigQuery tables for direct querying. Best for large-scale datasets that benefit from BigQuery's distributed processing. 
""" + GCS_JSONL = "JSONL" GCS_PARQUET = "PARQUET" BIGQUERY = "BIGQUERY" +@dataclass +class SplitConfiguration: + """Configuration for dataset splitting parameters.""" + + split_columns: List[str] + train_split_clause: str + val_split_clause: str + test_split_clause: str + + +@dataclass +class EdgeDatasetConfig: + """Configuration for edge dataset building parameters.""" + + distributed_context: DistributedContext + enumerated_edge_metadata: List[EnumeratorEdgeTypeMetadata] + applied_task_identifier: AppliedTaskIdentifier + output_bq_dataset: str + graph_metadata: GraphMetadataPbWrapper + format: EdgeDatasetFormat + split_config: SplitConfiguration + + +class DistributedEdgeDatasetCoordinator: + """Handles distributed coordination for edge dataset building.""" + + def __init__(self, distributed_context: DistributedContext): + self.distributed_context = distributed_context + self.we_initialized_dist = False + + def __enter__(self): + """Initialize distributed context if needed.""" + if not is_distributed_available_and_initialized(): + logger.info( + f"Building edge datasets -- Initializing torch distributed for {self.distributed_context.global_rank}..." + ) + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", + world_size=self.distributed_context.global_world_size, + rank=self.distributed_context.global_rank, + init_method=f"tcp://{self.distributed_context.main_worker_ip_address}:23456", + ) + logger.info( + f"Using backend: {dist.get_backend()} for distributed dataset building." + ) + self.we_initialized_dist = True + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Cleanup distributed context if we initialized it.""" + if self.we_initialized_dist: + logger.info( + f"Finished building edge datasets -- tearing down torch distributed for {self.distributed_context.global_rank}..." + ) + dist.destroy_process_group() + + def coordinate_resource_creation(self, creation_func: Callable[[], None]) -> None: + """Coordinate resource creation across distributed ranks.""" + # Rank 0 will create the resources, and all ranks will wait for it to finish. + # This is to ensure that resource creation doesn't happen across multiple ranks, + # since this will create redundant resources and potentially cause issues. + if self.distributed_context.global_rank == 0: + creation_func() + dist.barrier() # Ensure all ranks have created the resources + + +class EdgeDatasetResourceBuilder: + """Handles creation of edge dataset resources (BigQuery tables and GCS exports).""" + + def __init__(self, config: EdgeDatasetConfig): + self.config = config + self.bq_utils = BqUtils(project=get_resource_config().project) + self.gcs_utils = GcsUtils(project=get_resource_config().project) + self.MIXED_EDGE_TYPE = EdgeType("mixed", "mixed", "mixed") + self.split_info = [ + (DatasetSplit.TRAIN, config.split_config.train_split_clause), + (DatasetSplit.VAL, config.split_config.val_split_clause), + (DatasetSplit.TEST, config.split_config.test_split_clause), + ] + + def create_all_resources(self) -> None: + """Create all required resources for edge datasets.""" + intermediate_table = self._build_intermediate_edges_table() + self._create_split_tables(intermediate_table) + self._export_to_gcs_if_needed() + + def _build_intermediate_edges_table(self) -> str: + """Build an intermediate edges table by unioning multiple edge tables with split metadata.""" + # Create an intermediate edges table with some split-related metadata. 
+ has_split_columns = len(self.config.split_config.split_columns) > 0 + split_column_selector = ( + ", ".join(self.config.split_config.split_columns) + if has_split_columns + else "RAND() AS rand_split" + ) + + if has_split_columns: + logger.info( + f"Using split columns: {self.config.split_config.split_columns}" + ) + else: + logger.info("No split columns provided. Using random transductive split.") + + logger.info( + f"Using train/val/test clauses: '{self.config.split_config.train_split_clause}', " + f"'{self.config.split_config.val_split_clause}', '{self.config.split_config.test_split_clause}'" + ) + + edge_table_queries: List[str] = [] + for edge_metadata in self.config.enumerated_edge_metadata: + enumerated_reference = edge_metadata.enumerated_edge_data_reference + edge_table = BqUtils.format_bq_path( + bq_path=enumerated_reference.reference_uri + ) + condensed_edge_type = ( + self.config.graph_metadata.edge_type_to_condensed_edge_type_map[ + enumerated_reference.edge_type + ] + ) + edge_table_query = f""" + SELECT + {enumerated_reference.src_identifier} AS {SRC_FIELD}, + {enumerated_reference.dst_identifier} AS {DST_FIELD}, + {condensed_edge_type} AS {CONDENSED_EDGE_TYPE_FIELD}, + {split_column_selector} + FROM + `{edge_table}` + """ + edge_table_queries.append(edge_table_query) + + union_edges_query = " UNION ALL ".join(edge_table_queries) + logger.info( + f"Will write train/val/test datasets to BQ dataset {self.config.output_bq_dataset}" + ) + + intermediate_edges_table = BqUtils.join_path( + BqUtils.format_bq_path(self.config.output_bq_dataset), + f"intermediate_{self.config.applied_task_identifier}", + ) + + self.bq_utils.run_query( + query=union_edges_query, + destination=intermediate_edges_table, + write_disposition="WRITE_TRUNCATE", + labels={}, + ) + + return intermediate_edges_table + + def _create_split_tables(self, intermediate_table: str) -> None: + """Create separate BigQuery tables for train/validation/test splits.""" + for split, split_clause in self.split_info: + table_reference = self._create_table_reference(split) + + random_column_field = "row_id" + maybe_extra_field_selector = ( + f", RAND() as {random_column_field}" + if self.config.format == EdgeDatasetFormat.BIGQUERY + else "" + ) + + query = f"SELECT * {maybe_extra_field_selector} FROM `{intermediate_table}` WHERE {split_clause} ORDER BY RAND()" + + self.bq_utils.run_query( + query=query, + destination=table_reference.reference_uri, + write_disposition="WRITE_TRUNCATE", + labels=dict(), + ) + + def _export_to_gcs_if_needed(self) -> None: + """Export BigQuery tables to GCS if the format requires it.""" + if self.config.format not in ( + EdgeDatasetFormat.GCS_JSONL, + EdgeDatasetFormat.GCS_PARQUET, + ): + return + + for split, _ in self.split_info: + table_reference = self._create_table_reference(split) + + gcs_target_path = GcsUri.join( + gcs_constants.get_edge_dataset_output_path( + applied_task_identifier=self.config.applied_task_identifier, + ), + f"{split.value}_edges", + ) + destination_glob_path = GcsUri.join(gcs_target_path, "shard-*") + + self.bq_utils.export_to_gcs( + bq_table_path=table_reference.reference_uri, + destination_gcs_uri=destination_glob_path, + destination_format="NEWLINE_DELIMITED_JSON" + if self.config.format == EdgeDatasetFormat.GCS_JSONL + else "PARQUET", + ) + + def _create_table_reference(self, split: DatasetSplit) -> BigqueryEdgeDataReference: + """Create a BigQuery table reference for a given data split.""" + return BigqueryEdgeDataReference( + reference_uri=BqUtils.join_path( + 
BqUtils.format_bq_path(self.config.output_bq_dataset), + f"{split.value}_edges_{self.config.applied_task_identifier}", + ), + src_identifier=SRC_FIELD, + dst_identifier=DST_FIELD, + edge_type=self.MIXED_EDGE_TYPE, + ) + + +class EdgeDatasetStrategy(ABC): + """Abstract strategy for creating different types of edge datasets.""" + + @abstractmethod + def create_dataset( + self, + config: EdgeDatasetConfig, + split: DatasetSplit, + table_reference: BigqueryEdgeDataReference, + **kwargs, + ) -> IterableDataset: + """Create a dataset for the given split.""" + + +class BigQueryDatasetStrategy(EdgeDatasetStrategy): + """Strategy for creating BigQuery-based edge datasets.""" + + def create_dataset( + self, + config: EdgeDatasetConfig, + split: DatasetSplit, + table_reference: BigqueryEdgeDataReference, + **kwargs, + ) -> IterableDataset: + random_column_field = "row_id" + return BigQueryHeterogeneousGraphIterableDataset( + table=table_reference.reference_uri, + random_column=random_column_field, + project=get_resource_config().project, + **kwargs, + ) + + +class GcsDatasetStrategy(EdgeDatasetStrategy): + """Strategy for creating GCS-based edge datasets (JSONL or Parquet).""" + + def __init__(self, format_type: EdgeDatasetFormat): + self.format_type = format_type + self.gcs_utils = GcsUtils(project=get_resource_config().project) + + def create_dataset( + self, + config: EdgeDatasetConfig, + split: DatasetSplit, + table_reference: BigqueryEdgeDataReference, + **kwargs, + ) -> IterableDataset: + gcs_target_path = GcsUri.join( + gcs_constants.get_edge_dataset_output_path( + applied_task_identifier=config.applied_task_identifier, + ), + f"{split.value}_edges", + ) + files_at_glob_path = self.gcs_utils.list_uris_with_gcs_path_pattern( + gcs_path=gcs_target_path, pattern=".*shard-\d+" + ) + + dataset_cls = { + EdgeDatasetFormat.GCS_JSONL: GcsJSONLHeterogeneousGraphIterableDataset, + EdgeDatasetFormat.GCS_PARQUET: GcsParquetHeterogeneousGraphIterableDataset, + }[self.format_type] + + return dataset_cls(file_uris=files_at_glob_path, **kwargs) + + +class EdgeDatasetFactory: + """Factory for creating edge datasets using appropriate strategies.""" + + def __init__(self, config: EdgeDatasetConfig): + self.config = config + self.strategy_map = { + EdgeDatasetFormat.BIGQUERY: BigQueryDatasetStrategy(), + EdgeDatasetFormat.GCS_JSONL: GcsDatasetStrategy( + EdgeDatasetFormat.GCS_JSONL + ), + EdgeDatasetFormat.GCS_PARQUET: GcsDatasetStrategy( + EdgeDatasetFormat.GCS_PARQUET + ), + } + self.heterogeneous_kwargs = { + "src_field": SRC_FIELD, + "dst_field": DST_FIELD, + "condensed_edge_type_field": CONDENSED_EDGE_TYPE_FIELD, + } + self.MIXED_EDGE_TYPE = EdgeType("mixed", "mixed", "mixed") + self.split_info = [ + (DatasetSplit.TRAIN, config.split_config.train_split_clause), + (DatasetSplit.VAL, config.split_config.val_split_clause), + (DatasetSplit.TEST, config.split_config.test_split_clause), + ] + + def create_datasets(self) -> Dict[DatasetSplit, IterableDataset]: + """Create and return the edge datasets for each data split.""" + strategy = self.strategy_map[self.config.format] + datasets: Dict[DatasetSplit, IterableDataset] = {} + + for split, _ in self.split_info: + table_reference = BigqueryEdgeDataReference( + reference_uri=BqUtils.join_path( + BqUtils.format_bq_path(self.config.output_bq_dataset), + f"{split.value}_edges_{self.config.applied_task_identifier}", + ), + src_identifier=SRC_FIELD, + dst_identifier=DST_FIELD, + edge_type=self.MIXED_EDGE_TYPE, + ) + + datasets[split] = strategy.create_dataset( + 
config=self.config, + split=split, + table_reference=table_reference, + **self.heterogeneous_kwargs, + ) + + return datasets + + def build_edge_datasets( distributed_context: DistributedContext, enumerated_edge_metadata: List[EnumeratorEdgeTypeMetadata], @@ -149,8 +396,8 @@ def build_edge_datasets( reads edge data from BigQuery, filters it based on the provided split clauses, and writes the filtered data to either BigQuery or GCS in the specified format. - This function is designed to work in a distributed environment, where - multiple processes may be running in parallel. It ensures that the resources + This function is designed to work in a distributed environment (e.g. at start of training), + where multiple processes may be running in parallel. It ensures that the resources are created only once and that all processes wait for each other to finish before proceeding. It uses PyTorch's distributed package to manage the distributed context. It also handles the initialization and destruction of @@ -170,174 +417,30 @@ def build_edge_datasets( project: GCP project ID. """ - # Only init torch distributed if not already initialized - we_initialized = False - if not is_distributed_available_and_initialized(): - logger.info( - f"Building edge datasets -- Initializing torch distributed for {distributed_context.global_rank}..." - ) - dist.init_process_group( - backend="cpu:gloo,cuda:nccl", - world_size=distributed_context.global_world_size, - rank=distributed_context.global_rank, - init_method=f"tcp://{distributed_context.main_worker_ip_address}:23456", - ) - logger.info( - f"Using backend: {dist.get_backend()} for distributed dataset building." - ) - we_initialized = True - - bq_utils = BqUtils(project=get_resource_config().project) - gcs_utils = GcsUtils(project=get_resource_config().project) - - MIXED_EDGE_TYPE = EdgeType("mixed", "mixed", "mixed") - heterogeneous_kwargs = { - "src_field": SRC_FIELD, - "dst_field": DST_FIELD, - "condensed_edge_type_field": CONDENSED_EDGE_TYPE_FIELD, - } - - split_info = [ - (DatasetSplit.TRAIN, train_split_clause), - (DatasetSplit.VAL, val_split_clause), - (DatasetSplit.TEST, test_split_clause), - ] - - def create_resources() -> None: - """ - Create the required resources for edge datasets. - - This nested function handles the creation of all necessary resources for the edge - datasets. It first builds an intermediate edges table by combining all edge tables, - then creates separate train/validation/test tables by filtering the intermediate - table according to the provided split clauses. If the output format is GCS-based - (JSONL or PARQUET), it also exports the BigQuery tables to GCS. - - The function creates: - 1. An intermediate edges table containing all edges with split metadata - 2. Separate BigQuery tables for train/validation/test splits - 3. 
GCS exports of the split tables (if format is GCS_JSONL or GCS_PARQUET) - """ - - intermediate_edges_table = _build_intermediate_edges_table( - enumerated_edge_metadata=enumerated_edge_metadata, - applied_task_identifier=applied_task_identifier, - output_bq_dataset=output_bq_dataset, - graph_metadata=graph_metadata, - bq_utils=bq_utils, - split_columns=split_columns, - train_split_clause=train_split_clause, - val_split_clause=val_split_clause, - test_split_clause=test_split_clause, - ) - for split, split_clause in split_info: - table_reference = BigqueryEdgeDataReference( - reference_uri=BqUtils.join_path( - BqUtils.format_bq_path(output_bq_dataset), - f"{split.value}_edges_{applied_task_identifier}", - ), - src_identifier=SRC_FIELD, - dst_identifier=DST_FIELD, - edge_type=MIXED_EDGE_TYPE, - ) - random_column_field = "row_id" - maybe_extra_field_selector = ( - f", RAND() as {random_column_field}" - if format == EdgeDatasetFormat.BIGQUERY - else "" - ) - query = f"SELECT * {maybe_extra_field_selector} FROM `{intermediate_edges_table}` WHERE {split_clause} ORDER BY RAND()" + # Create configuration objects + split_config = SplitConfiguration( + split_columns=split_columns, + train_split_clause=train_split_clause, + val_split_clause=val_split_clause, + test_split_clause=test_split_clause, + ) - bq_utils.run_query( - query=query, - destination=table_reference.reference_uri, - write_disposition="WRITE_TRUNCATE", - labels=dict(), - ) - if format in (EdgeDatasetFormat.GCS_JSONL, EdgeDatasetFormat.GCS_PARQUET): - gcs_target_path = GcsUri.join( - gcs_constants.get_edge_dataset_output_path( - applied_task_identifier=applied_task_identifier, - ), - f"{split.value}_edges", - ) - destination_glob_path = GcsUri.join(gcs_target_path, "shard-*") - bq_utils.export_to_gcs( - bq_table_path=table_reference.reference_uri, - destination_gcs_uri=destination_glob_path, - destination_format="NEWLINE_DELIMITED_JSON" - if format == EdgeDatasetFormat.GCS_JSONL - else "PARQUET", - ) - - def instantiate_datasets() -> Dict[DatasetSplit, IterableDataset]: - """ - Instantiate and return the edge datasets for each data split. - - This nested function creates IterableDataset instances for train, validation, - and test splits. The type of dataset created depends on the specified format: - - BIGQUERY: Creates BigQueryHeterogeneousGraphIterableDataset instances that - read directly from BigQuery tables - - GCS_JSONL/GCS_PARQUET: Creates GcsJSONLHeterogeneousGraphIterableDataset or - GcsParquetHeterogeneousGraphIterableDataset instances that read from GCS files - - For GCS-based datasets, the function lists all shard files at the expected - GCS path and passes them to the dataset constructor. - - Returns: - Dict[DatasetSplit, IterableDataset]: A dictionary mapping each data split - (TRAIN, VAL, TEST) to its corresponding IterableDataset instance. 
- """ - - datasets: dict = dict() - for split, _ in split_info: - table_reference = BigqueryEdgeDataReference( - reference_uri=BqUtils.join_path( - BqUtils.format_bq_path(output_bq_dataset), - f"{split.value}_edges_{applied_task_identifier}", - ), - src_identifier=SRC_FIELD, - dst_identifier=DST_FIELD, - edge_type=MIXED_EDGE_TYPE, - ) - random_column_field = "row_id" - if format == EdgeDatasetFormat.BIGQUERY: - datasets[split] = BigQueryHeterogeneousGraphIterableDataset( - table=table_reference.reference_uri, - random_column=random_column_field, - project=get_resource_config().project, - **heterogeneous_kwargs, - ) - elif format in (EdgeDatasetFormat.GCS_JSONL, EdgeDatasetFormat.GCS_PARQUET): - gcs_target_path = GcsUri.join( - gcs_constants.get_edge_dataset_output_path( - applied_task_identifier=applied_task_identifier, - ), - f"{split.value}_edges", - ) - files_at_glob_path = gcs_utils.list_uris_with_gcs_path_pattern( - gcs_path=gcs_target_path, pattern=".*shard-\d+" - ) - dataset_cls = { - EdgeDatasetFormat.GCS_JSONL: GcsJSONLHeterogeneousGraphIterableDataset, - EdgeDatasetFormat.GCS_PARQUET: GcsParquetHeterogeneousGraphIterableDataset, - }[format] - datasets[split] = dataset_cls( - file_uris=files_at_glob_path, **heterogeneous_kwargs - ) - return datasets + config = EdgeDatasetConfig( + distributed_context=distributed_context, + enumerated_edge_metadata=enumerated_edge_metadata, + applied_task_identifier=applied_task_identifier, + output_bq_dataset=output_bq_dataset, + graph_metadata=graph_metadata, + format=format, + split_config=split_config, + ) - # Rank 0 will create the resources, and all ranks will wait for it to finish. - # This is to ensure that resource creation doesn't happen across multiple ranks, - # since this will create redundant resources and potentially cause issues. - if distributed_context.global_rank == 0: - create_resources() - dist.barrier() # Ensure all ranks have created the resources - datasets = instantiate_datasets() - if we_initialized: - logger.info( - f"Finished building edge datasets -- tearing down torch distributed for {distributed_context.global_rank}..." 
- ) - dist.destroy_process_group() + # Use context manager for distributed coordination + with DistributedEdgeDatasetCoordinator(distributed_context) as coordinator: + # Create resources using the builder + resource_builder = EdgeDatasetResourceBuilder(config) + coordinator.coordinate_resource_creation(resource_builder.create_all_resources) - return datasets + # Create and return datasets using the factory + factory = EdgeDatasetFactory(config) + return factory.create_datasets() From 251641f3ac5bef7da378f65ae101482fb9ca3d17 Mon Sep 17 00:00:00 2001 From: nshah Date: Fri, 12 Sep 2025 20:41:52 +0000 Subject: [PATCH 04/10] protocols --- .../lib/data/edge_dataset.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py index 05e3e5978..9daaa14c6 100644 --- a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py +++ b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py @@ -1,7 +1,6 @@ -from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Protocol import torch.distributed as dist from torch.utils.data import IterableDataset @@ -262,10 +261,9 @@ def _create_table_reference(self, split: DatasetSplit) -> BigqueryEdgeDataRefere ) -class EdgeDatasetStrategy(ABC): - """Abstract strategy for creating different types of edge datasets.""" +class EdgeDatasetStrategy(Protocol): + """Protocol for creating different types of edge datasets.""" - @abstractmethod def create_dataset( self, config: EdgeDatasetConfig, @@ -274,9 +272,10 @@ def create_dataset( **kwargs, ) -> IterableDataset: """Create a dataset for the given split.""" + ... 
-class BigQueryDatasetStrategy(EdgeDatasetStrategy): +class BigQueryDatasetStrategy: """Strategy for creating BigQuery-based edge datasets.""" def create_dataset( @@ -295,7 +294,7 @@ def create_dataset( ) -class GcsDatasetStrategy(EdgeDatasetStrategy): +class GcsDatasetStrategy: """Strategy for creating GCS-based edge datasets (JSONL or Parquet).""" def __init__(self, format_type: EdgeDatasetFormat): From 4945d9e19d4f9426e3895fab509d15a3202df696 Mon Sep 17 00:00:00 2001 From: nshah Date: Mon, 15 Sep 2025 21:36:16 +0000 Subject: [PATCH 05/10] comment --- .../lib/data/edge_dataset.py | 93 ++++++++++--------- 1 file changed, 48 insertions(+), 45 deletions(-) diff --git a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py index 9daaa14c6..a5d97e301 100644 --- a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py +++ b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from typing import Callable, Dict, List, Protocol +from typing import Callable, Dict, List, Protocol, Tuple import torch.distributed as dist from torch.utils.data import IterableDataset @@ -60,6 +60,13 @@ class SplitConfiguration: val_split_clause: str test_split_clause: str + def clause_per_split(self) -> List[Tuple[DatasetSplit, str]]: + return [ + (DatasetSplit.TRAIN, self.train_split_clause), + (DatasetSplit.VAL, self.val_split_clause), + (DatasetSplit.TEST, self.test_split_clause), + ] + @dataclass class EdgeDatasetConfig: @@ -74,6 +81,23 @@ class EdgeDatasetConfig: split_config: SplitConfiguration +def _get_BigqueryEdgeDataReference_for_split( + output_bq_dataset: str, + applied_task_identifier: AppliedTaskIdentifier, + split: DatasetSplit, +) -> BigqueryEdgeDataReference: + """Get a BigQuery edge data reference for a given split.""" + return BigqueryEdgeDataReference( + reference_uri=BqUtils.join_path( + BqUtils.format_bq_path(output_bq_dataset), + f"{split.value}_edges_{applied_task_identifier}", + ), + src_identifier=SRC_FIELD, + dst_identifier=DST_FIELD, + edge_type=EdgeType("mixed", "mixed", "mixed"), + ) + + class DistributedEdgeDatasetCoordinator: """Handles distributed coordination for edge dataset building.""" @@ -124,18 +148,13 @@ def __init__(self, config: EdgeDatasetConfig): self.config = config self.bq_utils = BqUtils(project=get_resource_config().project) self.gcs_utils = GcsUtils(project=get_resource_config().project) - self.MIXED_EDGE_TYPE = EdgeType("mixed", "mixed", "mixed") - self.split_info = [ - (DatasetSplit.TRAIN, config.split_config.train_split_clause), - (DatasetSplit.VAL, config.split_config.val_split_clause), - (DatasetSplit.TEST, config.split_config.test_split_clause), - ] + self.split_info = config.split_config.clause_per_split() def create_all_resources(self) -> None: """Create all required resources for edge datasets.""" intermediate_table = self._build_intermediate_edges_table() - self._create_split_tables(intermediate_table) - self._export_to_gcs_if_needed() + self._create_tables_for_each_split(intermediate_table) + self._export_split_tables_to_gcs_if_needed() def _build_intermediate_edges_table(self) -> str: """Build an intermediate edges table by unioning multiple edge tables with split metadata.""" @@ -200,10 +219,14 @@ def _build_intermediate_edges_table(self) -> str: return intermediate_edges_table - def _create_split_tables(self, intermediate_table: str) -> None: + 
def _create_tables_for_each_split(self, intermediate_table: str) -> None: """Create separate BigQuery tables for train/validation/test splits.""" for split, split_clause in self.split_info: - table_reference = self._create_table_reference(split) + table_reference = _get_BigqueryEdgeDataReference_for_split( + self.config.output_bq_dataset, + self.config.applied_task_identifier, + split, + ) random_column_field = "row_id" maybe_extra_field_selector = ( @@ -221,7 +244,7 @@ def _create_split_tables(self, intermediate_table: str) -> None: labels=dict(), ) - def _export_to_gcs_if_needed(self) -> None: + def _export_split_tables_to_gcs_if_needed(self) -> None: """Export BigQuery tables to GCS if the format requires it.""" if self.config.format not in ( EdgeDatasetFormat.GCS_JSONL, @@ -230,7 +253,11 @@ def _export_to_gcs_if_needed(self) -> None: return for split, _ in self.split_info: - table_reference = self._create_table_reference(split) + table_reference = _get_BigqueryEdgeDataReference_for_split( + self.config.output_bq_dataset, + self.config.applied_task_identifier, + split, + ) gcs_target_path = GcsUri.join( gcs_constants.get_edge_dataset_output_path( @@ -248,18 +275,6 @@ def _export_to_gcs_if_needed(self) -> None: else "PARQUET", ) - def _create_table_reference(self, split: DatasetSplit) -> BigqueryEdgeDataReference: - """Create a BigQuery table reference for a given data split.""" - return BigqueryEdgeDataReference( - reference_uri=BqUtils.join_path( - BqUtils.format_bq_path(self.config.output_bq_dataset), - f"{split.value}_edges_{self.config.applied_task_identifier}", - ), - src_identifier=SRC_FIELD, - dst_identifier=DST_FIELD, - edge_type=self.MIXED_EDGE_TYPE, - ) - class EdgeDatasetStrategy(Protocol): """Protocol for creating different types of edge datasets.""" @@ -268,7 +283,6 @@ def create_dataset( self, config: EdgeDatasetConfig, split: DatasetSplit, - table_reference: BigqueryEdgeDataReference, **kwargs, ) -> IterableDataset: """Create a dataset for the given split.""" @@ -282,9 +296,15 @@ def create_dataset( self, config: EdgeDatasetConfig, split: DatasetSplit, - table_reference: BigqueryEdgeDataReference, **kwargs, ) -> IterableDataset: + # Create table reference specific to BigQuery strategy + table_reference = _get_BigqueryEdgeDataReference_for_split( + config.output_bq_dataset, + config.applied_task_identifier, + split, + ) + random_column_field = "row_id" return BigQueryHeterogeneousGraphIterableDataset( table=table_reference.reference_uri, @@ -305,7 +325,6 @@ def create_dataset( self, config: EdgeDatasetConfig, split: DatasetSplit, - table_reference: BigqueryEdgeDataReference, **kwargs, ) -> IterableDataset: gcs_target_path = GcsUri.join( @@ -345,12 +364,7 @@ def __init__(self, config: EdgeDatasetConfig): "dst_field": DST_FIELD, "condensed_edge_type_field": CONDENSED_EDGE_TYPE_FIELD, } - self.MIXED_EDGE_TYPE = EdgeType("mixed", "mixed", "mixed") - self.split_info = [ - (DatasetSplit.TRAIN, config.split_config.train_split_clause), - (DatasetSplit.VAL, config.split_config.val_split_clause), - (DatasetSplit.TEST, config.split_config.test_split_clause), - ] + self.split_info = config.split_config.clause_per_split() def create_datasets(self) -> Dict[DatasetSplit, IterableDataset]: """Create and return the edge datasets for each data split.""" @@ -358,20 +372,9 @@ def create_datasets(self) -> Dict[DatasetSplit, IterableDataset]: datasets: Dict[DatasetSplit, IterableDataset] = {} for split, _ in self.split_info: - table_reference = BigqueryEdgeDataReference( - 
reference_uri=BqUtils.join_path( - BqUtils.format_bq_path(self.config.output_bq_dataset), - f"{split.value}_edges_{self.config.applied_task_identifier}", - ), - src_identifier=SRC_FIELD, - dst_identifier=DST_FIELD, - edge_type=self.MIXED_EDGE_TYPE, - ) - datasets[split] = strategy.create_dataset( config=self.config, split=split, - table_reference=table_reference, **self.heterogeneous_kwargs, ) From e44b55e2e178b8ccac0f71d36f8e504fe5a35485 Mon Sep 17 00:00:00 2001 From: nshah Date: Mon, 15 Sep 2025 23:38:35 +0000 Subject: [PATCH 06/10] stash --- .../lib/data/edge_dataset.py | 164 ++++++++---------- 1 file changed, 75 insertions(+), 89 deletions(-) diff --git a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py index a5d97e301..38e9a81c2 100644 --- a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py +++ b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from typing import Callable, Dict, List, Protocol, Tuple +from typing import Dict, List, Optional, Protocol, Tuple import torch.distributed as dist from torch.utils.data import IterableDataset @@ -52,8 +52,8 @@ class EdgeDatasetFormat(str, Enum): @dataclass -class SplitConfiguration: - """Configuration for dataset splitting parameters.""" +class PerSplitFilteredEdgeBigqueryMetadata: + """Configuration parameters to filter BigQuery tables by split (train/val/test).""" split_columns: List[str] train_split_clause: str @@ -69,16 +69,16 @@ def clause_per_split(self) -> List[Tuple[DatasetSplit, str]]: @dataclass -class EdgeDatasetConfig: - """Configuration for edge dataset building parameters.""" +class PerSplitFilteredEdgeDatasetConfig: + """Configuration parameters to build filtered datasets by split (train/val/test).""" distributed_context: DistributedContext enumerated_edge_metadata: List[EnumeratorEdgeTypeMetadata] applied_task_identifier: AppliedTaskIdentifier output_bq_dataset: str graph_metadata: GraphMetadataPbWrapper - format: EdgeDatasetFormat - split_config: SplitConfiguration + split_dataset_format: EdgeDatasetFormat + split_config: PerSplitFilteredEdgeBigqueryMetadata def _get_BigqueryEdgeDataReference_for_split( @@ -98,53 +98,10 @@ def _get_BigqueryEdgeDataReference_for_split( ) -class DistributedEdgeDatasetCoordinator: - """Handles distributed coordination for edge dataset building.""" - - def __init__(self, distributed_context: DistributedContext): - self.distributed_context = distributed_context - self.we_initialized_dist = False - - def __enter__(self): - """Initialize distributed context if needed.""" - if not is_distributed_available_and_initialized(): - logger.info( - f"Building edge datasets -- Initializing torch distributed for {self.distributed_context.global_rank}..." - ) - dist.init_process_group( - backend="cpu:gloo,cuda:nccl", - world_size=self.distributed_context.global_world_size, - rank=self.distributed_context.global_rank, - init_method=f"tcp://{self.distributed_context.main_worker_ip_address}:23456", - ) - logger.info( - f"Using backend: {dist.get_backend()} for distributed dataset building." 
- ) - self.we_initialized_dist = True - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Cleanup distributed context if we initialized it.""" - if self.we_initialized_dist: - logger.info( - f"Finished building edge datasets -- tearing down torch distributed for {self.distributed_context.global_rank}..." - ) - dist.destroy_process_group() - - def coordinate_resource_creation(self, creation_func: Callable[[], None]) -> None: - """Coordinate resource creation across distributed ranks.""" - # Rank 0 will create the resources, and all ranks will wait for it to finish. - # This is to ensure that resource creation doesn't happen across multiple ranks, - # since this will create redundant resources and potentially cause issues. - if self.distributed_context.global_rank == 0: - creation_func() - dist.barrier() # Ensure all ranks have created the resources - - -class EdgeDatasetResourceBuilder: +class PerSplitFilteredEdgeDatasetBuilder: """Handles creation of edge dataset resources (BigQuery tables and GCS exports).""" - def __init__(self, config: EdgeDatasetConfig): + def __init__(self, config: PerSplitFilteredEdgeDatasetConfig): self.config = config self.bq_utils = BqUtils(project=get_resource_config().project) self.gcs_utils = GcsUtils(project=get_resource_config().project) @@ -152,11 +109,11 @@ def __init__(self, config: EdgeDatasetConfig): def create_all_resources(self) -> None: """Create all required resources for edge datasets.""" - intermediate_table = self._build_intermediate_edges_table() - self._create_tables_for_each_split(intermediate_table) - self._export_split_tables_to_gcs_if_needed() + intermediate_table = self._build_intermediate_all_edges_table() + self._create_filtered_tables_for_each_split(intermediate_table) + self._export_filtered_tables_to_gcs_if_needed() - def _build_intermediate_edges_table(self) -> str: + def _build_intermediate_all_edges_table(self) -> str: """Build an intermediate edges table by unioning multiple edge tables with split metadata.""" # Create an intermediate edges table with some split-related metadata. 
has_split_columns = len(self.config.split_config.split_columns) > 0 @@ -219,7 +176,7 @@ def _build_intermediate_edges_table(self) -> str: return intermediate_edges_table - def _create_tables_for_each_split(self, intermediate_table: str) -> None: + def _create_filtered_tables_for_each_split(self, intermediate_table: str) -> None: """Create separate BigQuery tables for train/validation/test splits.""" for split, split_clause in self.split_info: table_reference = _get_BigqueryEdgeDataReference_for_split( @@ -231,7 +188,7 @@ def _create_tables_for_each_split(self, intermediate_table: str) -> None: random_column_field = "row_id" maybe_extra_field_selector = ( f", RAND() as {random_column_field}" - if self.config.format == EdgeDatasetFormat.BIGQUERY + if self.config.split_dataset_format == EdgeDatasetFormat.BIGQUERY else "" ) @@ -244,9 +201,9 @@ def _create_tables_for_each_split(self, intermediate_table: str) -> None: labels=dict(), ) - def _export_split_tables_to_gcs_if_needed(self) -> None: + def _export_filtered_tables_to_gcs_if_needed(self) -> None: """Export BigQuery tables to GCS if the format requires it.""" - if self.config.format not in ( + if self.config.split_dataset_format not in ( EdgeDatasetFormat.GCS_JSONL, EdgeDatasetFormat.GCS_PARQUET, ): @@ -271,17 +228,17 @@ def _export_split_tables_to_gcs_if_needed(self) -> None: bq_table_path=table_reference.reference_uri, destination_gcs_uri=destination_glob_path, destination_format="NEWLINE_DELIMITED_JSON" - if self.config.format == EdgeDatasetFormat.GCS_JSONL + if self.config.split_dataset_format == EdgeDatasetFormat.GCS_JSONL else "PARQUET", ) -class EdgeDatasetStrategy(Protocol): - """Protocol for creating different types of edge datasets.""" +class PerSplitIterableDatasetStrategy(Protocol): + """Protocol for creating different types of iterable datasets with filtered datasets for each split.""" def create_dataset( self, - config: EdgeDatasetConfig, + config: PerSplitFilteredEdgeDatasetConfig, split: DatasetSplit, **kwargs, ) -> IterableDataset: @@ -289,12 +246,12 @@ def create_dataset( ... 
-class BigQueryDatasetStrategy: - """Strategy for creating BigQuery-based edge datasets.""" +class PerSplitIterableDatasetBigqueryStrategy: + """Strategy for creating BigQuery-based iterable datasets with filtered datasets for each split.""" def create_dataset( self, - config: EdgeDatasetConfig, + config: PerSplitFilteredEdgeDatasetConfig, split: DatasetSplit, **kwargs, ) -> IterableDataset: @@ -314,7 +271,7 @@ def create_dataset( ) -class GcsDatasetStrategy: +class PerSplitIterableDatasetGcsStrategy: """Strategy for creating GCS-based edge datasets (JSONL or Parquet).""" def __init__(self, format_type: EdgeDatasetFormat): @@ -323,7 +280,7 @@ def __init__(self, format_type: EdgeDatasetFormat): def create_dataset( self, - config: EdgeDatasetConfig, + config: PerSplitFilteredEdgeDatasetConfig, split: DatasetSplit, **kwargs, ) -> IterableDataset: @@ -345,17 +302,17 @@ def create_dataset( return dataset_cls(file_uris=files_at_glob_path, **kwargs) -class EdgeDatasetFactory: - """Factory for creating edge datasets using appropriate strategies.""" +class PerSplitIterableDatasetFactory: + """Factory for creating per-split edge datasets using appropriate strategies.""" - def __init__(self, config: EdgeDatasetConfig): + def __init__(self, config: PerSplitFilteredEdgeDatasetConfig): self.config = config self.strategy_map = { - EdgeDatasetFormat.BIGQUERY: BigQueryDatasetStrategy(), - EdgeDatasetFormat.GCS_JSONL: GcsDatasetStrategy( + EdgeDatasetFormat.BIGQUERY: PerSplitIterableDatasetBigqueryStrategy(), + EdgeDatasetFormat.GCS_JSONL: PerSplitIterableDatasetGcsStrategy( EdgeDatasetFormat.GCS_JSONL ), - EdgeDatasetFormat.GCS_PARQUET: GcsDatasetStrategy( + EdgeDatasetFormat.GCS_PARQUET: PerSplitIterableDatasetGcsStrategy( EdgeDatasetFormat.GCS_PARQUET ), } @@ -368,7 +325,7 @@ def __init__(self, config: EdgeDatasetConfig): def create_datasets(self) -> Dict[DatasetSplit, IterableDataset]: """Create and return the edge datasets for each data split.""" - strategy = self.strategy_map[self.config.format] + strategy = self.strategy_map[self.config.split_dataset_format] datasets: Dict[DatasetSplit, IterableDataset] = {} for split, _ in self.split_info: @@ -387,7 +344,7 @@ def build_edge_datasets( applied_task_identifier: AppliedTaskIdentifier, output_bq_dataset: str, graph_metadata: GraphMetadataPbWrapper, - split_columns: List[str] = list(), + split_columns: Optional[List[str]] = None, train_split_clause: str = "rand_split BETWEEN 0 AND 0.8", val_split_clause: str = "rand_split BETWEEN 0.8 AND 0.9", test_split_clause: str = "rand_split BETWEEN 0.9 AND 1", @@ -419,30 +376,59 @@ def build_edge_datasets( project: GCP project ID. 
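    Returns:
        Dict[DatasetSplit, IterableDataset]: mapping from each data split
            (TRAIN, VAL, TEST) to an IterableDataset over that split's edges.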
""" + if split_columns is None: + split_columns = list() + # Create configuration objects - split_config = SplitConfiguration( + bq_split_metadata = PerSplitFilteredEdgeBigqueryMetadata( split_columns=split_columns, train_split_clause=train_split_clause, val_split_clause=val_split_clause, test_split_clause=test_split_clause, ) - config = EdgeDatasetConfig( + config = PerSplitFilteredEdgeDatasetConfig( distributed_context=distributed_context, enumerated_edge_metadata=enumerated_edge_metadata, applied_task_identifier=applied_task_identifier, output_bq_dataset=output_bq_dataset, graph_metadata=graph_metadata, - format=format, - split_config=split_config, + split_dataset_format=format, + split_config=bq_split_metadata, ) - # Use context manager for distributed coordination - with DistributedEdgeDatasetCoordinator(distributed_context) as coordinator: - # Create resources using the builder - resource_builder = EdgeDatasetResourceBuilder(config) - coordinator.coordinate_resource_creation(resource_builder.create_all_resources) - - # Create and return datasets using the factory - factory = EdgeDatasetFactory(config) + # Handle distributed initialization if needed + we_initialized_dist = False + if not is_distributed_available_and_initialized(): + logger.info( + f"Building edge datasets -- Initializing torch distributed for {distributed_context.global_rank}..." + ) + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", + world_size=distributed_context.global_world_size, + rank=distributed_context.global_rank, + init_method=f"tcp://{distributed_context.main_worker_ip_address}:23456", + ) + logger.info( + f"Using backend: {dist.get_backend()} for distributed dataset building." + ) + we_initialized_dist = True + + try: + # Run BQ / GCS operations to create filtered datasets + # Only rank 0 creates these datasets to avoid duplicate operations, all ranks wait for completion + split_data_builder = PerSplitFilteredEdgeDatasetBuilder(config) + if distributed_context.global_rank == 0: + split_data_builder.create_all_resources() + dist.barrier() # Ensure all ranks wait for resource creation to complete + + # Create and return torch IterableDatasets for each split using the factory + factory = PerSplitIterableDatasetFactory(config) return factory.create_datasets() + finally: + # Cleanup distributed context if we initialized it + if we_initialized_dist: + logger.info( + f"Finished building edge datasets -- tearing down torch distributed for {distributed_context.global_rank}..." 
+ ) + dist.destroy_process_group() From 4d83c6fcca04dcf35845fb278155df4b2ab07b76 Mon Sep 17 00:00:00 2001 From: nshah Date: Tue, 16 Sep 2025 00:43:49 +0000 Subject: [PATCH 07/10] drop unneeded port --- .../knowledge_graph_embedding/lib/data/edge_dataset.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py index 38e9a81c2..a3b4f8e72 100644 --- a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py +++ b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py @@ -22,7 +22,7 @@ ) from gigl.src.common.types import AppliedTaskIdentifier from gigl.src.common.types.dataset_split import DatasetSplit -from gigl.src.common.types.graph_data import EdgeType +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper from gigl.src.common.utils.bq import BqUtils from gigl.src.data_preprocessor.lib.enumerate.utils import EnumeratorEdgeTypeMetadata @@ -94,7 +94,7 @@ def _get_BigqueryEdgeDataReference_for_split( ), src_identifier=SRC_FIELD, dst_identifier=DST_FIELD, - edge_type=EdgeType("mixed", "mixed", "mixed"), + edge_type=EdgeType(NodeType("mixed"), Relation("mixed"), NodeType("mixed")), ) @@ -307,7 +307,7 @@ class PerSplitIterableDatasetFactory: def __init__(self, config: PerSplitFilteredEdgeDatasetConfig): self.config = config - self.strategy_map = { + self.strategy_map: Dict[EdgeDatasetFormat, PerSplitIterableDatasetStrategy] = { EdgeDatasetFormat.BIGQUERY: PerSplitIterableDatasetBigqueryStrategy(), EdgeDatasetFormat.GCS_JSONL: PerSplitIterableDatasetGcsStrategy( EdgeDatasetFormat.GCS_JSONL @@ -407,7 +407,6 @@ def build_edge_datasets( backend="cpu:gloo,cuda:nccl", world_size=distributed_context.global_world_size, rank=distributed_context.global_rank, - init_method=f"tcp://{distributed_context.main_worker_ip_address}:23456", ) logger.info( f"Using backend: {dist.get_backend()} for distributed dataset building." From 0a51b9b0f28adb5101f2e5b8748ee7fb7015add1 Mon Sep 17 00:00:00 2001 From: nshah Date: Wed, 24 Sep 2025 21:30:16 +0000 Subject: [PATCH 08/10] comments --- .../lib/data/edge_dataset.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py index a3b4f8e72..2eee6c431 100644 --- a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py +++ b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py @@ -31,6 +31,9 @@ logger = Logger() +_ROW_ID_FIELD = "row_id" + + class EdgeDatasetFormat(str, Enum): """ Enumeration of supported edge dataset output formats. @@ -87,6 +90,9 @@ def _get_BigqueryEdgeDataReference_for_split( split: DatasetSplit, ) -> BigqueryEdgeDataReference: """Get a BigQuery edge data reference for a given split.""" + + # This table contains edges of multiple edge types, which are reflected in a condensed_edge_type + # field. Hence, we use a mixed edge type in the reference. 
return BigqueryEdgeDataReference( reference_uri=BqUtils.join_path( BqUtils.format_bq_path(output_bq_dataset), @@ -185,9 +191,8 @@ def _create_filtered_tables_for_each_split(self, intermediate_table: str) -> Non split, ) - random_column_field = "row_id" maybe_extra_field_selector = ( - f", RAND() as {random_column_field}" + f", RAND() as {_ROW_ID_FIELD}" if self.config.split_dataset_format == EdgeDatasetFormat.BIGQUERY else "" ) @@ -262,10 +267,9 @@ def create_dataset( split, ) - random_column_field = "row_id" return BigQueryHeterogeneousGraphIterableDataset( table=table_reference.reference_uri, - random_column=random_column_field, + random_column=_ROW_ID_FIELD, project=get_resource_config().project, **kwargs, ) @@ -345,9 +349,9 @@ def build_edge_datasets( output_bq_dataset: str, graph_metadata: GraphMetadataPbWrapper, split_columns: Optional[List[str]] = None, - train_split_clause: str = "rand_split BETWEEN 0 AND 0.8", - val_split_clause: str = "rand_split BETWEEN 0.8 AND 0.9", - test_split_clause: str = "rand_split BETWEEN 0.9 AND 1", + train_split_clause: str = "rand_split >= 0 AND rand_split < 0.8", + val_split_clause: str = "rand_split >= 0.8 AND rand_split < 0.9", + test_split_clause: str = "rand_split >= 0.9 AND rand_split <= 1", format: EdgeDatasetFormat = EdgeDatasetFormat.GCS_PARQUET, ) -> Dict[DatasetSplit, IterableDataset]: """ @@ -404,7 +408,7 @@ def build_edge_datasets( f"Building edge datasets -- Initializing torch distributed for {distributed_context.global_rank}..." ) dist.init_process_group( - backend="cpu:gloo,cuda:nccl", + backend="gloo", world_size=distributed_context.global_world_size, rank=distributed_context.global_rank, ) From 31b2ebcec2db8a848b1f425d3ccac878e3389adc Mon Sep 17 00:00:00 2001 From: nshah Date: Wed, 24 Sep 2025 21:45:35 +0000 Subject: [PATCH 09/10] typing --- .../lib/data/edge_dataset.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py index 2eee6c431..9480813bc 100644 --- a/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py +++ b/python/gigl/experimental/knowledge_graph_embedding/lib/data/edge_dataset.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Protocol, Tuple +from typing import Optional, Protocol import torch.distributed as dist from torch.utils.data import IterableDataset @@ -58,12 +58,12 @@ class EdgeDatasetFormat(str, Enum): class PerSplitFilteredEdgeBigqueryMetadata: """Configuration parameters to filter BigQuery tables by split (train/val/test).""" - split_columns: List[str] + split_columns: list[str] train_split_clause: str val_split_clause: str test_split_clause: str - def clause_per_split(self) -> List[Tuple[DatasetSplit, str]]: + def clause_per_split(self) -> list[tuple[DatasetSplit, str]]: return [ (DatasetSplit.TRAIN, self.train_split_clause), (DatasetSplit.VAL, self.val_split_clause), @@ -76,7 +76,7 @@ class PerSplitFilteredEdgeDatasetConfig: """Configuration parameters to build filtered datasets by split (train/val/test).""" distributed_context: DistributedContext - enumerated_edge_metadata: List[EnumeratorEdgeTypeMetadata] + enumerated_edge_metadata: list[EnumeratorEdgeTypeMetadata] applied_task_identifier: AppliedTaskIdentifier output_bq_dataset: str graph_metadata: GraphMetadataPbWrapper @@ -141,7 +141,7 @@ def 
_build_intermediate_all_edges_table(self) -> str: f"'{self.config.split_config.val_split_clause}', '{self.config.split_config.test_split_clause}'" ) - edge_table_queries: List[str] = [] + edge_table_queries: list[str] = [] for edge_metadata in self.config.enumerated_edge_metadata: enumerated_reference = edge_metadata.enumerated_edge_data_reference edge_table = BqUtils.format_bq_path( @@ -311,7 +311,7 @@ class PerSplitIterableDatasetFactory: def __init__(self, config: PerSplitFilteredEdgeDatasetConfig): self.config = config - self.strategy_map: Dict[EdgeDatasetFormat, PerSplitIterableDatasetStrategy] = { + self.strategy_map: dict[EdgeDatasetFormat, PerSplitIterableDatasetStrategy] = { EdgeDatasetFormat.BIGQUERY: PerSplitIterableDatasetBigqueryStrategy(), EdgeDatasetFormat.GCS_JSONL: PerSplitIterableDatasetGcsStrategy( EdgeDatasetFormat.GCS_JSONL @@ -327,10 +327,10 @@ def __init__(self, config: PerSplitFilteredEdgeDatasetConfig): } self.split_info = config.split_config.clause_per_split() - def create_datasets(self) -> Dict[DatasetSplit, IterableDataset]: + def create_datasets(self) -> dict[DatasetSplit, IterableDataset]: """Create and return the edge datasets for each data split.""" strategy = self.strategy_map[self.config.split_dataset_format] - datasets: Dict[DatasetSplit, IterableDataset] = {} + datasets: dict[DatasetSplit, IterableDataset] = {} for split, _ in self.split_info: datasets[split] = strategy.create_dataset( @@ -344,16 +344,16 @@ def create_datasets(self) -> Dict[DatasetSplit, IterableDataset]: def build_edge_datasets( distributed_context: DistributedContext, - enumerated_edge_metadata: List[EnumeratorEdgeTypeMetadata], + enumerated_edge_metadata: list[EnumeratorEdgeTypeMetadata], applied_task_identifier: AppliedTaskIdentifier, output_bq_dataset: str, graph_metadata: GraphMetadataPbWrapper, - split_columns: Optional[List[str]] = None, + split_columns: Optional[list[str]] = None, train_split_clause: str = "rand_split >= 0 AND rand_split < 0.8", val_split_clause: str = "rand_split >= 0.8 AND rand_split < 0.9", test_split_clause: str = "rand_split >= 0.9 AND rand_split <= 1", format: EdgeDatasetFormat = EdgeDatasetFormat.GCS_PARQUET, -) -> Dict[DatasetSplit, IterableDataset]: +) -> dict[DatasetSplit, IterableDataset]: """ Build edge datasets for training, validation, and testing. 
This function reads edge data from BigQuery, filters it based on the provided split clauses, From ec23ab5a34b7565cdad86d9775a30d509cc51794 Mon Sep 17 00:00:00 2001 From: nshah Date: Sat, 27 Sep 2025 01:15:38 +0000 Subject: [PATCH 10/10] embedding export and unenumeration utils --- .../knowledge_graph_embedding/lib/infer.py | 450 ++++++++++++++++++ 1 file changed, 450 insertions(+) create mode 100644 python/gigl/experimental/knowledge_graph_embedding/lib/infer.py diff --git a/python/gigl/experimental/knowledge_graph_embedding/lib/infer.py b/python/gigl/experimental/knowledge_graph_embedding/lib/infer.py new file mode 100644 index 000000000..1fa3ea5ff --- /dev/null +++ b/python/gigl/experimental/knowledge_graph_embedding/lib/infer.py @@ -0,0 +1,450 @@ +import math +from typing import Union + +import gigl.experimental.knowledge_graph_embedding.lib.constants.bq as bq_constants +import gigl.experimental.knowledge_graph_embedding.lib.constants.gcs as gcs_constants +import torch +from gigl.experimental.knowledge_graph_embedding.lib.config import ( + HeterogeneousGraphSparseEmbeddingConfig, +) +from gigl.experimental.knowledge_graph_embedding.lib.data.edge_dataset import ( + AppliedTaskIdentifier, +) +from gigl.experimental.knowledge_graph_embedding.lib.data.node_batch import NodeBatch +from gigl.experimental.knowledge_graph_embedding.lib.model.heterogeneous_graph_model import ( + HeterogeneousGraphSparseEmbeddingModelAndLoss, + ModelPhase, +) +from google.cloud import bigquery +from torchrec.distributed import DistributedModelParallel, TrainPipelineSparseDist + +import gigl.src.data_preprocessor.lib.enumerate.queries as enumeration_queries +from gigl.common.data import export +from gigl.common.logger import Logger +from gigl.env.pipelines_config import get_resource_config +from gigl.src.common.types.graph_data import ( + CondensedEdgeType, + CondensedNodeType, + EdgeType, + NodeType, +) +from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper +from gigl.src.common.utils.bq import BqUtils +from gigl.src.data_preprocessor.lib.enumerate.utils import EnumeratorNodeTypeMetadata + +logger = Logger() + + +def infer_and_export_node_embeddings( + is_src: bool, + condensed_edge_type: CondensedEdgeType, + edge_type: EdgeType, + pipeline: TrainPipelineSparseDist, + applied_task_identifier: AppliedTaskIdentifier, + rank_prefix_str: str, + rank: int, + world_size: int, + device: torch.device, + kge_config: HeterogeneousGraphSparseEmbeddingConfig, + graph_metadata: GraphMetadataPbWrapper, + condensed_node_type_to_vocab_size_map: dict[CondensedNodeType, int], +): + """Infer and export node embeddings for either source or destination nodes of a given edge type. + + This function handles the complete inference pipeline for node embeddings, including: + - Setting the appropriate model phase (source or destination inference) + - Determining the correct node type and vocabulary size + - Creating the output directory path + - Initializing the embedding exporter + - Creating and processing the inference data loader + - Batch processing and exporting embeddings to GCS + + The function is designed to work in a distributed training setup where each process + handles a portion of the nodes based on its rank. The embeddings are sharded and + written to separate files per rank to enable parallel processing. + + Args: + is_src (bool): If True, process source nodes; if False, process destination nodes. + condensed_edge_type (CondensedEdgeType): The condensed representation of the edge type. 
+ edge_type (EdgeType): The full edge type containing source and destination node types. + pipeline (TrainPipelineSparseDist): The distributed training pipeline for inference. + applied_task_identifier (AppliedTaskIdentifier): Identifier for the applied task. + rank_prefix_str (str): A prefix string for logging, typically includes the rank of the process. + rank (int): The rank of the current process in distributed training. + world_size (int): The total number of processes in distributed training. + device (torch.device): The device to run the inference on. + kge_config (HeterogeneousGraphSparseEmbeddingConfig): The configuration for the KGE model. + graph_metadata (GraphMetadataPbWrapper): Metadata about the graph, including edge types and node types. + condensed_node_type_to_vocab_size_map (dict[CondensedNodeType, int]): A mapping from condensed node types to + their vocabulary sizes. + + Returns: + None: Embeddings are directly exported to GCS via the exporter. + """ + # Determine node type string for logging and model phase selection + node_type_str = "src" if is_src else "dst" + phase = ModelPhase.INFERENCE_SRC if is_src else ModelPhase.INFERENCE_DST + + # Set the model phase for inference (source or destination) + pipeline._model.module.set_phase(phase) + logger.info( + f"{rank_prefix_str} Set model phase to {pipeline._model.module.phase} for inference." + ) + + # Extract the condensed node types for both source and destination from the edge type + ( + src_condensed_node_type, + dst_condensed_node_type, + ) = graph_metadata.condensed_edge_type_to_condensed_node_types[ + condensed_edge_type + ] + + # Select the appropriate condensed node type based on whether we're processing src or dst + condensed_node_type = ( + src_condensed_node_type if is_src else dst_condensed_node_type + ) + + # Get the vocabulary size for the selected node type + vocab_size = condensed_node_type_to_vocab_size_map[condensed_node_type] + + # Get the actual node type (not condensed) for this inference + node_type = edge_type.src_node_type if is_src else edge_type.dst_node_type + + # Determine the appropriate GCS output path based on node type + if is_src: + embedding_dir = gcs_constants.get_embedding_output_path_for_src_node( + applied_task_identifier=applied_task_identifier, + edge_type=edge_type, + ) + else: + embedding_dir = gcs_constants.get_embedding_output_path_for_dst_node( + applied_task_identifier=applied_task_identifier, + edge_type=edge_type, + ) + + # Initialize the embedding exporter with rank-specific file naming + # This ensures each distributed process writes to its own file + exporter = export.EmbeddingExporter( + export_dir=embedding_dir, + file_prefix=f"{rank}_of_{world_size}_embeddings_", + min_shard_size_threshold_bytes=1_000_000_000, # 1GB threshold for sharding + ) + + # Calculate the range of nodes this rank should process + # Each rank gets approximately vocab_size / world_size nodes + nodes_per_rank = math.ceil(vocab_size / world_size) + dataset = range( + rank * nodes_per_rank, min((rank + 1) * nodes_per_rank, vocab_size) + ) + + logger.info( + f"Rank {rank_prefix_str} processing nodes {dataset} with length {len(dataset)} of {node_type_str} nodes of type {node_type}." 
+ ) + + # Create the data loader for this rank's subset of nodes + inference_loader = NodeBatch.build_data_loader( + dataset=dataset, + condensed_node_type=condensed_node_type, + condensed_edge_type=condensed_edge_type, + graph_metadata=graph_metadata, + sampling_config=kge_config.training.sampling, + dataloader_config=kge_config.training.dataloader, + pin_memory=device.type == "cuda", # Use pinned memory for GPU acceleration + ) + inference_iter = iter(inference_loader) + + # Process inference batches until all nodes are processed + while True: + try: + # Run inference on the next batch and get node IDs and their embeddings + ( + node_ids, + node_embeddings, + ) = pipeline.progress(inference_iter) + + # Export the embeddings to the exporter (which will handle GCS upload) + # Move tensors to CPU to free GPU memory + exporter.add_embedding( + id_batch=node_ids.cpu(), + embedding_batch=node_embeddings.cpu(), + embedding_type=node_type, + ) + except StopIteration: + # All batches processed - flush remaining embeddings and log completion + exporter.flush_embeddings() + logger.info( + f"{rank_prefix_str} Finished inference for {node_type_str} nodes of type {node_type} " + + f"(for edge type {edge_type.src_node_type}-{edge_type.relation}-{edge_type.dst_node_type}). " + + f"Embeddings written to {embedding_dir}." + ) + break + + +def infer_and_export_embeddings( + applied_task_identifier: AppliedTaskIdentifier, + rank_prefix_str: str, + rank: int, + world_size: int, + device: torch.device, + kge_config: HeterogeneousGraphSparseEmbeddingConfig, + model_and_loss: Union[ + DistributedModelParallel, HeterogeneousGraphSparseEmbeddingModelAndLoss + ], + optimizer: torch.optim.Optimizer, + graph_metadata: GraphMetadataPbWrapper, + condensed_node_type_to_vocab_size_map: dict[CondensedNodeType, int], +): + """Run inference to generate source and destination node embeddings for all edge types. + + This function iterates over each edge type in the graph metadata, infers embeddings for + both source and destination nodes, and exports them to GCS. It operates within a + distributed training setup where each process handles a portion of the node embeddings + based on its rank. The embeddings are saved in a structured manner in GCS, with each + process writing its portion to separate files. + + Args: + applied_task_identifier (AppliedTaskIdentifier): Identifier for the applied task. + rank_prefix_str (str): A prefix string for logging, typically includes the rank of the process. + rank (int): The rank of the current process in distributed training. + world_size (int): The total number of processes in distributed training. + device (torch.device): The device to run the inference on. + kge_config (HeterogeneousGraphSparseEmbeddingConfig): The configuration for the KGE model. + model_and_loss (Union[DistributedModelParallel, HeterogeneousGraphSparseEmbeddingModelAndLoss]): The model and loss function to use for inference. + optimizer (torch.optim.Optimizer): The optimizer used during training, needed for pipeline initialization. + graph_metadata (GraphMetadataPbWrapper): Metadata about the graph, including edge types and node types. + condensed_node_type_to_vocab_size_map (dict[CondensedNodeType, int]): A mapping from condensed node types to + their vocabulary sizes. + + Returns: + None: Embeddings are directly exported to GCS for each edge type. + """ + + logger.info( + f"{rank_prefix_str} Running inference to predict src and dst node embeddings for each edge type." 
+ ) + + # Initialize the distributed training pipeline for inference + pipeline = TrainPipelineSparseDist( + model=model_and_loss, optimizer=optimizer, device=device + ) + logger.info(f"{rank_prefix_str} Initialized TrainPipelineSparseDist for inference.") + + # Run inference in no_grad context to save memory and improve performance + with torch.no_grad(): + # Set model to evaluation mode for inference + pipeline._model.eval() + + # Process each edge type in the graph metadata + for ( + condensed_edge_type, + edge_type, + ) in sorted(graph_metadata.condensed_edge_type_to_edge_type_map.items()): + logger.info( + f"""{rank_prefix_str} Running inference for edge type {edge_type} on + src node type {edge_type.src_node_type} and dst node type {edge_type.dst_node_type}.""" + ) + + # Process source nodes for this edge type + infer_and_export_node_embeddings( + is_src=True, + condensed_edge_type=condensed_edge_type, + edge_type=edge_type, + pipeline=pipeline, + applied_task_identifier=applied_task_identifier, + rank_prefix_str=rank_prefix_str, + rank=rank, + world_size=world_size, + device=device, + kge_config=kge_config, + graph_metadata=graph_metadata, + condensed_node_type_to_vocab_size_map=condensed_node_type_to_vocab_size_map, + ) + + # Process destination nodes for this edge type + infer_and_export_node_embeddings( + is_src=False, + condensed_edge_type=condensed_edge_type, + edge_type=edge_type, + pipeline=pipeline, + applied_task_identifier=applied_task_identifier, + rank_prefix_str=rank_prefix_str, + rank=rank, + world_size=world_size, + device=device, + kge_config=kge_config, + graph_metadata=graph_metadata, + condensed_node_type_to_vocab_size_map=condensed_node_type_to_vocab_size_map, + ) + + logger.info(f"Finished writing all embeddings.") + + +def upload_embeddings_to_bigquery( + applied_task_identifier: AppliedTaskIdentifier, + graph_metadata: GraphMetadataPbWrapper, + enumerated_node_metadata: list[EnumeratorNodeTypeMetadata], +): + """Upload node embeddings from GCS to BigQuery for all edge types. + + This function iterates over each edge type in the graph metadata and loads the + previously inferred embeddings from GCS into BigQuery tables. It creates both + enumerated and unenumerated embedding tables for source and destination nodes + of each edge type. + + Args: + applied_task_identifier (AppliedTaskIdentifier): Identifier for the applied task. + graph_metadata (GraphMetadataPbWrapper): Metadata about the graph, including edge types and node types. + enumerated_node_metadata (list[EnumeratorNodeTypeMetadata]): Metadata for enumerated node types, used to map + node types to their corresponding BigQuery tables. + + Returns: + None: Embeddings are uploaded to BigQuery tables. + """ + + node_type_to_enumerated_metadata_tables: dict[NodeType, str] = { + node_type_metadata.enumerated_node_data_reference.node_type: node_type_metadata.bq_unique_node_ids_enumerated_table_name + for node_type_metadata in enumerated_node_metadata + } + + logger.info(f"Loading embeddings to BigQuery.") + + for edge_type in graph_metadata.edge_types: + edge_type_src_node_embedding_dir = ( + gcs_constants.get_embedding_output_path_for_src_node( + applied_task_identifier=applied_task_identifier, + edge_type=edge_type, + ) + ) + edge_type_dst_node_embedding_dir = ( + gcs_constants.get_embedding_output_path_for_dst_node( + applied_task_identifier=applied_task_identifier, + edge_type=edge_type, + ) + ) + + # Load src node embeddings to BigQuery. 
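+        # Embeddings are first loaded keyed by the enumerated integer ids used
+        # during training, then re-keyed to the original node ids via the
+        # enumerator mapping table for that node type.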
+        enum_src_node_embedding_table, unenum_src_node_embedding_table = (
+            bq_constants.get_src_node_embedding_table_for_edge_type(
+                applied_task_identifier=applied_task_identifier,
+                edge_type=edge_type,
+                is_enumerated=True,
+            ),
+            bq_constants.get_src_node_embedding_table_for_edge_type(
+                applied_task_identifier=applied_task_identifier,
+                edge_type=edge_type,
+                is_enumerated=False,
+            ),
+        )
+        project_id, dataset_id, table_id = BqUtils.parse_bq_table_path(
+            enum_src_node_embedding_table
+        )
+        export.load_embeddings_to_bigquery(
+            gcs_folder=edge_type_src_node_embedding_dir,
+            project_id=project_id,
+            dataset_id=dataset_id,
+            table_id=table_id,
+        )
+        logger.info(
+            f"Finished writing enumerated src node embeddings to BigQuery table `{enum_src_node_embedding_table}` for edge type {edge_type}."
+        )
+        unenumerate_embeddings_table(
+            enumerated_embeddings_table=enum_src_node_embedding_table,
+            embeddings_table_node_id_field=export._NODE_ID_KEY,
+            unenumerated_embeddings_table=unenum_src_node_embedding_table,
+            enumerator_mapping_table=node_type_to_enumerated_metadata_tables[
+                edge_type.src_node_type
+            ],
+        )
+        logger.info(
+            f"Finished unenumerating src node embeddings and wrote them to `{unenum_src_node_embedding_table}` using mapping `{node_type_to_enumerated_metadata_tables[edge_type.src_node_type]}`."
+        )
+
+        # Load dst node embeddings to BigQuery.
+        enum_dst_node_embedding_table, unenum_dst_node_embedding_table = (
+            bq_constants.get_dst_node_embedding_table_for_edge_type(
+                applied_task_identifier=applied_task_identifier,
+                edge_type=edge_type,
+                is_enumerated=True,
+            ),
+            bq_constants.get_dst_node_embedding_table_for_edge_type(
+                applied_task_identifier=applied_task_identifier,
+                edge_type=edge_type,
+                is_enumerated=False,
+            ),
+        )
+        project_id, dataset_id, table_id = BqUtils.parse_bq_table_path(
+            enum_dst_node_embedding_table
+        )
+        export.load_embeddings_to_bigquery(
+            gcs_folder=edge_type_dst_node_embedding_dir,
+            project_id=project_id,
+            dataset_id=dataset_id,
+            table_id=table_id,
+        )
+        logger.info(
+            f"Finished writing enumerated dst node embeddings to BigQuery table `{enum_dst_node_embedding_table}` for edge type {edge_type}."
+        )
+        unenumerate_embeddings_table(
+            enumerated_embeddings_table=enum_dst_node_embedding_table,
+            embeddings_table_node_id_field=export._NODE_ID_KEY,
+            unenumerated_embeddings_table=unenum_dst_node_embedding_table,
+            enumerator_mapping_table=node_type_to_enumerated_metadata_tables[
+                edge_type.dst_node_type
+            ],
+        )
+        logger.info(
+            f"Finished unenumerating dst node embeddings and wrote them to `{unenum_dst_node_embedding_table}` using mapping `{node_type_to_enumerated_metadata_tables[edge_type.dst_node_type]}`."
+        )
+
+
+def unenumerate_embeddings_table(
+    enumerated_embeddings_table: str,
+    embeddings_table_node_id_field: str,
+    unenumerated_embeddings_table: str,
+    enumerator_mapping_table: str,
+):
+    """Convert enumerated embeddings back to their original node IDs.
+
+    This function transforms embeddings from an enumerated embeddings table to an
+    unenumerated embeddings table by joining with a mapping table. The resulting
+    table will have the original node IDs as keys and the embeddings as values.
+
+    Args:
+        enumerated_embeddings_table (str): The BigQuery table containing enumerated embeddings.
+        embeddings_table_node_id_field (str): The field in the enumerated embeddings table
+            that contains node IDs.
+        unenumerated_embeddings_table (str): The destination BigQuery table for unenumerated
+            embeddings.
+ enumerator_mapping_table (str): The BigQuery table containing the mapping from + enumerated to original node IDs. + + Returns: + None: Results are written directly to the destination BigQuery table. + """ + + UNENUMERATION_QUERY = """ + SELECT + mapping.{original_node_id_field}, + * EXCEPT({node_id_field}, {enumerated_int_id_field}) + FROM + `{enumerated_assets_table}` enumerated_assets + INNER JOIN + `{mapping_table}` mapping + ON + mapping.int_id = enumerated_assets.{node_id_field} + QUALIFY RANK() OVER (PARTITION BY mapping.{original_node_id_field} ORDER BY RAND()) = 1 + """ + + bq_utils = BqUtils(project=get_resource_config().project) + bq_utils.run_query( + query=UNENUMERATION_QUERY.format( + enumerated_assets_table=enumerated_embeddings_table, + mapping_table=enumerator_mapping_table, + node_id_field=embeddings_table_node_id_field, + original_node_id_field=enumeration_queries.DEFAULT_ORIGINAL_NODE_ID_FIELD, + enumerated_int_id_field=enumeration_queries.DEFAULT_ENUMERATED_NODE_ID_FIELD, + ), + labels=dict(), + destination=unenumerated_embeddings_table, + write_disposition=bigquery.job.WriteDisposition.WRITE_TRUNCATE, + )
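
For illustration, a minimal sketch of invoking the unenumeration helper above on its own. The BigQuery table paths are hypothetical placeholders; in the pipeline they are derived from bq_constants and the enumerator node metadata, and BqUtils resolves the project from get_resource_config().

    from gigl.common.data import export
    from gigl.experimental.knowledge_graph_embedding.lib.infer import (
        unenumerate_embeddings_table,
    )

    # Re-key integer-enumerated embeddings back to original node ids.
    # The table paths below are placeholders, not real project assets.
    unenumerate_embeddings_table(
        enumerated_embeddings_table="my-project.kge_dataset.enum_src_user_embeddings",
        embeddings_table_node_id_field=export._NODE_ID_KEY,
        unenumerated_embeddings_table="my-project.kge_dataset.src_user_embeddings",
        enumerator_mapping_table="my-project.preprocess_dataset.user_ids_enumerated",
    )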