From f82f432e0ab12bb9e27deeb7065f2aa69dacf848 Mon Sep 17 00:00:00 2001
From: kmontemayor
Date: Thu, 28 Aug 2025 18:04:04 +0000
Subject: [PATCH 01/11] initial work

---
 examples/server_client/__init__.py     |   0
 examples/server_client/same_machine.py | 164 +++++++++++++++++++++++++
 2 files changed, 164 insertions(+)
 create mode 100644 examples/server_client/__init__.py
 create mode 100644 examples/server_client/same_machine.py

diff --git a/examples/server_client/__init__.py b/examples/server_client/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/server_client/same_machine.py b/examples/server_client/same_machine.py
new file mode 100644
index 000000000..daa34d6c2
--- /dev/null
+++ b/examples/server_client/same_machine.py
@@ -0,0 +1,164 @@
+import os
+
+# Suppress TensorFlow logs
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # isort: skip
+
+import time
+import uuid
+from pathlib import Path
+
+import graphlearn_torch as glt
+import torch
+
+import gigl.distributed as gd
+from gigl.common.logger import Logger
+from gigl.distributed.utils import get_free_port
+from gigl.types.graph import to_homogeneous
+
+logger = Logger()
+
+
+def run_server(
+    server_rank: int,
+    num_servers: int,
+    num_clients: int,
+    host: str,
+    port: int,
+    output_dir: str,
+) -> None:
+    dataset = gd.build_dataset_from_task_config_uri(
+        task_config_uri="gs://public-gigl/mocked_assets/2024-07-15--21-30-07-UTC/cora_homogeneous_node_anchor_edge_features_user_defined_labels/frozen_gbml_config.yaml",
+        is_inference=True,
+        _tfrecord_uri_pattern=".*tfrecord",
+    )
+    logger.info(
+        f"Dumping {to_homogeneous(dataset.node_ids).numel()} node_ids to {output_dir}/node_ids.pt"
+    )
+    torch.save(to_homogeneous(dataset.node_ids), f"{output_dir}/node_ids.pt")
+    logger.info(f"Initializing server")
+    glt.distributed.init_server(
+        num_servers=num_servers,
+        server_rank=server_rank,
+        dataset=dataset,
+        master_addr=host,
+        master_port=port,
+        num_clients=num_clients,
+    )
+
+    logger.info(f"Waiting for server rank {server_rank} to exit")
+    glt.distributed.wait_and_shutdown_server()
+    logger.info(f"Server rank {server_rank} exited")
+
+
+def run_client(
+    client_rank: int,
+    num_clients: int,
+    num_servers: int,
+    host: str,
+    port: int,
+    output_dir: str,
+) -> None:
+    glt.distributed.init_client(
+        num_servers=num_servers,
+        num_clients=num_clients,
+        client_rank=client_rank,
+        master_addr=host,
+        master_port=port,
+    )
+    current_ctx = glt.distributed.get_context()
+    current_device = torch.device(current_ctx.rank % torch.cuda.device_count())
+    logger.info(f"Client rank {client_rank} initialized on device {current_device}")
+
+    logger.info(f"Loading node_ids from {output_dir}/node_ids.pt")
+    node_ids = torch.load(f"{output_dir}/node_ids.pt")
+    logger.info(f"Loaded {node_ids.numel()} node_ids")
+    num_workers = 4
+
+    loader = glt.distributed.DistNeighborLoader(
+        data=None,
+        num_neighbors=[2, 2],
+        input_nodes=f"{output_dir}/node_ids.pt",
+        worker_options=glt.distributed.RemoteDistSamplingWorkerOptions(
+            server_rank=0,
+            num_workers=num_workers,
+            worker_devices=[torch.device("cpu") for i in range(num_workers)],
+            master_addr=host,
+            master_port=get_free_port(),
+        ),
+    )
+
+    for batch in loader:
+        logger.info(f"Batch: {batch}")
+
+    logger.info(f"Shutting down client")
+    glt.distributed.shutdown_client()
+    logger.info(f"Client rank {client_rank} exited")
+
+
+def main():
+    output_dir = f"/tmp/gigl/server_client/output/{uuid.uuid4()}"
+    Path(output_dir).mkdir(parents=True, exist_ok=True)
+
+    num_servers = 1
+    num_clients = 2
+    server_processes = []
+    mp_context = torch.multiprocessing.get_context("spawn")
+    server_client_port = get_free_port()
+
+    for server_rank in range(num_servers):
+        server_process = mp_context.Process(
+            target=run_server,
+            args=(
+                server_rank,
+                num_servers,
+                num_clients,
+                "localhost",
+                server_client_port,
+                output_dir,
+            ),
+        )
+        server_processes.append(server_process)
+
+    for server_process in server_processes:
+        server_process.start()
+
+    output_file = Path(f"{output_dir}/node_ids.pt")
+
+    while not output_file.exists():
+        time.sleep(5)
+        logger.info(
+            f"Waiting for server rank {server_rank} to dump node_ids to {output_dir}/node_ids.pt"
+        )
+
+    client_processes = []
+
+    for client_rank in range(num_clients):
+        client_process = mp_context.Process(
+            target=run_client,
+            args=(
+                client_rank,
+                num_clients,
+                num_servers,
+                "localhost",
+                server_client_port,
+                output_dir,
+            ),
+        )
+        client_processes.append(client_process)
+
+    for client_process in client_processes:
+        client_process.start()
+
+    logger.info(f"Waiting for client processes to exit")
+    for client_process in client_processes:
+        client_process.join()
+
+    logger.info(f"Waiting for server processes to exit")
+    for server_process in server_processes:
+        server_process.join()
+
+    logger.info(f"All processes exited")
+
+
+if __name__ == "__main__":
+    main()
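Note on the first cut above: the server and clients coordinate through a file on local disk -- the server dumps the node IDs with torch.save, and main() polls until the file appears before starting any client. A minimal sketch of that handshake (the path and the five-second poll interval come from the example; the tensor is a stand-in for to_homogeneous(dataset.node_ids)):

    import time
    from pathlib import Path

    import torch

    output_dir = "/tmp/gigl/server_client/output/example"  # assumed scratch dir

    # Server side: publish the IDs the clients will sample from.
    node_ids = torch.arange(2708)  # stand-in for the real node-ID tensor
    torch.save(node_ids, f"{output_dir}/node_ids.pt")

    # Client side: block until the server has written the tensor, then load it.
    output_file = Path(f"{output_dir}/node_ids.pt")
    while not output_file.exists():
        time.sleep(5)
    node_ids = torch.load(output_file)

Later commits in this series swap the raw tensor file for a richer metadata file.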
From 667759abdd2e079de83f4a6740f76428416b65bc Mon Sep 17 00:00:00 2001
From: kmontemayor
Date: Thu, 28 Aug 2025 19:46:00 +0000
Subject: [PATCH 02/11] start on loader

---
 examples/server_client/same_machine.py        | 21 ++++++++++++++-----
 .../distributed/distributed_neighborloader.py | 10 ++++++---
 2 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/examples/server_client/same_machine.py b/examples/server_client/same_machine.py
index daa34d6c2..055d2f3b2 100644
--- a/examples/server_client/same_machine.py
+++ b/examples/server_client/same_machine.py
@@ -3,6 +3,7 @@
 # Suppress TensorFlow logs
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # isort: skip
 
+import argparse
 import time
 import uuid
 from pathlib import Path
@@ -85,6 +86,7 @@ def run_client(
             master_addr=host,
             master_port=get_free_port(),
         ),
+        to_device=current_device,
     )
 
     for batch in loader:
@@ -96,11 +98,20 @@ def run_client(
 
 
 def main():
-    output_dir = f"/tmp/gigl/server_client/output/{uuid.uuid4()}"
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num_servers", type=int, default=1)
+    parser.add_argument("--num_clients", type=int, default=2)
+    parser.add_argument("--output_dir", type=str, default=f"/tmp/gigl/server_client/output/{uuid.uuid4()}")
+    parser.add_argument("--host", type=str, default="localhost")
+    args = parser.parse_args()
+    logger.info(f"Arguments: {args}")
+
+    # Parse arguments
+    num_servers = args.num_servers
+    num_clients = args.num_clients
+    output_dir = args.output_dir
     Path(output_dir).mkdir(parents=True, exist_ok=True)
 
-    num_servers = 1
-    num_clients = 2
     server_processes = []
     mp_context = torch.multiprocessing.get_context("spawn")
     server_client_port = get_free_port()
@@ -112,7 +123,7 @@ def main():
                 server_rank,
                 num_servers,
                 num_clients,
-                "localhost",
+                args.host,
                 server_client_port,
                 output_dir,
             ),
@@ -139,7 +150,7 @@ def main():
                 client_rank,
                 num_clients,
                 num_servers,
-                "localhost",
+                args.host,
                 server_client_port,
                 output_dir,
             ),
diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py
index 0ecc8e743..5ece8c92f 100644
--- a/python/gigl/distributed/distributed_neighborloader.py
+++ b/python/gigl/distributed/distributed_neighborloader.py
@@ -9,6 +9,7 @@
 from torch_geometric.typing import EdgeType
 
 import gigl.distributed.utils
+from gigl.common import Uri
 from gigl.common.logger import Logger
 from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT
 from gigl.distributed.dist_context import DistributedContext
@@ -37,10 +38,10 @@ class DistNeighborLoader(DistLoader):
     def __init__(
         self,
-        dataset: DistLinkPredictionDataset,
+        dataset: Optional[DistLinkPredictionDataset],
         num_neighbors: Union[list[int], dict[EdgeType, list[int]]],
         input_nodes: Optional[
-            Union[torch.Tensor, Tuple[NodeType, torch.Tensor]]
+            Union[torch.Tensor, Tuple[NodeType, torch.Tensor], list[Uri]]
         ] = None,
         num_workers: int = 1,
         batch_size: int = 1,
@@ -193,6 +194,8 @@ def __init__(
         )
 
         if input_nodes is None:
+            if dataset is None:
+                raise ValueError("Dataset must be provided if input_nodes are not provided.")
             if dataset.node_ids is None:
                 raise ValueError(
                     "Dataset must have node ids if input_nodes are not provided."
@@ -223,11 +226,12 @@ def __init__(
                 )
             else:
                 node_type = None
-        else:
+        elif isinstance(input_nodes, tuple):
             node_type, node_ids = input_nodes
             assert isinstance(
                 dataset.node_ids, abc.Mapping
             ), "Dataset must be heterogeneous if provided input nodes are a tuple."
+        elif isinstance(input_nodes, list):
 
         num_neighbors = patch_fanout_for_sampling(
             dataset.get_edge_types(), num_neighbors

From eb0cd89b912497e80503c3035fb4fd8f4328a987 Mon Sep 17 00:00:00 2001
From: kmontemayor
Date: Fri, 29 Aug 2025 17:56:42 +0000
Subject: [PATCH 03/11] wip

---
 examples/server_client/same_machine.py        |  6 +-
 .../distributed/distributed_neighborloader.py | 83 +++++++++++--------
 python/gigl/distributed/sampler.py            | 39 ++++++++-
 3 files changed, 87 insertions(+), 41 deletions(-)

diff --git a/examples/server_client/same_machine.py b/examples/server_client/same_machine.py
index 055d2f3b2..f48f20148 100644
--- a/examples/server_client/same_machine.py
+++ b/examples/server_client/same_machine.py
@@ -103,6 +103,7 @@ def main():
     parser.add_argument("--num_clients", type=int, default=2)
     parser.add_argument("--output_dir", type=str, default=f"/tmp/gigl/server_client/output/{uuid.uuid4()}")
     parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=get_free_port())
     args = parser.parse_args()
     logger.info(f"Arguments: {args}")
 
@@ -114,7 +115,6 @@ def main():
 
     server_processes = []
     mp_context = torch.multiprocessing.get_context("spawn")
-    server_client_port = get_free_port()
 
     for server_rank in range(num_servers):
         server_process = mp_context.Process(
@@ -124,7 +124,7 @@ def main():
                 num_servers,
                 num_clients,
                 args.host,
-                server_client_port,
+                args.port,
                 output_dir,
             ),
         )
@@ -151,7 +151,7 @@ def main():
                 num_clients,
                 num_servers,
                 args.host,
-                server_client_port,
+                args.port,
                 output_dir,
             ),
         )
diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py
index 5ece8c92f..8c649cfa1 100644
--- a/python/gigl/distributed/distributed_neighborloader.py
+++ b/python/gigl/distributed/distributed_neighborloader.py
@@ -21,6 +21,7 @@
     shard_nodes_by_process,
     strip_label_edges,
 )
+from gigl.distributed.sampler import RemoteUriSamplerInput
 from gigl.src.common.types.graph_data import (
     NodeType,  # TODO (mkolodner-sc): Change to use torch_geometric.typing
 )
@@ -41,7 +42,7 @@ def __init__(
         dataset: Optional[DistLinkPredictionDataset],
         num_neighbors: Union[list[int], dict[EdgeType, list[int]]],
         input_nodes: Optional[
-            Union[torch.Tensor, Tuple[NodeType, torch.Tensor], list[Uri]]
+            Union[torch.Tensor, Tuple[NodeType, torch.Tensor], Uri]
         ] = None,
         num_workers: int = 1,
         batch_size: int = 1,
@@ -208,45 +209,55 @@ def __init__(
         # Determines if the node ids passed in are heterogeneous or homogeneous.
         self._is_labeled_heterogeneous = False
-        if isinstance(input_nodes, torch.Tensor):
-            node_ids = input_nodes
-
-            # If the dataset is heterogeneous, we may be in the "labeled homogeneous" setting,
-            # if so, then we should use DEFAULT_HOMOGENEOUS_NODE_TYPE.
-            if isinstance(dataset.node_ids, abc.Mapping):
-                if (
-                    len(dataset.node_ids) == 1
-                    and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset.node_ids
-                ):
-                    node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
-                    self._is_labeled_heterogeneous = True
+        if dataset is None:
+            if not isinstance(input_nodes, tuple):
+                raise ValueError("input_nodes must be a tuple if dataset is None.")
+            node_type, uri = input_nodes
+            if not isinstance(uri, Uri):
+                raise ValueError(f"uri must be a Uri, received {uri} of type {type(uri)}")
+            if not isinstance(node_type, NodeType):
+                raise ValueError(f"node_type must be a NodeType, received {node_type} of type {type(node_type)}")
+            input_data = RemoteUriSamplerInput(uri, node_type)
+        else:
+            if isinstance(input_nodes, torch.Tensor):
+                node_ids = input_nodes
+
+                # If the dataset is heterogeneous, we may be in the "labeled homogeneous" setting,
+                # if so, then we should use DEFAULT_HOMOGENEOUS_NODE_TYPE.
+                if isinstance(dataset.node_ids, abc.Mapping):
+                    if (
+                        len(dataset.node_ids) == 1
+                        and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset.node_ids
+                    ):
+                        node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
+                        self._is_labeled_heterogeneous = True
                 else:
-                    raise ValueError(
-                        f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset.node_ids.keys()}"
-                    )
-            else:
-                node_type = None
-        elif isinstance(input_nodes, tuple):
-            node_type, node_ids = input_nodes
-            assert isinstance(
-                dataset.node_ids, abc.Mapping
-            ), "Dataset must be heterogeneous if provided input nodes are a tuple."
-        elif isinstance(input_nodes, list):
-
-        num_neighbors = patch_fanout_for_sampling(
-            dataset.get_edge_types(), num_neighbors
-        )
+                    raise ValueError(
+                        f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset.node_ids.keys()}"
+                    )
-        curr_process_nodes = shard_nodes_by_process(
-            input_nodes=node_ids,
-            local_process_rank=local_rank,
-            local_process_world_size=local_world_size,
-        )
+            else:
+                node_type = None
+            elif isinstance(input_nodes, tuple):
+                node_type, node_ids = input_nodes
+                if not isinstance(node_type, NodeType):
+                    raise ValueError(f"node_type must be a NodeType, received {node_type} of type {type(node_type)}")
+                if not isinstance(node_ids, torch.Tensor):
+                    raise ValueError(f"node_ids must be a torch.Tensor, received {node_ids} of type {type(node_ids)}")
+
+            num_neighbors = patch_fanout_for_sampling(
+                dataset.get_edge_types(), num_neighbors
+            )
-        self._node_feature_info = dataset.node_feature_info
-        self._edge_feature_info = dataset.edge_feature_info
+            curr_process_nodes = shard_nodes_by_process(
+                input_nodes=node_ids,
+                local_process_rank=local_rank,
+                local_process_world_size=local_world_size,
+            )
-        input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type)
+            self._node_feature_info = dataset.node_feature_info
+            self._edge_feature_info = dataset.edge_feature_info
+
+            input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type)
 
         # Sets up processes and torch device for initializing the GLT DistNeighborLoader, setting up RPC and worker groups to minimize
         # the memory overhead and CPU contention.
diff --git a/python/gigl/distributed/sampler.py b/python/gigl/distributed/sampler.py
index 09b03a03c..237ae933b 100644
--- a/python/gigl/distributed/sampler.py
+++ b/python/gigl/distributed/sampler.py
@@ -1,9 +1,44 @@
 from typing import Any, Optional, Union
+from dataclasses import asdict, dataclass
+import ast
+
 
 import torch
-from graphlearn_torch.sampler import NodeSamplerInput
+from graphlearn_torch.sampler import NodeSamplerInput, RemoteSamplerInput
+
+from gigl.common import Uri
+from gigl.types.graph import FeatureInfo
+from gigl.src.common.types.graph_data import NodeType, EdgeType
+from gigl.src.common.utils.file_loader import FileLoader
+
+
+@dataclass
+class RemoteNodeInfo:
+    node_type: NodeType
+    edge_types: list[tuple[NodeType, NodeType, NodeType]]
+    num_nodes: int
+    node_feature_info: Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]]
+    edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]
+
+    def dump(self) -> str:
+        return str(asdict(self))
+
+    @classmethod
+    def load(cls, uri: Uri) -> "RemoteNodeInfo":
+        with FileLoader().load_to_temp_file(uri) as temp_file:
+            s = temp_file.read()
+        return cls(**ast.literal_eval(s))
+
+
+class RemoteUriSamplerInput(RemoteSamplerInput):
+    def __init__(self, uri: Uri, input_type: Optional[Union[str, NodeType]]):
+        self._uri = uri
+        self._input_type = input_type
 
-from gigl.src.common.types.graph_data import NodeType
+    def to_local_sampler_input(self, dataset, **kwargs) -> NodeSamplerInput:
+        file_loader = FileLoader()
+        with file_loader.load_to_temp_file(self._uri) as temp_file:
+            tensor = torch.load(temp_file)
+        return NodeSamplerInput(node=tensor, input_type=self._input_type)
 
 
 class ABLPNodeSamplerInput(NodeSamplerInput):
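The RemoteNodeInfo introduced in this commit round-trips through str(asdict(...)) and ast.literal_eval. A small sketch of that round trip, trimmed to literal-friendly fields (the class name here is a hypothetical stand-in), plus the limitation that motivates the JSON serializer added later in the series:

    import ast
    from dataclasses import asdict, dataclass

    @dataclass
    class Info:  # hypothetical stand-in for RemoteNodeInfo with literal-only fields
        node_type: str
        num_nodes: int

    info = Info(node_type="user", num_nodes=2708)
    dumped = str(asdict(info))  # "{'node_type': 'user', 'num_nodes': 2708}"
    restored = Info(**ast.literal_eval(dumped))
    assert restored == info

    # ast.literal_eval only accepts Python literals. A field holding
    # torch.float32 would render as torch.float32 inside the dict repr and
    # fail to parse back; PATCH 05 replaces this scheme with explicit JSON
    # handling for torch.dtype.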
From f14fb73ff40241e9febc8923815175889120bda1 Mon Sep 17 00:00:00 2001
From: kmontemayor
Date: Wed, 3 Sep 2025 15:57:23 +0000
Subject: [PATCH 04/11] bleg

---
 examples/server_client/same_machine.py        | 58 ++++++++++++-----
 .../distributed/distributed_neighborloader.py | 29 ++++++----
 python/gigl/distributed/sampler.py            | 21 +++++--
 3 files changed, 77 insertions(+), 31 deletions(-)

diff --git a/examples/server_client/same_machine.py b/examples/server_client/same_machine.py
index f48f20148..3e67bd025 100644
--- a/examples/server_client/same_machine.py
+++ b/examples/server_client/same_machine.py
@@ -12,8 +12,10 @@
 import torch
 
 import gigl.distributed as gd
+from gigl.common import UriFactory
 from gigl.common.logger import Logger
 from gigl.distributed.utils import get_free_port
+from gigl.distributed.sampler import RemoteNodeInfo
 from gigl.types.graph import to_homogeneous
 
 logger = Logger()
@@ -32,10 +34,23 @@ def run_server(
         is_inference=True,
         _tfrecord_uri_pattern=".*tfrecord",
     )
+    node_id_uri = f"{output_dir}/node_ids.pt"
     logger.info(
-        f"Dumping {to_homogeneous(dataset.node_ids).numel()} node_ids to {output_dir}/node_ids.pt"
+        f"Dumping {to_homogeneous(dataset.node_ids).numel()} node_ids to {node_id_uri}"
     )
-    torch.save(to_homogeneous(dataset.node_ids), f"{output_dir}/node_ids.pt")
+    torch.save(to_homogeneous(dataset.node_ids), node_id_uri)
+    remote_node_info = RemoteNodeInfo(
+        node_type=None,
+        edge_types=dataset.get_edge_types(),
+        node_tensor_uri=node_id_uri,
+        node_feature_info=dataset.node_feature_info,
+        edge_feature_info=dataset.edge_feature_info,
+        num_partitions=dataset.num_partitions,
+        edge_dir=dataset.edge_dir,
+    )
+    with open(f"{output_dir}/remote_node_info.pyast", "w") as f:
+        f.write(remote_node_info.dump())
+    print(f"Wrote remote node info to {output_dir}/remote_node_info.pyast")
     logger.info(f"Initializing server")
     glt.distributed.init_server(
         num_servers=num_servers,
@@ -75,22 +90,35 @@ def run_client(
     logger.info(f"Loaded {node_ids.numel()} node_ids")
     num_workers = 4
 
-    loader = glt.distributed.DistNeighborLoader(
-        data=None,
+    # loader = glt.distributed.DistNeighborLoader(
+    #     data=None,
+    #     num_neighbors=[2, 2],
+    #     input_nodes=f"{output_dir}/node_ids.pt",
+    #     worker_options=glt.distributed.RemoteDistSamplingWorkerOptions(
+    #         server_rank=0,
+    #         num_workers=num_workers,
+    #         worker_devices=[torch.device("cpu") for i in range(num_workers)],
+    #         master_addr=host,
+    #         master_port=get_free_port(),
+    #     ),
+    #     to_device=current_device,
+    # )
+    torch.distributed.init_process_group(backend="gloo")
+    gigl_loader = gd.DistNeighborLoader(
+        dataset=None,
         num_neighbors=[2, 2],
-        input_nodes=f"{output_dir}/node_ids.pt",
-        worker_options=glt.distributed.RemoteDistSamplingWorkerOptions(
-            server_rank=0,
-            num_workers=num_workers,
-            worker_devices=[torch.device("cpu") for i in range(num_workers)],
-            master_addr=host,
-            master_port=get_free_port(),
-        ),
-        to_device=current_device,
+        input_nodes=UriFactory.create_uri(f"{output_dir}/remote_node_info.pyast"),
+        num_workers=num_workers,
+        batch_size=1,
+        pin_memory_device=current_device,
+        worker_concurrency=num_workers,
     )
 
-    for batch in loader:
-        logger.info(f"Batch: {batch}")
+    # for batch in loader:
+    #     logger.info(f"Batch: {batch}")
+
+    for batch in gigl_loader:
+        logger.info(f"Gigl Batch: {batch}")
 
     logger.info(f"Shutting down client")
     glt.distributed.shutdown_client()
diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py
index 8c649cfa1..c180bf43d 100644
--- a/python/gigl/distributed/distributed_neighborloader.py
+++ b/python/gigl/distributed/distributed_neighborloader.py
@@ -9,11 +9,12 @@
 from torch_geometric.typing import EdgeType
 
 import gigl.distributed.utils
-from gigl.common import Uri
+from gigl.common import Uri, UriFactory
 from gigl.common.logger import Logger
 from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT
 from gigl.distributed.dist_context import DistributedContext
 from gigl.distributed.dist_link_prediction_dataset import DistLinkPredictionDataset
+from gigl.distributed.sampler import RemoteNodeInfo
 from gigl.distributed.utils.neighborloader import (
     labeled_to_homogeneous,
     patch_fanout_for_sampling,
@@ -210,14 +211,16 @@ def __init__(
         # Determines if the node ids passed in are heterogeneous or homogeneous.
         self._is_labeled_heterogeneous = False
         if dataset is None:
-            if not isinstance(input_nodes, tuple):
-                raise ValueError("input_nodes must be a tuple if dataset is None.")
-            node_type, uri = input_nodes
-            if not isinstance(uri, Uri):
-                raise ValueError(f"uri must be a Uri, received {uri} of type {type(uri)}")
-            if not isinstance(node_type, NodeType):
-                raise ValueError(f"node_type must be a NodeType, received {node_type} of type {type(node_type)}")
-            input_data = RemoteUriSamplerInput(uri, node_type)
+            if not isinstance(input_nodes, Uri):
+                raise ValueError("input_nodes must be a Uri if dataset is None.")
+            uri = input_nodes
+            remote_node_info = RemoteNodeInfo.load(uri)
+
+            input_data = RemoteUriSamplerInput(UriFactory.create_uri(remote_node_info.node_tensor_uri), remote_node_info.node_type or DEFAULT_HOMOGENEOUS_NODE_TYPE)
+            self._node_feature_info = remote_node_info.node_feature_info
+            self._edge_feature_info = remote_node_info.edge_feature_info
+            num_partitions = remote_node_info.num_partitions
+            edge_dir = remote_node_info.edge_dir
         else:
             if isinstance(input_nodes, torch.Tensor):
                 node_ids = input_nodes
@@ -243,7 +246,7 @@ def __init__(
                 node_type = None
             elif isinstance(input_nodes, tuple):
                 node_type, node_ids = input_nodes
-                if not isinstance(node_type, NodeType):
+                if not isinstance(node_type, str):
                     raise ValueError(f"node_type must be a NodeType, received {node_type} of type {type(node_type)}")
                 if not isinstance(node_ids, torch.Tensor):
                     raise ValueError(f"node_ids must be a torch.Tensor, received {node_ids} of type {type(node_ids)}")
@@ -256,6 +259,8 @@ def __init__(
             self._node_feature_info = dataset.node_feature_info
             self._edge_feature_info = dataset.edge_feature_info
+            num_partitions = dataset.num_partitions
+            edge_dir = dataset.edge_dir
 
             input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type)
@@ -314,7 +319,7 @@ def __init__(
             master_port=dist_sampling_port_for_current_rank,
             # Load testing shows that when num_rpc_threads exceed 16, the performance
             # will degrade.
-            num_rpc_threads=min(dataset.num_partitions, 16),
+            num_rpc_threads=min(num_partitions, 16),
             rpc_timeout=600,
             channel_size=channel_size,
             pin_memory=device.type == "cuda",
@@ -330,7 +335,7 @@ def __init__(
             collect_features=True,
             with_neg=False,
             with_weight=False,
-            edge_dir=dataset.edge_dir,
+            edge_dir=edge_dir,
             seed=None,  # it's actually optional - None means random.
         )
diff --git a/python/gigl/distributed/sampler.py b/python/gigl/distributed/sampler.py
index 237ae933b..2d4c67ed9 100644
--- a/python/gigl/distributed/sampler.py
+++ b/python/gigl/distributed/sampler.py
@@ -1,6 +1,7 @@
 from typing import Any, Optional, Union
 from dataclasses import asdict, dataclass
 import ast
+import json
 
 import torch
@@ -10,23 +11,35 @@
 from gigl.types.graph import FeatureInfo
 from gigl.src.common.types.graph_data import NodeType, EdgeType
 from gigl.src.common.utils.file_loader import FileLoader
+from gigl.common.logger import Logger
+
+logger = Logger()
 
 
 @dataclass
 class RemoteNodeInfo:
-    node_type: NodeType
+    node_type: Optional[NodeType]
     edge_types: list[tuple[NodeType, NodeType, NodeType]]
-    num_nodes: int
+    node_tensor_uri: str
     node_feature_info: Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]]
     edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]
+    num_partitions: int
+    edge_dir: str
 
     def dump(self) -> str:
+        print(f"{asdict(self)=}")
+        print(f"{json.dumps(asdict(self))=}")
         return str(asdict(self))
 
     @classmethod
     def load(cls, uri: Uri) -> "RemoteNodeInfo":
-        with FileLoader().load_to_temp_file(uri) as temp_file:
-            s = temp_file.read()
+        logger.info(f"{uri=}")
+        tf = FileLoader().load_to_temp_file(uri, should_create_symlinks_if_possible=False)
+        with open(tf.name, "r") as f:
+            s = f.read()
+        logger.info(f"{s=}")
+        tf.close()
+        logger.info(f"{s=}")
         return cls(**ast.literal_eval(s))
 
 
 class RemoteUriSamplerInput(RemoteSamplerInput):
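At this point the loader effectively has two sampling modes, selected by whether a local dataset is present: MpDistSamplingWorkerOptions spawns sampling workers next to the loader, while RemoteDistSamplingWorkerOptions (wired up fully in the next commit) delegates sampling to the GLT servers that hold the dataset. A hedged sketch of the two option objects, using only the keyword arguments that appear in this series (addresses and ports are assumptions):

    import torch
    from graphlearn_torch.distributed import (
        MpDistSamplingWorkerOptions,
        RemoteDistSamplingWorkerOptions,
    )

    num_workers = 4
    # Local mode: sampling processes run on this machine, next to the loader.
    local_opts = MpDistSamplingWorkerOptions(
        num_workers=num_workers,
        worker_devices=[torch.device("cpu") for _ in range(num_workers)],
        worker_concurrency=num_workers,
        master_addr="localhost",  # assumed
        master_port=20000,        # assumed free port
    )
    # Remote mode: sampling runs on the server that holds the partitioned graph.
    remote_opts = RemoteDistSamplingWorkerOptions(
        server_rank=0,
        num_workers=num_workers,
        worker_devices=[torch.device("cpu") for _ in range(num_workers)],
        master_addr="localhost",  # assumed
        master_port=20001,        # assumed free port
    )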
input_nodes=f"{output_dir}/node_ids.pt", + # worker_options=glt.distributed.RemoteDistSamplingWorkerOptions( + # server_rank=0, + # num_workers=num_workers, + # worker_devices=[torch.device("cpu") for i in range(num_workers)], + # master_addr=host, + # master_port=get_free_port(), + # ), + # to_device=current_device, + # ) + torch.distributed.init_process_group( + backend="gloo", + world_size=1, + rank=0, + init_method=f"tcp://{host}:{get_free_port()}", + group_name="gigl_comms", + ) + gigl_loader = gd.DistNeighborLoader( + dataset=None, + num_neighbors=[2, 2], + input_nodes=UriFactory.create_uri(f"{output_dir}/remote_node_info.pyast"), + num_workers=num_workers, + batch_size=1, + pin_memory_device=current_device, + worker_concurrency=num_workers, + ) + + # for batch in loader: + # logger.info(f"Batch: {batch}") + + for batch in gigl_loader: + logger.info(f"Gigl Batch: {batch}") + + logger.info(f"Shutting down client") + glt.distributed.shutdown_client() + logger.info(f"Client rank {client_rank} exited") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=get_free_port()) + parser.add_argument( + "--output_dir", + type=str, + default=f"/tmp/gigl/server_client/output/{uuid.uuid4()}", + ) + args = parser.parse_args() + logger.info(f"Arguments: {args}") + run_client(0, 1, 1, args.host, args.port, args.output_dir) diff --git a/examples/server_client/same_machine.py b/examples/server_client/same_machine.py index 3e67bd025..216819eff 100644 --- a/examples/server_client/same_machine.py +++ b/examples/server_client/same_machine.py @@ -8,128 +8,25 @@ import uuid from pathlib import Path -import graphlearn_torch as glt import torch +from examples.server_client.client import run_client +from examples.server_client.server import run_server -import gigl.distributed as gd -from gigl.common import UriFactory from gigl.common.logger import Logger from gigl.distributed.utils import get_free_port -from gigl.distributed.sampler import RemoteNodeInfo -from gigl.types.graph import to_homogeneous logger = Logger() -def run_server( - server_rank: int, - num_servers: int, - num_clients: int, - host: str, - port: int, - output_dir: str, -) -> None: - dataset = gd.build_dataset_from_task_config_uri( - task_config_uri="gs://public-gigl/mocked_assets/2024-07-15--21-30-07-UTC/cora_homogeneous_node_anchor_edge_features_user_defined_labels/frozen_gbml_config.yaml", - is_inference=True, - _tfrecord_uri_pattern=".*tfrecord", - ) - node_id_uri = f"{output_dir}/node_ids.pt" - logger.info( - f"Dumping {to_homogeneous(dataset.node_ids).numel()} node_ids to {node_id_uri}" - ) - torch.save(to_homogeneous(dataset.node_ids), node_id_uri) - remote_node_info = RemoteNodeInfo( - node_type=None, - edge_types=dataset.get_edge_types(), - node_tensor_uri=node_id_uri, - node_feature_info=dataset.node_feature_info, - edge_feature_info=dataset.edge_feature_info, - num_partitions=dataset.num_partitions, - edge_dir=dataset.edge_dir, - ) - with open(f"{output_dir}/remote_node_info.pyast", "w") as f: - f.write(remote_node_info.dump()) - print(f"Wrote remote node info to {output_dir}/remote_node_info.pyast") - logger.info(f"Initializing server") - glt.distributed.init_server( - num_servers=num_servers, - server_rank=server_rank, - dataset=dataset, - master_addr=host, - master_port=port, - num_clients=num_clients, - ) - - logger.info(f"Waiting for server rank {server_rank} to exit") - 
glt.distributed.wait_and_shutdown_server() - logger.info(f"Server rank {server_rank} exited") - - -def run_client( - client_rank: int, - num_clients: int, - num_servers: int, - host: str, - port: int, - output_dir: str, -) -> None: - glt.distributed.init_client( - num_servers=num_servers, - num_clients=num_clients, - client_rank=client_rank, - master_addr=host, - master_port=port, - ) - current_ctx = glt.distributed.get_context() - current_device = torch.device(current_ctx.rank % torch.cuda.device_count()) - logger.info(f"Client rank {client_rank} initialized on device {current_device}") - - logger.info(f"Loading node_ids from {output_dir}/node_ids.pt") - node_ids = torch.load(f"{output_dir}/node_ids.pt") - logger.info(f"Loaded {node_ids.numel()} node_ids") - num_workers = 4 - - # loader = glt.distributed.DistNeighborLoader( - # data=None, - # num_neighbors=[2, 2], - # input_nodes=f"{output_dir}/node_ids.pt", - # worker_options=glt.distributed.RemoteDistSamplingWorkerOptions( - # server_rank=0, - # num_workers=num_workers, - # worker_devices=[torch.device("cpu") for i in range(num_workers)], - # master_addr=host, - # master_port=get_free_port(), - # ), - # to_device=current_device, - # ) - torch.distributed.init_process_group(backend="gloo") - gigl_loader = gd.DistNeighborLoader( - dataset=None, - num_neighbors=[2, 2], - input_nodes=UriFactory.create_uri(f"{output_dir}/remote_node_info.pyast"), - num_workers=num_workers, - batch_size=1, - pin_memory_device=current_device, - worker_concurrency=num_workers, - ) - - # for batch in loader: - # logger.info(f"Batch: {batch}") - - for batch in gigl_loader: - logger.info(f"Gigl Batch: {batch}") - - logger.info(f"Shutting down client") - glt.distributed.shutdown_client() - logger.info(f"Client rank {client_rank} exited") - - def main(): parser = argparse.ArgumentParser() parser.add_argument("--num_servers", type=int, default=1) parser.add_argument("--num_clients", type=int, default=2) - parser.add_argument("--output_dir", type=str, default=f"/tmp/gigl/server_client/output/{uuid.uuid4()}") + parser.add_argument( + "--output_dir", + type=str, + default=f"/tmp/gigl/server_client/output/{uuid.uuid4()}", + ) parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=get_free_port()) args = parser.parse_args() diff --git a/examples/server_client/server.py b/examples/server_client/server.py new file mode 100644 index 000000000..33be382f2 --- /dev/null +++ b/examples/server_client/server.py @@ -0,0 +1,86 @@ +import os + +# Suppress TensorFlow logs +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # isort: skip + +import argparse +import io +import uuid + +import graphlearn_torch as glt +import torch + +import gigl.distributed as gd +from gigl.common import UriFactory +from gigl.common.logger import Logger +from gigl.distributed.sampler import RemoteNodeInfo +from gigl.distributed.utils import get_free_port +from gigl.src.common.utils.file_loader import FileLoader +from gigl.types.graph import to_homogeneous + +logger = Logger() + + +def run_server( + server_rank: int, + num_servers: int, + num_clients: int, + host: str, + port: int, + output_dir: str, +) -> None: + dataset = gd.build_dataset_from_task_config_uri( + task_config_uri="gs://public-gigl/mocked_assets/2024-07-15--21-30-07-UTC/cora_homogeneous_node_anchor_edge_features_user_defined_labels/frozen_gbml_config.yaml", + is_inference=True, + _tfrecord_uri_pattern=".*tfrecord", + ) + node_id_uri = f"{output_dir}/node_ids.pt" + logger.info( + f"Dumping 
{to_homogeneous(dataset.node_ids).numel()} node_ids to {node_id_uri}" + ) + bytes_io = io.BytesIO() + torch.save(to_homogeneous(dataset.node_ids), bytes_io) + bytes_io.seek(0) + FileLoader().load_from_filelike(UriFactory.create_uri(node_id_uri), bytes_io) + bytes_io.close() + + remote_node_info = RemoteNodeInfo( + node_type=None, + edge_types=dataset.get_edge_types(), + node_tensor_uri=node_id_uri, + node_feature_info=dataset.node_feature_info, + edge_feature_info=dataset.edge_feature_info, + num_partitions=dataset.num_partitions, + edge_dir=dataset.edge_dir, + master_port=get_free_port(), + ) + with open(f"{output_dir}/remote_node_info.pyast", "w") as f: + f.write(remote_node_info.dump()) + print(f"Wrote remote node info to {output_dir}/remote_node_info.pyast") + logger.info(f"Initializing server") + glt.distributed.init_server( + num_servers=num_servers, + server_rank=server_rank, + dataset=dataset, + master_addr=host, + master_port=port, + num_clients=num_clients, + ) + + logger.info(f"Waiting for server rank {server_rank} to exit") + glt.distributed.wait_and_shutdown_server() + logger.info(f"Server rank {server_rank} exited") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=get_free_port()) + parser.add_argument( + "--output_dir", + type=str, + default=f"/tmp/gigl/server_client/output/{uuid.uuid4()}", + ) + args = parser.parse_args() + logger.info(f"Arguments: {args}") + run_server(0, 1, 1, args.host, args.port, args.output_dir) diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py index c180bf43d..c2f6e474d 100644 --- a/python/gigl/distributed/distributed_neighborloader.py +++ b/python/gigl/distributed/distributed_neighborloader.py @@ -3,7 +3,11 @@ import torch from graphlearn_torch.channel import SampleMessage -from graphlearn_torch.distributed import DistLoader, MpDistSamplingWorkerOptions +from graphlearn_torch.distributed import ( + DistLoader, + MpDistSamplingWorkerOptions, + RemoteDistSamplingWorkerOptions, +) from graphlearn_torch.sampler import NodeSamplerInput, SamplingConfig, SamplingType from torch_geometric.data import Data, HeteroData from torch_geometric.typing import EdgeType @@ -14,7 +18,7 @@ from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_link_prediction_dataset import DistLinkPredictionDataset -from gigl.distributed.sampler import RemoteNodeInfo +from gigl.distributed.sampler import RemoteNodeInfo, RemoteUriSamplerInput from gigl.distributed.utils.neighborloader import ( labeled_to_homogeneous, patch_fanout_for_sampling, @@ -22,7 +26,6 @@ shard_nodes_by_process, strip_label_edges, ) -from gigl.distributed.sampler import RemoteUriSamplerInput from gigl.src.common.types.graph_data import ( NodeType, # TODO (mkolodner-sc): Change to use torch_geometric.typing ) @@ -197,7 +200,9 @@ def __init__( if input_nodes is None: if dataset is None: - raise ValueError("Dataset must be provided if input_nodes are not provided.") + raise ValueError( + "Dataset must be provided if input_nodes are not provided." + ) if dataset.node_ids is None: raise ValueError( "Dataset must have node ids if input_nodes are not provided." 
@@ -216,11 +221,21 @@ def __init__( uri = input_nodes remote_node_info = RemoteNodeInfo.load(uri) - input_data = RemoteUriSamplerInput(UriFactory.create_uri(remote_node_info.node_tensor_uri), remote_node_info.node_type or DEFAULT_HOMOGENEOUS_NODE_TYPE) + input_data = RemoteUriSamplerInput( + UriFactory.create_uri(remote_node_info.node_tensor_uri), + remote_node_info.node_type or DEFAULT_HOMOGENEOUS_NODE_TYPE, + ) self._node_feature_info = remote_node_info.node_feature_info self._edge_feature_info = remote_node_info.edge_feature_info num_partitions = remote_node_info.num_partitions edge_dir = remote_node_info.edge_dir + worker_options = RemoteDistSamplingWorkerOptions( + server_rank=0, + num_workers=num_workers, + worker_devices=[torch.device("cpu") for i in range(num_workers)], + master_addr=master_ip_address, + master_port=remote_node_info.master_port, + ) else: if isinstance(input_nodes, torch.Tensor): node_ids = input_nodes @@ -243,9 +258,13 @@ def __init__( elif isinstance(input_nodes, tuple): node_type, node_ids = input_nodes if not isinstance(node_type, str): - raise ValueError(f"node_type must be a NodeType, received {node_type} of type {type(node_type)}") + raise ValueError( + f"node_type must be a NodeType, received {node_type} of type {type(node_type)}" + ) if not isinstance(node_ids, torch.Tensor): - raise ValueError(f"node_ids must be a torch.Tensor, received {node_ids} of type {type(node_ids)}") + raise ValueError( + f"node_ids must be a torch.Tensor, received {node_ids} of type {type(node_ids)}" + ) num_neighbors = patch_fanout_for_sampling( dataset.get_edge_types(), num_neighbors @@ -263,41 +282,62 @@ def __init__( edge_dir = dataset.edge_dir input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type) + worker_options = MpDistSamplingWorkerOptions( + num_workers=num_workers, + worker_devices=[torch.device("cpu") for _ in range(num_workers)], + worker_concurrency=worker_concurrency, + # Each worker will spawn several sampling workers, and all sampling workers spawned by workers in one group + # need to be connected. Thus, we need master ip address and master port to + # initate the connection. + # Note that different groups of workers are independent, and thus + # the sampling processes in different groups should be independent, and should + # use different master ports. + master_addr=master_ip_address, + master_port=dist_sampling_port_for_current_rank, + # Load testing show that when num_rpc_threads exceed 16, the performance + # will degrade. + num_rpc_threads=min(num_partitions, 16), + rpc_timeout=600, + channel_size=channel_size, + pin_memory=device.type == "cuda", + ) - # Sets up processes and torch device for initializing the GLT DistNeighborLoader, setting up RPC and worker groups to minimize - # the memory overhead and CPU contention. - logger.info( - f"Initializing neighbor loader worker in process: {local_rank}/{local_world_size} using device: {device}" - ) - should_use_cpu_workers = device.type == "cpu" - if should_use_cpu_workers and num_cpu_threads is None: + # Sets up processes and torch device for initializing the GLT DistNeighborLoader, setting up RPC and worker groups to minimize + # the memory overhead and CPU contention. logger.info( - "Using CPU workers, but found num_cpu_threads to be None. " - f"Will default setting num_cpu_threads to {DEFAULT_NUM_CPU_THREADS}." 
+ f"Initializing neighbor loader worker in process: {local_rank}/{local_world_size} using device: {device}" ) - num_cpu_threads = DEFAULT_NUM_CPU_THREADS + should_use_cpu_workers = device.type == "cpu" + if should_use_cpu_workers and num_cpu_threads is None: + logger.info( + "Using CPU workers, but found num_cpu_threads to be None. " + f"Will default setting num_cpu_threads to {DEFAULT_NUM_CPU_THREADS}." + ) + num_cpu_threads = DEFAULT_NUM_CPU_THREADS - neighbor_loader_ports = gigl.distributed.utils.get_free_ports_from_master_node( - num_ports=local_world_size - ) - neighbor_loader_port_for_current_rank = neighbor_loader_ports[local_rank] - - gigl.distributed.utils.init_neighbor_loader_worker( - master_ip_address=master_ip_address, - local_process_rank=local_rank, - local_process_world_size=local_world_size, - rank=node_rank, - world_size=node_world_size, - master_worker_port=neighbor_loader_port_for_current_rank, - device=device, - should_use_cpu_workers=should_use_cpu_workers, - # Lever to explore tuning for CPU based inference - num_cpu_threads=num_cpu_threads, - process_start_gap_seconds=process_start_gap_seconds, - ) - logger.info( - f"Finished initializing neighbor loader worker: {local_rank}/{local_world_size}" - ) + neighbor_loader_ports = ( + gigl.distributed.utils.get_free_ports_from_master_node( + num_ports=local_world_size + ) + ) + neighbor_loader_port_for_current_rank = neighbor_loader_ports[local_rank] + + gigl.distributed.utils.init_neighbor_loader_worker( + master_ip_address=master_ip_address, + local_process_rank=local_rank, + local_process_world_size=local_world_size, + rank=node_rank, + world_size=node_world_size, + master_worker_port=neighbor_loader_port_for_current_rank, + device=device, + should_use_cpu_workers=should_use_cpu_workers, + # Lever to explore tuning for CPU based inference + num_cpu_threads=num_cpu_threads, + process_start_gap_seconds=process_start_gap_seconds, + ) + logger.info( + f"Finished initializing neighbor loader worker: {local_rank}/{local_world_size}" + ) # Sets up worker options for the dataloader dist_sampling_ports = gigl.distributed.utils.get_free_ports_from_master_node( @@ -305,26 +345,6 @@ def __init__( ) dist_sampling_port_for_current_rank = dist_sampling_ports[local_rank] - worker_options = MpDistSamplingWorkerOptions( - num_workers=num_workers, - worker_devices=[torch.device("cpu") for _ in range(num_workers)], - worker_concurrency=worker_concurrency, - # Each worker will spawn several sampling workers, and all sampling workers spawned by workers in one group - # need to be connected. Thus, we need master ip address and master port to - # initate the connection. - # Note that different groups of workers are independent, and thus - # the sampling processes in different groups should be independent, and should - # use different master ports. - master_addr=master_ip_address, - master_port=dist_sampling_port_for_current_rank, - # Load testing show that when num_rpc_threads exceed 16, the performance - # will degrade. 
- num_rpc_threads=min(num_partitions, 16), - rpc_timeout=600, - channel_size=channel_size, - pin_memory=device.type == "cuda", - ) - sampling_config = SamplingConfig( sampling_type=SamplingType.NODE, num_neighbors=num_neighbors, @@ -345,6 +365,7 @@ def __init__( ) torch.distributed.destroy_process_group() + print(f"Using worker options: {worker_options}") super().__init__(dataset, input_data, sampling_config, device, worker_options) def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: diff --git a/python/gigl/distributed/sampler.py b/python/gigl/distributed/sampler.py index 2d4c67ed9..ff976b55f 100644 --- a/python/gigl/distributed/sampler.py +++ b/python/gigl/distributed/sampler.py @@ -1,17 +1,16 @@ -from typing import Any, Optional, Union -from dataclasses import asdict, dataclass import ast import json - +from dataclasses import dataclass +from typing import Any, Optional, Union import torch from graphlearn_torch.sampler import NodeSamplerInput, RemoteSamplerInput from gigl.common import Uri -from gigl.types.graph import FeatureInfo -from gigl.src.common.types.graph_data import NodeType, EdgeType -from gigl.src.common.utils.file_loader import FileLoader from gigl.common.logger import Logger +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from gigl.src.common.utils.file_loader import FileLoader +from gigl.types.graph import FeatureInfo logger = Logger() @@ -19,38 +18,206 @@ @dataclass class RemoteNodeInfo: node_type: Optional[NodeType] - edge_types: list[tuple[NodeType, NodeType, NodeType]] + edge_types: Optional[list[tuple[NodeType, NodeType, NodeType]]] node_tensor_uri: str node_feature_info: Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]] edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]] num_partitions: int edge_dir: str + master_port: int + + def serialize(self) -> str: + """Serialize the RemoteNodeInfo to a JSON string.""" + out_dict = {} + + # Handle node_type (str or None) + out_dict["node_type"] = self.node_type + + # Handle edge_types (list of EdgeType tuples -> list of lists) + if self.edge_types is not None: + out_dict["edge_types"] = [list(edge_type) for edge_type in self.edge_types] + else: + out_dict["edge_types"] = None + + # Handle simple fields + out_dict["node_tensor_uri"] = self.node_tensor_uri + out_dict["num_partitions"] = self.num_partitions + out_dict["edge_dir"] = self.edge_dir + out_dict["master_port"] = self.master_port + + def serialize_feature_info(feature_info: FeatureInfo) -> dict: + """Serialize FeatureInfo with proper torch.dtype handling.""" + return {"dim": feature_info.dim, "dtype": str(feature_info.dtype)} + + # Handle node_feature_info (FeatureInfo, dict, or None) + if self.node_feature_info is None: + out_dict["node_feature_info"] = None + elif isinstance(self.node_feature_info, dict): + out_dict["node_feature_info"] = { + k: serialize_feature_info(v) for k, v in self.node_feature_info.items() + } + else: + out_dict["node_feature_info"] = serialize_feature_info( + self.node_feature_info + ) + + # Handle edge_feature_info (FeatureInfo, dict, or None) + if self.edge_feature_info is None: + out_dict["edge_feature_info"] = None + elif isinstance(self.edge_feature_info, dict): + out_dict["edge_feature_info"] = { + str(list(k)): serialize_feature_info(v) + for k, v in self.edge_feature_info.items() + } + else: + out_dict["edge_feature_info"] = serialize_feature_info( + self.edge_feature_info + ) + + return json.dumps(out_dict, indent=2) def dump(self) -> str: - 
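With this refactor a client can build the GiGL loader without materializing the dataset locally; everything it needs is read from the RemoteNodeInfo file the server published. A hedged usage sketch, assuming glt.distributed.init_client(...) has already connected to the servers and that the rendezvous address and file path below exist:

    import torch

    import gigl.distributed as gd
    from gigl.common import UriFactory

    torch.distributed.init_process_group(
        backend="gloo",
        world_size=1,
        rank=0,
        init_method="tcp://localhost:29500",  # assumed rendezvous address
    )
    loader = gd.DistNeighborLoader(
        dataset=None,  # no local dataset: remote sampling mode
        num_neighbors=[2, 2],
        input_nodes=UriFactory.create_uri("/tmp/gigl/remote_node_info.pyast"),
        num_workers=4,
        batch_size=1,
    )
    for batch in loader:
        print(batch)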
print(f"{asdict(self)=}") - print(f"{json.dumps(asdict(self))=}") - return str(asdict(self)) + """Legacy method name for backward compatibility.""" + return self.serialize() + + def save(self, uri: Uri) -> None: + """Save RemoteNodeInfo to a URI.""" + json_str = self.serialize() + file_loader = FileLoader() + with file_loader.save_to_temp_file(json_str.encode(), uri) as temp_file: + pass # File is saved when context manager exits + + @classmethod + def deserialize(cls, json_str: str) -> "RemoteNodeInfo": + """Deserialize a JSON string to a RemoteNodeInfo instance.""" + try: + data = json.loads(json_str) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON string: {e}") + + # Validate required fields + required_fields = [ + "edge_types", + "node_tensor_uri", + "num_partitions", + "edge_dir", + "master_port", + ] + for field in required_fields: + if field not in data: + raise ValueError(f"Missing required field: {field}") + + def deserialize_feature_info(feature_info_data: dict) -> FeatureInfo: + """Deserialize FeatureInfo with proper torch.dtype handling.""" + if ( + not isinstance(feature_info_data, dict) + or "dim" not in feature_info_data + or "dtype" not in feature_info_data + ): + raise ValueError(f"Invalid FeatureInfo data: {feature_info_data}") + + dtype_str = feature_info_data["dtype"] + # Convert string representation back to torch.dtype + try: + dtype = getattr(torch, dtype_str.split(".")[-1]) + except AttributeError: + raise ValueError(f"Invalid torch dtype: {dtype_str}") + + return FeatureInfo(dim=feature_info_data["dim"], dtype=dtype) + + # Handle edge_types conversion from list of lists back to list of EdgeType tuples + edge_types = [] + if data["edge_types"] is not None: + for edge_type_list in data["edge_types"]: + edge_type = EdgeType( + src_node_type=NodeType(edge_type_list[0]), + relation=Relation(edge_type_list[1]), + dst_node_type=NodeType(edge_type_list[2]), + ) + edge_types.append(edge_type) + else: + edge_types = None + + # Handle node_feature_info deserialization + node_feature_info = None + if data["node_feature_info"] is not None: + if ( + isinstance(data["node_feature_info"], dict) + and "dim" in data["node_feature_info"] + ): + # Single FeatureInfo + node_feature_info = deserialize_feature_info(data["node_feature_info"]) + else: + # Dict of NodeType -> FeatureInfo + node_feature_info = { + NodeType(k): deserialize_feature_info(v) + for k, v in data["node_feature_info"].items() + } + + # Handle edge_feature_info deserialization + edge_feature_info = None + if data["edge_feature_info"] is not None: + if ( + isinstance(data["edge_feature_info"], dict) + and "dim" in data["edge_feature_info"] + ): + # Single FeatureInfo + edge_feature_info = deserialize_feature_info(data["edge_feature_info"]) + else: + # Dict of EdgeType -> FeatureInfo + edge_feature_info = {} + for k, v in data["edge_feature_info"].items(): + # Parse the string representation back to list then to EdgeType + edge_type_list = ast.literal_eval(k) + if isinstance(edge_type_list, list) and len(edge_type_list) == 3: + edge_type = EdgeType( + src_node_type=NodeType(edge_type_list[0]), + relation=Relation(edge_type_list[1]), + dst_node_type=NodeType(edge_type_list[2]), + ) + edge_feature_info[edge_type] = deserialize_feature_info(v) + + return cls( + node_type=NodeType(data["node_type"]) + if data["node_type"] is not None + else None, + edge_types=edge_types, + node_tensor_uri=data["node_tensor_uri"], + node_feature_info=node_feature_info, + edge_feature_info=edge_feature_info, + 
num_partitions=data["num_partitions"], + edge_dir=data["edge_dir"], + master_port=data["master_port"], + ) @classmethod def load(cls, uri: Uri) -> "RemoteNodeInfo": + """Load RemoteNodeInfo from a URI.""" logger.info(f"{uri=}") - tf = FileLoader().load_to_temp_file(uri, should_create_symlinks_if_possible=False) + tf = FileLoader().load_to_temp_file( + uri, should_create_symlinks_if_possible=False + ) with open(tf.name, "r") as f: - s = f.read() - logger.info(f"{s=}") + json_str = f.read() + logger.info(f"Loaded JSON: {json_str}") tf.close() - logger.info(f"{s=}") - return cls(**ast.literal_eval(s)) + return cls.deserialize(json_str) + class RemoteUriSamplerInput(RemoteSamplerInput): def __init__(self, uri: Uri, input_type: Optional[Union[str, NodeType]]): self._uri = uri self._input_type = input_type + @property + def input_type(self) -> Optional[Union[str, NodeType]]: + return self._input_type + def to_local_sampler_input(self, dataset, **kwargs) -> NodeSamplerInput: file_loader = FileLoader() with file_loader.load_to_temp_file(self._uri) as temp_file: - tensor = torch.load(temp_file) + tensor = torch.load(temp_file.name) + print(f"Loaded tensor: {tensor.shape}") return NodeSamplerInput(node=tensor, input_type=self._input_type) diff --git a/test.py b/test.py new file mode 100644 index 000000000..ca7c5b5f7 --- /dev/null +++ b/test.py @@ -0,0 +1,47 @@ +from gigl.common import Uri, UriFactory +from gigl.distributed.sampler import RemoteUriSamplerInput, RemoteNodeInfo +from gigl.src.common.utils.file_loader import FileLoader +import tempfile +import torch +from gigl.types.graph import FeatureInfo +import pydantic + + +# class RNI(pydantic.BaseModel): +# node_type: str +# edge_types: list[tuple[str, str, str]] +# node_tensor_uri: str +# num_partitions: int +# node_feature_info: FeatureInfo +# edge_dir: str + +# @pydantic.field_serializer('node_feature_info') +# def serialize_node_feature_info(self, node_feature_info: FeatureInfo) -> dict: +# return {"dim": node_feature_info.dim, "dtype": str(node_feature_info.dtype)} + + +uri = UriFactory.create_uri("gs://gigl-test/test.pt") +print(uri) + +rni = RemoteNodeInfo( + node_type="user", + edge_types=[("user", "to", "item")], + node_tensor_uri=uri.uri, + num_partitions=1, + node_feature_info=FeatureInfo(dim=1, dtype=torch.float32), + edge_feature_info={("user", "to", "item"): FeatureInfo(dim=2, dtype=torch.float32)}, + edge_dir="out", +) +print(f"{rni=}") +dumped = rni.dump() +print(f"{dumped=}") +with tempfile.NamedTemporaryFile("r+t") as temp_file: + temp_file.write(dumped) + temp_file.flush() + temp_file.seek(0) + print(f"{temp_file.name=}") + #print(f"{temp_file.read()=}") + u = UriFactory.create_uri(temp_file.name) + print(f"{u=}, {type(u)=}") + loaded = RemoteNodeInfo.load(u) + print(f"{loaded=}") From 1d8bbe9c180d3b1a82b363d4e6e4dfb7969b9bc6 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 3 Sep 2025 22:35:27 +0000 Subject: [PATCH 06/11] idk doesn't work --- examples/server_client/client.py | 14 ++--- examples/server_client/same_machine.py | 10 +-- examples/server_client/server.py | 62 ++++++++++++------- python/gigl/common/services/vertex_ai.py | 2 + .../distributed/distributed_neighborloader.py | 4 +- python/gigl/distributed/sampler.py | 12 ++-- .../common/services/vertex_ai_test.py | 6 +- 7 files changed, 66 insertions(+), 44 deletions(-) diff --git a/examples/server_client/client.py b/examples/server_client/client.py index 2c57e65e9..57717265e 100644 --- a/examples/server_client/client.py +++ b/examples/server_client/client.py @@ 
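One detail the JSON serializer above leans on: a torch.dtype stringifies as "torch.float32", and the attribute name after the dot can be handed back to getattr(torch, ...) to recover the dtype object. A self-contained check of that round trip:

    import torch

    dtype_str = str(torch.float32)  # "torch.float32"
    restored = getattr(torch, dtype_str.split(".")[-1])
    assert restored is torch.float32

    # Invalid names fail loudly, which is what deserialize() converts into a
    # ValueError:
    try:
        getattr(torch, "not_a_dtype")
    except AttributeError:
        pass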
-37,9 +37,9 @@ def run_client( current_device = torch.device(current_ctx.rank % torch.cuda.device_count()) logger.info(f"Client rank {client_rank} initialized on device {current_device}") - logger.info(f"Loading node_ids from {output_dir}/node_ids.pt") - node_ids = torch.load(f"{output_dir}/node_ids.pt") - logger.info(f"Loaded {node_ids.numel()} node_ids") + # logger.info(f"Loading node_ids from {output_dir}/node_ids.pt") + # node_ids = torch.load(f"{output_dir}/node_ids.pt") + # logger.info(f"Loaded {node_ids.numel()} node_ids") num_workers = 4 # loader = glt.distributed.DistNeighborLoader( @@ -57,10 +57,10 @@ def run_client( # ) torch.distributed.init_process_group( backend="gloo", - world_size=1, - rank=0, - init_method=f"tcp://{host}:{get_free_port()}", - group_name="gigl_comms", + world_size=num_clients, + rank=client_rank, + group_name="gigl_loader_comms", + init_method=f"tcp://{host}:{42132}", ) gigl_loader = gd.DistNeighborLoader( dataset=None, diff --git a/examples/server_client/same_machine.py b/examples/server_client/same_machine.py index 216819eff..63c6a7575 100644 --- a/examples/server_client/same_machine.py +++ b/examples/server_client/same_machine.py @@ -60,11 +60,11 @@ def main(): output_file = Path(f"{output_dir}/node_ids.pt") - while not output_file.exists(): - time.sleep(5) - logger.info( - f"Waiting for server rank {server_rank} to dump node_ids to {output_dir}/node_ids.pt" - ) + # while not output_file.exists(): + # time.sleep(5) + # logger.info( + # f"Waiting for server rank {server_rank} to dump node_ids to {output_dir}/node_ids.pt" + # ) client_processes = [] diff --git a/examples/server_client/server.py b/examples/server_client/server.py index 33be382f2..522123e29 100644 --- a/examples/server_client/server.py +++ b/examples/server_client/server.py @@ -29,35 +29,51 @@ def run_server( port: int, output_dir: str, ) -> None: + + torch.distributed.init_process_group(backend="gloo", world_size=num_servers, rank=server_rank, init_method=f"tcp://{host}:{43002}", group_name="gigl_server_comms") dataset = gd.build_dataset_from_task_config_uri( task_config_uri="gs://public-gigl/mocked_assets/2024-07-15--21-30-07-UTC/cora_homogeneous_node_anchor_edge_features_user_defined_labels/frozen_gbml_config.yaml", is_inference=True, _tfrecord_uri_pattern=".*tfrecord", ) node_id_uri = f"{output_dir}/node_ids.pt" - logger.info( - f"Dumping {to_homogeneous(dataset.node_ids).numel()} node_ids to {node_id_uri}" - ) - bytes_io = io.BytesIO() - torch.save(to_homogeneous(dataset.node_ids), bytes_io) - bytes_io.seek(0) - FileLoader().load_from_filelike(UriFactory.create_uri(node_id_uri), bytes_io) - bytes_io.close() + if server_rank == 0: + node_tensor_uris = [] + node_pb = to_homogeneous(dataset.node_pb) + if isinstance(node_pb, torch.Tensor): + total_node_ids = node_pb.numel() + elif isinstance(node_pb, glt.partition.RangePartitionBook): + total_node_ids = node_pb.partition_bounds[-1] + else: + raise ValueError(f"Unsupported node partition book type: {type(node_pb)}") - remote_node_info = RemoteNodeInfo( - node_type=None, - edge_types=dataset.get_edge_types(), - node_tensor_uri=node_id_uri, - node_feature_info=dataset.node_feature_info, - edge_feature_info=dataset.edge_feature_info, - num_partitions=dataset.num_partitions, - edge_dir=dataset.edge_dir, - master_port=get_free_port(), - ) - with open(f"{output_dir}/remote_node_info.pyast", "w") as f: - f.write(remote_node_info.dump()) - print(f"Wrote remote node info to {output_dir}/remote_node_info.pyast") - logger.info(f"Initializing server") + 
for client_rank in range(num_clients):
+            node_tensor_uri = f"{output_dir}/node_ids_{client_rank}.pt"
+            node_tensor_uris.append(node_tensor_uri)
+            logger.info(
+                f"Dumping {total_node_ids // num_clients} node_ids to {node_tensor_uri}"
+            )
+            bytes_io = io.BytesIO()
+            torch.save(torch.arange(start=client_rank * (total_node_ids // num_clients), end=(client_rank + 1) * (total_node_ids // num_clients)), bytes_io)
+            bytes_io.seek(0)
+            FileLoader().load_from_filelike(UriFactory.create_uri(node_tensor_uri), bytes_io)
+            bytes_io.close()
+
+        remote_node_info = RemoteNodeInfo(
+            node_type=None,
+            edge_types=dataset.get_edge_types(),
+            node_tensor_uris=node_tensor_uris,
+            node_feature_info=dataset.node_feature_info,
+            edge_feature_info=dataset.edge_feature_info,
+            num_partitions=dataset.num_partitions,
+            edge_dir=dataset.edge_dir,
+            master_port=get_free_port(),
+            num_servers=num_servers,
+        )
+        with open(f"{output_dir}/remote_node_info.pyast", "w") as f:
+            f.write(remote_node_info.dump())
+        print(f"Wrote remote node info to {output_dir}/remote_node_info.pyast")
+    logger.info(f"Initializing server {server_rank} / {num_servers}")
     glt.distributed.init_server(
         num_servers=num_servers,
         server_rank=server_rank,
@@ -67,7 +83,7 @@ def run_server(
         num_clients=num_clients,
     )
 
-    logger.info(f"Waiting for server rank {server_rank} to exit")
+    logger.info(f"Waiting for server rank {server_rank} / {num_servers} to exit")
     glt.distributed.wait_and_shutdown_server()
     logger.info(f"Server rank {server_rank} exited")
 
diff --git a/python/gigl/common/services/vertex_ai.py b/python/gigl/common/services/vertex_ai.py
index 04e285c4a..697dce0f4 100644
--- a/python/gigl/common/services/vertex_ai.py
+++ b/python/gigl/common/services/vertex_ai.py
@@ -208,6 +208,8 @@ def launch_job(self, job_config: VertexAiJobConfig) -> None:
             replica_count=job_config.replica_count - 1,
         )
         worker_pool_specs.append(worker_spec)
+        # worker_pool_specs.append({})
+        # worker_pool_specs.append(worker_spec)
 
     logger.info(
         f"Running Custom job {job_config.job_name} with worker_pool_specs {worker_pool_specs}, in project: {self._project}/{self._location} using staging bucket: {self._staging_bucket}, and attached labels: {job_config.labels}"
diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py
index c2f6e474d..3d43ac83e 100644
--- a/python/gigl/distributed/distributed_neighborloader.py
+++ b/python/gigl/distributed/distributed_neighborloader.py
@@ -222,7 +222,7 @@ def __init__(
         remote_node_info = RemoteNodeInfo.load(uri)
 
         input_data = RemoteUriSamplerInput(
-            UriFactory.create_uri(remote_node_info.node_tensor_uri),
+            UriFactory.create_uri(remote_node_info.node_tensor_uris[rank]),
             remote_node_info.node_type or DEFAULT_HOMOGENEOUS_NODE_TYPE,
         )
         self._node_feature_info = remote_node_info.node_feature_info
@@ -230,7 +230,7 @@ def __init__(
         num_partitions = remote_node_info.num_partitions
         edge_dir = remote_node_info.edge_dir
         worker_options = RemoteDistSamplingWorkerOptions(
-            server_rank=0,
+            server_rank=list(range(remote_node_info.num_servers)),
             num_workers=num_workers,
             worker_devices=[torch.device("cpu") for i in range(num_workers)],
             master_addr=master_ip_address,
diff --git a/python/gigl/distributed/sampler.py b/python/gigl/distributed/sampler.py
index ff976b55f..f5874bfdf 100644
--- a/python/gigl/distributed/sampler.py
+++ b/python/gigl/distributed/sampler.py
@@ -19,12 +19,13 @@ class RemoteNodeInfo:
 
     node_type: Optional[NodeType]
    edge_types: Optional[list[tuple[NodeType, NodeType, NodeType]]]
-    
node_tensor_uri: str + node_tensor_uris: list[str] node_feature_info: Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]] edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]] num_partitions: int edge_dir: str master_port: int + num_servers: int def serialize(self) -> str: """Serialize the RemoteNodeInfo to a JSON string.""" @@ -40,10 +41,11 @@ def serialize(self) -> str: out_dict["edge_types"] = None # Handle simple fields - out_dict["node_tensor_uri"] = self.node_tensor_uri + out_dict["node_tensor_uris"] = self.node_tensor_uris out_dict["num_partitions"] = self.num_partitions out_dict["edge_dir"] = self.edge_dir out_dict["master_port"] = self.master_port + out_dict["num_servers"] = self.num_servers def serialize_feature_info(feature_info: FeatureInfo) -> dict: """Serialize FeatureInfo with proper torch.dtype handling.""" @@ -98,10 +100,11 @@ def deserialize(cls, json_str: str) -> "RemoteNodeInfo": # Validate required fields required_fields = [ "edge_types", - "node_tensor_uri", + "node_tensor_uris", "num_partitions", "edge_dir", "master_port", + "num_servers", ] for field in required_fields: if field not in data: @@ -182,12 +185,13 @@ def deserialize_feature_info(feature_info_data: dict) -> FeatureInfo: if data["node_type"] is not None else None, edge_types=edge_types, - node_tensor_uri=data["node_tensor_uri"], + node_tensor_uris=data["node_tensor_uris"], node_feature_info=node_feature_info, edge_feature_info=edge_feature_info, num_partitions=data["num_partitions"], edge_dir=data["edge_dir"], master_port=data["master_port"], + num_servers=data["num_servers"], ) @classmethod diff --git a/python/tests/integration/common/services/vertex_ai_test.py b/python/tests/integration/common/services/vertex_ai_test.py index 30477a466..424ed431a 100644 --- a/python/tests/integration/common/services/vertex_ai_test.py +++ b/python/tests/integration/common/services/vertex_ai_test.py @@ -42,10 +42,10 @@ def test_launch_job(self): staging_bucket = resource_config.temp_assets_regional_bucket_path.uri job_name = f"GiGL-Integration-Test-{uuid.uuid4()}" container_uri = "condaforge/miniforge3:25.3.0-1" - command = ["python", "-c", "import logging; logging.info('Hello, World!')"] + command = ["printenv"] job_config = VertexAiJobConfig( - job_name=job_name, container_uri=container_uri, command=command + job_name=job_name, container_uri=container_uri, command=command, replica_count=2 ) vertex_ai_service = VertexAIService( @@ -57,7 +57,7 @@ def test_launch_job(self): vertex_ai_service.launch_job(job_config) - def test_run_pipeline(self): + def _test_run_pipeline(self): with tempfile.TemporaryDirectory() as tmpdir: pipeline_def = os.path.join(tmpdir, "pipeline.yaml") kfp.compiler.Compiler().compile(get_pipeline, pipeline_def) From 1f64a89e12a23cfd324f8ff04429a7036bb926de Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 3 Sep 2025 23:59:36 +0000 Subject: [PATCH 07/11] multi server works --- examples/server_client/client.py | 18 ++++++------ examples/server_client/same_machine.py | 1 - examples/server_client/server.py | 28 ++++++++++++++----- .../distributed/distributed_neighborloader.py | 26 +++++++++++++---- .../common/services/vertex_ai_test.py | 5 +++- 5 files changed, 54 insertions(+), 24 deletions(-) diff --git a/examples/server_client/client.py b/examples/server_client/client.py index 57717265e..7bbac0947 100644 --- a/examples/server_client/client.py +++ b/examples/server_client/client.py @@ -45,16 +45,19 @@ def run_client( # loader = glt.distributed.DistNeighborLoader( # 
data=None, # num_neighbors=[2, 2], - # input_nodes=f"{output_dir}/node_ids.pt", + # input_nodes=f"{output_dir}/node_ids_{client_rank}.pt", # worker_options=glt.distributed.RemoteDistSamplingWorkerOptions( - # server_rank=0, + # server_rank=[server_rank for server_rank in range(num_servers)], # num_workers=num_workers, # worker_devices=[torch.device("cpu") for i in range(num_workers)], # master_addr=host, - # master_port=get_free_port(), + # master_port=32421, # ), # to_device=current_device, # ) + + # for batch in loader: + # logger.info(f"Batch: {batch}") torch.distributed.init_process_group( backend="gloo", world_size=num_clients, @@ -71,12 +74,9 @@ def run_client( pin_memory_device=current_device, worker_concurrency=num_workers, ) - - # for batch in loader: - # logger.info(f"Batch: {batch}") - - for batch in gigl_loader: - logger.info(f"Gigl Batch: {batch}") + for i, batch in enumerate(gigl_loader): + if i % 100 == 0: + logger.info(f"Gigl Batch: {batch}") logger.info(f"Shutting down client") glt.distributed.shutdown_client() diff --git a/examples/server_client/same_machine.py b/examples/server_client/same_machine.py index 63c6a7575..ac1bbfe25 100644 --- a/examples/server_client/same_machine.py +++ b/examples/server_client/same_machine.py @@ -4,7 +4,6 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # isort: skip import argparse -import time import uuid from pathlib import Path diff --git a/examples/server_client/server.py b/examples/server_client/server.py index 522123e29..54fadfb77 100644 --- a/examples/server_client/server.py +++ b/examples/server_client/server.py @@ -29,8 +29,13 @@ def run_server( port: int, output_dir: str, ) -> None: - - torch.distributed.init_process_group(backend="gloo", world_size=num_servers, rank=server_rank, init_method=f"tcp://{host}:{43002}", group_name="gigl_server_comms") + torch.distributed.init_process_group( + backend="gloo", + world_size=num_servers, + rank=server_rank, + init_method=f"tcp://{host}:{43002}", + group_name="gigl_server_comms", + ) dataset = gd.build_dataset_from_task_config_uri( task_config_uri="gs://public-gigl/mocked_assets/2024-07-15--21-30-07-UTC/cora_homogeneous_node_anchor_edge_features_user_defined_labels/frozen_gbml_config.yaml", is_inference=True, @@ -47,16 +52,25 @@ def run_server( else: raise ValueError(f"Unsupported node partition book type: {type(node_pb)}") - for client_rank in range(num_clients): - node_tensor_uri = f"{output_dir}/node_ids_{client_rank}.pt" + num_shards = num_servers * num_clients + for shard_rank in range(num_shards): + node_tensor_uri = f"{output_dir}/node_ids_{shard_rank}.pt" node_tensor_uris.append(node_tensor_uri) logger.info( - f"Dumping {total_node_ids // num_clients} node_ids to {node_tensor_uri}" + f"Dumping {total_node_ids // num_shards} node_ids to {node_tensor_uri}" ) bytes_io = io.BytesIO() - torch.save(torch.arange(start=client_rank * (total_node_ids // num_clients), end=(client_rank + 1) * (total_node_ids // num_clients)), bytes_io) + torch.save( + torch.arange( + start=shard_rank * (total_node_ids // num_shards), + end=(shard_rank + 1) * (total_node_ids // num_shards), + ), + bytes_io, + ) bytes_io.seek(0) - FileLoader().load_from_filelike(UriFactory.create_uri(node_tensor_uri), bytes_io) + FileLoader().load_from_filelike( + UriFactory.create_uri(node_tensor_uri), bytes_io + ) bytes_io.close() remote_node_info = RemoteNodeInfo( diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py index 3d43ac83e..83d01d2fa 100644 --- 
a/python/gigl/distributed/distributed_neighborloader.py +++ b/python/gigl/distributed/distributed_neighborloader.py @@ -221,20 +221,34 @@ def __init__( uri = input_nodes remote_node_info = RemoteNodeInfo.load(uri) - input_data = RemoteUriSamplerInput( - UriFactory.create_uri(remote_node_info.node_tensor_uris[rank]), - remote_node_info.node_type or DEFAULT_HOMOGENEOUS_NODE_TYPE, - ) + num_shards = remote_node_info.num_servers + input_data = [] + for shard_rank in range(num_shards): + input_data.append( + RemoteUriSamplerInput( + UriFactory.create_uri( + remote_node_info.node_tensor_uris[ + shard_rank + remote_node_info.num_servers * node_rank + ] + ), + remote_node_info.node_type or DEFAULT_HOMOGENEOUS_NODE_TYPE, + ) + ) self._node_feature_info = remote_node_info.node_feature_info self._edge_feature_info = remote_node_info.edge_feature_info num_partitions = remote_node_info.num_partitions edge_dir = remote_node_info.edge_dir worker_options = RemoteDistSamplingWorkerOptions( - server_rank=list(range(remote_node_info.num_servers)), + server_rank=[ + server_rank for server_rank in range(remote_node_info.num_servers) + ], num_workers=num_workers, worker_devices=[torch.device("cpu") for i in range(num_workers)], master_addr=master_ip_address, - master_port=remote_node_info.master_port, + master_port=42192, + ) + num_neighbors = patch_fanout_for_sampling( + remote_node_info.edge_types, num_neighbors ) else: if isinstance(input_nodes, torch.Tensor): diff --git a/python/tests/integration/common/services/vertex_ai_test.py b/python/tests/integration/common/services/vertex_ai_test.py index 424ed431a..15aa1e6ca 100644 --- a/python/tests/integration/common/services/vertex_ai_test.py +++ b/python/tests/integration/common/services/vertex_ai_test.py @@ -45,7 +45,10 @@ def test_launch_job(self): command = ["printenv"] job_config = VertexAiJobConfig( - job_name=job_name, container_uri=container_uri, command=command, replica_count=2 + job_name=job_name, + container_uri=container_uri, + command=command, + replica_count=2, ) vertex_ai_service = VertexAIService( From 3c18a1129f1ab2c85a00671deb0f892427b3864f Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 5 Sep 2025 17:05:31 +0000 Subject: [PATCH 08/11] need to get client cluster comms up and running --- examples/server_client/client.py | 46 +++++++++++++++++-- examples/server_client/same_machine.py | 57 ++++-------------------- examples/server_client/server.py | 41 +++++++++++++++-- python/gigl/common/services/vertex_ai.py | 31 ++++++++++++- 4 files changed, 118 insertions(+), 57 deletions(-) diff --git a/examples/server_client/client.py b/examples/server_client/client.py index 7bbac0947..ae687df4c 100644 --- a/examples/server_client/client.py +++ b/examples/server_client/client.py @@ -12,7 +12,11 @@ import gigl.distributed as gd from gigl.common import UriFactory from gigl.common.logger import Logger -from gigl.distributed.utils import get_free_port +from gigl.distributed.utils import ( + get_free_port, + get_free_ports_from_master_node, + get_internal_ip_from_master_node, +) logger = Logger() @@ -34,7 +38,10 @@ def run_client( ) current_ctx = glt.distributed.get_context() print("Current context: ", current_ctx) - current_device = torch.device(current_ctx.rank % torch.cuda.device_count()) + if torch.cuda.is_available(): + current_device = torch.device(current_ctx.rank % torch.cuda.device_count()) + else: + current_device = torch.device("cpu") logger.info(f"Client rank {client_rank} initialized on device {current_device}") # logger.info(f"Loading node_ids 
from {output_dir}/node_ids.pt") @@ -58,6 +65,9 @@ def run_client( # for batch in loader: # logger.info(f"Batch: {batch}") + if client_rank == 0: + for k, v in os.environ.items(): + logger.info(f"Environment variable: {k} = {v}") torch.distributed.init_process_group( backend="gloo", world_size=num_clients, @@ -83,6 +93,22 @@ def run_client( logger.info(f"Client rank {client_rank} exited") +def run_clients( + num_clients: int, num_servers: int, host: str, port: int, output_dir: str +) -> list: + client_processes = [] + mp_context = torch.multiprocessing.get_context("spawn") + for client_rank in range(num_clients): + client_process = mp_context.Process( + target=run_client, + args=(client_rank, num_clients, num_servers, host, port, output_dir), + ) + client_processes.append(client_process) + for client_process in client_processes: + client_process.start() + return client_processes + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") @@ -92,6 +118,20 @@ def run_client( type=str, default=f"/tmp/gigl/server_client/output/{uuid.uuid4()}", ) + parser.add_argument("--num_clients", type=int, default=1) + parser.add_argument("--num_servers", type=int, default=1) args = parser.parse_args() logger.info(f"Arguments: {args}") - run_client(0, 1, 1, args.host, args.port, args.output_dir) + if args.host == "FROM ENV" and args.port == -1: + logger.info(f"Using host and port from process group") + torch.distributed.init_process_group(backend="gloo") + args.host = get_internal_ip_from_master_node() + args.port = get_free_ports_from_master_node(num_ports=1)[0] + torch.distributed.destroy_process_group() + elif args.host == "FROM ENV" or args.port == -1: + raise ValueError("Either host or port must be provided") + logger.info(f"Using host: {args.host}") + logger.info(f"Using port: {args.port}") + run_clients( + args.num_clients, args.num_servers, args.host, args.port, args.output_dir + ) diff --git a/examples/server_client/same_machine.py b/examples/server_client/same_machine.py index ac1bbfe25..ff2eae624 100644 --- a/examples/server_client/same_machine.py +++ b/examples/server_client/same_machine.py @@ -7,9 +7,8 @@ import uuid from pathlib import Path -import torch -from examples.server_client.client import run_client -from examples.server_client.server import run_server +from examples.server_client.client import run_clients +from examples.server_client.server import run_servers from gigl.common.logger import Logger from gigl.distributed.utils import get_free_port @@ -37,52 +36,12 @@ def main(): output_dir = args.output_dir Path(output_dir).mkdir(parents=True, exist_ok=True) - server_processes = [] - mp_context = torch.multiprocessing.get_context("spawn") - - for server_rank in range(num_servers): - server_process = mp_context.Process( - target=run_server, - args=( - server_rank, - num_servers, - num_clients, - args.host, - args.port, - output_dir, - ), - ) - server_processes.append(server_process) - - for server_process in server_processes: - server_process.start() - - output_file = Path(f"{output_dir}/node_ids.pt") - - # while not output_file.exists(): - # time.sleep(5) - # logger.info( - # f"Waiting for server rank {server_rank} to dump node_ids to {output_dir}/node_ids.pt" - # ) - - client_processes = [] - - for client_rank in range(num_clients): - client_process = mp_context.Process( - target=run_client, - args=( - client_rank, - num_clients, - num_servers, - args.host, - args.port, - output_dir, - ), - ) - 
client_processes.append(client_process)
-
-    for client_process in client_processes:
-        client_process.start()
+    server_processes = run_servers(
+        num_servers, num_clients, args.host, args.port, output_dir
+    )
+    client_processes = run_clients(
+        num_clients, num_servers, args.host, args.port, output_dir
+    )
 
     logger.info(f"Waiting for client processes to exit")
     for client_process in client_processes:
diff --git a/examples/server_client/server.py b/examples/server_client/server.py
index 54fadfb77..26e96f911 100644
--- a/examples/server_client/server.py
+++ b/examples/server_client/server.py
@@ -14,7 +14,11 @@
 from gigl.common import UriFactory
 from gigl.common.logger import Logger
 from gigl.distributed.sampler import RemoteNodeInfo
-from gigl.distributed.utils import get_free_port
+from gigl.distributed.utils import (
+    get_free_port,
+    get_free_ports_from_master_node,
+    get_internal_ip_from_master_node,
+)
 from gigl.src.common.utils.file_loader import FileLoader
 from gigl.types.graph import to_homogeneous
 
@@ -84,8 +88,7 @@ def run_server(
         master_port=get_free_port(),
         num_servers=num_servers,
     )
-    with open(f"{output_dir}/remote_node_info.pyast", "w") as f:
-        f.write(remote_node_info.dump())
+    FileLoader().load_from_filelike(UriFactory.create_uri(f"{output_dir}/remote_node_info.pyast"), io.BytesIO(remote_node_info.dump().encode()))
     print(f"Wrote remote node info to {output_dir}/remote_node_info.pyast")
     logger.info(f"Initializing server {server_rank} / {num_servers}")
     glt.distributed.init_server(
@@ -102,10 +105,28 @@ def run_server(
     logger.info(f"Server rank {server_rank} exited")
 
 
+def run_servers(
+    num_servers: int, num_clients: int, host: str, port: int, output_dir: str
+) -> list:
+    server_processes = []
+    mp_context = torch.multiprocessing.get_context("spawn")
+    for server_rank in range(num_servers):
+        server_process = mp_context.Process(
+            target=run_server,
+            args=(server_rank, num_servers, num_clients, host, port, output_dir),
+        )
+        server_processes.append(server_process)
+    for server_process in server_processes:
+        server_process.start()
+    return server_processes
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=get_free_port())
+    parser.add_argument("--num_servers", type=int, default=1)
+    parser.add_argument("--num_clients", type=int, default=1)
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -113,4 +134,16 @@ def run_server(
     )
     args = parser.parse_args()
     logger.info(f"Arguments: {args}")
-    run_server(0, 1, 1, args.host, args.port, args.output_dir)
+    if args.host == "FROM ENV" and args.port == -1:
+        logger.info(f"Using host and port from process group")
+        torch.distributed.init_process_group(backend="gloo")
+        args.host = get_internal_ip_from_master_node()
+        args.port = get_free_ports_from_master_node(num_ports=1)[0]
+        torch.distributed.destroy_process_group()
+    elif args.host == "FROM ENV" or args.port == -1:
+        raise ValueError("Either host or port must be provided")
+    logger.info(f"Using host: {args.host}")
+    logger.info(f"Using port: {args.port}")
+    run_servers(
+        args.num_servers, args.num_clients, args.host, args.port, args.output_dir
+    )
diff --git a/python/gigl/common/services/vertex_ai.py b/python/gigl/common/services/vertex_ai.py
index 697dce0f4..737c5894a 100644
--- a/python/gigl/common/services/vertex_ai.py
+++ b/python/gigl/common/services/vertex_ai.py
@@ -140,7 +140,11 @@ def project(self) -> str:
         """The GCP project that is being used for this 
service.""" return self._project - def launch_job(self, job_config: VertexAiJobConfig) -> None: + def launch_job( + self, + job_config: VertexAiJobConfig, + worker_job_config: Optional[VertexAiJobConfig] = None, + ) -> None: """ Launch a Vertex AI CustomJob. See the docs for more info. @@ -211,6 +215,31 @@ def launch_job(self, job_config: VertexAiJobConfig) -> None: # worker_pool_specs.append({}) # worker_pool_specs.append(worker_spec) + if worker_job_config: + worker_machine_spec = MachineSpec( + machine_type=worker_job_config.machine_type, + accelerator_type=worker_job_config.accelerator_type, + accelerator_count=worker_job_config.accelerator_count, + ) + worker_container_spec = ContainerSpec( + image_uri=worker_job_config.container_uri, + command=worker_job_config.command, + args=worker_job_config.args, + env=env_vars, + ) + worker_disk_spec = DiskSpec( + boot_disk_type=worker_job_config.boot_disk_type, + boot_disk_size_gb=worker_job_config.boot_disk_size_gb, + ) + worker_spec = WorkerPoolSpec( + machine_spec=worker_machine_spec, + container_spec=worker_container_spec, + disk_spec=worker_disk_spec, + replica_count=worker_job_config.replica_count, + ) + worker_pool_specs.append({}) + worker_pool_specs.append(worker_spec) + logger.info( f"Running Custom job {job_config.job_name} with worker_pool_specs {worker_pool_specs}, in project: {self._project}/{self._location} using staging bucket: {self._staging_bucket}, and attached labels: {job_config.labels}" ) From da7c3284d1309a94fbe057483715a414bea61580 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 9 Sep 2025 02:08:12 +0000 Subject: [PATCH 09/11] vai[cpu] works --- examples/server_client/client.py | 80 ++++++++++++++++--- examples/server_client/same_machine.py | 10 ++- examples/server_client/server.py | 30 ++++++- python/gigl/common/services/vertex_ai.py | 3 +- .../distributed/distributed_neighborloader.py | 7 +- python/gigl/distributed/sampler.py | 4 + python/gigl/distributed/utils/networking.py | 16 +++- 7 files changed, 127 insertions(+), 23 deletions(-) diff --git a/examples/server_client/client.py b/examples/server_client/client.py index ae687df4c..10fce801a 100644 --- a/examples/server_client/client.py +++ b/examples/server_client/client.py @@ -4,6 +4,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # isort: skip import argparse +import json import uuid import graphlearn_torch as glt @@ -15,8 +16,10 @@ from gigl.distributed.utils import ( get_free_port, get_free_ports_from_master_node, + get_internal_ip_from_all_ranks, get_internal_ip_from_master_node, ) +from gigl.distributed.utils.networking import get_ports_for_server_client_clusters logger = Logger() @@ -27,8 +30,16 @@ def run_client( num_servers: int, host: str, port: int, + client_master_ip: str, + client_port: int, output_dir: str, ) -> None: + logger.info( + f"Running client with args: {client_rank=} {num_clients=} {num_servers=} {host=} {port=} {client_master_ip=} {client_port=} {output_dir=}" + ) + logger.info( + f"Initializing client {client_rank} / {num_clients} for {num_servers} servers on host {host} and port {port}" + ) glt.distributed.init_client( num_servers=num_servers, num_clients=num_clients, @@ -36,6 +47,7 @@ def run_client( master_addr=host, master_port=port, ) + logger.info(f"Client {client_rank} initialized") current_ctx = glt.distributed.get_context() print("Current context: ", current_ctx) if torch.cuda.is_available(): @@ -43,7 +55,11 @@ def run_client( else: current_device = torch.device("cpu") logger.info(f"Client rank {client_rank} initialized on device 
{current_device}") - + logger.info(f"Client rank {client_rank} requesting dataset metadata from server...") + metadata = glt.distributed.request_server( + 0, glt.distributed.DistServer.get_dataset_meta + ) + logger.info(f"Dataset metadata: {metadata}") # logger.info(f"Loading node_ids from {output_dir}/node_ids.pt") # node_ids = torch.load(f"{output_dir}/node_ids.pt") # logger.info(f"Loaded {node_ids.numel()} node_ids") @@ -65,15 +81,23 @@ def run_client( # for batch in loader: # logger.info(f"Batch: {batch}") - if client_rank == 0: - for k, v in os.environ.items(): - logger.info(f"Environment variable: {k} = {v}") + if os.environ.get("CLUSTER_SPEC"): + server_spec = json.loads(os.environ.get("CLUSTER_SPEC")) + else: + server_spec = None + logger.info(f"Server spec: {server_spec}") + # if client_rank == 0: + # for k, v in os.environ.items(): + # logger.info(f"Environment variable: {k} = {v}") + + init_method = f"tcp://{client_master_ip}:{client_port}" + logger.info(f"Init method: {init_method}") torch.distributed.init_process_group( backend="gloo", world_size=num_clients, rank=client_rank, group_name="gigl_loader_comms", - init_method=f"tcp://{host}:{42132}", + init_method=init_method, ) gigl_loader = gd.DistNeighborLoader( dataset=None, @@ -86,22 +110,38 @@ def run_client( ) for i, batch in enumerate(gigl_loader): if i % 100 == 0: - logger.info(f"Gigl Batch: {batch}") + logger.info(f"Client rank {client_rank} gigl batch {i}: {batch}") + logger.info(f"Client rank {client_rank} finished loading data for {i} batches") logger.info(f"Shutting down client") glt.distributed.shutdown_client() logger.info(f"Client rank {client_rank} exited") def run_clients( - num_clients: int, num_servers: int, host: str, port: int, output_dir: str + num_clients: int, + num_servers: int, + host: str, + port: int, + client_master_ip: str, + client_port: int, + output_dir: str, ) -> list: client_processes = [] mp_context = torch.multiprocessing.get_context("spawn") for client_rank in range(num_clients): client_process = mp_context.Process( target=run_client, - args=(client_rank, num_clients, num_servers, host, port, output_dir), + args=( + client_rank, + num_clients, + num_servers, + host, + port, + client_master_ip, + client_port, + output_dir, + ), ) client_processes.append(client_process) for client_process in client_processes: @@ -122,16 +162,36 @@ def run_clients( parser.add_argument("--num_servers", type=int, default=1) args = parser.parse_args() logger.info(f"Arguments: {args}") + client_port = None if args.host == "FROM ENV" and args.port == -1: logger.info(f"Using host and port from process group") torch.distributed.init_process_group(backend="gloo") args.host = get_internal_ip_from_master_node() args.port = get_free_ports_from_master_node(num_ports=1)[0] + server_port, client_port = get_ports_for_server_client_clusters( + args.num_servers, args.num_clients + ) + logger.info(f"Server port: {server_port}, client port: {client_port}") + ips = get_internal_ip_from_all_ranks() + logger.info(f"IPs: {ips}") + client_master_ip = ips[args.num_servers] + logger.info(f"Client master IP: {client_master_ip}") torch.distributed.destroy_process_group() elif args.host == "FROM ENV" or args.port == -1: raise ValueError("Either host or port must be provided") logger.info(f"Using host: {args.host}") logger.info(f"Using port: {args.port}") - run_clients( - args.num_clients, args.num_servers, args.host, args.port, args.output_dir + client_rank = int(os.environ.get("RANK")) - args.num_servers + run_client( + 
client_rank=client_rank,
+        num_clients=args.num_clients,
+        num_servers=args.num_servers,
+        host=args.host,
+        port=args.port,
+        client_master_ip=client_master_ip,
+        client_port=client_port,
+        output_dir=args.output_dir,
+    )
+    # run_clients(
+    #     args.num_clients, args.num_servers, args.host, args.port, client_port, args.output_dir
+    # )
diff --git a/examples/server_client/same_machine.py b/examples/server_client/same_machine.py
index ff2eae624..39ff4563f 100644
--- a/examples/server_client/same_machine.py
+++ b/examples/server_client/same_machine.py
@@ -37,10 +37,16 @@ def main():
     Path(output_dir).mkdir(parents=True, exist_ok=True)
 
     server_processes = run_servers(
-        num_servers, num_clients, args.host, args.port, output_dir
+        server_rank=0, num_servers=num_servers, num_clients=num_clients, host=args.host, port=args.port, output_dir=output_dir
     )
     client_processes = run_clients(
         num_clients,
         num_servers,
         args.host,
         args.port,
+        "localhost",
+        get_free_port(),
         output_dir,
     )
 
diff --git a/examples/server_client/server.py b/examples/server_client/server.py
index 26e96f911..5fa3b40c4 100644
--- a/examples/server_client/server.py
+++ b/examples/server_client/server.py
@@ -17,8 +17,10 @@
 from gigl.distributed.utils import (
     get_free_port,
     get_free_ports_from_master_node,
+    get_internal_ip_from_all_ranks,
     get_internal_ip_from_master_node,
 )
+from gigl.distributed.utils.networking import get_ports_for_server_client_clusters
 from gigl.src.common.utils.file_loader import FileLoader
 from gigl.types.graph import to_homogeneous
 
@@ -85,10 +87,14 @@ def run_server(
         edge_feature_info=dataset.edge_feature_info,
         num_partitions=dataset.num_partitions,
         edge_dir=dataset.edge_dir,
+        master_addr=host,
         master_port=get_free_port(),
         num_servers=num_servers,
     )
-    FileLoader().load_from_filelike(UriFactory.create_uri(f"{output_dir}/remote_node_info.pyast"), io.BytesIO(remote_node_info.dump().encode()))
+    FileLoader().load_from_filelike(
+        UriFactory.create_uri(f"{output_dir}/remote_node_info.pyast"),
+        io.BytesIO(remote_node_info.dump().encode()),
+    )
     print(f"Wrote remote node info to {output_dir}/remote_node_info.pyast")
     logger.info(f"Initializing server {server_rank} / {num_servers}")
     glt.distributed.init_server(
@@ -106,11 +112,16 @@ def run_server(
 
 
 def run_servers(
+    server_rank: int,
     num_servers: int,
     num_clients: int,
     host: str,
     port: int,
     output_dir: str,
 ) -> list:
     server_processes = []
     mp_context = torch.multiprocessing.get_context("spawn")
-    for server_rank in range(num_servers):
+    for i in range(num_servers):
         server_process = mp_context.Process(
             target=run_server,
             args=(server_rank, num_servers, num_clients, host, port, output_dir),
         )
@@ -139,11 +150,22 @@ def run_servers(
         torch.distributed.init_process_group(backend="gloo")
         args.host = get_internal_ip_from_master_node()
         args.port = get_free_ports_from_master_node(num_ports=1)[0]
+        server_port, client_port = get_ports_for_server_client_clusters(
+            args.num_servers, args.num_clients
+        )
+        logger.info(f"Server port: {server_port}, client port: {client_port}")
+        ips = get_internal_ip_from_all_ranks()
+        logger.info(f"IPs: {ips}")
         torch.distributed.destroy_process_group()
     elif args.host == "FROM ENV" or args.port == -1:
         raise ValueError("Either host or port must be provided")
     logger.info(f"Using host: {args.host}")
     logger.info(f"Using port: {args.port}")
     run_servers( 
- args.num_servers, args.num_clients, args.host, args.port, args.output_dir + int(os.environ.get("RANK")), + args.num_servers, + args.num_clients, + args.host, + args.port, + args.output_dir, ) diff --git a/python/gigl/common/services/vertex_ai.py b/python/gigl/common/services/vertex_ai.py index 737c5894a..d258600b3 100644 --- a/python/gigl/common/services/vertex_ai.py +++ b/python/gigl/common/services/vertex_ai.py @@ -237,7 +237,8 @@ def launch_job( disk_spec=worker_disk_spec, replica_count=worker_job_config.replica_count, ) - worker_pool_specs.append({}) + if job_config.replica_count == 1: + worker_pool_specs.append({}) worker_pool_specs.append(worker_spec) logger.info( diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py index 83d01d2fa..3f9016662 100644 --- a/python/gigl/distributed/distributed_neighborloader.py +++ b/python/gigl/distributed/distributed_neighborloader.py @@ -238,14 +238,17 @@ def __init__( self._edge_feature_info = remote_node_info.edge_feature_info num_partitions = remote_node_info.num_partitions edge_dir = remote_node_info.edge_dir + logger.info( + f"Using master ip address: {remote_node_info.master_addr}, master port: {remote_node_info.master_port}" + ) worker_options = RemoteDistSamplingWorkerOptions( server_rank=[ server_rank for server_rank in range(remote_node_info.num_servers) ], num_workers=num_workers, worker_devices=[torch.device("cpu") for i in range(num_workers)], - master_addr=master_ip_address, - master_port=42192, + master_addr=remote_node_info.master_addr, + master_port=remote_node_info.master_port, ) num_neighbors = patch_fanout_for_sampling( remote_node_info.edge_types, num_neighbors diff --git a/python/gigl/distributed/sampler.py b/python/gigl/distributed/sampler.py index f5874bfdf..9da3c3ed1 100644 --- a/python/gigl/distributed/sampler.py +++ b/python/gigl/distributed/sampler.py @@ -24,6 +24,7 @@ class RemoteNodeInfo: edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]] num_partitions: int edge_dir: str + master_addr: str master_port: int num_servers: int @@ -44,6 +45,7 @@ def serialize(self) -> str: out_dict["node_tensor_uris"] = self.node_tensor_uris out_dict["num_partitions"] = self.num_partitions out_dict["edge_dir"] = self.edge_dir + out_dict["master_addr"] = self.master_addr out_dict["master_port"] = self.master_port out_dict["num_servers"] = self.num_servers @@ -103,6 +105,7 @@ def deserialize(cls, json_str: str) -> "RemoteNodeInfo": "node_tensor_uris", "num_partitions", "edge_dir", + "master_addr", "master_port", "num_servers", ] @@ -190,6 +193,7 @@ def deserialize_feature_info(feature_info_data: dict) -> FeatureInfo: edge_feature_info=edge_feature_info, num_partitions=data["num_partitions"], edge_dir=data["edge_dir"], + master_addr=data["master_addr"], master_port=data["master_port"], num_servers=data["num_servers"], ) diff --git a/python/gigl/distributed/utils/networking.py b/python/gigl/distributed/utils/networking.py index 23c86c383..70b62f715 100644 --- a/python/gigl/distributed/utils/networking.py +++ b/python/gigl/distributed/utils/networking.py @@ -44,7 +44,7 @@ def get_free_ports(num_ports: int) -> list[int]: def get_free_ports_from_master_node( - num_ports=1, _global_rank_override: Optional[int] = None + num_ports=1, master_node_id: int = 0, _global_rank_override: Optional[int] = None ) -> list[int]: """ Get free ports from master node, that can be used for communication between workers. 
@@ -67,17 +67,17 @@ def get_free_ports_from_master_node( else _global_rank_override ) logger.info( - f"Rank {rank} is requesting {num_ports} free ports from rank 0 (master)" + f"Rank {rank} is requesting {num_ports} free ports from rank {master_node_id} (master)" ) ports: list[int] - if rank == 0: + if rank == master_node_id: ports = get_free_ports(num_ports) logger.info(f"Rank {rank} found free ports: {ports}") else: ports = [0] * num_ports # Broadcast from master from rank 0 to all other ranks - torch.distributed.broadcast_object_list(ports, src=0) + torch.distributed.broadcast_object_list(ports, src=master_node_id) logger.info(f"Rank {rank} received ports: {ports}") return ports @@ -150,3 +150,11 @@ def get_internal_ip_from_all_ranks() -> list[str]: assert all(ip for ip in ip_list), "Could not retrieve all ranks' internal IPs" return ip_list + + +def get_ports_for_server_client_clusters( + num_servers: int, num_clients: int +) -> tuple[int, int]: + server_port = get_free_ports_from_master_node(1)[0] + client_port = get_free_ports_from_master_node(1, master_node_id=num_servers)[0] + return server_port, client_port From 99c1b70be6af953f01e307881f2375e3ac40ede6 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 9 Sep 2025 16:46:18 +0000 Subject: [PATCH 10/11] works --- examples/server_client/client.py | 2 +- examples/server_client/server.py | 5 +++-- python/gigl/common/services/vertex_ai.py | 3 ++- python/gigl/distributed/utils/networking.py | 5 ++++- python/tests/integration/common/services/vertex_ai_test.py | 7 ++++++- 5 files changed, 16 insertions(+), 6 deletions(-) diff --git a/examples/server_client/client.py b/examples/server_client/client.py index 10fce801a..e31f8aa3b 100644 --- a/examples/server_client/client.py +++ b/examples/server_client/client.py @@ -1,7 +1,7 @@ import os # Suppress TensorFlow logs -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # isort: skip +#os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # isort: skip import argparse import json diff --git a/examples/server_client/server.py b/examples/server_client/server.py index 5fa3b40c4..8c2beaf57 100644 --- a/examples/server_client/server.py +++ b/examples/server_client/server.py @@ -1,7 +1,7 @@ import os # Suppress TensorFlow logs -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # isort: skip +#os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # isort: skip import argparse import io @@ -35,6 +35,7 @@ def run_server( port: int, output_dir: str, ) -> None: + logger.info(f"Initializing server {server_rank} / {num_servers}. 
Cluster rank: {os.environ.get('RANK')}") torch.distributed.init_process_group( backend="gloo", world_size=num_servers, @@ -121,7 +122,7 @@ def run_servers( ) -> list: server_processes = [] mp_context = torch.multiprocessing.get_context("spawn") - for i in range(num_servers): + for i in range(1): server_process = mp_context.Process( target=run_server, args=(server_rank, num_servers, num_clients, host, port, output_dir), diff --git a/python/gigl/common/services/vertex_ai.py b/python/gigl/common/services/vertex_ai.py index d258600b3..f97f7fcbe 100644 --- a/python/gigl/common/services/vertex_ai.py +++ b/python/gigl/common/services/vertex_ai.py @@ -239,8 +239,9 @@ def launch_job( ) if job_config.replica_count == 1: worker_pool_specs.append({}) + # worker_pool_specs.append({}) worker_pool_specs.append(worker_spec) - + print(f"worker_pool_specs: {worker_pool_specs}") logger.info( f"Running Custom job {job_config.job_name} with worker_pool_specs {worker_pool_specs}, in project: {self._project}/{self._location} using staging bucket: {self._staging_bucket}, and attached labels: {job_config.labels}" ) diff --git a/python/gigl/distributed/utils/networking.py b/python/gigl/distributed/utils/networking.py index 70b62f715..31cc05b4d 100644 --- a/python/gigl/distributed/utils/networking.py +++ b/python/gigl/distributed/utils/networking.py @@ -1,11 +1,14 @@ import socket from typing import Optional +import logging +import os import torch from gigl.common.logger import Logger logger = Logger() +logging.info(f"Vanilla logger on {os.environ.get('HOSTNAME')}:{os.environ.get('RANK')}") def get_free_port() -> int: @@ -97,7 +100,7 @@ def get_internal_ip_from_master_node( assert ( torch.distributed.is_initialized() ), "Distributed environment must be initialized" - + print(f"torch.distributed.is_initialized(): {torch.distributed.is_initialized()}") rank = ( torch.distributed.get_rank() if _global_rank_override is None diff --git a/python/tests/integration/common/services/vertex_ai_test.py b/python/tests/integration/common/services/vertex_ai_test.py index 15aa1e6ca..29566ca2d 100644 --- a/python/tests/integration/common/services/vertex_ai_test.py +++ b/python/tests/integration/common/services/vertex_ai_test.py @@ -41,13 +41,18 @@ def test_launch_job(self): service_account = resource_config.service_account_email staging_bucket = resource_config.temp_assets_regional_bucket_path.uri job_name = f"GiGL-Integration-Test-{uuid.uuid4()}" - container_uri = "condaforge/miniforge3:25.3.0-1" + # container_uri = "condaforge/miniforge3:25.3.0-1" + # container_uri = "us-docker.pkg.dev/vertex-ai/training/tf-gpu.2-17.py310:latest" + container_uri = "gcr.io/snap-umap-dev/gbml_cuda:20250909-152833" command = ["printenv"] job_config = VertexAiJobConfig( job_name=job_name, container_uri=container_uri, command=command, + machine_type="n1-standard-32", + accelerator_type="NVIDIA_TESLA_T4", + accelerator_count=2, replica_count=2, ) From acf0dba3b8edd712bb33006a5cf8f42efd97a275 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 9 Sep 2025 16:47:59 +0000 Subject: [PATCH 11/11] Revert --- .../common/services/vertex_ai_test.py | 16 ++--- .../dataset_input_metadata_translator_test.py | 2 +- .../pb_wrappers/preprocessed_metadata_test.py | 59 +++++++++++++++++++ 3 files changed, 64 insertions(+), 13 deletions(-) create mode 100644 python/tests/unit/src/common/types/pb_wrappers/preprocessed_metadata_test.py diff --git a/python/tests/integration/common/services/vertex_ai_test.py b/python/tests/integration/common/services/vertex_ai_test.py index 
29566ca2d..30477a466 100644
--- a/python/tests/integration/common/services/vertex_ai_test.py
+++ b/python/tests/integration/common/services/vertex_ai_test.py
@@ -41,19 +41,11 @@ def test_launch_job(self):
         service_account = resource_config.service_account_email
         staging_bucket = resource_config.temp_assets_regional_bucket_path.uri
         job_name = f"GiGL-Integration-Test-{uuid.uuid4()}"
-        # container_uri = "condaforge/miniforge3:25.3.0-1"
-        # container_uri = "us-docker.pkg.dev/vertex-ai/training/tf-gpu.2-17.py310:latest"
-        container_uri = "gcr.io/snap-umap-dev/gbml_cuda:20250909-152833"
-        command = ["printenv"]
+        container_uri = "condaforge/miniforge3:25.3.0-1"
+        command = ["python", "-c", "import logging; logging.info('Hello, World!')"]
 
         job_config = VertexAiJobConfig(
-            job_name=job_name,
-            container_uri=container_uri,
-            command=command,
-            machine_type="n1-standard-32",
-            accelerator_type="NVIDIA_TESLA_T4",
-            accelerator_count=2,
-            replica_count=2,
+            job_name=job_name, container_uri=container_uri, command=command
         )
 
         vertex_ai_service = VertexAIService(
@@ -65,7 +57,7 @@ def test_launch_job(self):
 
         vertex_ai_service.launch_job(job_config)
 
-    def _test_run_pipeline(self):
+    def test_run_pipeline(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             pipeline_def = os.path.join(tmpdir, "pipeline.yaml")
             kfp.compiler.Compiler().compile(get_pipeline, pipeline_def)
diff --git a/python/tests/unit/distributed/dataset_input_metadata_translator_test.py b/python/tests/unit/distributed/dataset_input_metadata_translator_test.py
index 75df1232e..4b24fa2a3 100644
--- a/python/tests/unit/distributed/dataset_input_metadata_translator_test.py
+++ b/python/tests/unit/distributed/dataset_input_metadata_translator_test.py
@@ -157,7 +157,7 @@ def test_translator_correctness(self, _, mocked_dataset_info: MockedDatasetInfo)
             ].tfrecord_uri_prefix,
         )
         self.assertEqual(
-            sorted(seralized_node_info.feature_keys),
+            seralized_node_info.feature_keys,
             (
                 preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_keys_map[
                     condensed_node_type
diff --git a/python/tests/unit/src/common/types/pb_wrappers/preprocessed_metadata_test.py b/python/tests/unit/src/common/types/pb_wrappers/preprocessed_metadata_test.py
new file mode 100644
index 000000000..bde5566a2
--- /dev/null
+++ b/python/tests/unit/src/common/types/pb_wrappers/preprocessed_metadata_test.py
@@ -0,0 +1,59 @@
+import unittest
+
+from gigl.src.common.constants.graph_metadata import DEFAULT_CONDENSED_NODE_TYPE
+from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
+from gigl.src.mocking.lib.versioning import get_mocked_dataset_artifact_metadata
+from gigl.src.mocking.mocking_assets.mocked_datasets_for_pipeline_tests import (
+    CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO,
+)
+
+
+class PreprocessedMetadataTest(unittest.TestCase):
+    def test_feature_schema_keys_match_original_keys(self):
+        """
+        We currently observe a bug in the FeatureEmbeddingLayer which occurs if we sort the feature keys, leading
+        to training failures when DDP is used. This test ensures that we don't sort the feature keys as part of the
+        preprocessed metadata pb wrapper, and that they are in the same order as the original feature keys.
+
+        TODO (mkolodner-sc): Once the reason why sorting the feature keys breaks training is understood and fixed, this test should be removed. 
+ """ + + cora_dataset_info = get_mocked_dataset_artifact_metadata()[ + CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO.name + ] + + gbml_config_pb_wrapper = ( + GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=cora_dataset_info.frozen_gbml_config_uri + ) + ) + + preprocessed_metadata_wrapper = ( + gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper + ) + feature_schema = ( + preprocessed_metadata_wrapper.condensed_node_type_to_feature_schema_map[ + DEFAULT_CONDENSED_NODE_TYPE + ] + ) + + # Get the original feature keys from the base proto + feature_spec_keys = list(feature_schema.feature_spec.keys()) + feature_index_keys = list(feature_schema.feature_index.keys()) + feature_schema_keys = list(feature_schema.schema.keys()) + + # The feature vocab keys are not required to match the feature keys, so we shouldn't check them. + + original_feature_keys = preprocessed_metadata_wrapper.preprocessed_metadata_pb.condensed_node_type_to_preprocessed_metadata[ + DEFAULT_CONDENSED_NODE_TYPE + ].feature_keys + + # Assert that all the keys in the feature schema match the original feature keys from the base proto + + self.assertEqual(feature_spec_keys, original_feature_keys) + self.assertEqual(feature_index_keys, original_feature_keys) + self.assertEqual(feature_schema_keys, original_feature_keys) + + +if __name__ == "__main__": + unittest.main()
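
A note on the sharding scheme used by the server in PATCH 07 above: the node ID space is cut into num_servers * num_clients contiguous torch.arange slices, and each slice is saved to its own file for one downstream loader. A minimal standalone sketch of that arithmetic follows (shard_node_ids is an illustrative name, not part of the GiGL API; only torch is assumed):

import torch


def shard_node_ids(total_node_ids: int, num_shards: int) -> list[torch.Tensor]:
    # Shard k receives the half-open range [k * per_shard, (k + 1) * per_shard),
    # mirroring the torch.arange loop in examples/server_client/server.py.
    per_shard = total_node_ids // num_shards
    return [
        torch.arange(start=shard_rank * per_shard, end=(shard_rank + 1) * per_shard)
        for shard_rank in range(num_shards)
    ]


# Because of the integer division, up to num_shards - 1 trailing node ids are
# silently dropped when total_node_ids is not divisible by num_shards:
assert sum(t.numel() for t in shard_node_ids(10, 3)) == 9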
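Similarly, the port-election pattern behind get_ports_for_server_client_clusters in PATCH 09 (the master rank of each sub-cluster binds a free port locally and broadcasts it to every rank) can be exercised on its own. A condensed sketch, assuming torch.distributed.init_process_group has already been called with the gloo backend (the helper names here are illustrative):

import socket

import torch.distributed as dist


def find_free_port() -> int:
    # Binding to port 0 makes the OS pick an unused port; closing the socket
    # releases it for the caller to re-bind shortly afterwards.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def broadcast_port_from(master_rank: int) -> int:
    # Only the designated master picks a real port; every other rank passes a
    # placeholder that broadcast_object_list overwrites in place.
    ports = [find_free_port() if dist.get_rank() == master_rank else 0]
    dist.broadcast_object_list(ports, src=master_rank)
    return ports[0]


def elect_cluster_ports(num_servers: int) -> tuple[int, int]:
    # Global rank 0 owns the server cluster's port; the first client rank
    # (global rank == num_servers) owns the client cluster's port.
    return broadcast_port_from(0), broadcast_port_from(num_servers)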
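Finally, the worker_pool_specs.append({}) juggling in PATCHes 08 through 10 reflects that Vertex AI custom jobs treat worker_pool_specs as a positional list: pool 0 holds the single chief replica, and an empty spec stands for a skipped pool index. A schematic of the layout being built (plain dicts for illustration; the machine types are placeholders, not values taken from this PR):

# Pool 0 must contain exactly the chief replica; a second group of workers
# goes in a later pool, with {} filling any pool index that is left unused.
chief_pool = {
    "machine_spec": {"machine_type": "n1-standard-16"},  # placeholder type
    "replica_count": 1,
}
server_pool = {
    "machine_spec": {"machine_type": "n1-standard-16"},  # placeholder type
    "replica_count": 2,
}

# Chief in pool 0, pool 1 skipped, the second worker group in pool 2:
worker_pool_specs = [chief_pool, {}, server_pool]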