From 804300630b70514b674d1bba65ad7b09983bc9ea Mon Sep 17 00:00:00 2001
From: swong3
Date: Tue, 2 Dec 2025 23:35:20 +0000
Subject: [PATCH 1/9] Added heterogeneous training

---
 .../id_embeddings/heterogeneous_training.py | 937 ++++++++++++++++++
 1 file changed, 937 insertions(+)
 create mode 100644 examples/id_embeddings/heterogeneous_training.py

diff --git a/examples/id_embeddings/heterogeneous_training.py b/examples/id_embeddings/heterogeneous_training.py
new file mode 100644
index 000000000..13f7e7578
--- /dev/null
+++ b/examples/id_embeddings/heterogeneous_training.py
@@ -0,0 +1,937 @@
+"""
+This file contains an example of how to run heterogeneous ID embedding training using live subgraph sampling powered by GraphLearn-for-PyTorch (GLT).
+While `_run_example_training` is coupled with GiGL orchestration, the `_training_process` and `testing_process` functions are generic
+and can be used as references for writing training pipelines that do not depend on GiGL orchestration.
+
+To run this file with GiGL orchestration, set the fields similar to below:
+
+trainerConfig:
+  trainerArgs:
+    log_every_n_batch: "50"
+    ssl_positive_label_percentage: "0.05"
+  command: python -m examples.id_embeddings.heterogeneous_training
+featureFlags:
+  should_run_glt_backend: 'True'
+
+Given a frozen task config with some already populated data preprocessor output, the training script can be run locally using:
+WORLD_SIZE=1 RANK=0 MASTER_ADDR="localhost" MASTER_PORT=20000 python -m examples.id_embeddings.heterogeneous_training --task_config_uri=
+
+A frozen task config with data preprocessor outputs can be generated by running an e2e pipeline with `stop_after=data_preprocessor` and using the
+frozen config generated from the `config_populator` component after the run has completed.
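+
+For reference, a sketch of additional trainerArgs that this script reads (the values shown are simply the
+in-code defaults from `_run_example_training`, not tuned recommendations):
+
+trainerConfig:
+  trainerArgs:
+    num_neighbors: "[10, 10]"
+    sampling_workers_per_process: "4"
+    main_batch_size: "16"
+    random_batch_size: "16"
+    embedding_dim: "64"
+    num_layers: "2"
+    learning_rate: "0.01"
+    weight_decay: "0.0005"
+    num_max_train_batches: "1000"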
+""" + +from __future__ import annotations + +import os + +# Suppress TensorFlow logs +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # isort: skip + +import argparse +import statistics +import time +from collections.abc import Iterator, Mapping +from dataclasses import dataclass +from typing import Any, Literal, Optional, Sequence, Union +from torch_geometric.data import Data, HeteroData + +import torch +import torch.distributed +import torch.multiprocessing as mp +from torch import nn +from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward +from torch.optim import AdamW +from torchrec.distributed.model_parallel import DistributedModelParallel as DMP +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper +from torchrec.optim.optimizers import in_backward_optimizer_filter +from torchrec.optim.rowwise_adagrad import RowWiseAdagrad + +import gigl.distributed.utils +from gigl.common import Uri, UriFactory +from gigl.common.logger import Logger +from gigl.common.utils.torch_training import is_distributed_available_and_initialized +from gigl.distributed import ( + DistABLPLoader, + DistDataset, + build_dataset_from_task_config_uri, +) +from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.distributed.utils import get_available_device +from gigl.nn.models import LightGCN +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict +from gigl.utils.iterator import InfiniteIterator +from gigl.utils.sampling import parse_fanout + +logger = Logger() + + +@dataclass +class DMPConfig: + device: torch.device + world_size: int + local_world_size: int + pg: Optional[torch.distributed.ProcessGroup] = None + compute_device: str = "cuda" # or "cpu" + prefer_sharding_types: Optional[Sequence[str]] = ("table_wise", "row_wise") + compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED + + +def wrap_with_dmp(model: nn.Module, cfg: DMPConfig) -> nn.Module: + """Wraps `model` with TorchRec DMP (shards EBCs, DP for the rest).""" + dmp_model = DMP(module=model, device=cfg.device) + return dmp_model + + +def unwrap_from_dmp(model: nn.Module) -> nn.Module: + """Return the underlying nn.Module if wrapped by DMP, otherwise the module itself.""" + return getattr(model, "module", model) + + +def _sync_metric_across_processes(metric: torch.Tensor) -> float: + """ + Takes the average of a training metric across multiple processes. Note that this function requires DDP to be initialized. 
+ Args: + metric (torch.Tensor): The metric, expressed as a torch Tensor, which should be synced across multiple processes + Returns: + float: The average of the provided metric across all training processes + """ + assert is_distributed_available_and_initialized(), "DDP is not initialized" + # Make a copy of the local loss tensor + loss_tensor = metric.detach().clone() + torch.distributed.all_reduce(loss_tensor, op=torch.distributed.ReduceOp.SUM) + return loss_tensor.item() / torch.distributed.get_world_size() + + +def _setup_dataloaders( + dataset: DistDataset, + split: Literal["train", "val", "test"], + supervision_edge_type: EdgeType, + num_neighbors: list[int], + sampling_workers_per_process: int, + main_batch_size: int, + random_batch_size: int, + device: torch.device, + sampling_worker_shared_channel_size: str, + process_start_gap_seconds: int, +) -> tuple[DistABLPLoader, DistNeighborLoader]: + """ + Sets up main and random dataloaders for training and testing purposes + Args: + dataset (DistDataset): Loaded Distributed Dataset for training and testing + split (Literal["train", "val", "test"]): The current split which we are loading data for + supervision_edge_type (EdgeType): The supervision edge type to use for training in format query_node -> relation -> labeled_node + num_neighbors: list[int]: Fanout for subgraph sampling, where the ith item corresponds to the number of items to sample for the ith hop + sampling_workers_per_process (int): Number of sampling workers per training/testing process + main_batch_size (int): Batch size for main dataloader with query and labeled nodes + random_batch_size (int): Batch size for random negative dataloader + device (torch.device): Device to put loaded subgraphs on + sampling_worker_shared_channel_size (str): Shared-memory buffer size (bytes) allocated for the channel during sampling + process_start_gap_seconds (int): The amount of time to sleep for initializing each dataloader. For large-scale settings, consider setting this + field to 30-60 seconds to ensure dataloaders don't compete for memory during initialization, causing OOM. + Returns: + DistABLPLoader: Dataloader for loading main batch data with query and labeled nodes + DistNeighborLoader: Dataloader for loading random negative data + """ + rank = torch.distributed.get_rank() + + if split == "train": + main_input_nodes = dataset.train_node_ids + shuffle = True + elif split == "val": + main_input_nodes = dataset.val_node_ids + shuffle = False + else: + main_input_nodes = dataset.test_node_ids + shuffle = False + + query_node_type = supervision_edge_type.src_node_type + labeled_node_type = supervision_edge_type.dst_node_type + + assert isinstance(main_input_nodes, Mapping) + + main_loader = DistABLPLoader( + dataset=dataset, + num_neighbors=num_neighbors, + input_nodes=(query_node_type, main_input_nodes[query_node_type]), + supervision_edge_type=supervision_edge_type, + num_workers=sampling_workers_per_process, + batch_size=main_batch_size, + pin_memory_device=device, + worker_concurrency=sampling_workers_per_process, + channel_size=sampling_worker_shared_channel_size, + # Each train_main_loader will wait for `process_start_gap_seconds` * `local_process_rank` seconds before initializing to reduce peak memory usage. 
+ # This is done so that each process on the current machine which initializes a `main_loader` doesn't compete for memory, causing potential OOM + process_start_gap_seconds=process_start_gap_seconds, + shuffle=shuffle, + ) + + logger.info(f"---Rank {rank} finished setting up main loader") + + # We need to wait for all processes to finish initializing the main_loader before creating the random_negative_loader so that its initialization doesn't compete for memory with the main_loader, causing potential OOM. + torch.distributed.barrier() + + assert isinstance(dataset.node_ids, Mapping) + + random_negative_loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=num_neighbors, + input_nodes=(labeled_node_type, dataset.node_ids[labeled_node_type]), + num_workers=sampling_workers_per_process, + batch_size=random_batch_size, + pin_memory_device=device, + worker_concurrency=sampling_workers_per_process, + channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + shuffle=shuffle, + ) + + logger.info(f"--Rank {rank} finished setting up random negative loader") + + # Wait for all processes to finish initializing the random_loader + torch.distributed.barrier() + + return main_loader, random_negative_loader + + +def bpr_loss( + query_emb: torch.Tensor, # [M, D] + pos_emb: torch.Tensor, # [M, D] + neg_emb: torch.Tensor, # [M, D] or [M, K, D] + l2_lambda: float = 0.0, + l2_params: Optional[Sequence[torch.Tensor]] = None, +) -> torch.Tensor: + """Bayesian Personalized Ranking loss with dot-product scores. + + Supports one negative per positive ([M, D]) or K negatives ([M, K, D]). + """ + # s_pos: [M] + s_pos = (query_emb * pos_emb).sum(dim=-1) + + if neg_emb.dim() == 2: # [M, D] + s_neg = (query_emb * neg_emb).sum(dim=-1) + loss = -torch.nn.functional.logsigmoid(s_pos - s_neg) + elif neg_emb.dim() == 3: # [M, K, D] + # Broadcast query: [M, 1, D] + s_neg = (query_emb.unsqueeze(1) * neg_emb).sum(dim=-1) # [M, K] + loss = -torch.nn.functional.logsigmoid(s_pos.unsqueeze(1) - s_neg).mean(dim=1) + else: + raise ValueError("neg_emb must be [M, D] or [M, K, D]") + + loss = loss.mean() + + if l2_lambda > 0.0 and l2_params: + l2 = sum(p.pow(2).sum() for p in l2_params) + loss = loss + l2_lambda * l2 + + return loss + + +def _compute_bpr_batch( + model: nn.Module, + main_data: HeteroData, + random_negative_data: HeteroData, + supervision_edge_type: EdgeType, + device: torch.device, + num_random_negs_per_pos: int = 1, + use_hard_negs: bool = True, + l2_lambda: float = 0.0, + debug_log: bool = False, +) -> tuple[torch.Tensor, dict[str, float]]: + """Compute a BPR batch using LightGCN embeddings and GLT-batched indices for heterogeneous graphs. + + Strategy: one (or K) random negative(s) per positive. If hard negatives exist in the batch, + we concatenate them as additional negatives (weighting equally). + + Returns: + loss: The BPR loss + debug_info: Dictionary with debug statistics (scores, embedding stats, etc.) 
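
    Shape sketch (illustrative assumption, not taken from a real run): with batch_size B = 4, one positive per
    query (so M = 4), embedding_dim D = 64, and num_random_negs_per_pos = 1, the query and positive embeddings
    are [4, 64]; the random pool contributes a [4, 64] negative tensor, and any in-batch hard negatives are
    stacked along a new axis to give a [4, 2, 64] negative tensor before `bpr_loss` is called.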
+ """ + logger.info(f"Computing BPR batch") + # Extract relevant node types from the supervision edge + query_node_type = supervision_edge_type.src_node_type + labeled_node_type = supervision_edge_type.dst_node_type + + logger.info(f"Encoding main data") + # Encode - LightGCN returns dict[NodeType, Tensor] for heterogeneous graphs + main_emb = model(data=main_data, device=device) + rand_emb = model(data=random_negative_data, device=device) + + # Debug: collect statistics + debug_info = {} + + logger.info(f"Query indices and positives from the main batch") + # Query indices and positives from the main batch + B = int(main_data[query_node_type].batch_size) + query_idx = torch.arange(B, device=device) # [B] + + logger.info(f"Positives from the main batch") + pos_idx = torch.cat(list(main_data.y_positive.values())).to(device) # [M] + # Repeat queries to align with positives + rep_query_idx = query_idx.repeat_interleave( + torch.tensor([len(v) for v in main_data.y_positive.values()], device=device) + ) # [M] + + logger.info(f"Hard negatives from the main batch") + # Optional hard negatives from the main batch + if use_hard_negs and hasattr(main_data, "y_negative"): + hard_neg_idx = torch.cat(list(main_data.y_negative.values())).to(device) + hard_neg_emb = main_emb[labeled_node_type][hard_neg_idx] # [H, D] + else: + hard_neg_idx = torch.empty(0, dtype=torch.long, device=device) + hard_neg_emb = torch.empty( + 0, main_emb[labeled_node_type].size(1), device=device + ) + + logger.info(f"Random negatives: take the first K*M rows from rand_emb for simplicity") + # Random negatives: take the first K*M rows from rand_emb for simplicity + M = rep_query_idx.numel() + D = main_emb[labeled_node_type].size(1) + + total_needed = M * max(1, num_random_negs_per_pos) + rand_batch_size = int(random_negative_data[labeled_node_type].batch_size) + if rand_batch_size < total_needed: + # Tile if fewer than needed + tile = (total_needed + rand_batch_size - 1) // rand_batch_size + rand_pool = rand_emb[labeled_node_type][:rand_batch_size].repeat(tile, 1)[ + :total_needed + ] + else: + rand_pool = rand_emb[labeled_node_type][:total_needed] + + logger.info(f"Positive and query embeddings") + if num_random_negs_per_pos == 1: + rand_neg_emb = rand_pool # [M, D] + else: + rand_neg_emb = rand_pool.view(M, num_random_negs_per_pos, D) # [M, K, D] + + logger.info(f"Computing scores for debugging") + # Positive and query embeddings + q = main_emb[query_node_type][rep_query_idx] # [M, D] + pos = main_emb[labeled_node_type][pos_idx] # [M, D] + + # If we have hard negatives, merge with random negatives by stacking along K + if hard_neg_emb.numel() > 0: + # Align hard negatives count to M. 
+ if hard_neg_emb.size(0) < M: + ht = (M + hard_neg_emb.size(0) - 1) // hard_neg_emb.size(0) + hard_neg_emb = hard_neg_emb.repeat(ht, 1)[:M] + if rand_neg_emb.dim() == 2: # [M, D] + neg = torch.stack([rand_neg_emb, hard_neg_emb], dim=1) # [M, 2, D] + else: # [M, K, D] + neg = torch.cat( + [rand_neg_emb, hard_neg_emb.unsqueeze(1)], dim=1 + ) # [M, K+1, D] + else: + neg = rand_neg_emb # [M, D] or [M, K, D] + + # Compute scores for debugging + s_pos = (q * pos).sum(dim=-1) # [M] + if neg.dim() == 2: # [M, D] + s_neg = (q * neg).sum(dim=-1) # [M] + else: # [M, K, D] + s_neg = (q.unsqueeze(1) * neg).sum(dim=-1) # [M, K] + + # Collect debug info + if debug_log: + debug_info["pos_score_mean"] = s_pos.mean().item() + debug_info["pos_score_std"] = s_pos.std().item() + debug_info["neg_score_mean"] = s_neg.mean().item() + debug_info["neg_score_std"] = s_neg.std().item() + debug_info["query_emb_mean"] = q.mean().item() + debug_info["query_emb_std"] = q.std().item() + debug_info["query_emb_norm"] = q.norm(dim=-1).mean().item() + debug_info["pos_emb_mean"] = pos.mean().item() + debug_info["pos_emb_std"] = pos.std().item() + debug_info["pos_emb_norm"] = pos.norm(dim=-1).mean().item() + debug_info["neg_emb_mean"] = neg.mean().item() + debug_info["neg_emb_std"] = neg.std().item() + debug_info["neg_emb_norm"] = neg.norm(dim=-1).mean().item() + debug_info["num_positives"] = M + debug_info["num_hard_negs"] = hard_neg_idx.numel() + # Sample some actual node IDs + debug_info["sample_query_ids"] = rep_query_idx[:5].tolist() + debug_info["sample_pos_ids"] = pos_idx[:5].tolist() + + loss = bpr_loss(q, pos, neg, l2_lambda=l2_lambda, l2_params=None) + + if debug_log: + debug_info["bpr_loss"] = loss.item() + + return loss, debug_info + + +def _training_process( + local_rank: int, + local_world_size: int, + machine_rank: int, + machine_world_size: int, + dataset: DistDataset, + supervision_edge_type: EdgeType, + node_type_to_num_nodes: dict[NodeType, int], + master_ip_address: str, + master_default_process_group_port: int, + model_uri: Uri, + num_neighbors: list[int], + sampling_workers_per_process: int, + main_batch_size: int, + random_batch_size: int, + embedding_dim: int, + num_layers: int, + sampling_worker_shared_channel_size: str, + process_start_gap_seconds: int, + log_every_n_batch: int, + learning_rate: float, + weight_decay: float, + num_max_train_batches: int, + num_val_batches: int, + val_every_n_batch: int, + should_skip_training: bool, + num_random_negs_per_pos: int, + l2_lambda: float, +) -> None: + """ + This function is spawned by each machine for training a heterogeneous LightGCN model given some loaded distributed dataset. 
+ Args: + local_rank (int): Process number on the current machine + local_world_size (int): Number of training processes spawned by each machine + machine_rank (int): Rank of the current machine + machine_world_size (int): Total number of machines + dataset (DistDataset): Loaded Distributed Dataset for training + supervision_edge_type (EdgeType): The supervision edge type to use for training in format query_node -> relation -> labeled_node + node_type_to_num_nodes (dict[NodeType, int]): Map from node types to node counts for LightGCN + master_ip_address (str): IP Address of the master worker for distributed communication + master_default_process_group_port (int): Port on the master worker for setting up distributed process group communication + model_uri (Uri): URI Path to save the model to + num_neighbors: list[int]: Fanout for subgraph sampling, where the ith item corresponds to the number of items to sample for the ith hop + sampling_workers_per_process (int): Number of sampling workers per training process + main_batch_size (int): Batch size for main dataloader with query and labeled nodes + random_batch_size (int): Batch size for random negative dataloader + embedding_dim (int): Embedding dimension of the model + num_layers (int): Number of LightGCN layers + sampling_worker_shared_channel_size (str): Shared-memory buffer size (bytes) allocated for the channel during sampling + process_start_gap_seconds (int): The amount of time to sleep for initializing each dataloader. For large-scale settings, consider setting this + field to 30-60 seconds to ensure dataloaders don't compete for memory during initialization, causing OOM. + log_every_n_batch (int): The frequency we should log batch information when training + learning_rate (float): Learning rate for training + weight_decay (float): Weight decay for training + num_max_train_batches (int): The maximum number of batches to train for across all training processes + num_val_batches (int): The number of batches to do validation for across all training processes + val_every_n_batch: (int): The frequency we should log batch information when validating + should_skip_training (bool): Whether training should be skipped and we should only run testing. Assumes model has been uploaded to the model_uri. 
+ num_random_negs_per_pos (int): Number of random negatives per positive + l2_lambda (float): L2 regularization strength + """ + world_size = machine_world_size * local_world_size + rank = machine_rank * local_world_size + local_rank + logger.info( + f"---Current training process rank: {rank}, training process world size: {world_size}" + ) + + torch.distributed.init_process_group( + backend="nccl" if torch.cuda.is_available() else "gloo", + init_method=f"tcp://{master_ip_address}:{master_default_process_group_port}", + world_size=world_size, + rank=rank, + ) + + device = get_available_device(local_process_rank=local_rank) + if torch.cuda.is_available(): + torch.cuda.set_device(device) + logger.info(f"Training process rank {rank} is using device {device}") + + # Build LightGCN model + logger.info(f"---Rank {rank} building LightGCN model") + logger.info(f"node_type_to_num_nodes: {node_type_to_num_nodes}") + logger.info(f"embedding_dim: {embedding_dim}") + logger.info(f"num_layers: {num_layers}") + logger.info(f"device: {device}") + base_model = LightGCN( + node_type_to_num_nodes=node_type_to_num_nodes, + embedding_dim=embedding_dim, + num_layers=num_layers, + device=device, + ) + + logger.info(f"base_model: {base_model}") + + # Apply sparse optimizer to embedding bag collection BEFORE DMP wrapping + # This is the correct TorchRec pattern - optimizer must be applied before sharding + logger.info(f"---Rank {rank} applying sparse optimizer to embedding bag collection (BEFORE DMP)") + sparse_lr = learning_rate + for name, param in base_model._embedding_bag_collection.named_parameters(): + logger.info(f" Applying RowWiseAdagrad to {name}") + apply_optimizer_in_backward( + optimizer_class=RowWiseAdagrad, + params=[param], + optimizer_kwargs={"lr": sparse_lr, "weight_decay": weight_decay}, + ) + logger.info(f"Applied RowWiseAdagrad (lr={sparse_lr}, weight_decay={weight_decay}) to embedding parameters") + + logger.info(f"---Rank {rank} wrapping LightGCN model with DMP") + model = wrap_with_dmp( + base_model, + DMPConfig( + device=device, + world_size=world_size, + local_world_size=local_world_size, + pg=torch.distributed.group.WORLD, + compute_device="cuda" if device.type == "cuda" else "cpu", + ), + ) + logger.info(f"model: {model}") + + # Initialize embeddings AFTER DMP wrapping (DMP reinitializes them, so doing it before is useless!) + # NOTE: Use small scale - toy test showed 50x causes saturation, 0.1x works well + logger.info(f"---Rank {rank} initializing embeddings with scaled Xavier uniform (AFTER DMP)") + unwrapped_model = unwrap_from_dmp(model) + logger.info(f"EmbeddingBagCollection parameters:") + init_count = 0 + EMBEDDING_SCALE = 0.1 # Small scale to avoid saturation (toy test confirmed this works!) 
+ for name, param in unwrapped_model._embedding_bag_collection.named_parameters(): + logger.info(f" Found parameter: {name}, shape: {param.shape}, device: {param.device}") + logger.info(f" BEFORE init - mean={param.mean().item():.6f}, std={param.std().item():.6f}, norm={param.norm().item():.6f}") + + # Xavier uniform: U(-sqrt(6/(fan_in + fan_out)), sqrt(6/(fan_in + fan_out))) + # For embeddings: fan_in = 1, fan_out = embedding_dim + # Then scale up to combat LightGCN's aggressive neighbor averaging + torch.nn.init.xavier_uniform_(param) + param.data *= EMBEDDING_SCALE + init_count += 1 + + logger.info(f" AFTER init (scaled {EMBEDDING_SCALE}x) - mean={param.mean().item():.6f}, std={param.std().item():.6f}, norm={param.norm().item():.6f}") + + logger.info(f"Initialized {init_count} embedding parameters after DMP wrapping with {EMBEDDING_SCALE}x scaling") + + # After DMP, create dense optimizer for non-embedding parameters (e.g., LGConv layers) + logger.info(f"---Rank {rank} creating dense optimizer for non-embedding parameters") + dense_params = dict(in_backward_optimizer_filter(model.named_parameters())) + + if dense_params: + logger.info(f"Found {len(dense_params)} dense parameters") + dense_optimizer = KeyedOptimizerWrapper( + dense_params, + lambda params: AdamW(params, lr=learning_rate, weight_decay=weight_decay), + ) + # Combine fused (sparse) optimizer with dense optimizer + optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer]) + logger.info(f"Created CombinedOptimizer with fused (sparse) and dense optimizers") + else: + logger.info("No dense parameters found, using only fused optimizer") + optimizer = CombinedOptimizer([model.fused_optimizer]) + + logger.info(f"optimizer: {optimizer}") + logger.info(f"num_neighbors: {num_neighbors}") + + + if should_skip_training: + logger.info(f"Rank {rank}: Skipping training and loading model from {model_uri}") + state_dict = load_state_dict_from_uri(model_uri) + model.load_state_dict(state_dict) + else: + logger.info(f"Rank {rank}: Setting up training dataloaders") + train_main_loader, train_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="train", + supervision_edge_type=supervision_edge_type, + num_neighbors=num_neighbors, + sampling_workers_per_process=sampling_workers_per_process, + main_batch_size=main_batch_size, + random_batch_size=random_batch_size, + device=device, + sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + ) + + logger.info(f"Rank {rank}: Setting up validation dataloaders") + val_main_loader, val_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="val", + supervision_edge_type=supervision_edge_type, + num_neighbors=num_neighbors, + sampling_workers_per_process=sampling_workers_per_process, + main_batch_size=main_batch_size, + random_batch_size=random_batch_size, + device=device, + sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + ) + + logger.info(f"Rank {rank}: Starting training loop") + model.train() + start_time = time.time() + + train_main_iter = InfiniteIterator(train_main_loader) + train_random_neg_iter = InfiniteIterator(train_random_negative_loader) + + val_main_iter = InfiniteIterator(val_main_loader) + val_random_neg_iter = InfiniteIterator(val_random_negative_loader) + + batch_losses = [] + val_losses = [] + + # Track a sample embedding to see if it changes + sample_node_type = 
list(node_type_to_num_nodes.keys())[0] + sample_node_id = 0 + + logger.info(f"Starting training loop") + + for batch_num in range(num_max_train_batches): + logger.info(f"Training batch {batch_num + 1} of {num_max_train_batches}") + # Enable debug logging for first batch and every log_every_n_batch + debug_this_batch = (batch_num == 0) or ((batch_num + 1) % log_every_n_batch == 0) + + optimizer.zero_grad() + logger.info(f"Zeroing gradients") + main_data = next(train_main_iter) + logger.info(f"Main data: {main_data}") + random_negative_data = next(train_random_neg_iter) + logger.info(f"Random negative data: {random_negative_data}") + # Log sample data for debugging + if debug_this_batch and rank == 0: + logger.info(f"\n{'='*60}") + logger.info(f"DEBUG Batch {batch_num + 1}") + logger.info(f"{'='*60}") + logger.info(f"Main data node types: {main_data.node_types}") + logger.info(f"Main data edge types: {main_data.edge_types}") + for node_type in main_data.node_types: + node_store = main_data[node_type] + batch_size = getattr(node_store, 'batch_size', 'N/A') + n_id_shape = node_store.n_id.shape if hasattr(node_store, 'n_id') else 'N/A' + num_nodes = node_store.num_nodes if hasattr(node_store, 'num_nodes') else 'N/A' + logger.info(f" {node_type}: batch_size={batch_size}, n_id shape={n_id_shape}, num_nodes={num_nodes}") + if hasattr(main_data, 'y_positive'): + logger.info(f"y_positive keys: {list(main_data.y_positive.keys())}") + for k, v in main_data.y_positive.items(): + logger.info(f" {k}: {len(v)} positives, sample IDs: {v[:5].tolist() if len(v) > 0 else []}") + + loss, debug_info = _compute_bpr_batch( + model=model, + main_data=main_data, + random_negative_data=random_negative_data, + supervision_edge_type=supervision_edge_type, + device=device, + num_random_negs_per_pos=num_random_negs_per_pos, + use_hard_negs=True, + l2_lambda=l2_lambda, + debug_log=debug_this_batch, + ) + + if debug_this_batch and rank == 0: + logger.info(f"\nBatch {batch_num + 1} DEBUG INFO:") + for key, value in sorted(debug_info.items()): + logger.info(f" {key}: {value}") + + loss.backward() + + # Check if gradients exist and their magnitudes + # NOTE: With TorchRec's fused optimizer, embedding gradients are not materialized on .grad + # (they're applied directly in backward), so we only check non-embedding params here + if debug_this_batch and rank == 0: + grad_norms = [] + param_norms = [] + for name, param in model.named_parameters(): + if param.grad is not None: + grad_norms.append(param.grad.norm().item()) + param_norms.append(param.norm().item()) + if grad_norms: + logger.info(f" Non-embedding gradient norms - mean: {sum(grad_norms)/len(grad_norms):.6f}, " + f"max: {max(grad_norms):.6f}, min: {min(grad_norms):.6f}") + logger.info(f" Non-embedding param norms - mean: {sum(param_norms)/len(param_norms):.6f}, " + f"max: {max(param_norms):.6f}, min: {min(param_norms):.6f}") + else: + logger.info(" Note: No .grad found on params (expected with TorchRec fused optimizer for embeddings)") + + optimizer.step() + + logger.info(f"Stepped optimizer") + + batch_loss = _sync_metric_across_processes(loss) + batch_losses.append(batch_loss) + + if (batch_num + 1) % log_every_n_batch == 0: + avg_loss = statistics.mean(batch_losses[-log_every_n_batch:]) + elapsed = time.time() - start_time + logger.info( + f"Rank {rank} | Batch {batch_num + 1}/{num_max_train_batches} | " + f"Train Loss: {avg_loss:.4f} | Elapsed: {elapsed:.2f}s" + ) + logger.info(f"{'='*60}\n") + + # # Validation + # if (batch_num + 1) % val_every_n_batch == 0: + # 
model.eval() + # with torch.no_grad(): + # val_batch_losses = [] + # for _ in range(num_val_batches): + # val_main_data = next(val_main_iter) + # val_random_negative_data = next(val_random_neg_iter) + + # val_loss, _ = _compute_bpr_batch( + # model=model, + # main_data=val_main_data, + # random_negative_data=val_random_negative_data, + # supervision_edge_type=supervision_edge_type, + # device=device, + # num_random_negs_per_pos=num_random_negs_per_pos, + # use_hard_negs=True, + # l2_lambda=l2_lambda, + # debug_log=False, + # ) + + # val_batch_loss = _sync_metric_across_processes(val_loss) + # val_batch_losses.append(val_batch_loss) + + # avg_val_loss = statistics.mean(val_batch_losses) + # val_losses.append(avg_val_loss) + # logger.info( + # f"Rank {rank} | Batch {batch_num + 1} | Val Loss: {avg_val_loss:.4f}" + # ) + + # model.train() + + logger.info(f"Rank {rank}: Training completed. Saving model to {model_uri}") + if rank == 0: + save_state_dict(model.state_dict(), model_uri) + + # Final testing + logger.info(f"Rank {rank}: Setting up test dataloaders") + test_main_loader, test_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="test", + supervision_edge_type=supervision_edge_type, + num_neighbors=num_neighbors, + sampling_workers_per_process=sampling_workers_per_process, + main_batch_size=main_batch_size, + random_batch_size=random_batch_size, + device=device, + sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + ) + + logger.info(f"Rank {rank}: Running test evaluation") + model.eval() + with torch.no_grad(): + test_losses = [] + test_main_iter = InfiniteIterator(test_main_loader) + test_random_neg_iter = InfiniteIterator(test_random_negative_loader) + + num_test_batches = min(100, num_val_batches) + for _ in range(num_test_batches): + test_main_data = next(test_main_iter) + test_random_negative_data = next(test_random_neg_iter) + + test_loss, _ = _compute_bpr_batch( + model=model, + main_data=test_main_data, + random_negative_data=test_random_negative_data, + supervision_edge_type=supervision_edge_type, + device=device, + num_random_negs_per_pos=num_random_negs_per_pos, + use_hard_negs=True, + l2_lambda=l2_lambda, + debug_log=False, + ) + + test_batch_loss = _sync_metric_across_processes(test_loss) + test_losses.append(test_batch_loss) + + avg_test_loss = statistics.mean(test_losses) + logger.info(f"Rank {rank} | Test Loss: {avg_test_loss:.4f}") + + torch.distributed.destroy_process_group() + + +def _run_example_training( + task_config_uri: str, +): + """ + Runs an example heterogeneous training + testing loop using GiGL Orchestration. + Args: + task_config_uri (str): Path to YAML-serialized GbmlConfig proto. 
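            For example, this can be the frozen task config generated by the `config_populator` component,
            as described in the module docstring.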
+ """ + start_time = time.time() + mp.set_start_method("spawn") + logger.info(f"Starting sub process method: {mp.get_start_method()}") + + gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=UriFactory.create_uri(task_config_uri) + ) + + # Training Hyperparameters for the training and test processes + trainer_args = dict[Any, Any](gbml_config_pb_wrapper.trainer_config.trainer_args) + + local_world_size = int(trainer_args.get("local_world_size", "1")) + if torch.cuda.is_available(): + if local_world_size > torch.cuda.device_count(): + raise ValueError( + f"Specified a local world_size of {local_world_size} which exceeds the number of devices {torch.cuda.device_count()}" + ) + + # Parse supervision edge type + supervision_edge_types = gbml_config_pb_wrapper.gbml_config_pb.task_metadata.node_anchor_based_link_prediction_task_metadata.supervision_edge_types + assert len(supervision_edge_types) == 1, "Expected exactly one supervision edge type" + supervision_edge_type_pb = supervision_edge_types[0] + supervision_edge_type = EdgeType( + src_node_type=NodeType(supervision_edge_type_pb.src_node_type), + relation=Relation(supervision_edge_type_pb.relation), + dst_node_type=NodeType(supervision_edge_type_pb.dst_node_type), + ) + + # Parses the fanout as a string. For heterogeneous case, fanouts should be specified as a string of a list of integers, such as "[10, 10]". + fanout = trainer_args.get("num_neighbors", "[10, 10]") + num_neighbors = parse_fanout(fanout) + + # While the ideal value for `sampling_workers_per_process` has been identified to be between `2` and `4`, this may need some tuning depending on the + # pipeline. We default this value to `4` here for simplicity. A `sampling_workers_per_process` which is too small may not have enough parallelization for + # sampling, which would slow down training, while a value which is too large may slow down each sampling process due to competing resources, which would also + # then slow down training. + sampling_workers_per_process: int = int( + trainer_args.get("sampling_workers_per_process", "4") + ) + + main_batch_size = int(trainer_args.get("main_batch_size", "16")) + random_batch_size = int(trainer_args.get("random_batch_size", "16")) + + # LightGCN Hyperparameters + embedding_dim = int(trainer_args.get("embedding_dim", "64")) + num_layers = int(trainer_args.get("num_layers", "2")) + + # BPR params + num_random_negs_per_pos = int(trainer_args.get("num_random_negs_per_pos", "1")) + l2_lambda = float(trainer_args.get("l2_lambda", "0.0")) + + # This value represents the the shared-memory buffer size (bytes) allocated for the channel during sampling, and + # is the place to store pre-fetched data, so if it is too small then prefetching is limited, causing sampling slowdown. This parameter is a string + # with `{numeric_value}{storage_size}`, where storage size could be `MB`, `GB`, etc. We default this value to 4GB, + # but in production may need some tuning. 
+ sampling_worker_shared_channel_size: str = trainer_args.get( + "sampling_worker_shared_channel_size", "4GB" + ) + + process_start_gap_seconds = int(trainer_args.get("process_start_gap_seconds", "0")) + log_every_n_batch = int(trainer_args.get("log_every_n_batch", "25")) + + learning_rate = float(trainer_args.get("learning_rate", "0.01")) + weight_decay = float(trainer_args.get("weight_decay", "0.0005")) + num_max_train_batches = int(trainer_args.get("num_max_train_batches", "1000")) + num_val_batches = int(trainer_args.get("num_val_batches", "100")) + val_every_n_batch = int(trainer_args.get("val_every_n_batch", "50")) + + logger.info( + f"Got training args local_world_size={local_world_size}, \ + num_neighbors={num_neighbors}, \ + sampling_workers_per_process={sampling_workers_per_process}, \ + main_batch_size={main_batch_size}, \ + random_batch_size={random_batch_size}, \ + embedding_dim={embedding_dim}, \ + num_layers={num_layers}, \ + num_random_negs_per_pos={num_random_negs_per_pos}, \ + l2_lambda={l2_lambda}, \ + sampling_worker_shared_channel_size={sampling_worker_shared_channel_size}, \ + process_start_gap_seconds={process_start_gap_seconds}, \ + log_every_n_batch={log_every_n_batch}, \ + learning_rate={learning_rate}, \ + weight_decay={weight_decay}, \ + num_max_train_batches={num_max_train_batches}, \ + num_val_batches={num_val_batches}, \ + val_every_n_batch={val_every_n_batch}" + ) + + # This `init_process_group` is only called to get the master_ip_address, master port, and rank/world_size fields which help with partitioning, sampling, + # and distributed training/testing. We can use `gloo` here since these fields we are extracting don't require GPU capabilities provided by `nccl`. + # Note that this init_process_group uses env:// to setup the connection. + # In VAI we create one process per node thus these variables are exposed through env i.e. MASTER_PORT , MASTER_ADDR , WORLD_SIZE , RANK that VAI sets up for us. + # If running locally, these env variables will need to be setup by the user manually. 
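    # For a single-machine local run these can be set manually, e.g. (matching the module docstring):
    #   WORLD_SIZE=1 RANK=0 MASTER_ADDR="localhost" MASTER_PORT=20000 python -m examples.id_embeddings.heterogeneous_training --task_config_uri=<frozen_task_config_uri>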
+ torch.distributed.init_process_group(backend="gloo") + + master_ip_address = gigl.distributed.utils.get_internal_ip_from_master_node() + machine_rank = torch.distributed.get_rank() + machine_world_size = torch.distributed.get_world_size() + master_default_process_group_port = ( + gigl.distributed.utils.get_free_ports_from_master_node(num_ports=1) + )[0] + # Destroying the process group as one will be re-initialized in the training process using above information + torch.distributed.destroy_process_group() + + logger.info(f"--- Launching data loading process ---") + dataset = build_dataset_from_task_config_uri( + task_config_uri=UriFactory.create_uri(task_config_uri), + is_inference=False, + ) + + # Calculate node_type_to_num_nodes from dataset + node_type_to_num_nodes: dict[NodeType, int] = {} + for node_type_str, node_ids_tensor in dataset.node_ids.items(): + node_type = NodeType(node_type_str) + max_id = int(node_ids_tensor.max().item()) + num_nodes = max_id + 1 + node_type_to_num_nodes[node_type] = num_nodes + logger.info(f"Node type {node_type}: {num_nodes} nodes (max_id={max_id})") + + logger.info( + f"--- Data loading process finished, took {time.time() - start_time:.3f} seconds" + ) + + model_uri = UriFactory.create_uri( + gbml_config_pb_wrapper.gbml_config_pb.shared_config.trained_model_metadata.trained_model_uri + ) + + should_skip_training = gbml_config_pb_wrapper.shared_config.should_skip_training + + logger.info("--- Launching training processes ...\n") + start_time = time.time() + torch.multiprocessing.spawn( + _training_process, + args=( # Corresponding arguments in `_training_process` function + local_world_size, # local_world_size + machine_rank, # machine_rank + machine_world_size, # machine_world_size + dataset, # dataset + supervision_edge_type, # supervision_edge_type + node_type_to_num_nodes, # node_type_to_num_nodes + master_ip_address, # master_ip_address + master_default_process_group_port, # master_default_process_group_port + model_uri, # model_uri + num_neighbors, # num_neighbors + sampling_workers_per_process, # sampling_workers_per_process + main_batch_size, # main_batch_size + random_batch_size, # random_batch_size + embedding_dim, # embedding_dim + num_layers, # num_layers + sampling_worker_shared_channel_size, # sampling_worker_shared_channel_size + process_start_gap_seconds, # process_start_gap_seconds + log_every_n_batch, # log_every_n_batch + learning_rate, # learning_rate + weight_decay, # weight_decay + num_max_train_batches, # num_max_train_batches + num_val_batches, # num_val_batches + val_every_n_batch, # val_every_n_batch + should_skip_training, # should_skip_training + num_random_negs_per_pos, # num_random_negs_per_pos + l2_lambda, # l2_lambda + ), + nprocs=local_world_size, + join=True, + ) + logger.info(f"--- Training finished, took {time.time() - start_time} seconds") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Arguments for distributed model training on VertexAI" + ) + parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + + # We use parse_known_args instead of parse_args since we only need job_name and task_config_uri for distributed trainer + logger.info(f"Starting heterogeneous training") + args, unused_args = parser.parse_known_args() + logger.info(f"Args: {args}") + logger.info(f"Unused arguments: {unused_args}") + + # We only need `task_config_uri` for running trainer + _run_example_training( + task_config_uri=args.task_config_uri, + ) From 
df4d9d7a74724f05a72365c3c72dbdfa488a4cfb Mon Sep 17 00:00:00 2001
From: swong3
Date: Wed, 3 Dec 2025 00:05:43 +0000
Subject: [PATCH 2/9] Added heterogeneous training

---
 examples/id_embeddings/heterogeneous_training.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/examples/id_embeddings/heterogeneous_training.py b/examples/id_embeddings/heterogeneous_training.py
index 13f7e7578..f64b8f7cb 100644
--- a/examples/id_embeddings/heterogeneous_training.py
+++ b/examples/id_embeddings/heterogeneous_training.py
@@ -483,13 +483,12 @@ def _training_process(
     )
     logger.info(f"model: {model}")
 
-    # Initialize embeddings AFTER DMP wrapping (DMP reinitializes them, so doing it before is useless!)
-    # NOTE: Use small scale - toy test showed 50x causes saturation, 0.1x works well
+    # Initialize embeddings AFTER DMP wrapping (DMP reinitializes them, so doing it before is useless)
     logger.info(f"---Rank {rank} initializing embeddings with scaled Xavier uniform (AFTER DMP)")
     unwrapped_model = unwrap_from_dmp(model)
     logger.info(f"EmbeddingBagCollection parameters:")
     init_count = 0
-    EMBEDDING_SCALE = 0.1 # Small scale to avoid saturation (toy test confirmed this works!)
+    EMBEDDING_SCALE = 0.1 # Small scale to avoid saturation
     for name, param in unwrapped_model._embedding_bag_collection.named_parameters():
         logger.info(f"  Found parameter: {name}, shape: {param.shape}, device: {param.device}")
         logger.info(f"  BEFORE init - mean={param.mean().item():.6f}, std={param.std().item():.6f}, norm={param.norm().item():.6f}")

From 673d7671170f7126751456a4aa426bf30a5970a7 Mon Sep 17 00:00:00 2001
From: swong3
Date: Wed, 3 Dec 2025 23:09:16 +0000
Subject: [PATCH 3/9] Updating loop

---
 examples/id_embeddings/heterogeneous_training.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/id_embeddings/heterogeneous_training.py b/examples/id_embeddings/heterogeneous_training.py
index f64b8f7cb..7f3fbdb68 100644
--- a/examples/id_embeddings/heterogeneous_training.py
+++ b/examples/id_embeddings/heterogeneous_training.py
@@ -585,9 +585,9 @@ def _training_process(
             optimizer.zero_grad()
             logger.info(f"Zeroing gradients")
             main_data = next(train_main_iter)
-            logger.info(f"Main data: {main_data}")
+            # logger.info(f"Main data: {main_data}")
             random_negative_data = next(train_random_neg_iter)
-            logger.info(f"Random negative data: {random_negative_data}")
+            # logger.info(f"Random negative data: {random_negative_data}")
             # Log sample data for debugging
             if debug_this_batch and rank == 0:
                 logger.info(f"\n{'='*60}")
@@ -788,11 +788,11 @@ def _run_example_training(
     # sampling, which would slow down training, while a value which is too large may slow down each sampling process due to competing resources, which would also
     # then slow down training.
sampling_workers_per_process: int = int( - trainer_args.get("sampling_workers_per_process", "4") + trainer_args.get("sampling_workers_per_process", "8") ) - main_batch_size = int(trainer_args.get("main_batch_size", "16")) - random_batch_size = int(trainer_args.get("random_batch_size", "16")) + main_batch_size = int(trainer_args.get("main_batch_size", "512")) + random_batch_size = int(trainer_args.get("random_batch_size", "512")) # LightGCN Hyperparameters embedding_dim = int(trainer_args.get("embedding_dim", "64")) From 39d91109bdb92f3d5ad4ae943b7d8ed2fca74255 Mon Sep 17 00:00:00 2001 From: swong3 Date: Thu, 11 Dec 2025 17:34:48 +0000 Subject: [PATCH 4/9] Updated training loop --- .../id_embeddings/heterogeneous_training.py | 488 ++++++++++++++---- 1 file changed, 393 insertions(+), 95 deletions(-) diff --git a/examples/id_embeddings/heterogeneous_training.py b/examples/id_embeddings/heterogeneous_training.py index 7f3fbdb68..322d981d6 100644 --- a/examples/id_embeddings/heterogeneous_training.py +++ b/examples/id_embeddings/heterogeneous_training.py @@ -105,6 +105,139 @@ def _sync_metric_across_processes(metric: torch.Tensor) -> float: return loss_tensor.item() / torch.distributed.get_world_size() +# ============================================================================ +# Evaluation Metrics: Recall@K and NDCG@K (LightGCN paper metrics) +# ============================================================================ + + +def recall_at_k( + scores: torch.Tensor, + labels: torch.Tensor, + k_values: list[int], +) -> dict[int, float]: + """ + Computes Recall@K for collaborative filtering. + + Recall@K = (# of relevant items in top-K) / (# of relevant items) + + For link prediction with 1 positive per query, this simplifies to: + Recall@K = 1 if positive is in top-K, else 0 + + Args: + scores: Score tensor of shape [B, 1 + N], where B is batch size, N is number of negatives. + First column contains positive scores, rest are negative scores. + labels: Label tensor of shape [B, 1 + N], where 1 indicates positive, 0 indicates negative. + First column should contain 1 (positive label). + k_values: List of K values to compute Recall@K for (e.g., [10, 20, 50, 100]) + + Returns: + Dictionary mapping K -> Recall@K value + """ + batch_size = scores.size(0) + max_k = max(k_values) + + # Get top-K indices (shape: B x max_k) + topk_indices = torch.topk(scores, k=max_k, dim=1).indices + + # Gather labels for top-K items (shape: B x max_k) + topk_labels = torch.gather(labels, dim=1, index=topk_indices) + + # For each K, check if positive appears in top-K + recalls = {} + for k in k_values: + # Check if any of the top-K items is positive (label == 1) + hits = (topk_labels[:, :k].sum(dim=1) > 0).float() # (B,) + recalls[k] = hits.mean().item() + + return recalls + + +def ndcg_at_k( + scores: torch.Tensor, + labels: torch.Tensor, + k_values: list[int], +) -> dict[int, float]: + """ + Computes Normalized Discounted Cumulative Gain (NDCG@K) for collaborative filtering. + + DCG@K = sum_{i=1}^{K} (2^{rel_i} - 1) / log2(i + 1) + NDCG@K = DCG@K / IDCG@K + + For binary relevance (0 or 1), this simplifies to: + DCG@K = sum_{i=1}^{K} rel_i / log2(i + 1) + + Args: + scores: Score tensor of shape [B, 1 + N], where B is batch size, N is number of negatives. + First column contains positive scores, rest are negative scores. + labels: Label tensor of shape [B, 1 + N], where 1 indicates positive, 0 indicates negative. + First column should contain 1 (positive label). 
+ k_values: List of K values to compute NDCG@K for (e.g., [10, 20, 50, 100]) + + Returns: + Dictionary mapping K -> NDCG@K value + """ + batch_size = scores.size(0) + max_k = max(k_values) + + # Get top-K indices (shape: B x max_k) + topk_indices = torch.topk(scores, k=max_k, dim=1).indices + + # Gather labels for top-K items (shape: B x max_k) + topk_labels = torch.gather(labels, dim=1, index=topk_indices).float() # (B, max_k) + + # Compute position weights: 1 / log2(i + 1) for i in [1, 2, ..., max_k] + positions = torch.arange(1, max_k + 1, device=scores.device).float() + weights = 1.0 / torch.log2(positions + 1.0) # (max_k,) + + ndcgs = {} + for k in k_values: + # DCG@K: sum of (relevance * weight) for top-K + dcg = (topk_labels[:, :k] * weights[:k]).sum(dim=1) # (B,) + + # IDCG@K: ideal DCG (assumes we have 1 relevant item, so IDCG@K = 1 / log2(2) = 1.0) + # Since we have exactly 1 positive per query, the ideal ranking puts it at position 1 + idcg = 1.0 / torch.log2(torch.tensor(2.0, device=scores.device)) # 1.0 + + # NDCG@K = DCG@K / IDCG@K + ndcg = dcg / idcg + ndcgs[k] = ndcg.mean().item() + + return ndcgs + + +def compute_metrics_from_scores( + scores: torch.Tensor, + labels: torch.Tensor, + k_values: list[int] = [10, 20, 50, 100], +) -> dict[str, float]: + """ + Compute all evaluation metrics (Recall@K, NDCG@K) from score and label tensors. + + Args: + scores: Score tensor of shape [B, 1 + N] + labels: Label tensor of shape [B, 1 + N] + k_values: List of K values for evaluation (default: [10, 20, 50, 100] as in LightGCN paper) + + Returns: + Dictionary with metric names and values, e.g.: + { + 'recall@10': 0.123, + 'recall@20': 0.234, + 'ndcg@10': 0.456, + 'ndcg@20': 0.567, + } + """ + recalls = recall_at_k(scores, labels, k_values) + ndcgs = ndcg_at_k(scores, labels, k_values) + + metrics = {} + for k in k_values: + metrics[f'recall@{k}'] = recalls[k] + metrics[f'ndcg@{k}'] = ndcgs[k] + + return metrics + + def _setup_dataloaders( dataset: DistDataset, split: Literal["train", "val", "test"], @@ -137,25 +270,36 @@ def _setup_dataloaders( """ rank = torch.distributed.get_rank() + logger.info(f"Rank {rank} setting up main loader for split {split}") + + logger.info(dataset.graph) if split == "train": - main_input_nodes = dataset.train_node_ids + dsts, srcs, _, _ = dataset.graph[("item", "to_train_gigl_positive", "user")].topo.to_coo() + main_input_nodes = srcs.unique() + logger.info(f"source: {srcs.shape}, destination: {dsts.shape}") + logger.info(main_input_nodes.shape) + logger.info(f"Rank {rank} train_node_ids: {main_input_nodes}") shuffle = True elif split == "val": main_input_nodes = dataset.val_node_ids shuffle = False else: - main_input_nodes = dataset.test_node_ids + dsts, srcs, _, _ = dataset.graph[("item", "to_test_gigl_positive", "user")].topo.to_coo() + main_input_nodes = srcs.unique() + logger.info(f"source: {srcs.shape}, destination: {dsts.shape}") + logger.info(main_input_nodes.shape) + logger.info(f"Rank {rank} test_node_ids: {main_input_nodes}") shuffle = False query_node_type = supervision_edge_type.src_node_type labeled_node_type = supervision_edge_type.dst_node_type - assert isinstance(main_input_nodes, Mapping) + # assert isinstance(main_input_nodes, Mapping) main_loader = DistABLPLoader( dataset=dataset, num_neighbors=num_neighbors, - input_nodes=(query_node_type, main_input_nodes[query_node_type]), + input_nodes=(query_node_type, main_input_nodes), supervision_edge_type=supervision_edge_type, num_workers=sampling_workers_per_process, batch_size=main_batch_size, 
@@ -249,12 +393,12 @@ def _compute_bpr_batch( loss: The BPR loss debug_info: Dictionary with debug statistics (scores, embedding stats, etc.) """ - logger.info(f"Computing BPR batch") + # logger.info(f"Computing BPR batch") # Extract relevant node types from the supervision edge query_node_type = supervision_edge_type.src_node_type labeled_node_type = supervision_edge_type.dst_node_type - logger.info(f"Encoding main data") + # logger.info(f"Encoding main data") # Encode - LightGCN returns dict[NodeType, Tensor] for heterogeneous graphs main_emb = model(data=main_data, device=device) rand_emb = model(data=random_negative_data, device=device) @@ -262,19 +406,19 @@ def _compute_bpr_batch( # Debug: collect statistics debug_info = {} - logger.info(f"Query indices and positives from the main batch") + # logger.info(f"Query indices and positives from the main batch") # Query indices and positives from the main batch B = int(main_data[query_node_type].batch_size) query_idx = torch.arange(B, device=device) # [B] - logger.info(f"Positives from the main batch") + # logger.info(f"Positives from the main batch") pos_idx = torch.cat(list(main_data.y_positive.values())).to(device) # [M] # Repeat queries to align with positives rep_query_idx = query_idx.repeat_interleave( torch.tensor([len(v) for v in main_data.y_positive.values()], device=device) ) # [M] - logger.info(f"Hard negatives from the main batch") + # logger.info(f"Hard negatives from the main batch") # Optional hard negatives from the main batch if use_hard_negs and hasattr(main_data, "y_negative"): hard_neg_idx = torch.cat(list(main_data.y_negative.values())).to(device) @@ -285,7 +429,7 @@ def _compute_bpr_batch( 0, main_emb[labeled_node_type].size(1), device=device ) - logger.info(f"Random negatives: take the first K*M rows from rand_emb for simplicity") + # logger.info(f"Random negatives: take the first K*M rows from rand_emb for simplicity") # Random negatives: take the first K*M rows from rand_emb for simplicity M = rep_query_idx.numel() D = main_emb[labeled_node_type].size(1) @@ -301,13 +445,13 @@ def _compute_bpr_batch( else: rand_pool = rand_emb[labeled_node_type][:total_needed] - logger.info(f"Positive and query embeddings") + # logger.info(f"Positive and query embeddings") if num_random_negs_per_pos == 1: rand_neg_emb = rand_pool # [M, D] else: rand_neg_emb = rand_pool.view(M, num_random_negs_per_pos, D) # [M, K, D] - logger.info(f"Computing scores for debugging") + # logger.info(f"Computing scores for debugging") # Positive and query embeddings q = main_emb[query_node_type][rep_query_idx] # [M, D] pos = main_emb[labeled_node_type][pos_idx] # [M, D] @@ -363,13 +507,124 @@ def _compute_bpr_batch( return loss, debug_info +def _evaluate_with_metrics( + model: nn.Module, + main_data_loader: Iterator, + random_negative_data_loader: Iterator, + supervision_edge_type: EdgeType, + device: torch.device, + num_batches: int, + num_random_negs_per_pos: int = 1, + use_hard_negs: bool = True, + k_values: list[int] = [10, 20, 50, 100], +) -> dict[str, float]: + """ + Evaluate the model using Recall@K and NDCG@K metrics (as used in LightGCN paper). 
+ + Args: + model: The trained model + main_data_loader: Iterator for main data (positives + hard negatives) + random_negative_data_loader: Iterator for random negatives + supervision_edge_type: The edge type for supervision + device: Device to run evaluation on + num_batches: Number of batches to evaluate on + num_random_negs_per_pos: Number of random negatives per positive + use_hard_negs: Whether to include hard negatives + k_values: List of K values for Recall@K and NDCG@K + + Returns: + Dictionary with metric names and values averaged across all batches + """ + model.eval() + + query_node_type = supervision_edge_type.src_node_type + labeled_node_type = supervision_edge_type.dst_node_type + + # Accumulate metrics across batches + all_metrics = {f'recall@{k}': [] for k in k_values} + all_metrics.update({f'ndcg@{k}': [] for k in k_values}) + + with torch.no_grad(): + for batch_idx in range(num_batches): + main_data = next(main_data_loader) + random_negative_data = next(random_negative_data_loader) + + # Get embeddings + main_emb = model(data=main_data, device=device) + rand_emb = model(data=random_negative_data, device=device) + + # Get query indices and positives + B = int(main_data[query_node_type].batch_size) + query_idx = torch.arange(B, device=device) + + pos_idx = torch.cat(list(main_data.y_positive.values())).to(device) + rep_query_idx = query_idx.repeat_interleave( + torch.tensor([len(v) for v in main_data.y_positive.values()], device=device) + ) + + # Get embeddings + query_emb = main_emb[query_node_type][rep_query_idx] # [M, D] + pos_emb = main_emb[labeled_node_type][pos_idx] # [M, D] + + # Random negatives + M = pos_idx.size(0) + rand_B = int(random_negative_data[labeled_node_type].batch_size) + rand_neg_idx = torch.randint(0, rand_B, (M * num_random_negs_per_pos,), device=device) + rand_neg_emb = rand_emb[labeled_node_type][rand_neg_idx].view(M, num_random_negs_per_pos, -1) # [M, K, D] + + # Hard negatives (if available) + if use_hard_negs and hasattr(main_data, "y_negative"): + hard_neg_idx = torch.cat(list(main_data.y_negative.values())).to(device) + hard_neg_emb = main_emb[labeled_node_type][hard_neg_idx] # [H, D] + else: + hard_neg_emb = torch.empty(0, main_emb[labeled_node_type].size(1), device=device) + + # Compute scores + # Positive scores: [M, 1] + pos_scores = (query_emb * pos_emb).sum(dim=-1, keepdim=True) + + # Random negative scores: [M, K] + rand_neg_scores = torch.bmm( + query_emb.unsqueeze(1), # [M, 1, D] + rand_neg_emb.transpose(1, 2), # [M, D, K] + ).squeeze(1) # [M, K] + + # Hard negative scores: [M, H] + if hard_neg_emb.size(0) > 0: + hard_neg_scores = torch.matmul(query_emb, hard_neg_emb.T) # [M, H] + # Concatenate all negative scores + all_neg_scores = torch.cat([rand_neg_scores, hard_neg_scores], dim=1) # [M, K+H] + else: + all_neg_scores = rand_neg_scores # [M, K] + + # Concatenate positive and negative scores: [M, 1+K+H] + scores = torch.cat([pos_scores, all_neg_scores], dim=1) + + # Create labels: first column is 1 (positive), rest are 0 (negatives) + labels = torch.zeros_like(scores) + labels[:, 0] = 1.0 + + # Compute metrics for this batch + batch_metrics = compute_metrics_from_scores(scores, labels, k_values) + + # Accumulate + for metric_name, metric_value in batch_metrics.items(): + all_metrics[metric_name].append(metric_value) + + # Average across batches + avg_metrics = {name: sum(values) / len(values) for name, values in all_metrics.items()} + + return avg_metrics + + def _training_process( local_rank: int, local_world_size: int, 
machine_rank: int, machine_world_size: int, dataset: DistDataset, - supervision_edge_type: EdgeType, + train_supervision_edge_type: EdgeType, + test_supervision_edge_type: EdgeType, node_type_to_num_nodes: dict[NodeType, int], master_ip_address: str, master_default_process_group_port: int, @@ -400,7 +655,8 @@ def _training_process( machine_rank (int): Rank of the current machine machine_world_size (int): Total number of machines dataset (DistDataset): Loaded Distributed Dataset for training - supervision_edge_type (EdgeType): The supervision edge type to use for training in format query_node -> relation -> labeled_node + train_supervision_edge_type (EdgeType): The supervision edge type to use for training (e.g., user -> to_train -> item) + test_supervision_edge_type (EdgeType): The supervision edge type to use for testing (e.g., user -> to_test -> item) node_type_to_num_nodes (dict[NodeType, int]): Map from node types to node counts for LightGCN master_ip_address (str): IP Address of the master worker for distributed communication master_default_process_group_port (int): Port on the master worker for setting up distributed process group communication @@ -483,26 +739,26 @@ def _training_process( ) logger.info(f"model: {model}") - # Initialize embeddings AFTER DMP wrapping (DMP reinitializes them, so doing it before is useless) - logger.info(f"---Rank {rank} initializing embeddings with scaled Xavier uniform (AFTER DMP)") - unwrapped_model = unwrap_from_dmp(model) - logger.info(f"EmbeddingBagCollection parameters:") - init_count = 0 - EMBEDDING_SCALE = 0.1 # Small scale to avoid saturation - for name, param in unwrapped_model._embedding_bag_collection.named_parameters(): - logger.info(f" Found parameter: {name}, shape: {param.shape}, device: {param.device}") - logger.info(f" BEFORE init - mean={param.mean().item():.6f}, std={param.std().item():.6f}, norm={param.norm().item():.6f}") + # # Initialize embeddings AFTER DMP wrapping (DMP reinitializes them, so doing it before is useless) + # logger.info(f"---Rank {rank} initializing embeddings with scaled Xavier uniform (AFTER DMP)") + # unwrapped_model = unwrap_from_dmp(model) + # logger.info(f"EmbeddingBagCollection parameters:") + # init_count = 0 + # EMBEDDING_SCALE = 0.1 # Small scale to avoid saturation + # for name, param in unwrapped_model._embedding_bag_collection.named_parameters(): + # logger.info(f" Found parameter: {name}, shape: {param.shape}, device: {param.device}") + # logger.info(f" BEFORE init - mean={param.mean().item():.6f}, std={param.std().item():.6f}, norm={param.norm().item():.6f}") - # Xavier uniform: U(-sqrt(6/(fan_in + fan_out)), sqrt(6/(fan_in + fan_out))) - # For embeddings: fan_in = 1, fan_out = embedding_dim - # Then scale up to combat LightGCN's aggressive neighbor averaging - torch.nn.init.xavier_uniform_(param) - param.data *= EMBEDDING_SCALE - init_count += 1 + # # Xavier uniform: U(-sqrt(6/(fan_in + fan_out)), sqrt(6/(fan_in + fan_out))) + # # For embeddings: fan_in = 1, fan_out = embedding_dim + # # Then scale up to combat LightGCN's aggressive neighbor averaging + # torch.nn.init.xavier_uniform_(param) + # param.data *= EMBEDDING_SCALE + # init_count += 1 - logger.info(f" AFTER init (scaled {EMBEDDING_SCALE}x) - mean={param.mean().item():.6f}, std={param.std().item():.6f}, norm={param.norm().item():.6f}") + # logger.info(f" AFTER init (scaled {EMBEDDING_SCALE}x) - mean={param.mean().item():.6f}, std={param.std().item():.6f}, norm={param.norm().item():.6f}") - logger.info(f"Initialized {init_count} embedding 
parameters after DMP wrapping with {EMBEDDING_SCALE}x scaling") + # logger.info(f"Initialized {init_count} embedding parameters after DMP wrapping with {EMBEDDING_SCALE}x scaling") # After DMP, create dense optimizer for non-embedding parameters (e.g., LGConv layers) logger.info(f"---Rank {rank} creating dense optimizer for non-embedding parameters") @@ -534,7 +790,7 @@ def _training_process( train_main_loader, train_random_negative_loader = _setup_dataloaders( dataset=dataset, split="train", - supervision_edge_type=supervision_edge_type, + supervision_edge_type=train_supervision_edge_type, # Use TRAIN edges for training num_neighbors=num_neighbors, sampling_workers_per_process=sampling_workers_per_process, main_batch_size=main_batch_size, @@ -545,18 +801,18 @@ def _training_process( ) logger.info(f"Rank {rank}: Setting up validation dataloaders") - val_main_loader, val_random_negative_loader = _setup_dataloaders( - dataset=dataset, - split="val", - supervision_edge_type=supervision_edge_type, - num_neighbors=num_neighbors, - sampling_workers_per_process=sampling_workers_per_process, - main_batch_size=main_batch_size, - random_batch_size=random_batch_size, - device=device, - sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, - process_start_gap_seconds=process_start_gap_seconds, - ) + # val_main_loader, val_random_negative_loader = _setup_dataloaders( + # dataset=dataset, + # split="val", + # supervision_edge_type=train_supervision_edge_type, # Use TRAIN edges for validation + # num_neighbors=num_neighbors, + # sampling_workers_per_process=sampling_workers_per_process, + # main_batch_size=main_batch_size, + # random_batch_size=random_batch_size, + # device=device, + # sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, + # process_start_gap_seconds=process_start_gap_seconds, + # ) logger.info(f"Rank {rank}: Starting training loop") model.train() @@ -565,8 +821,8 @@ def _training_process( train_main_iter = InfiniteIterator(train_main_loader) train_random_neg_iter = InfiniteIterator(train_random_negative_loader) - val_main_iter = InfiniteIterator(val_main_loader) - val_random_neg_iter = InfiniteIterator(val_random_negative_loader) + # val_main_iter = InfiniteIterator(val_main_loader) + # val_random_neg_iter = InfiniteIterator(val_random_negative_loader) batch_losses = [] val_losses = [] @@ -583,7 +839,7 @@ def _training_process( debug_this_batch = (batch_num == 0) or ((batch_num + 1) % log_every_n_batch == 0) optimizer.zero_grad() - logger.info(f"Zeroing gradients") + # logger.info(f"Zeroing gradients") main_data = next(train_main_iter) # logger.info(f"Main data: {main_data}") random_negative_data = next(train_random_neg_iter) @@ -601,16 +857,16 @@ def _training_process( n_id_shape = node_store.n_id.shape if hasattr(node_store, 'n_id') else 'N/A' num_nodes = node_store.num_nodes if hasattr(node_store, 'num_nodes') else 'N/A' logger.info(f" {node_type}: batch_size={batch_size}, n_id shape={n_id_shape}, num_nodes={num_nodes}") - if hasattr(main_data, 'y_positive'): - logger.info(f"y_positive keys: {list(main_data.y_positive.keys())}") - for k, v in main_data.y_positive.items(): - logger.info(f" {k}: {len(v)} positives, sample IDs: {v[:5].tolist() if len(v) > 0 else []}") + # if hasattr(main_data, 'y_positive'): + # logger.info(f"y_positive keys: {list(main_data.y_positive.keys())}") + # for k, v in main_data.y_positive.items(): + # logger.info(f" {k}: {len(v)} positives, sample IDs: {v[:5].tolist() if len(v) > 0 else []}") loss, debug_info = 
_compute_bpr_batch( model=model, main_data=main_data, random_negative_data=random_negative_data, - supervision_edge_type=supervision_edge_type, + supervision_edge_type=train_supervision_edge_type, # Use TRAIN edges during training device=device, num_random_negs_per_pos=num_random_negs_per_pos, use_hard_negs=True, @@ -645,8 +901,6 @@ def _training_process( optimizer.step() - logger.info(f"Stepped optimizer") - batch_loss = _sync_metric_across_processes(loss) batch_losses.append(batch_loss) @@ -692,15 +946,15 @@ def _training_process( # model.train() logger.info(f"Rank {rank}: Training completed. Saving model to {model_uri}") - if rank == 0: - save_state_dict(model.state_dict(), model_uri) + # if rank == 0: + # save_state_dict(model.state_dict(), model_uri) - # Final testing + # Final testing with Recall@K and NDCG@K metrics logger.info(f"Rank {rank}: Setting up test dataloaders") test_main_loader, test_random_negative_loader = _setup_dataloaders( dataset=dataset, split="test", - supervision_edge_type=supervision_edge_type, + supervision_edge_type=test_supervision_edge_type, # Use TEST edges for evaluation! num_neighbors=num_neighbors, sampling_workers_per_process=sampling_workers_per_process, main_batch_size=main_batch_size, @@ -710,35 +964,35 @@ def _training_process( process_start_gap_seconds=process_start_gap_seconds, ) - logger.info(f"Rank {rank}: Running test evaluation") - model.eval() - with torch.no_grad(): - test_losses = [] - test_main_iter = InfiniteIterator(test_main_loader) - test_random_neg_iter = InfiniteIterator(test_random_negative_loader) + logger.info(f"Rank {rank}: Running test evaluation with Recall@K and NDCG@K metrics") + logger.info(f" Evaluating on TEST edges: {test_supervision_edge_type}") - num_test_batches = min(100, num_val_batches) - for _ in range(num_test_batches): - test_main_data = next(test_main_iter) - test_random_negative_data = next(test_random_neg_iter) - - test_loss, _ = _compute_bpr_batch( - model=model, - main_data=test_main_data, - random_negative_data=test_random_negative_data, - supervision_edge_type=supervision_edge_type, - device=device, - num_random_negs_per_pos=num_random_negs_per_pos, - use_hard_negs=True, - l2_lambda=l2_lambda, - debug_log=False, - ) + # K values to evaluate (as in LightGCN paper, Table 2) + # Paper uses: Recall@20, NDCG@20 for Gowalla + eval_k_values = [10, 20, 50, 100] - test_batch_loss = _sync_metric_across_processes(test_loss) - test_losses.append(test_batch_loss) + test_metrics = _evaluate_with_metrics( + model=model, + main_data_loader=InfiniteIterator(test_main_loader), + random_negative_data_loader=InfiniteIterator(test_random_negative_loader), + supervision_edge_type=test_supervision_edge_type, # Use TEST edges for metric computation + device=device, + num_batches=min(100, num_val_batches), + num_random_negs_per_pos=num_random_negs_per_pos, + use_hard_negs=True, + k_values=eval_k_values, + ) - avg_test_loss = statistics.mean(test_losses) - logger.info(f"Rank {rank} | Test Loss: {avg_test_loss:.4f}") + # Log test metrics + if rank == 0: + logger.info(f"\n{'='*80}") + logger.info("TEST EVALUATION RESULTS (LightGCN Paper Metrics)") + logger.info(f"{'='*80}") + for k in eval_k_values: + recall = test_metrics[f'recall@{k}'] + ndcg = test_metrics[f'ndcg@{k}'] + logger.info(f" Recall@{k:3d}: {recall:.4f} | NDCG@{k:3d}: {ndcg:.4f}") + logger.info(f"{'='*80}\n") torch.distributed.destroy_process_group() @@ -769,19 +1023,55 @@ def _run_example_training( f"Specified a local world_size of {local_world_size} which exceeds the 
number of devices {torch.cuda.device_count()}" ) - # Parse supervision edge type + # Parse supervision edge types - expecting 2: one for training, one for testing supervision_edge_types = gbml_config_pb_wrapper.gbml_config_pb.task_metadata.node_anchor_based_link_prediction_task_metadata.supervision_edge_types - assert len(supervision_edge_types) == 1, "Expected exactly one supervision edge type" - supervision_edge_type_pb = supervision_edge_types[0] - supervision_edge_type = EdgeType( - src_node_type=NodeType(supervision_edge_type_pb.src_node_type), - relation=Relation(supervision_edge_type_pb.relation), - dst_node_type=NodeType(supervision_edge_type_pb.dst_node_type), - ) + + if len(supervision_edge_types) == 1: + # Legacy behavior: only one edge type specified (for training) + # Create test edge type by replacing "to_train" with "to_test" + supervision_edge_type_pb = supervision_edge_types[0] + train_supervision_edge_type = EdgeType( + src_node_type=NodeType(supervision_edge_type_pb.src_node_type), + relation=Relation(supervision_edge_type_pb.relation), + dst_node_type=NodeType(supervision_edge_type_pb.dst_node_type), + ) + test_relation = supervision_edge_type_pb.relation.replace("to_train", "to_test") + test_supervision_edge_type = EdgeType( + src_node_type=NodeType(supervision_edge_type_pb.src_node_type), + relation=Relation(test_relation), + dst_node_type=NodeType(supervision_edge_type_pb.dst_node_type), + ) + logger.info("Using single supervision edge type (legacy mode)") + elif len(supervision_edge_types) == 2: + # New behavior: two edge types specified (training and testing) + train_edge_type_pb = supervision_edge_types[0] + test_edge_type_pb = supervision_edge_types[1] + + train_supervision_edge_type = EdgeType( + src_node_type=NodeType(train_edge_type_pb.src_node_type), + relation=Relation(train_edge_type_pb.relation), + dst_node_type=NodeType(train_edge_type_pb.dst_node_type), + ) + test_supervision_edge_type = EdgeType( + src_node_type=NodeType(test_edge_type_pb.src_node_type), + relation=Relation(test_edge_type_pb.relation), + dst_node_type=NodeType(test_edge_type_pb.dst_node_type), + ) + logger.info("Using explicit train and test supervision edge types") + else: + raise ValueError(f"Expected 1 or 2 supervision edge types, got {len(supervision_edge_types)}") + + logger.info(f"Train supervision edge type: {train_supervision_edge_type}") + logger.info(f"Test supervision edge type: {test_supervision_edge_type}") # Parses the fanout as a string. For heterogeneous case, fanouts should be specified as a string of a list of integers, such as "[10, 10]". fanout = trainer_args.get("num_neighbors", "[10, 10]") + # for debugging: + # fanout = "{('user', 'to_train', 'item'): [10, 10, 10, 10], ('user', 'to_test', 'item'): [0, 0, 0, 0], ('item', 'to_train', 'user'): [15, 15, 15, 15], ('item', 'to_test', 'user'): [0, 0, 0, 0]}" + logger.info(f"fanout: {fanout}") num_neighbors = parse_fanout(fanout) + logger.info(f"num_neighbors: {num_neighbors}") + # While the ideal value for `sampling_workers_per_process` has been identified to be between `2` and `4`, this may need some tuning depending on the # pipeline. We default this value to `4` here for simplicity. 
A `sampling_workers_per_process` which is too small may not have enough parallelization for @@ -791,8 +1081,8 @@ def _run_example_training( trainer_args.get("sampling_workers_per_process", "8") ) - main_batch_size = int(trainer_args.get("main_batch_size", "512")) - random_batch_size = int(trainer_args.get("random_batch_size", "512")) + main_batch_size = int(trainer_args.get("main_batch_size", "2048")) + random_batch_size = int(trainer_args.get("random_batch_size", "2048")) # LightGCN Hyperparameters embedding_dim = int(trainer_args.get("embedding_dim", "64")) @@ -812,10 +1102,12 @@ def _run_example_training( process_start_gap_seconds = int(trainer_args.get("process_start_gap_seconds", "0")) log_every_n_batch = int(trainer_args.get("log_every_n_batch", "25")) + log_every_n_batch = 25 learning_rate = float(trainer_args.get("learning_rate", "0.01")) weight_decay = float(trainer_args.get("weight_decay", "0.0005")) num_max_train_batches = int(trainer_args.get("num_max_train_batches", "1000")) + num_max_train_batches = 10 num_val_batches = int(trainer_args.get("num_val_batches", "100")) val_every_n_batch = int(trainer_args.get("val_every_n_batch", "50")) @@ -860,6 +1152,11 @@ def _run_example_training( task_config_uri=UriFactory.create_uri(task_config_uri), is_inference=False, ) + logger.info(f"Dataset: {dataset}") + logger.info(dir(dataset)) + logger.info(f"Dataset.train_node_ids: {dataset.train_node_ids}") + logger.info(f"Dataset.val_node_ids: {dataset.val_node_ids}") + logger.info(f"Dataset.test_node_ids: {dataset.test_node_ids}") # Calculate node_type_to_num_nodes from dataset node_type_to_num_nodes: dict[NodeType, int] = {} @@ -889,7 +1186,8 @@ def _run_example_training( machine_rank, # machine_rank machine_world_size, # machine_world_size dataset, # dataset - supervision_edge_type, # supervision_edge_type + train_supervision_edge_type, # train_supervision_edge_type (user -> to_train -> item) + test_supervision_edge_type, # test_supervision_edge_type (user -> to_test -> item) node_type_to_num_nodes, # node_type_to_num_nodes master_ip_address, # master_ip_address master_default_process_group_port, # master_default_process_group_port From 5df7c4c1a6cc70baf4229f3f5d9b813183269583 Mon Sep 17 00:00:00 2001 From: swong3 Date: Tue, 16 Dec 2025 21:51:04 +0000 Subject: [PATCH 5/9] Added recall metrics --- .../id_embeddings/heterogeneous_training.py | 556 +++++++----------- 1 file changed, 223 insertions(+), 333 deletions(-) diff --git a/examples/id_embeddings/heterogeneous_training.py b/examples/id_embeddings/heterogeneous_training.py index 322d981d6..e4fa95ef4 100644 --- a/examples/id_embeddings/heterogeneous_training.py +++ b/examples/id_embeddings/heterogeneous_training.py @@ -29,6 +29,13 @@ import argparse import statistics +import collections +import matplotlib +matplotlib.use('Agg') # Must be before importing pyplot for headless environments +import matplotlib.pyplot as plt +from google.cloud import storage +import io +from urllib.parse import urlparse import time from collections.abc import Iterator, Mapping from dataclasses import dataclass @@ -78,6 +85,23 @@ class DMPConfig: prefer_sharding_types: Optional[Sequence[str]] = ("table_wise", "row_wise") compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED +def upload_file_to_gcs(local_path: str, gcs_uri: str) -> None: + """ + Uploads a local file to a GCS URI of the form gs://bucket/path/to/file.png. 
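+
+    For example (an illustrative bucket and path, not a real destination),
+    "gs://example-bucket/plots/train_loss.png" is split into bucket
+    "example-bucket" and blob path "plots/train_loss.png" before uploading.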
+ """ + if not gcs_uri.startswith("gs://"): + raise ValueError(f"gcs_uri must start with 'gs://', got {gcs_uri}") + + parsed = urlparse(gcs_uri.replace("gs://", "https://", 1)) + bucket_name = parsed.netloc + blob_path = parsed.path.lstrip("/") + + client = storage.Client() # uses default project/credentials on Vertex AI + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_path) + blob.upload_from_filename(local_path) + + def wrap_with_dmp(model: nn.Module, cfg: DMPConfig) -> nn.Module: """Wraps `model` with TorchRec DMP (shards EBCs, DP for the rest).""" @@ -104,140 +128,6 @@ def _sync_metric_across_processes(metric: torch.Tensor) -> float: torch.distributed.all_reduce(loss_tensor, op=torch.distributed.ReduceOp.SUM) return loss_tensor.item() / torch.distributed.get_world_size() - -# ============================================================================ -# Evaluation Metrics: Recall@K and NDCG@K (LightGCN paper metrics) -# ============================================================================ - - -def recall_at_k( - scores: torch.Tensor, - labels: torch.Tensor, - k_values: list[int], -) -> dict[int, float]: - """ - Computes Recall@K for collaborative filtering. - - Recall@K = (# of relevant items in top-K) / (# of relevant items) - - For link prediction with 1 positive per query, this simplifies to: - Recall@K = 1 if positive is in top-K, else 0 - - Args: - scores: Score tensor of shape [B, 1 + N], where B is batch size, N is number of negatives. - First column contains positive scores, rest are negative scores. - labels: Label tensor of shape [B, 1 + N], where 1 indicates positive, 0 indicates negative. - First column should contain 1 (positive label). - k_values: List of K values to compute Recall@K for (e.g., [10, 20, 50, 100]) - - Returns: - Dictionary mapping K -> Recall@K value - """ - batch_size = scores.size(0) - max_k = max(k_values) - - # Get top-K indices (shape: B x max_k) - topk_indices = torch.topk(scores, k=max_k, dim=1).indices - - # Gather labels for top-K items (shape: B x max_k) - topk_labels = torch.gather(labels, dim=1, index=topk_indices) - - # For each K, check if positive appears in top-K - recalls = {} - for k in k_values: - # Check if any of the top-K items is positive (label == 1) - hits = (topk_labels[:, :k].sum(dim=1) > 0).float() # (B,) - recalls[k] = hits.mean().item() - - return recalls - - -def ndcg_at_k( - scores: torch.Tensor, - labels: torch.Tensor, - k_values: list[int], -) -> dict[int, float]: - """ - Computes Normalized Discounted Cumulative Gain (NDCG@K) for collaborative filtering. - - DCG@K = sum_{i=1}^{K} (2^{rel_i} - 1) / log2(i + 1) - NDCG@K = DCG@K / IDCG@K - - For binary relevance (0 or 1), this simplifies to: - DCG@K = sum_{i=1}^{K} rel_i / log2(i + 1) - - Args: - scores: Score tensor of shape [B, 1 + N], where B is batch size, N is number of negatives. - First column contains positive scores, rest are negative scores. - labels: Label tensor of shape [B, 1 + N], where 1 indicates positive, 0 indicates negative. - First column should contain 1 (positive label). 
- k_values: List of K values to compute NDCG@K for (e.g., [10, 20, 50, 100]) - - Returns: - Dictionary mapping K -> NDCG@K value - """ - batch_size = scores.size(0) - max_k = max(k_values) - - # Get top-K indices (shape: B x max_k) - topk_indices = torch.topk(scores, k=max_k, dim=1).indices - - # Gather labels for top-K items (shape: B x max_k) - topk_labels = torch.gather(labels, dim=1, index=topk_indices).float() # (B, max_k) - - # Compute position weights: 1 / log2(i + 1) for i in [1, 2, ..., max_k] - positions = torch.arange(1, max_k + 1, device=scores.device).float() - weights = 1.0 / torch.log2(positions + 1.0) # (max_k,) - - ndcgs = {} - for k in k_values: - # DCG@K: sum of (relevance * weight) for top-K - dcg = (topk_labels[:, :k] * weights[:k]).sum(dim=1) # (B,) - - # IDCG@K: ideal DCG (assumes we have 1 relevant item, so IDCG@K = 1 / log2(2) = 1.0) - # Since we have exactly 1 positive per query, the ideal ranking puts it at position 1 - idcg = 1.0 / torch.log2(torch.tensor(2.0, device=scores.device)) # 1.0 - - # NDCG@K = DCG@K / IDCG@K - ndcg = dcg / idcg - ndcgs[k] = ndcg.mean().item() - - return ndcgs - - -def compute_metrics_from_scores( - scores: torch.Tensor, - labels: torch.Tensor, - k_values: list[int] = [10, 20, 50, 100], -) -> dict[str, float]: - """ - Compute all evaluation metrics (Recall@K, NDCG@K) from score and label tensors. - - Args: - scores: Score tensor of shape [B, 1 + N] - labels: Label tensor of shape [B, 1 + N] - k_values: List of K values for evaluation (default: [10, 20, 50, 100] as in LightGCN paper) - - Returns: - Dictionary with metric names and values, e.g.: - { - 'recall@10': 0.123, - 'recall@20': 0.234, - 'ndcg@10': 0.456, - 'ndcg@20': 0.567, - } - """ - recalls = recall_at_k(scores, labels, k_values) - ndcgs = ndcg_at_k(scores, labels, k_values) - - metrics = {} - for k in k_values: - metrics[f'recall@{k}'] = recalls[k] - metrics[f'ndcg@{k}'] = ndcgs[k] - - return metrics - - def _setup_dataloaders( dataset: DistDataset, split: Literal["train", "val", "test"], @@ -506,116 +396,92 @@ def _compute_bpr_batch( return loss, debug_info +@torch.no_grad() +def compute_full_lightgcn_embeddings(model, dataset, node_type_to_num_nodes, device): + model.eval() -def _evaluate_with_metrics( - model: nn.Module, - main_data_loader: Iterator, - random_negative_data_loader: Iterator, - supervision_edge_type: EdgeType, - device: torch.device, - num_batches: int, - num_random_negs_per_pos: int = 1, - use_hard_negs: bool = True, - k_values: list[int] = [10, 20, 50, 100], -) -> dict[str, float]: - """ - Evaluate the model using Recall@K and NDCG@K metrics (as used in LightGCN paper). 
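+    # Sketch of the approach (assumes the full graph and both embedding tables
+    # fit on a single device): build one HeteroData containing every user and
+    # item id plus the complete train edge set, then run a single LightGCN
+    # forward pass to materialize the full user/item embedding tables.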
+ from torch_geometric.data import HeteroData - Args: - model: The trained model - main_data_loader: Iterator for main data (positives + hard negatives) - random_negative_data_loader: Iterator for random negatives - supervision_edge_type: The edge type for supervision - device: Device to run evaluation on - num_batches: Number of batches to evaluate on - num_random_negs_per_pos: Number of random negatives per positive - use_hard_negs: Whether to include hard negatives - k_values: List of K values for Recall@K and NDCG@K + data = HeteroData() - Returns: - Dictionary with metric names and values averaged across all batches - """ - model.eval() + num_users = node_type_to_num_nodes[NodeType("user")] + num_items = node_type_to_num_nodes[NodeType("item")] - query_node_type = supervision_edge_type.src_node_type - labeled_node_type = supervision_edge_type.dst_node_type + logger.info(f"num_users: {num_users}, num_items: {num_items}") - # Accumulate metrics across batches - all_metrics = {f'recall@{k}': [] for k in k_values} - all_metrics.update({f'ndcg@{k}': [] for k in k_values}) + data["user"].node = torch.arange(num_users, device=device, dtype=torch.long) + data["user"].batch_size = num_users + data["item"].node = torch.arange(num_items, device=device, dtype=torch.long) + data["item"].batch_size = num_items - with torch.no_grad(): - for batch_idx in range(num_batches): - main_data = next(main_data_loader) - random_negative_data = next(random_negative_data_loader) + logger.info(f"data: {data}") - # Get embeddings - main_emb = model(data=main_data, device=device) - rand_emb = model(data=random_negative_data, device=device) + dsts, srcs, _, _ = dataset.graph[("item", "to_train_gigl_positive", "user")].topo.to_coo() + logger.info(f"dsts: {dsts.shape}, srcs: {srcs.shape}") + edge_ui = torch.stack([srcs.to(device), dsts.to(device)], dim=0) # [2, E] + edge_iu = torch.stack([dsts.to(device), srcs.to(device)], dim=0) # [2, E] + logger.info(f"edge_ui: {edge_ui.shape}, edge_iu: {edge_iu.shape}") - # Get query indices and positives - B = int(main_data[query_node_type].batch_size) - query_idx = torch.arange(B, device=device) + data["user", "to_train", "item"].edge_index = edge_ui + data["item", "to_train", "user"].edge_index = edge_iu - pos_idx = torch.cat(list(main_data.y_positive.values())).to(device) - rep_query_idx = query_idx.repeat_interleave( - torch.tensor([len(v) for v in main_data.y_positive.values()], device=device) - ) + logger.info(f"data: {data}") + + emb_dict = model(data=data, device=device) # dict[NodeType, Tensor] + + logger.info(f"emb_dict: {emb_dict}") + + + user_emb = emb_dict[NodeType("user")] # shape [num_users, D] + item_emb = emb_dict[NodeType("item")] # shape [num_items, D] + logger.info(f"user_emb: {user_emb.shape}, item_emb: {item_emb.shape}") - # Get embeddings - query_emb = main_emb[query_node_type][rep_query_idx] # [M, D] - pos_emb = main_emb[labeled_node_type][pos_idx] # [M, D] + return user_emb, item_emb, srcs, dsts # srcs/dsts = full train edge list - # Random negatives - M = pos_idx.size(0) - rand_B = int(random_negative_data[labeled_node_type].batch_size) - rand_neg_idx = torch.randint(0, rand_B, (M * num_random_negs_per_pos,), device=device) - rand_neg_emb = rand_emb[labeled_node_type][rand_neg_idx].view(M, num_random_negs_per_pos, -1) # [M, K, D] +def build_train_pos_lists(num_users, srcs, dsts): + # srcs, dsts are 1D tensors of same length E + pos_items_per_user = [[] for _ in range(num_users)] + for u, i in zip(srcs.tolist(), dsts.tolist()): + 
pos_items_per_user[u].append(i) + return pos_items_per_user - # Hard negatives (if available) - if use_hard_negs and hasattr(main_data, "y_negative"): - hard_neg_idx = torch.cat(list(main_data.y_negative.values())).to(device) - hard_neg_emb = main_emb[labeled_node_type][hard_neg_idx] # [H, D] - else: - hard_neg_emb = torch.empty(0, main_emb[labeled_node_type].size(1), device=device) - # Compute scores - # Positive scores: [M, 1] - pos_scores = (query_emb * pos_emb).sum(dim=-1, keepdim=True) +@torch.no_grad() +def compute_full_recall_at_k(user_emb, item_emb, pos_items_per_user, K=20, device=None): + if device is None: + device = user_emb.device - # Random negative scores: [M, K] - rand_neg_scores = torch.bmm( - query_emb.unsqueeze(1), # [M, 1, D] - rand_neg_emb.transpose(1, 2), # [M, D, K] - ).squeeze(1) # [M, K] + num_users = user_emb.size(0) + num_items = item_emb.size(0) - # Hard negative scores: [M, H] - if hard_neg_emb.size(0) > 0: - hard_neg_scores = torch.matmul(query_emb, hard_neg_emb.T) # [M, H] - # Concatenate all negative scores - all_neg_scores = torch.cat([rand_neg_scores, hard_neg_scores], dim=1) # [M, K+H] - else: - all_neg_scores = rand_neg_scores # [M, K] + # logger.info(f"num_users: {num_users}, num_items: {num_items}") + # logger.info(f"pos_items_per_user: {pos_items_per_user}") + # logger.info(f"length of pos_items_per_user: {len(pos_items_per_user)}") - # Concatenate positive and negative scores: [M, 1+K+H] - scores = torch.cat([pos_scores, all_neg_scores], dim=1) + recalls = [] + batch_users = 256 # chunk users to avoid huge score matrix - # Create labels: first column is 1 (positive), rest are 0 (negatives) - labels = torch.zeros_like(scores) - labels[:, 0] = 1.0 + for start in range(0, num_users, batch_users): + end = min(start + batch_users, num_users) + u_batch = torch.arange(start, end, device=device) - # Compute metrics for this batch - batch_metrics = compute_metrics_from_scores(scores, labels, k_values) + # [B, D] x [D, I] -> [B, I] + scores = user_emb[u_batch] @ item_emb.T # full ranking over all items - # Accumulate - for metric_name, metric_value in batch_metrics.items(): - all_metrics[metric_name].append(metric_value) + # Top-K items per user + topk_items = torch.topk(scores, K, dim=1).indices # [B, K] - # Average across batches - avg_metrics = {name: sum(values) / len(values) for name, values in all_metrics.items()} + topk_sets = [set(row.tolist()) for row in topk_items] - return avg_metrics + for local_idx, u in enumerate(u_batch.tolist()): + pos_items = pos_items_per_user[u] + if not pos_items: + continue # skip users with no train positives + hits = sum(1 for i in pos_items if i in topk_sets[local_idx]) + recalls.append(hits / len(pos_items)) + + return sum(recalls) / len(recalls) if recalls else 0.0 def _training_process( local_rank: int, @@ -646,6 +512,7 @@ def _training_process( should_skip_training: bool, num_random_negs_per_pos: int, l2_lambda: float, + plots_output_uri: str, ) -> None: """ This function is spawned by each machine for training a heterogeneous LightGCN model given some loaded distributed dataset. 
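The `compute_full_recall_at_k` helper above ranks every item for each chunk of users and then checks hits through per-user Python sets. The following is a minimal, self-contained sketch of the same full-ranking Recall@K computation using a dense relevance matrix instead of the Python loop; it is not part of the patch, and the names `user_emb` ([U, D]), `item_emb` ([I, D]), `user_idx`, and `item_idx` (the train edge endpoints) are assumptions. It is only viable when a [U, I] matrix fits in memory.

import torch

def recall_at_k_dense(
    user_emb: torch.Tensor,   # [U, D] final user embeddings
    item_emb: torch.Tensor,   # [I, D] final item embeddings
    user_idx: torch.Tensor,   # [E] user index of each train edge
    item_idx: torch.Tensor,   # [E] item index of each train edge
    k: int = 20,
    chunk: int = 256,
) -> float:
    num_users, num_items = user_emb.size(0), item_emb.size(0)
    # Dense 0/1 relevance matrix over all (user, item) pairs.
    relevance = torch.zeros(num_users, num_items, device=user_emb.device)
    relevance[user_idx, item_idx] = 1.0
    per_user_recall = []
    for start in range(0, num_users, chunk):
        end = min(start + chunk, num_users)
        scores = user_emb[start:end] @ item_emb.T     # [B, I] rank every item
        topk = scores.topk(k, dim=1).indices          # [B, k]
        rel = relevance[start:end]                    # [B, I]
        hits = rel.gather(1, topk).sum(dim=1)         # [B] positives found in top-k
        num_pos = rel.sum(dim=1)                      # [B] train positives per user
        mask = num_pos > 0                            # skip users with no positives
        per_user_recall.append(hits[mask] / num_pos[mask])
    if not per_user_recall:
        return 0.0
    recalls = torch.cat(per_user_recall)
    return recalls.mean().item() if recalls.numel() > 0 else 0.0

As in the patch, recall is hits divided by the user's total number of train positives (not by min(k, num_positives)), so the two computations should agree up to floating point.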
@@ -680,6 +547,7 @@ def _training_process( num_random_negs_per_pos (int): Number of random negatives per positive l2_lambda (float): L2 regularization strength """ + os.environ["LOCAL_WORLD_SIZE"] = str(local_world_size) world_size = machine_world_size * local_world_size rank = machine_rank * local_world_size + local_rank logger.info( @@ -739,27 +607,6 @@ def _training_process( ) logger.info(f"model: {model}") - # # Initialize embeddings AFTER DMP wrapping (DMP reinitializes them, so doing it before is useless) - # logger.info(f"---Rank {rank} initializing embeddings with scaled Xavier uniform (AFTER DMP)") - # unwrapped_model = unwrap_from_dmp(model) - # logger.info(f"EmbeddingBagCollection parameters:") - # init_count = 0 - # EMBEDDING_SCALE = 0.1 # Small scale to avoid saturation - # for name, param in unwrapped_model._embedding_bag_collection.named_parameters(): - # logger.info(f" Found parameter: {name}, shape: {param.shape}, device: {param.device}") - # logger.info(f" BEFORE init - mean={param.mean().item():.6f}, std={param.std().item():.6f}, norm={param.norm().item():.6f}") - - # # Xavier uniform: U(-sqrt(6/(fan_in + fan_out)), sqrt(6/(fan_in + fan_out))) - # # For embeddings: fan_in = 1, fan_out = embedding_dim - # # Then scale up to combat LightGCN's aggressive neighbor averaging - # torch.nn.init.xavier_uniform_(param) - # param.data *= EMBEDDING_SCALE - # init_count += 1 - - # logger.info(f" AFTER init (scaled {EMBEDDING_SCALE}x) - mean={param.mean().item():.6f}, std={param.std().item():.6f}, norm={param.norm().item():.6f}") - - # logger.info(f"Initialized {init_count} embedding parameters after DMP wrapping with {EMBEDDING_SCALE}x scaling") - # After DMP, create dense optimizer for non-embedding parameters (e.g., LGConv layers) logger.info(f"---Rank {rank} creating dense optimizer for non-embedding parameters") dense_params = dict(in_backward_optimizer_filter(model.named_parameters())) @@ -780,6 +627,13 @@ def _training_process( logger.info(f"optimizer: {optimizer}") logger.info(f"num_neighbors: {num_neighbors}") + # Metric tracking (only meaningful on rank 0) + train_batch_indices: list[int] = [] + train_losses: list[float] = [] + + full_eval_batch_indices: list[int] = [] + full_eval_recall20: list[float] = [] + if should_skip_training: logger.info(f"Rank {rank}: Skipping training and loading model from {model_uri}") @@ -800,19 +654,11 @@ def _training_process( process_start_gap_seconds=process_start_gap_seconds, ) - logger.info(f"Rank {rank}: Setting up validation dataloaders") - # val_main_loader, val_random_negative_loader = _setup_dataloaders( - # dataset=dataset, - # split="val", - # supervision_edge_type=train_supervision_edge_type, # Use TRAIN edges for validation - # num_neighbors=num_neighbors, - # sampling_workers_per_process=sampling_workers_per_process, - # main_batch_size=main_batch_size, - # random_batch_size=random_batch_size, - # device=device, - # sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, - # process_start_gap_seconds=process_start_gap_seconds, - # ) + # How often to run *full* Recall@K over all TRAIN edges. + # This is expensive (full graph + full item ranking), so keep it infrequent. 
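+        # Rough cost, with illustrative numbers: each eval scores a [num_users, num_items]
+        # matrix in chunks of 256 users, so e.g. 100k items at fp32 is ~100 MB of scores
+        # per chunk plus a top-k sort over every item, repeated for every user chunk.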
+ full_recall_eval_every_n_batch = 10 # adjust up/down as you like + full_recall_K = 20 + logger.info(f"Rank {rank}: Starting training loop") model.train() @@ -821,9 +667,6 @@ def _training_process( train_main_iter = InfiniteIterator(train_main_loader) train_random_neg_iter = InfiniteIterator(train_random_negative_loader) - # val_main_iter = InfiniteIterator(val_main_loader) - # val_random_neg_iter = InfiniteIterator(val_random_negative_loader) - batch_losses = [] val_losses = [] @@ -904,6 +747,13 @@ def _training_process( batch_loss = _sync_metric_across_processes(loss) batch_losses.append(batch_loss) + # Track train loss for plotting (rank 0 only) + if rank == 0: + global_step = batch_num + 1 + train_batch_indices.append(global_step) + train_losses.append(batch_loss) + + if (batch_num + 1) % log_every_n_batch == 0: avg_loss = statistics.mean(batch_losses[-log_every_n_batch:]) elapsed = time.time() - start_time @@ -913,86 +763,113 @@ def _training_process( ) logger.info(f"{'='*60}\n") - # # Validation - # if (batch_num + 1) % val_every_n_batch == 0: - # model.eval() - # with torch.no_grad(): - # val_batch_losses = [] - # for _ in range(num_val_batches): - # val_main_data = next(val_main_iter) - # val_random_negative_data = next(val_random_neg_iter) - - # val_loss, _ = _compute_bpr_batch( - # model=model, - # main_data=val_main_data, - # random_negative_data=val_random_negative_data, - # supervision_edge_type=supervision_edge_type, - # device=device, - # num_random_negs_per_pos=num_random_negs_per_pos, - # use_hard_negs=True, - # l2_lambda=l2_lambda, - # debug_log=False, - # ) - - # val_batch_loss = _sync_metric_across_processes(val_loss) - # val_batch_losses.append(val_batch_loss) - - # avg_val_loss = statistics.mean(val_batch_losses) - # val_losses.append(avg_val_loss) - # logger.info( - # f"Rank {rank} | Batch {batch_num + 1} | Val Loss: {avg_val_loss:.4f}" - # ) - - # model.train() + # ------------------------------------------------------------------ + # Periodic *full* Recall@K over TRAIN edges vs batch (rank 0 only) + # ------------------------------------------------------------------ + if (batch_num + 1) % full_recall_eval_every_n_batch == 0 and rank == 0: + global_step = batch_num + 1 + logger.info( + f"Rank {rank}: running FULL TRAIN Recall@{full_recall_K} eval " + f"at batch {global_step}" + ) + + # Compute full LightGCN embeddings for all users/items on TRAIN graph + user_emb, item_emb, srcs, dsts = compute_full_lightgcn_embeddings( + model=model, + dataset=dataset, + node_type_to_num_nodes=node_type_to_num_nodes, + device=device, + ) + pos_items_per_user = build_train_pos_lists( + num_users=node_type_to_num_nodes[NodeType("user")], + srcs=srcs, + dsts=dsts, + ) + + # Compute true Recall@K over all items & all TRAIN edges + full_recall = compute_full_recall_at_k( + user_emb=user_emb, + item_emb=item_emb, + pos_items_per_user=pos_items_per_user, + K=full_recall_K, + device=device, + ) + + full_eval_batch_indices.append(global_step) + full_eval_recall20.append(full_recall) + + logger.info( + f"[FULL TRAIN Recall] batch {global_step}: " + f"Recall@{full_recall_K}={full_recall:.4f}" + ) logger.info(f"Rank {rank}: Training completed. 
Saving model to {model_uri}") # if rank == 0: # save_state_dict(model.state_dict(), model_uri) - # Final testing with Recall@K and NDCG@K metrics - logger.info(f"Rank {rank}: Setting up test dataloaders") - test_main_loader, test_random_negative_loader = _setup_dataloaders( + user_emb, item_emb, srcs, dsts = compute_full_lightgcn_embeddings( + model=unwrap_from_dmp(model), dataset=dataset, - split="test", - supervision_edge_type=test_supervision_edge_type, # Use TEST edges for evaluation! - num_neighbors=num_neighbors, - sampling_workers_per_process=sampling_workers_per_process, - main_batch_size=main_batch_size, - random_batch_size=random_batch_size, + node_type_to_num_nodes=node_type_to_num_nodes, device=device, - sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, - process_start_gap_seconds=process_start_gap_seconds, ) - - logger.info(f"Rank {rank}: Running test evaluation with Recall@K and NDCG@K metrics") - logger.info(f" Evaluating on TEST edges: {test_supervision_edge_type}") - - # K values to evaluate (as in LightGCN paper, Table 2) - # Paper uses: Recall@20, NDCG@20 for Gowalla - eval_k_values = [10, 20, 50, 100] - - test_metrics = _evaluate_with_metrics( - model=model, - main_data_loader=InfiniteIterator(test_main_loader), - random_negative_data_loader=InfiniteIterator(test_random_negative_loader), - supervision_edge_type=test_supervision_edge_type, # Use TEST edges for metric computation + pos_items_per_user = build_train_pos_lists( + num_users=node_type_to_num_nodes[NodeType("user")], + srcs=srcs, + dsts=dsts, + ) + recall20 = compute_full_recall_at_k( + user_emb=user_emb, + item_emb=item_emb, + pos_items_per_user=pos_items_per_user, + K=20, device=device, - num_batches=min(100, num_val_batches), - num_random_negs_per_pos=num_random_negs_per_pos, - use_hard_negs=True, - k_values=eval_k_values, ) + logger.info(f"Full-train Recall@20 over all items: {recall20:.4f}") - # Log test metrics + # ---------------------------------------------------------------------- + # Offline plotting with matplotlib (only rank 0) + # ---------------------------------------------------------------------- if rank == 0: - logger.info(f"\n{'='*80}") - logger.info("TEST EVALUATION RESULTS (LightGCN Paper Metrics)") - logger.info(f"{'='*80}") - for k in eval_k_values: - recall = test_metrics[f'recall@{k}'] - ndcg = test_metrics[f'ndcg@{k}'] - logger.info(f" Recall@{k:3d}: {recall:.4f} | NDCG@{k:3d}: {ndcg:.4f}") - logger.info(f"{'='*80}\n") + # Train loss vs batch + loss_plot_path = "train_loss_vs_batch.png" + if train_batch_indices and train_losses: + plt.figure() + plt.plot(train_batch_indices, train_losses, marker=".", linewidth=1) + plt.xlabel("Batch") + plt.ylabel("Train BPR loss") + plt.title("Train Loss vs Batch") + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.savefig(loss_plot_path) + plt.close() + logger.info(f"Saved train loss curve to {loss_plot_path}") + + if plots_output_uri: + gcs_path = plots_output_uri.rstrip("/") + "/train_loss_vs_batch.png" + upload_file_to_gcs(loss_plot_path, gcs_path) + logger.info(f"Uploaded train loss curve to {gcs_path}") + + # Full-train Recall@20 vs batch + recall_plot_path = f"full_train_recall@{full_recall_K}_vs_batch.png" + if full_eval_batch_indices and full_eval_recall20: + plt.figure() + plt.plot(full_eval_batch_indices, full_eval_recall20, marker="o", linewidth=1) + plt.xlabel("Batch") + plt.ylabel(f"Full-train Recall@{full_recall_K}") + plt.title(f"Full-train Recall@{full_recall_K} vs Batch") + plt.grid(True, alpha=0.3) + 
plt.tight_layout() + plt.savefig(recall_plot_path) + plt.close() + logger.info(f"Saved full-train Recall@{full_recall_K} curve to {recall_plot_path}") + + if plots_output_uri: + gcs_path = plots_output_uri.rstrip("/") + f"/{recall_plot_path}" + upload_file_to_gcs(recall_plot_path, gcs_path) + logger.info(f"Uploaded full-train Recall curve to {gcs_path}") + + torch.distributed.destroy_process_group() @@ -1081,8 +958,8 @@ def _run_example_training( trainer_args.get("sampling_workers_per_process", "8") ) - main_batch_size = int(trainer_args.get("main_batch_size", "2048")) - random_batch_size = int(trainer_args.get("random_batch_size", "2048")) + main_batch_size = int(trainer_args.get("main_batch_size", "512")) + random_batch_size = int(trainer_args.get("random_batch_size", "512")) # LightGCN Hyperparameters embedding_dim = int(trainer_args.get("embedding_dim", "64")) @@ -1092,6 +969,9 @@ def _run_example_training( num_random_negs_per_pos = int(trainer_args.get("num_random_negs_per_pos", "1")) l2_lambda = float(trainer_args.get("l2_lambda", "0.0")) + plots_output_uri = trainer_args.get("plots_output_uri", "") # optional + + # This value represents the the shared-memory buffer size (bytes) allocated for the channel during sampling, and # is the place to store pre-fetched data, so if it is too small then prefetching is limited, causing sampling slowdown. This parameter is a string # with `{numeric_value}{storage_size}`, where storage size could be `MB`, `GB`, etc. We default this value to 4GB, @@ -1102,12 +982,12 @@ def _run_example_training( process_start_gap_seconds = int(trainer_args.get("process_start_gap_seconds", "0")) log_every_n_batch = int(trainer_args.get("log_every_n_batch", "25")) - log_every_n_batch = 25 + log_every_n_batch = 100 learning_rate = float(trainer_args.get("learning_rate", "0.01")) weight_decay = float(trainer_args.get("weight_decay", "0.0005")) num_max_train_batches = int(trainer_args.get("num_max_train_batches", "1000")) - num_max_train_batches = 10 + num_max_train_batches = 100 num_val_batches = int(trainer_args.get("num_val_batches", "100")) val_every_n_batch = int(trainer_args.get("val_every_n_batch", "50")) @@ -1144,8 +1024,6 @@ def _run_example_training( master_default_process_group_port = ( gigl.distributed.utils.get_free_ports_from_master_node(num_ports=1) )[0] - # Destroying the process group as one will be re-initialized in the training process using above information - torch.distributed.destroy_process_group() logger.info(f"--- Launching data loading process ---") dataset = build_dataset_from_task_config_uri( @@ -1166,6 +1044,17 @@ def _run_example_training( num_nodes = max_id + 1 node_type_to_num_nodes[node_type] = num_nodes logger.info(f"Node type {node_type}: {num_nodes} nodes (max_id={max_id})") + output_list = [None for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather_object(output_list, node_type_to_num_nodes) + logger.info(f"output_list: {output_list}") + node_type_to_num_nodes = collections.defaultdict(int) + for d in output_list: + for node_type, num_nodes in d.items(): + node_type_to_num_nodes[node_type] = max(node_type_to_num_nodes[node_type], num_nodes) + logger.info(f"node_type_to_num_nodes: {node_type_to_num_nodes}") + # Destroying the process group as one will be re-initialized in the training process using above information + + torch.distributed.destroy_process_group() logger.info( f"--- Data loading process finished, took {time.time() - start_time:.3f} seconds" @@ -1209,6 +1098,7 @@ def _run_example_training( 
should_skip_training, # should_skip_training num_random_negs_per_pos, # num_random_negs_per_pos l2_lambda, # l2_lambda + plots_output_uri, # plots_output_uri ), nprocs=local_world_size, join=True, From 248da83a920084daf87620368e9ceaaab976abf1 Mon Sep 17 00:00:00 2001 From: swong3 Date: Thu, 18 Dec 2025 19:44:11 +0000 Subject: [PATCH 6/9] fix --- .../id_embeddings/heterogeneous_training.py | 291 ++++++++++-------- 1 file changed, 157 insertions(+), 134 deletions(-) diff --git a/examples/id_embeddings/heterogeneous_training.py b/examples/id_embeddings/heterogeneous_training.py index e4fa95ef4..8a8459487 100644 --- a/examples/id_embeddings/heterogeneous_training.py +++ b/examples/id_embeddings/heterogeneous_training.py @@ -33,9 +33,6 @@ import matplotlib matplotlib.use('Agg') # Must be before importing pyplot for headless environments import matplotlib.pyplot as plt -from google.cloud import storage -import io -from urllib.parse import urlparse import time from collections.abc import Iterator, Mapping from dataclasses import dataclass @@ -55,8 +52,9 @@ from torchrec.optim.rowwise_adagrad import RowWiseAdagrad import gigl.distributed.utils -from gigl.common import Uri, UriFactory +from gigl.common import Uri, UriFactory, GcsUri, LocalUri from gigl.common.logger import Logger +from gigl.common.utils.gcs import GcsUtils from gigl.common.utils.torch_training import is_distributed_available_and_initialized from gigl.distributed import ( DistABLPLoader, @@ -85,21 +83,31 @@ class DMPConfig: prefer_sharding_types: Optional[Sequence[str]] = ("table_wise", "row_wise") compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED -def upload_file_to_gcs(local_path: str, gcs_uri: str) -> None: +def upload_file_to_gcs(local_path: str, gcs_uri: str, project: Optional[str] = None) -> None: """ Uploads a local file to a GCS URI of the form gs://bucket/path/to/file.png. + + Args: + local_path (str): Path to the local file to upload. + gcs_uri (str): GCS URI destination (e.g., gs://bucket/path/to/file.png). + project (Optional[str]): GCP project ID. If None, uses default credentials. 
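+
+    Example (illustrative; the bucket and object path are hypothetical):
+        upload_file_to_gcs("train_loss_vs_batch.png", "gs://my-bucket/plots/train_loss_vs_batch.png")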
""" if not gcs_uri.startswith("gs://"): raise ValueError(f"gcs_uri must start with 'gs://', got {gcs_uri}") - parsed = urlparse(gcs_uri.replace("gs://", "https://", 1)) - bucket_name = parsed.netloc - blob_path = parsed.path.lstrip("/") - - client = storage.Client() # uses default project/credentials on Vertex AI - bucket = client.bucket(bucket_name) - blob = bucket.blob(blob_path) - blob.upload_from_filename(local_path) + # Use GiGL's GcsUtils for proper GCS handling + gcs_utils = GcsUtils(project=project) + local_uri = LocalUri(local_path) + print(local_uri) + gcs_uri_obj = GcsUri(gcs_uri) + print(gcs_uri_obj) + + logger.info(f"Uploading {local_path} to {gcs_uri}") + gcs_utils.upload_files_to_gcs( + local_file_path_to_gcs_path_map={local_uri: gcs_uri_obj}, + parallel=False, + ) + logger.info(f"Successfully uploaded to {gcs_uri}") @@ -407,34 +415,34 @@ def compute_full_lightgcn_embeddings(model, dataset, node_type_to_num_nodes, dev num_users = node_type_to_num_nodes[NodeType("user")] num_items = node_type_to_num_nodes[NodeType("item")] - logger.info(f"num_users: {num_users}, num_items: {num_items}") + # logger.info(f"num_users: {num_users}, num_items: {num_items}") data["user"].node = torch.arange(num_users, device=device, dtype=torch.long) data["user"].batch_size = num_users data["item"].node = torch.arange(num_items, device=device, dtype=torch.long) data["item"].batch_size = num_items - logger.info(f"data: {data}") + # logger.info(f"data: {data}") dsts, srcs, _, _ = dataset.graph[("item", "to_train_gigl_positive", "user")].topo.to_coo() - logger.info(f"dsts: {dsts.shape}, srcs: {srcs.shape}") + # logger.info(f"dsts: {dsts.shape}, srcs: {srcs.shape}") edge_ui = torch.stack([srcs.to(device), dsts.to(device)], dim=0) # [2, E] edge_iu = torch.stack([dsts.to(device), srcs.to(device)], dim=0) # [2, E] - logger.info(f"edge_ui: {edge_ui.shape}, edge_iu: {edge_iu.shape}") + # logger.info(f"edge_ui: {edge_ui.shape}, edge_iu: {edge_iu.shape}") data["user", "to_train", "item"].edge_index = edge_ui data["item", "to_train", "user"].edge_index = edge_iu - logger.info(f"data: {data}") + # logger.info(f"data: {data}") emb_dict = model(data=data, device=device) # dict[NodeType, Tensor] - logger.info(f"emb_dict: {emb_dict}") + # logger.info(f"emb_dict: {emb_dict}") user_emb = emb_dict[NodeType("user")] # shape [num_users, D] item_emb = emb_dict[NodeType("item")] # shape [num_items, D] - logger.info(f"user_emb: {user_emb.shape}, item_emb: {item_emb.shape}") + # logger.info(f"user_emb: {user_emb.shape}, item_emb: {item_emb.shape}") return user_emb, item_emb, srcs, dsts # srcs/dsts = full train edge list @@ -457,6 +465,7 @@ def compute_full_recall_at_k(user_emb, item_emb, pos_items_per_user, K=20, devic # logger.info(f"num_users: {num_users}, num_items: {num_items}") # logger.info(f"pos_items_per_user: {pos_items_per_user}") # logger.info(f"length of pos_items_per_user: {len(pos_items_per_user)}") + # logger.info(pos_items_per_user[:2]) recalls = [] batch_users = 256 # chunk users to avoid huge score matrix @@ -467,20 +476,28 @@ def compute_full_recall_at_k(user_emb, item_emb, pos_items_per_user, K=20, devic # [B, D] x [D, I] -> [B, I] scores = user_emb[u_batch] @ item_emb.T # full ranking over all items + # logger.info(f"scores: {scores.shape}") # Top-K items per user topk_items = torch.topk(scores, K, dim=1).indices # [B, K] - + # logger.info(topk_items.shape) + # logger.info(f"topk_items: {topk_items}") topk_sets = [set(row.tolist()) for row in topk_items] + # logger.info(f"topk_sets: {topk_sets}") for 
local_idx, u in enumerate(u_batch.tolist()): + # logger.info(f"local_idx: {local_idx}, u: {u}") pos_items = pos_items_per_user[u] + # logger.info(f"pos_items: {pos_items}") if not pos_items: continue # skip users with no train positives + # logger.info(f"topk_sets[local_idx]: {topk_sets[local_idx]}") hits = sum(1 for i in pos_items if i in topk_sets[local_idx]) + # logger.info(f"hits: {hits}") recalls.append(hits / len(pos_items)) + # logger.info(f"recalls: {recalls}") return sum(recalls) / len(recalls) if recalls else 0.0 def _training_process( @@ -656,7 +673,7 @@ def _training_process( # How often to run *full* Recall@K over all TRAIN edges. # This is expensive (full graph + full item ranking), so keep it infrequent. - full_recall_eval_every_n_batch = 10 # adjust up/down as you like + full_recall_eval_every_n_batch = 100 # adjust up/down as you like full_recall_K = 20 @@ -682,7 +699,7 @@ def _training_process( debug_this_batch = (batch_num == 0) or ((batch_num + 1) % log_every_n_batch == 0) optimizer.zero_grad() - # logger.info(f"Zeroing gradients") + logger.info(f"Zeroing gradients") main_data = next(train_main_iter) # logger.info(f"Main data: {main_data}") random_negative_data = next(train_random_neg_iter) @@ -704,7 +721,7 @@ def _training_process( # logger.info(f"y_positive keys: {list(main_data.y_positive.keys())}") # for k, v in main_data.y_positive.items(): # logger.info(f" {k}: {len(v)} positives, sample IDs: {v[:5].tolist() if len(v) > 0 else []}") - + logger.info(f"Computing BPR batch") loss, debug_info = _compute_bpr_batch( model=model, main_data=main_data, @@ -723,29 +740,30 @@ def _training_process( logger.info(f" {key}: {value}") loss.backward() - + logger.info(f"Backward pass completed") # Check if gradients exist and their magnitudes # NOTE: With TorchRec's fused optimizer, embedding gradients are not materialized on .grad # (they're applied directly in backward), so we only check non-embedding params here - if debug_this_batch and rank == 0: - grad_norms = [] - param_norms = [] - for name, param in model.named_parameters(): - if param.grad is not None: - grad_norms.append(param.grad.norm().item()) - param_norms.append(param.norm().item()) - if grad_norms: - logger.info(f" Non-embedding gradient norms - mean: {sum(grad_norms)/len(grad_norms):.6f}, " - f"max: {max(grad_norms):.6f}, min: {min(grad_norms):.6f}") - logger.info(f" Non-embedding param norms - mean: {sum(param_norms)/len(param_norms):.6f}, " - f"max: {max(param_norms):.6f}, min: {min(param_norms):.6f}") - else: - logger.info(" Note: No .grad found on params (expected with TorchRec fused optimizer for embeddings)") + # if debug_this_batch and rank == 0: + # grad_norms = [] + # param_norms = [] + # for name, param in model.named_parameters(): + # if param.grad is not None: + # grad_norms.append(param.grad.norm().item()) + # param_norms.append(param.norm().item()) + # if grad_norms: + # logger.info(f" Non-embedding gradient norms - mean: {sum(grad_norms)/len(grad_norms):.6f}, " + # f"max: {max(grad_norms):.6f}, min: {min(grad_norms):.6f}") + # logger.info(f" Non-embedding param norms - mean: {sum(param_norms)/len(param_norms):.6f}, " + # f"max: {max(param_norms):.6f}, min: {min(param_norms):.6f}") + # else: + # logger.info(" Note: No .grad found on params (expected with TorchRec fused optimizer for embeddings)") optimizer.step() - + logger.info(f"Step completed") batch_loss = _sync_metric_across_processes(loss) batch_losses.append(batch_loss) + logger.info(f"Batch loss: {batch_loss}") # Track train loss for plotting 
(rank 0 only) if rank == 0: @@ -766,108 +784,112 @@ def _training_process( # ------------------------------------------------------------------ # Periodic *full* Recall@K over TRAIN edges vs batch (rank 0 only) # ------------------------------------------------------------------ - if (batch_num + 1) % full_recall_eval_every_n_batch == 0 and rank == 0: - global_step = batch_num + 1 - logger.info( - f"Rank {rank}: running FULL TRAIN Recall@{full_recall_K} eval " - f"at batch {global_step}" - ) - - # Compute full LightGCN embeddings for all users/items on TRAIN graph - user_emb, item_emb, srcs, dsts = compute_full_lightgcn_embeddings( - model=model, - dataset=dataset, - node_type_to_num_nodes=node_type_to_num_nodes, - device=device, - ) - pos_items_per_user = build_train_pos_lists( - num_users=node_type_to_num_nodes[NodeType("user")], - srcs=srcs, - dsts=dsts, - ) - - # Compute true Recall@K over all items & all TRAIN edges - full_recall = compute_full_recall_at_k( - user_emb=user_emb, - item_emb=item_emb, - pos_items_per_user=pos_items_per_user, - K=full_recall_K, - device=device, - ) - - full_eval_batch_indices.append(global_step) - full_eval_recall20.append(full_recall) - - logger.info( - f"[FULL TRAIN Recall] batch {global_step}: " - f"Recall@{full_recall_K}={full_recall:.4f}" - ) + # if (batch_num + 1) % full_recall_eval_every_n_batch == 0 and rank == 0: + # global_step = batch_num + 1 + # logger.info( + # f"Rank {rank}: running FULL TRAIN Recall@{full_recall_K} eval " + # f"at batch {global_step}" + # ) + + # # Compute full LightGCN embeddings for all users/items on TRAIN graph + # user_emb, item_emb, srcs, dsts = compute_full_lightgcn_embeddings( + # model=model, + # dataset=dataset, + # node_type_to_num_nodes=node_type_to_num_nodes, + # device=device, + # ) + # pos_items_per_user = build_train_pos_lists( + # num_users=node_type_to_num_nodes[NodeType("user")], + # srcs=srcs, + # dsts=dsts, + # ) + + # # Compute true Recall@K over all items & all TRAIN edges + # full_recall = compute_full_recall_at_k( + # user_emb=user_emb, + # item_emb=item_emb, + # pos_items_per_user=pos_items_per_user, + # K=full_recall_K, + # device=device, + # ) + + # full_eval_batch_indices.append(global_step) + # full_eval_recall20.append(full_recall) + + # logger.info( + # f"[FULL TRAIN Recall] batch {global_step}: " + # f"Recall@{full_recall_K}={full_recall:.4f}" + # ) logger.info(f"Rank {rank}: Training completed. 
Saving model to {model_uri}") # if rank == 0: # save_state_dict(model.state_dict(), model_uri) - user_emb, item_emb, srcs, dsts = compute_full_lightgcn_embeddings( - model=unwrap_from_dmp(model), - dataset=dataset, - node_type_to_num_nodes=node_type_to_num_nodes, - device=device, - ) - pos_items_per_user = build_train_pos_lists( - num_users=node_type_to_num_nodes[NodeType("user")], - srcs=srcs, - dsts=dsts, - ) - recall20 = compute_full_recall_at_k( - user_emb=user_emb, - item_emb=item_emb, - pos_items_per_user=pos_items_per_user, - K=20, - device=device, - ) - logger.info(f"Full-train Recall@20 over all items: {recall20:.4f}") + if rank == 0: + logger.info(f"Computing full LightGCN embeddings after training") + user_emb, item_emb, srcs, dsts = compute_full_lightgcn_embeddings( + model=unwrap_from_dmp(model), + dataset=dataset, + node_type_to_num_nodes=node_type_to_num_nodes, + device=device, + ) + logger.info(f"Building train pos lists after training") + pos_items_per_user = build_train_pos_lists( + num_users=node_type_to_num_nodes[NodeType("user")], + srcs=srcs, + dsts=dsts, + ) + logger.info(f"Computing full Recall@20 after training") + recall20 = compute_full_recall_at_k( + user_emb=user_emb, + item_emb=item_emb, + pos_items_per_user=pos_items_per_user, + K=20, + device=device, + ) + logger.info(f"Full-train Recall@20 over all items: {recall20:.4f}") # ---------------------------------------------------------------------- # Offline plotting with matplotlib (only rank 0) # ---------------------------------------------------------------------- - if rank == 0: - # Train loss vs batch - loss_plot_path = "train_loss_vs_batch.png" - if train_batch_indices and train_losses: - plt.figure() - plt.plot(train_batch_indices, train_losses, marker=".", linewidth=1) - plt.xlabel("Batch") - plt.ylabel("Train BPR loss") - plt.title("Train Loss vs Batch") - plt.grid(True, alpha=0.3) - plt.tight_layout() - plt.savefig(loss_plot_path) - plt.close() - logger.info(f"Saved train loss curve to {loss_plot_path}") - - if plots_output_uri: - gcs_path = plots_output_uri.rstrip("/") + "/train_loss_vs_batch.png" - upload_file_to_gcs(loss_plot_path, gcs_path) - logger.info(f"Uploaded train loss curve to {gcs_path}") - - # Full-train Recall@20 vs batch - recall_plot_path = f"full_train_recall@{full_recall_K}_vs_batch.png" - if full_eval_batch_indices and full_eval_recall20: - plt.figure() - plt.plot(full_eval_batch_indices, full_eval_recall20, marker="o", linewidth=1) - plt.xlabel("Batch") - plt.ylabel(f"Full-train Recall@{full_recall_K}") - plt.title(f"Full-train Recall@{full_recall_K} vs Batch") - plt.grid(True, alpha=0.3) - plt.tight_layout() - plt.savefig(recall_plot_path) - plt.close() - logger.info(f"Saved full-train Recall@{full_recall_K} curve to {recall_plot_path}") - - if plots_output_uri: - gcs_path = plots_output_uri.rstrip("/") + f"/{recall_plot_path}" - upload_file_to_gcs(recall_plot_path, gcs_path) - logger.info(f"Uploaded full-train Recall curve to {gcs_path}") + # if rank == 0: + # # Train loss vs batch + # loss_plot_path = "train_loss_vs_batch.png" + # if train_batch_indices and train_losses: + # plt.figure() + # plt.plot(train_batch_indices, train_losses, marker=".", linewidth=1) + # plt.xlabel("Batch") + # plt.ylabel("Train BPR loss") + # plt.title("Train Loss vs Batch") + # plt.grid(True, alpha=0.3) + # plt.tight_layout() + # plt.savefig(loss_plot_path) + # plt.close() + # logger.info(f"Saved train loss curve to {loss_plot_path}") + + # if plots_output_uri: + # gcs_path = 
plots_output_uri.rstrip("/") + "/train_loss_vs_batch.png" + # upload_file_to_gcs(loss_plot_path, gcs_path) + # logger.info(f"Uploaded train loss curve to {gcs_path}") + + # # Full-train Recall@20 vs batch + # recall_plot_path = f"full_train_recall@{full_recall_K}_vs_batch.png" + # if full_eval_batch_indices and full_eval_recall20: + # plt.figure() + # plt.plot(full_eval_batch_indices, full_eval_recall20, marker="o", linewidth=1) + # plt.xlabel("Batch") + # plt.ylabel(f"Full-train Recall@{full_recall_K}") + # plt.title(f"Full-train Recall@{full_recall_K} vs Batch") + # plt.grid(True, alpha=0.3) + # plt.tight_layout() + # plt.savefig(recall_plot_path) + # plt.close() + # logger.info(f"Saved full-train Recall@{full_recall_K} curve to {recall_plot_path}") + + # if plots_output_uri: + # gcs_path = plots_output_uri.rstrip("/") + f"/{recall_plot_path}" + # upload_file_to_gcs(recall_plot_path, gcs_path) + # logger.info(f"Uploaded full-train Recall curve to {gcs_path}") @@ -970,6 +992,7 @@ def _run_example_training( l2_lambda = float(trainer_args.get("l2_lambda", "0.0")) plots_output_uri = trainer_args.get("plots_output_uri", "") # optional + plots_output_uri = "gs://gigl-dev-temp-assets/swong3" # This value represents the the shared-memory buffer size (bytes) allocated for the channel during sampling, and @@ -987,7 +1010,7 @@ def _run_example_training( learning_rate = float(trainer_args.get("learning_rate", "0.01")) weight_decay = float(trainer_args.get("weight_decay", "0.0005")) num_max_train_batches = int(trainer_args.get("num_max_train_batches", "1000")) - num_max_train_batches = 100 + num_max_train_batches = 10000 num_val_batches = int(trainer_args.get("num_val_batches", "100")) val_every_n_batch = int(trainer_args.get("val_every_n_batch", "50")) From 0952da6716589869cb06420198266c22c9b642fc Mon Sep 17 00:00:00 2001 From: swong3 Date: Thu, 18 Dec 2025 22:42:12 +0000 Subject: [PATCH 7/9] Fix recall distr issues --- .../id_embeddings/heterogeneous_training.py | 233 +++++++++--------- 1 file changed, 113 insertions(+), 120 deletions(-) diff --git a/examples/id_embeddings/heterogeneous_training.py b/examples/id_embeddings/heterogeneous_training.py index 8a8459487..b1c2d6b45 100644 --- a/examples/id_embeddings/heterogeneous_training.py +++ b/examples/id_embeddings/heterogeneous_training.py @@ -650,6 +650,7 @@ def _training_process( full_eval_batch_indices: list[int] = [] full_eval_recall20: list[float] = [] + full_recall_K = 20 if should_skip_training: @@ -674,7 +675,6 @@ def _training_process( # How often to run *full* Recall@K over all TRAIN edges. # This is expensive (full graph + full item ranking), so keep it infrequent. 
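        # Note (assumption about the DMP setup): because the embedding tables are
        # sharded across ranks, every rank must join the full-embedding forward pass;
        # the periodic and final evals below therefore run on all ranks, with barriers,
        # even though only rank 0 computes and logs Recall@K.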
full_recall_eval_every_n_batch = 100 # adjust up/down as you like - full_recall_K = 20 logger.info(f"Rank {rank}: Starting training loop") @@ -685,11 +685,6 @@ def _training_process( train_random_neg_iter = InfiniteIterator(train_random_negative_loader) batch_losses = [] - val_losses = [] - - # Track a sample embedding to see if it changes - sample_node_type = list(node_type_to_num_nodes.keys())[0] - sample_node_id = 0 logger.info(f"Starting training loop") @@ -699,7 +694,7 @@ def _training_process( debug_this_batch = (batch_num == 0) or ((batch_num + 1) % log_every_n_batch == 0) optimizer.zero_grad() - logger.info(f"Zeroing gradients") + # logger.info(f"Zeroing gradients") main_data = next(train_main_iter) # logger.info(f"Main data: {main_data}") random_negative_data = next(train_random_neg_iter) @@ -721,7 +716,7 @@ def _training_process( # logger.info(f"y_positive keys: {list(main_data.y_positive.keys())}") # for k, v in main_data.y_positive.items(): # logger.info(f" {k}: {len(v)} positives, sample IDs: {v[:5].tolist() if len(v) > 0 else []}") - logger.info(f"Computing BPR batch") + # logger.info(f"Computing BPR batch") loss, debug_info = _compute_bpr_batch( model=model, main_data=main_data, @@ -740,30 +735,12 @@ def _training_process( logger.info(f" {key}: {value}") loss.backward() - logger.info(f"Backward pass completed") - # Check if gradients exist and their magnitudes - # NOTE: With TorchRec's fused optimizer, embedding gradients are not materialized on .grad - # (they're applied directly in backward), so we only check non-embedding params here - # if debug_this_batch and rank == 0: - # grad_norms = [] - # param_norms = [] - # for name, param in model.named_parameters(): - # if param.grad is not None: - # grad_norms.append(param.grad.norm().item()) - # param_norms.append(param.norm().item()) - # if grad_norms: - # logger.info(f" Non-embedding gradient norms - mean: {sum(grad_norms)/len(grad_norms):.6f}, " - # f"max: {max(grad_norms):.6f}, min: {min(grad_norms):.6f}") - # logger.info(f" Non-embedding param norms - mean: {sum(param_norms)/len(param_norms):.6f}, " - # f"max: {max(param_norms):.6f}, min: {min(param_norms):.6f}") - # else: - # logger.info(" Note: No .grad found on params (expected with TorchRec fused optimizer for embeddings)") optimizer.step() - logger.info(f"Step completed") + # logger.info(f"Step completed") batch_loss = _sync_metric_across_processes(loss) batch_losses.append(batch_loss) - logger.info(f"Batch loss: {batch_loss}") + # logger.info(f"Batch loss: {batch_loss}") # Track train loss for plotting (rank 0 only) if rank == 0: @@ -784,112 +761,128 @@ def _training_process( # ------------------------------------------------------------------ # Periodic *full* Recall@K over TRAIN edges vs batch (rank 0 only) # ------------------------------------------------------------------ - # if (batch_num + 1) % full_recall_eval_every_n_batch == 0 and rank == 0: - # global_step = batch_num + 1 - # logger.info( - # f"Rank {rank}: running FULL TRAIN Recall@{full_recall_K} eval " - # f"at batch {global_step}" - # ) - - # # Compute full LightGCN embeddings for all users/items on TRAIN graph - # user_emb, item_emb, srcs, dsts = compute_full_lightgcn_embeddings( - # model=model, - # dataset=dataset, - # node_type_to_num_nodes=node_type_to_num_nodes, - # device=device, - # ) - # pos_items_per_user = build_train_pos_lists( - # num_users=node_type_to_num_nodes[NodeType("user")], - # srcs=srcs, - # dsts=dsts, - # ) - - # # Compute true Recall@K over all items & all TRAIN edges - # 
full_recall = compute_full_recall_at_k( - # user_emb=user_emb, - # item_emb=item_emb, - # pos_items_per_user=pos_items_per_user, - # K=full_recall_K, - # device=device, - # ) - - # full_eval_batch_indices.append(global_step) - # full_eval_recall20.append(full_recall) - - # logger.info( - # f"[FULL TRAIN Recall] batch {global_step}: " - # f"Recall@{full_recall_K}={full_recall:.4f}" - # ) + if (batch_num + 1) % full_recall_eval_every_n_batch == 0: + global_step = batch_num + 1 + logger.info( + f"Rank {rank}: running FULL TRAIN Recall@{full_recall_K} eval " + f"at batch {global_step}" + ) + + torch.distributed.barrier() + + model.eval() + + + # Compute full LightGCN embeddings for all users/items on TRAIN graph + user_emb, item_emb, srcs, dsts = compute_full_lightgcn_embeddings( + model=model, + dataset=dataset, + node_type_to_num_nodes=node_type_to_num_nodes, + device=device, + ) + + if rank == 0: + pos_items_per_user = build_train_pos_lists( + num_users=node_type_to_num_nodes[NodeType("user")], + srcs=srcs, + dsts=dsts, + ) + + # Compute true Recall@K over all items & all TRAIN edges + full_recall = compute_full_recall_at_k( + user_emb=user_emb, + item_emb=item_emb, + pos_items_per_user=pos_items_per_user, + K=full_recall_K, + device=device, + ) + + full_eval_batch_indices.append(global_step) + full_eval_recall20.append(full_recall) + + logger.info( + f"[FULL TRAIN Recall] batch {global_step}: " + f"Recall@{full_recall_K}={full_recall:.4f}" + ) + + torch.distributed.barrier() + model.train() logger.info(f"Rank {rank}: Training completed. Saving model to {model_uri}") # if rank == 0: # save_state_dict(model.state_dict(), model_uri) - if rank == 0: - logger.info(f"Computing full LightGCN embeddings after training") + logger.info(f"Rank {rank}: training done, syncing before final recall") + torch.distributed.barrier() + + model.eval() user_emb, item_emb, srcs, dsts = compute_full_lightgcn_embeddings( - model=unwrap_from_dmp(model), + model=model, dataset=dataset, node_type_to_num_nodes=node_type_to_num_nodes, device=device, ) - logger.info(f"Building train pos lists after training") - pos_items_per_user = build_train_pos_lists( - num_users=node_type_to_num_nodes[NodeType("user")], - srcs=srcs, - dsts=dsts, - ) - logger.info(f"Computing full Recall@20 after training") - recall20 = compute_full_recall_at_k( - user_emb=user_emb, - item_emb=item_emb, - pos_items_per_user=pos_items_per_user, - K=20, - device=device, - ) - logger.info(f"Full-train Recall@20 over all items: {recall20:.4f}") + + if rank == 0: + logger.info(f"Building train pos lists after training") + pos_items_per_user = build_train_pos_lists( + num_users=node_type_to_num_nodes[NodeType("user")], + srcs=srcs, + dsts=dsts, + ) + logger.info(f"Computing full Recall@20 after training") + recall20 = compute_full_recall_at_k( + user_emb=user_emb, + item_emb=item_emb, + pos_items_per_user=pos_items_per_user, + K=20, + device=device, + ) + logger.info(f"Full-train Recall@20 over all items: {recall20:.4f}") + + torch.distributed.barrier() # ---------------------------------------------------------------------- # Offline plotting with matplotlib (only rank 0) # ---------------------------------------------------------------------- - # if rank == 0: - # # Train loss vs batch - # loss_plot_path = "train_loss_vs_batch.png" - # if train_batch_indices and train_losses: - # plt.figure() - # plt.plot(train_batch_indices, train_losses, marker=".", linewidth=1) - # plt.xlabel("Batch") - # plt.ylabel("Train BPR loss") - # plt.title("Train Loss vs 
Batch") - # plt.grid(True, alpha=0.3) - # plt.tight_layout() - # plt.savefig(loss_plot_path) - # plt.close() - # logger.info(f"Saved train loss curve to {loss_plot_path}") - - # if plots_output_uri: - # gcs_path = plots_output_uri.rstrip("/") + "/train_loss_vs_batch.png" - # upload_file_to_gcs(loss_plot_path, gcs_path) - # logger.info(f"Uploaded train loss curve to {gcs_path}") - - # # Full-train Recall@20 vs batch - # recall_plot_path = f"full_train_recall@{full_recall_K}_vs_batch.png" - # if full_eval_batch_indices and full_eval_recall20: - # plt.figure() - # plt.plot(full_eval_batch_indices, full_eval_recall20, marker="o", linewidth=1) - # plt.xlabel("Batch") - # plt.ylabel(f"Full-train Recall@{full_recall_K}") - # plt.title(f"Full-train Recall@{full_recall_K} vs Batch") - # plt.grid(True, alpha=0.3) - # plt.tight_layout() - # plt.savefig(recall_plot_path) - # plt.close() - # logger.info(f"Saved full-train Recall@{full_recall_K} curve to {recall_plot_path}") - - # if plots_output_uri: - # gcs_path = plots_output_uri.rstrip("/") + f"/{recall_plot_path}" - # upload_file_to_gcs(recall_plot_path, gcs_path) - # logger.info(f"Uploaded full-train Recall curve to {gcs_path}") + if rank == 0: + # Train loss vs batch + loss_plot_path = "train_loss_vs_batch.png" + if train_batch_indices and train_losses: + plt.figure() + plt.plot(train_batch_indices, train_losses, marker=".", linewidth=1) + plt.xlabel("Batch") + plt.ylabel("Train BPR loss") + plt.title("Train Loss vs Batch") + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.savefig(loss_plot_path) + plt.close() + logger.info(f"Saved train loss curve to {loss_plot_path}") + + if plots_output_uri: + gcs_path = plots_output_uri.rstrip("/") + "/train_loss_vs_batch.png" + upload_file_to_gcs(loss_plot_path, gcs_path) + logger.info(f"Uploaded train loss curve to {gcs_path}") + + # Full-train Recall@20 vs batch + recall_plot_path = f"full_train_recall@{full_recall_K}_vs_batch.png" + if full_eval_batch_indices and full_eval_recall20: + plt.figure() + plt.plot(full_eval_batch_indices, full_eval_recall20, marker="o", linewidth=1) + plt.xlabel("Batch") + plt.ylabel(f"Full-train Recall@{full_recall_K}") + plt.title(f"Full-train Recall@{full_recall_K} vs Batch") + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.savefig(recall_plot_path) + plt.close() + logger.info(f"Saved full-train Recall@{full_recall_K} curve to {recall_plot_path}") + + if plots_output_uri: + gcs_path = plots_output_uri.rstrip("/") + f"/{recall_plot_path}" + upload_file_to_gcs(recall_plot_path, gcs_path) + logger.info(f"Uploaded full-train Recall curve to {gcs_path}") From 504407ac9201c50be02b59fce2b565643abbff9a Mon Sep 17 00:00:00 2001 From: swong3 Date: Thu, 18 Dec 2025 23:23:12 +0000 Subject: [PATCH 8/9] Changed default train batches --- examples/id_embeddings/heterogeneous_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/id_embeddings/heterogeneous_training.py b/examples/id_embeddings/heterogeneous_training.py index b1c2d6b45..ffa341b24 100644 --- a/examples/id_embeddings/heterogeneous_training.py +++ b/examples/id_embeddings/heterogeneous_training.py @@ -1003,7 +1003,7 @@ def _run_example_training( learning_rate = float(trainer_args.get("learning_rate", "0.01")) weight_decay = float(trainer_args.get("weight_decay", "0.0005")) num_max_train_batches = int(trainer_args.get("num_max_train_batches", "1000")) - num_max_train_batches = 10000 + num_max_train_batches = 1000 num_val_batches = int(trainer_args.get("num_val_batches", "100")) 
val_every_n_batch = int(trainer_args.get("val_every_n_batch", "50")) From 10a49c984a33c9975f740aba0795d8592c2b2e21 Mon Sep 17 00:00:00 2001 From: swong3 Date: Fri, 19 Dec 2025 19:32:24 +0000 Subject: [PATCH 9/9] Changed default train batches --- examples/id_embeddings/heterogeneous_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/id_embeddings/heterogeneous_training.py b/examples/id_embeddings/heterogeneous_training.py index ffa341b24..8db686102 100644 --- a/examples/id_embeddings/heterogeneous_training.py +++ b/examples/id_embeddings/heterogeneous_training.py @@ -1003,7 +1003,7 @@ def _run_example_training( learning_rate = float(trainer_args.get("learning_rate", "0.01")) weight_decay = float(trainer_args.get("weight_decay", "0.0005")) num_max_train_batches = int(trainer_args.get("num_max_train_batches", "1000")) - num_max_train_batches = 1000 + num_max_train_batches = 5000 num_val_batches = int(trainer_args.get("num_val_batches", "100")) val_every_n_batch = int(trainer_args.get("val_every_n_batch", "50"))
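
A minimal, self-contained sketch of the evaluation pattern that PATCH 7 introduces above: embeddings are produced on every rank, the expensive full-catalog ranking and Recall@K bookkeeping are gated to rank 0, and the whole step is bracketed by `torch.distributed.barrier()` calls plus `model.eval()`/`model.train()` toggles. This is not the repository's implementation: `recall_at_k` and the `compute_embeddings` callable are hypothetical stand-ins for `compute_full_recall_at_k` and `compute_full_lightgcn_embeddings`, and the per-user recall convention (hits divided by the number of positives) is an assumption that may differ from the helpers in `heterogeneous_training.py`. It assumes an already-initialized process group, as in `_training_process`.

from typing import Callable, Optional, Sequence

import torch
import torch.distributed as dist


def recall_at_k(
    user_emb: torch.Tensor,                      # [num_users, dim]
    item_emb: torch.Tensor,                      # [num_items, dim]
    pos_items_per_user: Sequence[torch.Tensor],  # ground-truth item ids per user
    k: int = 20,
) -> float:
    """Scores every item for every user, takes the top-k, and averages per-user recall."""
    k = min(k, item_emb.size(0))
    scores = user_emb @ item_emb.T               # [num_users, num_items]
    topk = scores.topk(k, dim=1).indices.cpu()   # [num_users, k]
    per_user = []
    for user_id, positives in enumerate(pos_items_per_user):
        if positives.numel() == 0:
            continue  # users with no positives contribute nothing to the average
        hits = torch.isin(topk[user_id], positives.cpu()).sum().item()
        per_user.append(hits / positives.numel())
    return sum(per_user) / max(len(per_user), 1)


def periodic_full_eval(
    model: torch.nn.Module,
    compute_embeddings: Callable[[torch.nn.Module], tuple[torch.Tensor, torch.Tensor]],
    pos_items_per_user: Sequence[torch.Tensor],
    rank: int,
    k: int = 20,
) -> Optional[float]:
    """Rank-0-gated full eval step, bracketed by barriers as in PATCH 7."""
    dist.barrier()                               # every rank finishes its optimizer step first
    model.eval()

    recall: Optional[float] = None
    with torch.no_grad():
        # All ranks run the embedding computation (collective in the real script);
        # only rank 0 pays for ranking the full item catalog.
        user_emb, item_emb = compute_embeddings(model)
        if rank == 0:
            recall = recall_at_k(user_emb, item_emb, pos_items_per_user, k=k)

    dist.barrier()                               # keep ranks in lockstep before training resumes
    model.train()
    return recall

The two barriers and the eval/train toggling keep the non-zero ranks from racing ahead into the next training batch while rank 0 is still ranking; the embedding computation stays on every rank, presumably because it involves cross-rank collectives, which is what moving the `if rank == 0:` gate below it in PATCH 7 ("Fix recall distr issues") suggests.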