diff --git a/Makefile b/Makefile index da7ad72a2..b3951f1ba 100644 --- a/Makefile +++ b/Makefile @@ -154,7 +154,7 @@ assert_yaml_configs_parse: # Ex. `make unit_test_py PY_TEST_FILES="eval_metrics_test.py"` # By default, runs all tests under python/tests/unit. # See the help text for "--test_file_pattern" in python/tests/test_args.py for more details. -unit_test_py: clean_build_files_py type_check +unit_test_py: clean_build_files_py #type_check ( cd python ; \ python -m tests.unit.main \ --env=test \ diff --git a/proto/snapchat/research/gbml/gigl_resource_config.proto b/proto/snapchat/research/gbml/gigl_resource_config.proto index 9f08b03e7..aaf8cf253 100644 --- a/proto/snapchat/research/gbml/gigl_resource_config.proto +++ b/proto/snapchat/research/gbml/gigl_resource_config.proto @@ -132,6 +132,9 @@ message VertexAiResourceConfig { message VertexAiGraphStoreConfig { VertexAiResourceConfig graph_store_pool = 1; VertexAiResourceConfig compute_pool = 2; + + int32 num_processes_per_storage_machine = 3; + int32 num_processes_per_compute_machine = 4; } // (deprecated) // Configuration for distributed training resources diff --git a/python/gigl/distributed/dataset_factory.py b/python/gigl/distributed/dataset_factory.py index 077ef2a52..c60853891 100644 --- a/python/gigl/distributed/dataset_factory.py +++ b/python/gigl/distributed/dataset_factory.py @@ -520,6 +520,7 @@ def build_dataset_from_task_config_uri( ) # Read from GbmlConfig for preprocessed data metadata, GNN model uri, and bigquery embedding table path + logger.info(f"Reading GbmlConfig from URI: {task_config_uri}") gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( gbml_config_uri=UriFactory.create_uri(task_config_uri) ) diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py index 3be9662f0..90a1e02ec 100644 --- a/python/gigl/distributed/distributed_neighborloader.py +++ b/python/gigl/distributed/distributed_neighborloader.py @@ -3,7 +3,13 @@ import torch from graphlearn_torch.channel import SampleMessage -from graphlearn_torch.distributed import DistLoader, MpDistSamplingWorkerOptions +from graphlearn_torch.distributed import ( + DistLoader, + DistServer, + MpDistSamplingWorkerOptions, + RemoteDistSamplingWorkerOptions, + request_server, +) from graphlearn_torch.sampler import NodeSamplerInput, SamplingConfig, SamplingType from torch_geometric.data import Data, HeteroData from torch_geometric.typing import EdgeType @@ -20,6 +26,7 @@ shard_nodes_by_process, strip_label_edges, ) +from gigl.env.distributed import GraphStoreInfo from gigl.src.common.types.graph_data import ( NodeType, # TODO (mkolodner-sc): Change to use torch_geometric.typing ) @@ -37,10 +44,10 @@ class DistNeighborLoader(DistLoader): def __init__( self, - dataset: DistDataset, + dataset: Optional[DistDataset], num_neighbors: Union[list[int], dict[EdgeType, list[int]]], input_nodes: Optional[ - Union[torch.Tensor, Tuple[NodeType, torch.Tensor]] + Union[torch.Tensor, Tuple[NodeType, torch.Tensor], list[torch.Tensor]] ] = None, num_workers: int = 1, batch_size: int = 1, @@ -54,6 +61,7 @@ def __init__( num_cpu_threads: Optional[int] = None, shuffle: bool = False, drop_last: bool = False, + graph_store_info: Optional[GraphStoreInfo] = None, ): """ Note: We try to adhere to pyg dataloader api as much as possible. @@ -193,6 +201,10 @@ def __init__( ) if input_nodes is None: + if dataset is None: + raise ValueError( + "Dataset must be provided if input_nodes are not provided." 
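For context on the two VertexAiGraphStoreConfig fields added above, here is a minimal sketch of populating them through the generated pb2 classes; the machine types, accelerator, and counts are illustrative placeholders, not recommendations.

from snapchat.research.gbml import gigl_resource_config_pb2 as resource_pb2

graph_store_config = resource_pb2.VertexAiGraphStoreConfig(
    # Pool that holds the partitioned graph and serves sampling requests.
    graph_store_pool=resource_pb2.VertexAiResourceConfig(
        machine_type="n2-highmem-32",  # placeholder
        num_replicas=2,
    ),
    # Pool that runs the model forward passes.
    compute_pool=resource_pb2.VertexAiResourceConfig(
        machine_type="n1-standard-16",  # placeholder
        gpu_type="nvidia-tesla-t4",     # placeholder
        gpu_limit=1,
        num_replicas=2,
    ),
    # New fields: how many server / loader processes to spawn per machine in each pool.
    num_processes_per_storage_machine=2,
    num_processes_per_compute_machine=1,
)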
+ ) if dataset.node_ids is None: raise ValueError( "Dataset must have node ids if input_nodes are not provided." @@ -205,44 +217,106 @@ def __init__( # Determines if the node ids passed in are heterogeneous or homogeneous. self._is_labeled_heterogeneous = False - if isinstance(input_nodes, torch.Tensor): - node_ids = input_nodes - - # If the dataset is heterogeneous, we may be in the "labeled homogeneous" setting, - # if so, then we should use DEFAULT_HOMOGENEOUS_NODE_TYPE. - if isinstance(dataset.node_ids, abc.Mapping): - if ( - len(dataset.node_ids) == 1 - and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset.node_ids - ): - node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE - self._is_labeled_heterogeneous = True - else: - raise ValueError( - f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset.node_ids.keys()}" + if dataset is None: + if graph_store_info is None: + raise ValueError( + "graph_store_info must be provided if dataset is not provided." + ) + num_partitions, partition_idx, ntypes, etypes = request_server( + server_rank=0, + func=DistServer.get_dataset_meta, + ) + if not isinstance(input_nodes, list): + raise ValueError( + "input_nodes must be a list if dataset is not provided." + ) + if ( + len(input_nodes) + != graph_store_info.num_storage_nodes + * graph_store_info.num_processes_per_storage + ): + raise ValueError( + f"input_nodes must be a list of length {graph_store_info.num_storage_nodes * graph_store_info.num_processes_per_storage}, got {len(input_nodes)}. E.g. one entry per process in the storage cluster." + ) + worker_options = RemoteDistSamplingWorkerOptions( + server_rank=[ + server_rank + for server_rank in range( + graph_store_info.num_storage_nodes + * graph_store_info.num_processes_per_storage ) - else: - node_type = None + ], + num_workers=num_workers, + worker_devices=[torch.device("cpu") for i in range(num_workers)], + master_addr=graph_store_info.cluster_master_ip, + master_port=graph_store_info.cluster_master_port, + ) else: - node_type, node_ids = input_nodes - assert isinstance( - dataset.node_ids, abc.Mapping - ), "Dataset must be heterogeneous if provided input nodes are a tuple." - - num_neighbors = patch_fanout_for_sampling( - dataset.get_edge_types(), num_neighbors - ) + if isinstance(input_nodes, torch.Tensor): + node_ids = input_nodes + + # If the dataset is heterogeneous, we may be in the "labeled homogeneous" setting, + # if so, then we should use DEFAULT_HOMOGENEOUS_NODE_TYPE. + if isinstance(dataset.node_ids, abc.Mapping): + if ( + len(dataset.node_ids) == 1 + and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset.node_ids + ): + node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE + self._is_labeled_heterogeneous = True + else: + raise ValueError( + f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset.node_ids.keys()}" + ) + else: + node_type = None + elif isinstance(input_nodes, tuple): + node_type, node_ids = input_nodes + assert isinstance( + dataset.node_ids, abc.Mapping + ), "Dataset must be heterogeneous if provided input nodes are a tuple." 
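A minimal client-side sketch of the new remote ("graph store") mode introduced above, assuming the GLT RPC client has already been initialized against the storage servers; only the arguments relevant to this mode are shown, and the fanout is illustrative.

import os

from gigl.distributed.distributed_neighborloader import DistNeighborLoader
from gigl.distributed.server_client.utils import get_sampler_input_for_inference
from gigl.distributed.utils import get_graph_store_info

cluster_info = get_graph_store_info()
# Assumption: one process per compute node, so RANK doubles as this client's rank.
client_rank = int(os.environ["RANK"])

# One tensor of seed nodes per storage process, fetched over RPC.
input_nodes = get_sampler_input_for_inference(
    client_rank=client_rank,
    cluster_info=cluster_info,
)

loader = DistNeighborLoader(
    dataset=None,                  # no local dataset; sampling is served remotely
    num_neighbors=[10, 10],        # illustrative fanout
    input_nodes=input_nodes,       # list with one entry per storage process
    graph_store_info=cluster_info,
)

for batch in loader:
    ...  # run the model on each sampled subgraph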
+ else: + raise ValueError( + f"input_nodes must be a torch.Tensor or a tuple of (node_type, node_ids), got {type(input_nodes)}" + ) + etypes = dataset.get_edge_types() - curr_process_nodes = shard_nodes_by_process( - input_nodes=node_ids, - local_process_rank=local_rank, - local_process_world_size=local_world_size, - ) + curr_process_nodes = shard_nodes_by_process( + input_nodes=node_ids, + local_process_rank=local_rank, + local_process_world_size=local_world_size, + ) - self._node_feature_info = dataset.node_feature_info - self._edge_feature_info = dataset.edge_feature_info + self._node_feature_info = dataset.node_feature_info + self._edge_feature_info = dataset.edge_feature_info - input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type) + input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type) + dist_sampling_ports = ( + gigl.distributed.utils.get_free_ports_from_master_node( + num_ports=local_world_size + ) + ) + dist_sampling_port_for_current_rank = dist_sampling_ports[local_rank] + + worker_options = MpDistSamplingWorkerOptions( + num_workers=num_workers, + worker_devices=[torch.device("cpu") for _ in range(num_workers)], + worker_concurrency=worker_concurrency, + # Each worker will spawn several sampling workers, and all sampling workers spawned by workers in one group + # need to be connected. Thus, we need master ip address and master port to + # initate the connection. + # Note that different groups of workers are independent, and thus + # the sampling processes in different groups should be independent, and should + # use different master ports. + master_addr=master_ip_address, + master_port=dist_sampling_port_for_current_rank, + # Load testing show that when num_rpc_threads exceed 16, the performance + # will degrade. + num_rpc_threads=min(dataset.num_partitions, 16), + rpc_timeout=600, + channel_size=channel_size, + pin_memory=device.type == "cuda", + ) # Sets up processes and torch device for initializing the GLT DistNeighborLoader, setting up RPC and worker groups to minimize # the memory overhead and CPU contention. @@ -280,31 +354,14 @@ def __init__( ) # Sets up worker options for the dataloader - dist_sampling_ports = gigl.distributed.utils.get_free_ports_from_master_node( - num_ports=local_world_size - ) - dist_sampling_port_for_current_rank = dist_sampling_ports[local_rank] - - worker_options = MpDistSamplingWorkerOptions( - num_workers=num_workers, - worker_devices=[torch.device("cpu") for _ in range(num_workers)], - worker_concurrency=worker_concurrency, - # Each worker will spawn several sampling workers, and all sampling workers spawned by workers in one group - # need to be connected. Thus, we need master ip address and master port to - # initate the connection. - # Note that different groups of workers are independent, and thus - # the sampling processes in different groups should be independent, and should - # use different master ports. - master_addr=master_ip_address, - master_port=dist_sampling_port_for_current_rank, - # Load testing show that when num_rpc_threads exceed 16, the performance - # will degrade. - num_rpc_threads=min(dataset.num_partitions, 16), - rpc_timeout=600, - channel_size=channel_size, - pin_memory=device.type == "cuda", - ) + if should_cleanup_distributed_context and torch.distributed.is_initialized(): + logger.info( + f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." 
+ ) + torch.distributed.destroy_process_group() + + num_neighbors = patch_fanout_for_sampling(etypes, num_neighbors) sampling_config = SamplingConfig( sampling_type=SamplingType.NODE, num_neighbors=num_neighbors, @@ -315,16 +372,10 @@ def __init__( collect_features=True, with_neg=False, with_weight=False, - edge_dir=dataset.edge_dir, + edge_dir=dataset.edge_dir if dataset is not None else "out", seed=None, # it's actually optional - None means random. ) - if should_cleanup_distributed_context and torch.distributed.is_initialized(): - logger.info( - f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." - ) - torch.distributed.destroy_process_group() - super().__init__(dataset, input_data, sampling_config, device, worker_options) def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: diff --git a/python/gigl/distributed/server_client/__init__.py b/python/gigl/distributed/server_client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/gigl/distributed/server_client/remote_dataset.py b/python/gigl/distributed/server_client/remote_dataset.py new file mode 100644 index 000000000..19f263f2c --- /dev/null +++ b/python/gigl/distributed/server_client/remote_dataset.py @@ -0,0 +1,59 @@ +from typing import Optional, Union + +import torch + +from gigl.common.logger import Logger +from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.utils.neighborloader import shard_nodes_by_process +from gigl.src.common.types.graph_data import EdgeType, NodeType +from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE, FeatureInfo + +logger = Logger() + +_dataset: Optional[DistDataset] = None + + +def register_dataset(dataset: DistDataset) -> None: + global _dataset + if _dataset is not None: + raise ValueError("Dataset already registered! Cannot register a new dataset.") + _dataset = dataset + + +def get_node_feature_info() -> Union[FeatureInfo, dict[NodeType, FeatureInfo], None]: + if _dataset is None: + raise ValueError( + "Dataset not registered! Register the dataset first with `gigl.distributed.server_client.register_dataset`" + ) + return _dataset.node_feature_info + + +def get_edge_feature_info() -> Union[FeatureInfo, dict[EdgeType, FeatureInfo], None]: + if _dataset is None: + raise ValueError( + "Dataset not registered! Register the dataset first with `gigl.distributed.server_client.register_dataset`" + ) + return _dataset.edge_feature_info + + +def get_node_ids_for_rank( + rank: int, world_size: int, node_type: NodeType = DEFAULT_HOMOGENEOUS_NODE_TYPE +) -> torch.Tensor: + logger.info( + f"Getting node ids for rank {rank} / {world_size} with node type {node_type}" + ) + if _dataset is None: + raise ValueError( + "Dataset not registered! 
Register the dataset first with `gigl.distributed.server_client.register_dataset`" + ) + if isinstance(_dataset.node_ids, torch.Tensor): + nodes = _dataset.node_ids + elif isinstance(_dataset.node_ids, dict): + nodes = _dataset.node_ids[node_type] + else: + raise ValueError( + f"Node ids must be a torch.Tensor or a dict[NodeType, torch.Tensor], got {type(_dataset.node_ids)}" + ) + logger.info(f"Sharding nodes {nodes.shape} for rank {rank} / {world_size}") + logger.info(f"Nodes: {nodes}") + return shard_nodes_by_process(nodes, rank, world_size) diff --git a/python/gigl/distributed/server_client/server_main.py b/python/gigl/distributed/server_client/server_main.py new file mode 100644 index 000000000..8b1207ef4 --- /dev/null +++ b/python/gigl/distributed/server_client/server_main.py @@ -0,0 +1,103 @@ +import argparse +import os + +# Suppress TensorFlow logs +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # isort: skip + +import graphlearn_torch as glt +import torch + +import gigl.distributed as gd +from gigl.common import Uri, UriFactory +from gigl.common.logger import Logger +from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.server_client.remote_dataset import register_dataset +from gigl.distributed.utils import get_graph_store_info +from gigl.env.distributed import GraphStoreInfo + +logger = Logger() + + +def run_server( + server_rank: int, + cluster_info: GraphStoreInfo, + dataset: DistDataset, +) -> None: + logger.info( + f"Initializing server {server_rank} / {cluster_info.num_storage_nodes * cluster_info.num_processes_per_storage}. on {cluster_info.storage_cluster_master_ip}:{cluster_info.storage_cluster_master_port}. Cluster rank: {os.environ.get('RANK')}, port: {cluster_info.cluster_master_port}" + ) + register_dataset(dataset) + logger.info("registered dataset") + glt.distributed.init_server( + num_servers=cluster_info.storage_world_size, + server_rank=server_rank, + dataset=dataset, + master_addr=cluster_info.cluster_master_ip, + master_port=cluster_info.cluster_master_port, + num_clients=cluster_info.compute_world_size, + ) + + logger.info( + f"Waiting for server rank {server_rank} / {cluster_info.num_storage_nodes} to exit" + ) + glt.distributed.wait_and_shutdown_server() + logger.info(f"Server rank {server_rank} exited") + + +def run_servers( + server_rank: int, + cluster_info: GraphStoreInfo, + task_config_uri: Uri, + is_inference: bool, +) -> list: + init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{cluster_info.storage_cluster_master_port}" + logger.info( + f"Initializing server {server_rank} / {cluster_info.num_storage_nodes}. 
OS rank: {os.environ['RANK']}, OS world size: {os.environ['WORLD_SIZE']} init method: {init_method}" + ) + torch.distributed.init_process_group( + backend="gloo", + world_size=cluster_info.num_storage_nodes, + rank=server_rank, + init_method=init_method, + group_name="gigl_server_comms", + ) + logger.info(f"Server {server_rank} / {cluster_info.num_storage_nodes} process group initialized") + dataset = gd.build_dataset_from_task_config_uri( + task_config_uri=task_config_uri, + is_inference=is_inference, + _tfrecord_uri_pattern=".*tfrecord", + ) + server_processes = [] + mp_context = torch.multiprocessing.get_context("spawn") + for i in range(cluster_info.num_processes_per_storage): + server_process = mp_context.Process( + target=run_server, + args=( + server_rank * cluster_info.num_processes_per_storage + i, # server_rank + cluster_info, # cluster_info + dataset, # dataset + ), + ) + server_processes.append(server_process) + for server_process in server_processes: + server_process.start() + return server_processes + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--task_config_uri", type=str, required=True) + parser.add_argument("--resource_config_uri", type=str, required=True) + parser.add_argument("--is_inference", action="store_true") + args = parser.parse_args() + logger.info(f"Arguments: {args}") + + is_inference = args.is_inference + torch.distributed.init_process_group() + cluster_info = get_graph_store_info() + run_servers( + server_rank=int(os.environ["RANK"]) - cluster_info.num_compute_nodes, + cluster_info=cluster_info, + task_config_uri=UriFactory.create_uri(args.task_config_uri), + is_inference=is_inference, + ) diff --git a/python/gigl/distributed/server_client/utils.py b/python/gigl/distributed/server_client/utils.py new file mode 100644 index 000000000..8a2e569cf --- /dev/null +++ b/python/gigl/distributed/server_client/utils.py @@ -0,0 +1,29 @@ +import torch +from graphlearn_torch.distributed import request_server + +from gigl.common.logger import Logger +from gigl.distributed.server_client.remote_dataset import get_node_ids_for_rank +from gigl.env.distributed import GraphStoreInfo +from gigl.src.common.types.graph_data import NodeType +from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE + +logger = Logger() + + +def get_sampler_input_for_inference( + client_rank: int, + cluster_info: GraphStoreInfo, + node_type: NodeType = DEFAULT_HOMOGENEOUS_NODE_TYPE, +) -> list[torch.Tensor]: + sampler_input: list[torch.Tensor] = [] + for server_rank in range(cluster_info.storage_world_size): + world_size = cluster_info.compute_world_size + node_ids = request_server( + server_rank, + get_node_ids_for_rank, + client_rank, + world_size, + node_type, + ) + sampler_input.append(node_ids) + return sampler_input diff --git a/python/gigl/distributed/utils/networking.py b/python/gigl/distributed/utils/networking.py index ea38dbf07..f42f298dd 100644 --- a/python/gigl/distributed/utils/networking.py +++ b/python/gigl/distributed/utils/networking.py @@ -5,8 +5,12 @@ import torch from gigl.common.logger import Logger -from gigl.common.utils.vertex_ai_context import ClusterSpec, get_cluster_spec -from gigl.env.distributed import GraphStoreInfo +from gigl.common.utils.vertex_ai_context import get_cluster_spec +from gigl.env.distributed import ( + GRAPH_STORE_PROCESSES_PER_COMPUTE_VAR_NAME, + GRAPH_STORE_PROCESSES_PER_STORAGE_VAR_NAME, + GraphStoreInfo, +) logger = Logger() @@ -39,7 +43,8 @@ def get_free_ports(num_ports: int) -> list[int]: # OS assigns 
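The helpers in remote_dataset above are plain module-level functions, so a compute-side client can invoke them over GLT RPC once its client connection is up. A hedged sketch follows; the ranks are illustrative.

from graphlearn_torch.distributed import request_server

from gigl.distributed.server_client import remote_dataset
from gigl.distributed.utils import get_graph_store_info

cluster_info = get_graph_store_info()
client_rank = 0  # illustrative: this client's rank within the compute world

# Feature metadata lives with the dataset registered on the servers.
node_feature_info = request_server(0, remote_dataset.get_node_feature_info)

# Each server shards its node ids across all clients; ask server 0 for this client's shard.
node_ids = request_server(
    0,
    remote_dataset.get_node_ids_for_rank,
    client_rank,
    cluster_info.compute_world_size,
)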
a free port; we want to keep it open until we have all ports so we only return unique ports s.bind(("", 0)) open_sockets.append(s) - ports.append(s.getsockname()[1]) + port = s.getsockname()[1] + ports.append(port) # Free up ports by closing the sockets for s in open_sockets: s.close() @@ -147,10 +152,17 @@ def get_internal_ip_from_node( # Other nodes will receive the master's IP via broadcast ip_list = [None] - device = "cuda" if torch.cuda.is_available() else "cpu" + device = ( + "cuda" + if torch.distributed.get_backend() == torch.distributed.Backend.NCCL + else "cpu" + ) + logger.info( + f"Rank {rank} broadcasting internal IP list: {ip_list} to rank {node_rank}" + ) torch.distributed.broadcast_object_list(ip_list, src=node_rank, device=device) node_ip = ip_list[0] - logger.info(f"Rank {rank} received master internal IP: {node_ip}") + logger.info(f"Rank {rank} received master node's internal IP: {node_ip}") assert node_ip is not None, "Could not retrieve master node's internal IP" return node_ip @@ -187,11 +199,6 @@ def get_internal_ip_from_all_ranks() -> list[str]: def get_graph_store_info() -> GraphStoreInfo: """ Get the information about the graph store cluster. - MUST be called with a torch.distributed process group initialized, for the *entire* training cluster. - E.g. the process group *must* include both the compute and storage nodes. - - This function should only be called on clusters that are setup by GiGL. - E.g. when GiGLResourceConfig.trainer_resource_config.vertex_ai_graph_store_trainer_config is set. Returns: GraphStoreInfo: The information about the graph store cluster. @@ -203,8 +210,11 @@ def get_graph_store_info() -> GraphStoreInfo: # If we want to ever support other (non-VAI) environments, # we must switch here depending on the environment. cluster_spec = get_cluster_spec() - - _validate_cluster_spec(cluster_spec) + # We setup the VAI cluster such that the compute nodes come first, followed by the storage nodes. + if len(cluster_spec.cluster["workerpool0"]) != 1: + raise ValueError( + f"Expected exactly one machine in workerpool0, but got {len(cluster_spec.cluster['workerpool0'])}" + ) if "workerpool1" in cluster_spec.cluster: num_compute_nodes = len(cluster_spec.cluster["workerpool0"]) + len( @@ -227,10 +237,19 @@ def get_graph_store_info() -> GraphStoreInfo: num_ports=1, node_rank=num_compute_nodes )[0] + num_processes_per_storage = int( + os.environ[GRAPH_STORE_PROCESSES_PER_STORAGE_VAR_NAME] + ) + num_processes_per_compute = int( + os.environ[GRAPH_STORE_PROCESSES_PER_COMPUTE_VAR_NAME] + ) + return GraphStoreInfo( num_cluster_nodes=num_storage_nodes + num_compute_nodes, num_storage_nodes=num_storage_nodes, num_compute_nodes=num_compute_nodes, + num_processes_per_storage=num_processes_per_storage, + num_processes_per_compute=num_processes_per_compute, cluster_master_ip=cluster_master_ip, storage_cluster_master_ip=storage_cluster_master_ip, compute_cluster_master_ip=compute_cluster_master_ip, @@ -238,37 +257,3 @@ def get_graph_store_info() -> GraphStoreInfo: storage_cluster_master_port=storage_cluster_master_port, compute_cluster_master_port=compute_cluster_master_port, ) - - -def _validate_cluster_spec(cluster_spec: ClusterSpec) -> None: - """Validate the cluster spec is setup as we'd expect.""" - - if len(cluster_spec.cluster["workerpool0"]) != 1: - raise ValueError( - f"Expected exactly one machine in workerpool0, but got {len(cluster_spec.cluster['workerpool0'])}" - ) - - # We want to ensure that the cluster is setup as we'd expect. - # e.g. 
`[[compute0], [compute1, ..., computeN], [storage0, ..., storageN]]` - # So we do this by checking that the task index matches up with the rank. - env_rank = int(os.environ["RANK"]) - if cluster_spec.task.type == "workerpool0": - offset = 0 - elif cluster_spec.task.type == "workerpool1": - offset = len(cluster_spec.cluster["workerpool0"]) - elif cluster_spec.task.type == "workerpool2": - if "workerpool1" in cluster_spec.cluster: - offset = len(cluster_spec.cluster["workerpool0"]) + len( - cluster_spec.cluster["workerpool1"] - ) - else: - offset = len(cluster_spec.cluster["workerpool0"]) - else: - raise ValueError( - f"Expected task type to be workerpool0, workerpool1, or workerpool2, but got {cluster_spec.task.type}" - ) - - if cluster_spec.task.index + offset != env_rank: - raise ValueError( - f"Expected task index to be {env_rank}, but got {cluster_spec.task.index + offset}" - ) diff --git a/python/gigl/env/distributed.py b/python/gigl/env/distributed.py index e8999be67..7527c3585 100644 --- a/python/gigl/env/distributed.py +++ b/python/gigl/env/distributed.py @@ -1,6 +1,7 @@ """Information about distributed environments.""" from dataclasses import dataclass +from typing import Final @dataclass(frozen=True) @@ -21,6 +22,14 @@ class DistributedContext: global_world_size: int +GRAPH_STORE_PROCESSES_PER_STORAGE_VAR_NAME: Final[ + str +] = "GRAPH_STORE_PROCESSES_PER_SERVER" +GRAPH_STORE_PROCESSES_PER_COMPUTE_VAR_NAME: Final[ + str +] = "GRAPH_STORE_PROCESSES_PER_COMPUTE" + + @dataclass(frozen=True) class GraphStoreInfo: """Information about a graph store cluster.""" @@ -32,6 +41,11 @@ class GraphStoreInfo: # Number of nodes in the compute cluster num_compute_nodes: int + # Number of processes per storage machine + num_processes_per_storage: int + # Number of processes per compute machine + num_processes_per_compute: int + # IP address of the master node for the whole cluster cluster_master_ip: str # IP address of the master node for the storage cluster @@ -45,3 +59,19 @@ class GraphStoreInfo: storage_cluster_master_port: int # Port of the master node for the compute cluster compute_cluster_master_port: int + + @property + def storage_world_size(self) -> int: + return self.num_storage_nodes * self.num_processes_per_storage + + @property + def compute_world_size(self) -> int: + return self.num_compute_nodes * self.num_processes_per_compute + + @property + def cluster_world_size(self) -> int: + return ( + self.num_cluster_nodes + * self.num_processes_per_storage + * self.num_processes_per_compute + ) diff --git a/python/gigl/src/common/types/pb_wrappers/gigl_resource_config.py b/python/gigl/src/common/types/pb_wrappers/gigl_resource_config.py index 76c339e84..c7fc04dde 100644 --- a/python/gigl/src/common/types/pb_wrappers/gigl_resource_config.py +++ b/python/gigl/src/common/types/pb_wrappers/gigl_resource_config.py @@ -17,6 +17,7 @@ SharedResourceConfig, SparkResourceConfig, TrainerResourceConfig, + VertexAiGraphStoreConfig, VertexAiResourceConfig, ) @@ -35,11 +36,13 @@ _VERTEX_AI_TRAINER_CONFIG = "vertex_ai_trainer_config" _KFP_TRAINER_CONFIG = "kfp_trainer_config" _LOCAL_TRAINER_CONFIG = "local_trainer_config" +_VERTEX_AI_GRAPH_STORE_TRAINER_CONFIG = "vertex_ai_graph_store_trainer_config" _INFERENCER_CONFIG_FIELD = "inferencer_config" _VERTEX_AI_INFERENCER_CONFIG = "vertex_ai_inferencer_config" _DATAFLOW_INFERENCER_CONFIG = "dataflow_inferencer_config" _LOCAL_INFERENCER_CONFIG = "local_inferencer_config" +_VERTEX_AI_GRAPH_STORE_INFERENCER_CONFIG = "vertex_ai_graph_store_inferencer_config" 
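To make the new GraphStoreInfo fields and derived world-size properties concrete, a small illustration with made-up values; the IPs and ports are placeholders.

from gigl.env.distributed import GraphStoreInfo

info = GraphStoreInfo(
    num_cluster_nodes=4,
    num_storage_nodes=2,
    num_compute_nodes=2,
    num_processes_per_storage=2,   # read from GRAPH_STORE_PROCESSES_PER_SERVER
    num_processes_per_compute=1,   # read from GRAPH_STORE_PROCESSES_PER_COMPUTE
    cluster_master_ip="10.0.0.2",
    storage_cluster_master_ip="10.0.0.4",
    compute_cluster_master_ip="10.0.0.2",
    cluster_master_port=29500,
    storage_cluster_master_port=29501,
    compute_cluster_master_port=29502,
)

assert info.storage_world_size == 4  # 2 storage nodes * 2 processes each
assert info.compute_world_size == 2  # 2 compute nodes * 1 process each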
@dataclass @@ -47,10 +50,20 @@ class GiglResourceConfigWrapper: resource_config: GiglResourceConfig _loaded_shared_resource_config: Optional[SharedResourceConfig] = None _trainer_config: Optional[ - Union[VertexAiResourceConfig, KFPResourceConfig, LocalResourceConfig] + Union[ + VertexAiResourceConfig, + KFPResourceConfig, + LocalResourceConfig, + VertexAiGraphStoreConfig, + ] ] = None _inference_config: Optional[ - Union[DataflowResourceConfig, VertexAiResourceConfig, LocalResourceConfig] + Union[ + DataflowResourceConfig, + VertexAiResourceConfig, + LocalResourceConfig, + VertexAiGraphStoreConfig, + ] ] = None _split_gen_config: Union[ @@ -269,7 +282,12 @@ def vertex_ai_inferencer_region(self) -> str: @property def trainer_config( self, - ) -> Union[VertexAiResourceConfig, KFPResourceConfig, LocalResourceConfig]: + ) -> Union[ + VertexAiResourceConfig, + KFPResourceConfig, + LocalResourceConfig, + VertexAiGraphStoreConfig, + ]: """ Returns the trainer config specified in the resource config. (e.g. Vertex AI, KFP, Local) """ @@ -287,7 +305,10 @@ def trainer_config( self.resource_config.trainer_config ) _trainer_config: Union[ - VertexAiResourceConfig, KFPResourceConfig, LocalResourceConfig + VertexAiResourceConfig, + KFPResourceConfig, + LocalResourceConfig, + VertexAiGraphStoreConfig, ] if deprecated_config.WhichOneof(_TRAINER_CONFIG_FIELD) == _VERTEX_AI_TRAINER_CONFIG: # type: ignore[arg-type] logger.info( @@ -331,18 +352,27 @@ def trainer_config( _trainer_config = config.kfp_trainer_config elif config.WhichOneof(_TRAINER_CONFIG_FIELD) == _LOCAL_TRAINER_CONFIG: # type: ignore[arg-type] _trainer_config = config.local_trainer_config + elif config.WhichOneof(_TRAINER_CONFIG_FIELD) == _VERTEX_AI_GRAPH_STORE_TRAINER_CONFIG: # type: ignore[arg-type] + _trainer_config = config.vertex_ai_graph_store_trainer_config else: raise ValueError(f"Invalid trainer_config type: {config}") else: raise ValueError( f"Trainer config not found in resource config; neither trainer_config nor trainer_resource_config is set: {self.resource_config}" ) - return _trainer_config + self._trainer_config = _trainer_config + + return self._trainer_config @property def inferencer_config( self, - ) -> Union[DataflowResourceConfig, VertexAiResourceConfig, LocalResourceConfig]: + ) -> Union[ + DataflowResourceConfig, + VertexAiResourceConfig, + LocalResourceConfig, + VertexAiGraphStoreConfig, + ]: """ Returns the inferencer config specified in the resource config. 
(Dataflow) """ @@ -364,6 +394,10 @@ def inferencer_config( self._inference_config = config.local_inferencer_config elif config.WhichOneof(_INFERENCER_CONFIG_FIELD) == _VERTEX_AI_INFERENCER_CONFIG: # type: ignore[arg-type] self._inference_config = config.vertex_ai_inferencer_config + elif config.WhichOneof(_INFERENCER_CONFIG_FIELD) == _VERTEX_AI_GRAPH_STORE_INFERENCER_CONFIG: # type: ignore[arg-type] + self._inference_config = ( + config.vertex_ai_graph_store_inferencer_config + ) else: raise ValueError("Invalid inferencer_config type") else: diff --git a/python/gigl/src/inference/v2/glt_inferencer.py b/python/gigl/src/inference/v2/glt_inferencer.py index 66483b0ff..d28c4fbff 100644 --- a/python/gigl/src/inference/v2/glt_inferencer.py +++ b/python/gigl/src/inference/v2/glt_inferencer.py @@ -1,4 +1,5 @@ import argparse +from collections.abc import Mapping from typing import Optional from google.cloud.aiplatform_v1.types import Scheduling, accelerator_type, env_var @@ -10,6 +11,10 @@ ) from gigl.common.logger import Logger from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService +from gigl.env.distributed import ( + GRAPH_STORE_PROCESSES_PER_COMPUTE_VAR_NAME, + GRAPH_STORE_PROCESSES_PER_STORAGE_VAR_NAME, +) from gigl.env.pipelines_config import get_resource_config from gigl.src.common.constants.components import GiGLComponents from gigl.src.common.types import AppliedTaskIdentifier @@ -20,6 +25,7 @@ from gigl.src.common.utils.metrics_service_provider import initialize_metrics from snapchat.research.gbml.gigl_resource_config_pb2 import ( LocalResourceConfig, + VertexAiGraphStoreConfig, VertexAiResourceConfig, ) @@ -49,44 +55,25 @@ class GLTInferencer: GiGL Component that runs a GLT Inference using a provided class path """ - def __execute_VAI_inference( + def _launch_single_pool( self, + vertex_ai_resource_config: VertexAiResourceConfig, applied_task_identifier: AppliedTaskIdentifier, task_config_uri: Uri, resource_config_uri: Uri, - cpu_docker_uri: Optional[str] = None, - cuda_docker_uri: Optional[str] = None, + inference_process_command: str, + inference_process_runtime_args: Mapping[str, str], + resource_config_wrapper: GiglResourceConfigWrapper, + cpu_docker_uri: Optional[str], + cuda_docker_uri: Optional[str], ) -> None: - resource_config_wrapper: GiglResourceConfigWrapper = get_resource_config( - resource_config_uri=resource_config_uri - ) - gbml_config_pb_wrapper = ( - GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( - gbml_config_uri=task_config_uri - ) - ) - inference_process_command = gbml_config_pb_wrapper.inferencer_config.command - if not inference_process_command: - raise ValueError( - "Currently, GLT Inferencer only supports inferencer process command which" - + f" was not provided in inferencer config: {gbml_config_pb_wrapper.inferencer_config}" - ) - inference_process_runtime_args = ( - gbml_config_pb_wrapper.inferencer_config.inferencer_args - ) - assert isinstance( - resource_config_wrapper.inferencer_config, VertexAiResourceConfig - ) - inferencer_resource_config: VertexAiResourceConfig = ( - resource_config_wrapper.inferencer_config - ) - - is_cpu_training = _determine_if_cpu_inference( - inferencer_resource_config=inferencer_resource_config + """Launch a single pool inference job on Vertex AI.""" + is_cpu_inference = _determine_if_cpu_inference( + inferencer_resource_config=vertex_ai_resource_config ) cpu_docker_uri = cpu_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU cuda_docker_uri = cuda_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA - 
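A hedged sketch of how the new inferencer oneof is selected: when vertex_ai_graph_store_inferencer_config is set, the wrapper's inferencer_config property returns a VertexAiGraphStoreConfig and GLTInferencer takes the server/client path. Values below are placeholders.

from snapchat.research.gbml import gigl_resource_config_pb2 as resource_pb2

inferencer_resource_config = resource_pb2.InferencerResourceConfig(
    vertex_ai_graph_store_inferencer_config=resource_pb2.VertexAiGraphStoreConfig(
        graph_store_pool=resource_pb2.VertexAiResourceConfig(
            machine_type="n2-highmem-32",  # placeholder
            num_replicas=2,
        ),
        compute_pool=resource_pb2.VertexAiResourceConfig(
            machine_type="n1-standard-16",  # placeholder
            num_replicas=2,
        ),
        num_processes_per_storage_machine=2,
    ),
)

# This is the oneof case the wrapper and the inferencer dispatch on.
assert (
    inferencer_resource_config.WhichOneof("inferencer_config")
    == "vertex_ai_graph_store_inferencer_config"
)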
container_uri = cpu_docker_uri if is_cpu_training else cuda_docker_uri + container_uri = cpu_docker_uri if is_cpu_inference else cuda_docker_uri job_args = ( [ @@ -94,7 +81,7 @@ def __execute_VAI_inference( f"--task_config_uri={task_config_uri}", f"--resource_config_uri={resource_config_uri}", ] - + ([] if is_cpu_training else ["--use_cuda"]) + + ([] if is_cpu_inference else ["--use_cuda"]) + ([f"--{k}={v}" for k, v in inference_process_runtime_args.items()]) ) @@ -110,17 +97,17 @@ def __execute_VAI_inference( command=command, args=job_args, environment_variables=environment_variables, - machine_type=inferencer_resource_config.machine_type, - accelerator_type=inferencer_resource_config.gpu_type.upper().replace( + machine_type=vertex_ai_resource_config.machine_type, + accelerator_type=vertex_ai_resource_config.gpu_type.upper().replace( "-", "_" ), - accelerator_count=inferencer_resource_config.gpu_limit, - replica_count=inferencer_resource_config.num_replicas, + accelerator_count=vertex_ai_resource_config.gpu_limit, + replica_count=vertex_ai_resource_config.num_replicas, labels=resource_config_wrapper.get_resource_labels( component=GiGLComponents.Inferencer ), - timeout_s=inferencer_resource_config.timeout - if inferencer_resource_config.timeout + timeout_s=vertex_ai_resource_config.timeout + if vertex_ai_resource_config.timeout else None, # This should be `aiplatform.gapic.Scheduling.Strategy[inferencer_resource_config.scheduling_strategy]` # But mypy complains otherwise... @@ -128,9 +115,9 @@ def __execute_VAI_inference( # TODO(kmonte): Fix this scheduling_strategy=getattr( Scheduling.Strategy, - inferencer_resource_config.scheduling_strategy, + vertex_ai_resource_config.scheduling_strategy, ) - if inferencer_resource_config.scheduling_strategy + if vertex_ai_resource_config.scheduling_strategy else None, ) @@ -142,6 +129,213 @@ def __execute_VAI_inference( ) vertex_ai_service.launch_job(job_config=job_config) + def _launch_server_client( + self, + vertex_ai_graph_store_config: VertexAiGraphStoreConfig, + applied_task_identifier: AppliedTaskIdentifier, + task_config_uri: Uri, + resource_config_uri: Uri, + inference_process_command: str, + inference_process_runtime_args: Mapping[str, str], + resource_config_wrapper: GiglResourceConfigWrapper, + cpu_docker_uri: Optional[str], + cuda_docker_uri: Optional[str], + ) -> None: + """Launch a server/client inference job on Vertex AI using graph store config.""" + storage_pool_config = vertex_ai_graph_store_config.graph_store_pool + compute_pool_config = vertex_ai_graph_store_config.compute_pool + + # Determine if CPU or GPU based on compute pool + is_cpu_inference = _determine_if_cpu_inference( + inferencer_resource_config=compute_pool_config + ) + cpu_docker_uri = cpu_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU + cuda_docker_uri = cuda_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA + container_uri = cpu_docker_uri if is_cpu_inference else cuda_docker_uri + + compute_job_args = ( + [ + f"--job_name={applied_task_identifier}", + f"--task_config_uri={task_config_uri}", + f"--resource_config_uri={resource_config_uri}", + ] + + ([] if is_cpu_inference else ["--use_cuda"]) + + ([f"--{k}={v}" for k, v in inference_process_runtime_args.items()]) + ) + + command = inference_process_command.strip().split(" ") + logger.info(f"Running inference with command: {command}") + vai_job_name = f"gigl_infer_{applied_task_identifier}" + num_storage_processes = ( + vertex_ai_graph_store_config.num_processes_per_storage_machine + ) + if not 
num_storage_processes: + num_storage_processes = 1 + num_compute_processes = ( + vertex_ai_graph_store_config.num_processes_per_compute_machine + ) + if not num_compute_processes: + if is_cpu_inference: + num_compute_processes = 1 + else: + num_compute_processes = ( + vertex_ai_graph_store_config.compute_pool.gpu_limit + ) + # Add server/client environment variables + environment_variables: list[env_var.EnvVar] = [ + env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3"), + env_var.EnvVar( + name=GRAPH_STORE_PROCESSES_PER_STORAGE_VAR_NAME, + value=str(num_storage_processes), + ), + env_var.EnvVar( + name=GRAPH_STORE_PROCESSES_PER_COMPUTE_VAR_NAME, + value=str(num_compute_processes), + ), + ] + + # Create compute pool job config + compute_job_config = VertexAiJobConfig( + job_name=vai_job_name, + container_uri=container_uri, + command=command, + args=compute_job_args, + environment_variables=environment_variables, + machine_type=compute_pool_config.machine_type, + accelerator_type=compute_pool_config.gpu_type.upper().replace("-", "_"), + accelerator_count=compute_pool_config.gpu_limit, + replica_count=compute_pool_config.num_replicas, + labels=resource_config_wrapper.get_resource_labels( + component=GiGLComponents.Inferencer + ), + timeout_s=compute_pool_config.timeout + if compute_pool_config.timeout + else None, + # This should be `aiplatform.gapic.Scheduling.Strategy[inferencer_resource_config.scheduling_strategy]` + # But mypy complains otherwise... + # python/gigl/src/inference/v2/glt_inferencer.py:124: error: The type "type[Strategy]" is not generic and not indexable [misc] + # TODO(kmonte): Fix this + scheduling_strategy=getattr( + Scheduling.Strategy, + compute_pool_config.scheduling_strategy, + ) + if compute_pool_config.scheduling_strategy + else None, + ) + + # Create storage pool job config + storage_job_args = [ + f"--job_name={applied_task_identifier}", + f"--task_config_uri={task_config_uri}", + f"--resource_config_uri={resource_config_uri}", + ] + ([] if is_cpu_inference else ["--use_cuda"]) + storage_job_command = [ + "python", + "-m", + "gigl.distributed.server_client.server_main", + ] + storage_job_config = VertexAiJobConfig( + job_name=vai_job_name, # Will be ignored, using compute pool's job name + container_uri=container_uri, + command=storage_job_command, + args=storage_job_args, + environment_variables=environment_variables, + machine_type=storage_pool_config.machine_type, + accelerator_type=storage_pool_config.gpu_type.upper().replace("-", "_"), + accelerator_count=storage_pool_config.gpu_limit, + replica_count=storage_pool_config.num_replicas, + labels=resource_config_wrapper.get_resource_labels( + component=GiGLComponents.Inferencer + ), + # This should be `aiplatform.gapic.Scheduling.Strategy[inferencer_resource_config.scheduling_strategy]` + # But mypy complains otherwise... 
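Inside either pool's container, the two environment variables attached to the job configs above are what get_graph_store_info later reads back; a minimal sketch of consuming them directly:

import os

from gigl.env.distributed import (
    GRAPH_STORE_PROCESSES_PER_COMPUTE_VAR_NAME,
    GRAPH_STORE_PROCESSES_PER_STORAGE_VAR_NAME,
)

# Per-machine process counts chosen by the inferencer when launching the job.
num_processes_per_storage = int(os.environ[GRAPH_STORE_PROCESSES_PER_STORAGE_VAR_NAME])
num_processes_per_compute = int(os.environ[GRAPH_STORE_PROCESSES_PER_COMPUTE_VAR_NAME])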
+ # python/gigl/src/inference/v2/glt_inferencer.py:124: error: The type "type[Strategy]" is not generic and not indexable [misc] + # TODO(kmonte): Fix this + scheduling_strategy=getattr( + Scheduling.Strategy, + storage_pool_config.scheduling_strategy, + ) + if storage_pool_config.scheduling_strategy + else None, + ) + + # Determine region from compute pool or use default region + region = ( + compute_pool_config.gcp_region_override + if compute_pool_config.gcp_region_override + else resource_config_wrapper.region + ) + + vertex_ai_service = VertexAIService( + project=resource_config_wrapper.project, + location=region, + service_account=resource_config_wrapper.service_account_email, + staging_bucket=resource_config_wrapper.temp_assets_regional_bucket_path.uri, + ) + vertex_ai_service.launch_graph_store_job( + compute_pool_job_config=compute_job_config, + storage_pool_job_config=storage_job_config, + ) + + def __execute_VAI_inference( + self, + applied_task_identifier: AppliedTaskIdentifier, + task_config_uri: Uri, + resource_config_uri: Uri, + cpu_docker_uri: Optional[str] = None, + cuda_docker_uri: Optional[str] = None, + ) -> None: + resource_config_wrapper: GiglResourceConfigWrapper = get_resource_config( + resource_config_uri=resource_config_uri + ) + gbml_config_pb_wrapper = ( + GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=task_config_uri + ) + ) + inference_process_command = gbml_config_pb_wrapper.inferencer_config.command + if not inference_process_command: + raise ValueError( + "Currently, GLT Inferencer only supports inferencer process command which" + + f" was not provided in inferencer config: {gbml_config_pb_wrapper.inferencer_config}" + ) + inference_process_runtime_args = ( + gbml_config_pb_wrapper.inferencer_config.inferencer_args + ) + + if isinstance( + resource_config_wrapper.inferencer_config, VertexAiResourceConfig + ): + self._launch_single_pool( + vertex_ai_resource_config=resource_config_wrapper.inferencer_config, + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + inference_process_command=inference_process_command, + inference_process_runtime_args=inference_process_runtime_args, + resource_config_wrapper=resource_config_wrapper, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + ) + elif isinstance( + resource_config_wrapper.inferencer_config, VertexAiGraphStoreConfig + ): + self._launch_server_client( + vertex_ai_graph_store_config=resource_config_wrapper.inferencer_config, + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + inference_process_command=inference_process_command, + inference_process_runtime_args=inference_process_runtime_args, + resource_config_wrapper=resource_config_wrapper, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + ) + else: + raise NotImplementedError( + f"Unsupported resource config for glt inference: {type(resource_config_wrapper.inferencer_config).__name__}" + ) + def run( self, applied_task_identifier: AppliedTaskIdentifier, @@ -157,10 +351,12 @@ def run( if isinstance(resource_config_wrapper.inferencer_config, LocalResourceConfig): raise NotImplementedError( - f"Local GLT Inferencer is not yet supported, please specify a {VertexAiResourceConfig.__name__} resource config field." 
+ f"Local GLT Inferencer is not yet supported, please specify a {VertexAiResourceConfig.__name__} or {VertexAiGraphStoreConfig.__name__} resource config field." ) elif isinstance( resource_config_wrapper.inferencer_config, VertexAiResourceConfig + ) or isinstance( + resource_config_wrapper.inferencer_config, VertexAiGraphStoreConfig ): self.__execute_VAI_inference( applied_task_identifier=applied_task_identifier, diff --git a/python/gigl/src/mocking/dataset_asset_mocking_suite.py b/python/gigl/src/mocking/dataset_asset_mocking_suite.py index 2d2054bb7..7320ca47e 100644 --- a/python/gigl/src/mocking/dataset_asset_mocking_suite.py +++ b/python/gigl/src/mocking/dataset_asset_mocking_suite.py @@ -570,8 +570,8 @@ def compute_datasets_to_mock( for attr in dir(self) if callable(getattr(self, attr)) and attr.startswith("mock") ] - print(f"All mocking functions: {all_mocking_func_names}") - print(f"Selected datasets: {selected_datasets}") + # print(f"All mocking functions: {all_mocking_func_names}") + # print(f"Selected datasets: {selected_datasets}") mocking_func_names: list[str] if selected_datasets: @@ -593,7 +593,9 @@ def compute_datasets_to_mock( mocked_dataset_info = mocking_func() mocked_datasets[mocked_dataset_info.name] = mocked_dataset_info - logger.info(f"Mocked datasets registered successfully: {list(mocked_datasets)}") + logger.debug( + f"Mocked datasets registered successfully: {list(mocked_datasets)}" + ) return mocked_datasets diff --git a/python/gigl/src/training/v2/glt_trainer.py b/python/gigl/src/training/v2/glt_trainer.py index 13f249f84..cf8a2f376 100644 --- a/python/gigl/src/training/v2/glt_trainer.py +++ b/python/gigl/src/training/v2/glt_trainer.py @@ -1,4 +1,5 @@ import argparse +from collections.abc import Mapping from typing import Optional from google.cloud.aiplatform_v1.types import Scheduling, accelerator_type, env_var @@ -52,37 +53,20 @@ class GLTTrainer: GiGL Component that runs a GLT Training using a provided class path """ - def __execute_VAI_training( + def _launch_single_pool_training( self, + vertex_ai_resource_config: VertexAiResourceConfig, applied_task_identifier: AppliedTaskIdentifier, task_config_uri: Uri, resource_config_uri: Uri, - cpu_docker_uri: Optional[str] = None, - cuda_docker_uri: Optional[str] = None, + training_process_command: str, + training_process_runtime_args: Mapping[str, str], + resource_config: GiglResourceConfigWrapper, + cpu_docker_uri: Optional[str], + cuda_docker_uri: Optional[str], ) -> None: - resource_config: GiglResourceConfigWrapper = get_resource_config( - resource_config_uri=resource_config_uri - ) - gbml_config_pb_wrapper = ( - GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( - gbml_config_uri=task_config_uri - ) - ) - training_process_command = gbml_config_pb_wrapper.trainer_config.command - if not training_process_command: - raise ValueError( - "Currently, GLT Trainer only supports training process command which" - + f" was not provided in trainer config: {gbml_config_pb_wrapper.trainer_config}" - ) - training_process_runtime_args = ( - gbml_config_pb_wrapper.trainer_config.trainer_args - ) - - assert isinstance(resource_config.trainer_config, VertexAiResourceConfig) - trainer_resource_config: VertexAiResourceConfig = resource_config.trainer_config - is_cpu_training = _determine_if_cpu_training( - trainer_resource_config=trainer_resource_config + trainer_resource_config=vertex_ai_resource_config ) cpu_docker_uri = cpu_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU cuda_docker_uri = cuda_docker_uri or 
DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA @@ -110,15 +94,17 @@ def __execute_VAI_training( command=command, args=job_args, environment_variables=environment_variables, - machine_type=trainer_resource_config.machine_type, - accelerator_type=trainer_resource_config.gpu_type.upper().replace("-", "_"), - accelerator_count=trainer_resource_config.gpu_limit, - replica_count=trainer_resource_config.num_replicas, + machine_type=vertex_ai_resource_config.machine_type, + accelerator_type=vertex_ai_resource_config.gpu_type.upper().replace( + "-", "_" + ), + accelerator_count=vertex_ai_resource_config.gpu_limit, + replica_count=vertex_ai_resource_config.num_replicas, labels=resource_config.get_resource_labels( component=GiGLComponents.Trainer ), - timeout_s=trainer_resource_config.timeout - if trainer_resource_config.timeout + timeout_s=vertex_ai_resource_config.timeout + if vertex_ai_resource_config.timeout else None, # This should be `aiplatform.gapic.Scheduling.Strategy[trainer_resource_config.scheduling_strategy]` # But mypy complains otherwise... @@ -126,9 +112,9 @@ def __execute_VAI_training( # TODO(kmonte): Fix this scheduling_strategy=getattr( Scheduling.Strategy, - trainer_resource_config.scheduling_strategy, + vertex_ai_resource_config.scheduling_strategy, ) - if trainer_resource_config.scheduling_strategy + if vertex_ai_resource_config.scheduling_strategy else None, ) @@ -140,6 +126,49 @@ def __execute_VAI_training( ) vertex_ai_service.launch_job(job_config=job_config) + def __execute_VAI_training( + self, + applied_task_identifier: AppliedTaskIdentifier, + task_config_uri: Uri, + resource_config_uri: Uri, + cpu_docker_uri: Optional[str] = None, + cuda_docker_uri: Optional[str] = None, + ) -> None: + resource_config: GiglResourceConfigWrapper = get_resource_config( + resource_config_uri=resource_config_uri + ) + gbml_config_pb_wrapper = ( + GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=task_config_uri + ) + ) + training_process_command = gbml_config_pb_wrapper.trainer_config.command + if not training_process_command: + raise ValueError( + "Currently, GLT Trainer only supports training process command which" + + f" was not provided in trainer config: {gbml_config_pb_wrapper.trainer_config}" + ) + training_process_runtime_args = ( + gbml_config_pb_wrapper.trainer_config.trainer_args + ) + + if isinstance(resource_config.trainer_config, VertexAiResourceConfig): + self._launch_single_pool_training( + vertex_ai_resource_config=resource_config.trainer_config, + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + training_process_command=training_process_command, + training_process_runtime_args=training_process_runtime_args, + resource_config=resource_config, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + ) + else: + raise NotImplementedError( + f"Unsupported resource config for glt training: {type(resource_config.trainer_config).__name__}" + ) + def run( self, applied_task_identifier: AppliedTaskIdentifier, diff --git a/python/gigl/src/validation_check/libs/resource_config_checks.py b/python/gigl/src/validation_check/libs/resource_config_checks.py index d688f0477..af297289f 100644 --- a/python/gigl/src/validation_check/libs/resource_config_checks.py +++ b/python/gigl/src/validation_check/libs/resource_config_checks.py @@ -143,6 +143,7 @@ def check_if_trainer_resource_config_valid( gigl_resource_config_pb2.LocalResourceConfig, gigl_resource_config_pb2.VertexAiResourceConfig, 
gigl_resource_config_pb2.KFPResourceConfig, + gigl_resource_config_pb2.VertexAiGraphStoreConfig, ] = wrapper.trainer_config if isinstance(trainer_config, gigl_resource_config_pb2.LocalResourceConfig): @@ -152,9 +153,20 @@ def check_if_trainer_resource_config_valid( else: # Case where trainer config is gigl_resource_config_pb2.VertexAiResourceConfig or gigl_resource_config_pb2.KFPResourceConfig if isinstance(trainer_config, gigl_resource_config_pb2.VertexAiResourceConfig): - assert_proto_field_value_is_truthy( - proto=trainer_config, field_name="machine_type" + _validate_vertex_ai_resource_config( + vertex_ai_resource_config_pb=trainer_config ) + elif isinstance( + trainer_config, gigl_resource_config_pb2.VertexAiGraphStoreConfig + ): + _validate_vertex_ai_resource_config( + vertex_ai_resource_config_pb=trainer_config.graph_store_pool + ) + _validate_accelerator_type(proto_config=trainer_config.graph_store_pool) + _validate_vertex_ai_resource_config( + vertex_ai_resource_config_pb=trainer_config.compute_pool + ) + _validate_accelerator_type(proto_config=trainer_config.compute_pool) elif isinstance(trainer_config, gigl_resource_config_pb2.KFPResourceConfig): for field in [ "cpu_request", @@ -165,9 +177,9 @@ def check_if_trainer_resource_config_valid( ) else: raise ValueError( - f"""Expected distributed trainer config to be one of {gigl_resource_config_pb2.LocalResourceConfig.__name__}, - {gigl_resource_config_pb2.VertexAiResourceConfig.__name__}, - or {gigl_resource_config_pb2.KFPResourceConfig.__name__}. + f"""Expected distributed trainer config to be one of {gigl_resource_config_pb2.LocalResourceConfig.__name__}, + {gigl_resource_config_pb2.VertexAiResourceConfig.__name__}, + or {gigl_resource_config_pb2.KFPResourceConfig.__name__}. Got {type(trainer_config)}""" ) @@ -177,16 +189,10 @@ def check_if_trainer_resource_config_valid( ]: assert_proto_field_value_is_truthy(proto=trainer_config, field_name=field) - if trainer_config.gpu_type == AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED.name: # type: ignore - assert ( - trainer_config.gpu_limit == 0 - ), f"""gpu_limit must be equal to 0 for cpu training, indicated by provided gpu_type {trainer_config.gpu_type}. - Got gpu_limit {trainer_config.gpu_limit}""" - else: - assert ( - trainer_config.gpu_limit > 0 - ), f"""gpu_limit must be greater than 0 for gpu training, indicated by provided gpu_type {trainer_config.gpu_type}. - Got gpu_limit {trainer_config.gpu_limit}. Use gpu_type {AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED.name} for cpu training.""" # type: ignore + if not isinstance( + trainer_config, gigl_resource_config_pb2.VertexAiGraphStoreConfig + ): + _validate_accelerator_type(proto_config=trainer_config) def check_if_inferencer_resource_config_valid( @@ -216,12 +222,12 @@ def check_if_inferencer_resource_config_valid( if inferencer_config.gpu_type == AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED.name: # type: ignore assert ( inferencer_config.gpu_limit == 0 - ), f"""gpu_limit must be equal to 0 for cpu training, indicated by provided gpu_type {inferencer_config.gpu_type}. + ), f"""gpu_limit must be equal to 0 for cpu training, indicated by provided gpu_type {inferencer_config.gpu_type}. Got gpu_limit {inferencer_config.gpu_limit}""" else: assert ( inferencer_config.gpu_limit > 0 - ), f"""gpu_limit must be greater than 0 for gpu training, indicated by provided gpu_type {inferencer_config.gpu_type}. + ), f"""gpu_limit must be greater than 0 for gpu training, indicated by provided gpu_type {inferencer_config.gpu_type}. 
Got gpu_limit {inferencer_config.gpu_limit}. Use gpu_type {AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED.name} for cpu training.""" # type: ignore elif isinstance(inferencer_config, gigl_resource_config_pb2.LocalResourceConfig): assert_proto_field_value_is_truthy( @@ -229,8 +235,40 @@ def check_if_inferencer_resource_config_valid( ) else: raise ValueError( - f"""Expected inferencer config to be one of {gigl_resource_config_pb2.DataflowResourceConfig.__name__}, - {gigl_resource_config_pb2.VertexAiResourceConfig.__name__}, - or {gigl_resource_config_pb2.LocalResourceConfig.__name__}. + f"""Expected inferencer config to be one of {gigl_resource_config_pb2.DataflowResourceConfig.__name__}, + {gigl_resource_config_pb2.VertexAiResourceConfig.__name__}, + or {gigl_resource_config_pb2.LocalResourceConfig.__name__}. Got {type(inferencer_config)}""" ) + + +def _validate_vertex_ai_resource_config( + vertex_ai_resource_config_pb: gigl_resource_config_pb2.VertexAiResourceConfig, +) -> None: + """ + Checks if the provided Vertex AI resource configuration is valid. + """ + assert_proto_field_value_is_truthy( + proto=vertex_ai_resource_config_pb, field_name="machine_type" + ) + + +def _validate_accelerator_type( + proto_config: Union[ + gigl_resource_config_pb2.VertexAiResourceConfig, + gigl_resource_config_pb2.KFPResourceConfig, + ], +) -> None: + """ + Checks if the provided accelerator type is valid. + """ + if proto_config.gpu_type == AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED.name: # type: ignore + assert ( + proto_config.gpu_limit == 0 + ), f"""gpu_limit must be equal to 0 for cpu training, indicated by provided gpu_type {proto_config.gpu_type}. + Got gpu_limit {proto_config.gpu_limit}""" + else: + assert ( + proto_config.gpu_limit > 0 + ), f"""gpu_limit must be greater than 0 for gpu training, indicated by provided gpu_type {proto_config.gpu_type}. + Got gpu_limit {proto_config.gpu_limit}. 
Use gpu_type {AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED.name} for cpu training.""" # type: ignore diff --git a/python/snapchat/research/gbml/gigl_resource_config_pb2.py b/python/snapchat/research/gbml/gigl_resource_config_pb2.py index 36be03847..61424eba8 100644 --- a/python/snapchat/research/gbml/gigl_resource_config_pb2.py +++ b/python/snapchat/research/gbml/gigl_resource_config_pb2.py @@ -15,7 +15,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n1snapchat/research/gbml/gigl_resource_config.proto\x12\x16snapchat.research.gbml\"Y\n\x13SparkResourceConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x16\n\x0enum_local_ssds\x18\x02 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x03 \x01(\r\"r\n\x16\x44\x61taflowResourceConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\x12\x17\n\x0fmax_num_workers\x18\x02 \x01(\r\x12\x14\n\x0cmachine_type\x18\x03 \x01(\t\x12\x14\n\x0c\x64isk_size_gb\x18\x04 \x01(\r\"\xbc\x01\n\x16\x44\x61taPreprocessorConfig\x12P\n\x18\x65\x64ge_preprocessor_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfig\x12P\n\x18node_preprocessor_config\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfig\"h\n\x15VertexAiTrainerConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x10\n\x08gpu_type\x18\x02 \x01(\t\x12\x11\n\tgpu_limit\x18\x03 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x04 \x01(\r\"z\n\x10KFPTrainerConfig\x12\x13\n\x0b\x63pu_request\x18\x01 \x01(\t\x12\x16\n\x0ememory_request\x18\x02 \x01(\t\x12\x10\n\x08gpu_type\x18\x03 \x01(\t\x12\x11\n\tgpu_limit\x18\x04 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x05 \x01(\r\")\n\x12LocalTrainerConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\"\xb4\x01\n\x16VertexAiResourceConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x10\n\x08gpu_type\x18\x02 \x01(\t\x12\x11\n\tgpu_limit\x18\x03 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x04 \x01(\r\x12\x0f\n\x07timeout\x18\x05 \x01(\r\x12\x1b\n\x13gcp_region_override\x18\x06 \x01(\t\x12\x1b\n\x13scheduling_strategy\x18\x07 \x01(\t\"{\n\x11KFPResourceConfig\x12\x13\n\x0b\x63pu_request\x18\x01 \x01(\t\x12\x16\n\x0ememory_request\x18\x02 \x01(\t\x12\x10\n\x08gpu_type\x18\x03 \x01(\t\x12\x11\n\tgpu_limit\x18\x04 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x05 \x01(\r\"*\n\x13LocalResourceConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\"\xaa\x01\n\x18VertexAiGraphStoreConfig\x12H\n\x10graph_store_pool\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfig\x12\x44\n\x0c\x63ompute_pool\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfig\"\x93\x02\n\x18\x44istributedTrainerConfig\x12Q\n\x18vertex_ai_trainer_config\x18\x01 \x01(\x0b\x32-.snapchat.research.gbml.VertexAiTrainerConfigH\x00\x12\x46\n\x12kfp_trainer_config\x18\x02 \x01(\x0b\x32(.snapchat.research.gbml.KFPTrainerConfigH\x00\x12J\n\x14local_trainer_config\x18\x03 \x01(\x0b\x32*.snapchat.research.gbml.LocalTrainerConfigH\x00\x42\x10\n\x0etrainer_config\"\xf5\x02\n\x15TrainerResourceConfig\x12R\n\x18vertex_ai_trainer_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfigH\x00\x12G\n\x12kfp_trainer_config\x18\x02 \x01(\x0b\x32).snapchat.research.gbml.KFPResourceConfigH\x00\x12K\n\x14local_trainer_config\x18\x03 \x01(\x0b\x32+.snapchat.research.gbml.LocalResourceConfigH\x00\x12`\n$vertex_ai_graph_store_trainer_config\x18\x04 \x01(\x0b\x32\x30.snapchat.research.gbml.VertexAiGraphStoreConfigH\x00\x42\x10\n\x0etrainer_config\"\x91\x03\n\x18InferencerResourceConfig\x12U\n\x1bvertex_ai_inferencer_config\x18\x01 
\x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfigH\x00\x12T\n\x1a\x64\x61taflow_inferencer_config\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfigH\x00\x12N\n\x17local_inferencer_config\x18\x03 \x01(\x0b\x32+.snapchat.research.gbml.LocalResourceConfigH\x00\x12\x63\n\'vertex_ai_graph_store_inferencer_config\x18\x04 \x01(\x0b\x32\x30.snapchat.research.gbml.VertexAiGraphStoreConfigH\x00\x42\x13\n\x11inferencer_config\"\xa3\x04\n\x14SharedResourceConfig\x12Y\n\x0fresource_labels\x18\x01 \x03(\x0b\x32@.snapchat.research.gbml.SharedResourceConfig.ResourceLabelsEntry\x12_\n\x15\x63ommon_compute_config\x18\x02 \x01(\x0b\x32@.snapchat.research.gbml.SharedResourceConfig.CommonComputeConfig\x1a\x97\x02\n\x13\x43ommonComputeConfig\x12\x0f\n\x07project\x18\x01 \x01(\t\x12\x0e\n\x06region\x18\x02 \x01(\t\x12\x1a\n\x12temp_assets_bucket\x18\x03 \x01(\t\x12#\n\x1btemp_regional_assets_bucket\x18\x04 \x01(\t\x12\x1a\n\x12perm_assets_bucket\x18\x05 \x01(\t\x12#\n\x1btemp_assets_bq_dataset_name\x18\x06 \x01(\t\x12!\n\x19\x65mbedding_bq_dataset_name\x18\x07 \x01(\t\x12!\n\x19gcp_service_account_email\x18\x08 \x01(\t\x12\x17\n\x0f\x64\x61taflow_runner\x18\x0b \x01(\t\x1a\x35\n\x13ResourceLabelsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc8\x05\n\x12GiglResourceConfig\x12$\n\x1ashared_resource_config_uri\x18\x01 \x01(\tH\x00\x12N\n\x16shared_resource_config\x18\x02 \x01(\x0b\x32,.snapchat.research.gbml.SharedResourceConfigH\x00\x12K\n\x13preprocessor_config\x18\x0c \x01(\x0b\x32..snapchat.research.gbml.DataPreprocessorConfig\x12L\n\x17subgraph_sampler_config\x18\r \x01(\x0b\x32+.snapchat.research.gbml.SparkResourceConfig\x12K\n\x16split_generator_config\x18\x0e \x01(\x0b\x32+.snapchat.research.gbml.SparkResourceConfig\x12L\n\x0etrainer_config\x18\x0f \x01(\x0b\x32\x30.snapchat.research.gbml.DistributedTrainerConfigB\x02\x18\x01\x12M\n\x11inferencer_config\x18\x10 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfigB\x02\x18\x01\x12N\n\x17trainer_resource_config\x18\x11 \x01(\x0b\x32-.snapchat.research.gbml.TrainerResourceConfig\x12T\n\x1ainferencer_resource_config\x18\x12 \x01(\x0b\x32\x30.snapchat.research.gbml.InferencerResourceConfigB\x11\n\x0fshared_resource*\xf3\x01\n\tComponent\x12\x15\n\x11\x43omponent_Unknown\x10\x00\x12\x1e\n\x1a\x43omponent_Config_Validator\x10\x01\x12\x1e\n\x1a\x43omponent_Config_Populator\x10\x02\x12\x1f\n\x1b\x43omponent_Data_Preprocessor\x10\x03\x12\x1e\n\x1a\x43omponent_Subgraph_Sampler\x10\x04\x12\x1d\n\x19\x43omponent_Split_Generator\x10\x05\x12\x15\n\x11\x43omponent_Trainer\x10\x06\x12\x18\n\x14\x43omponent_Inferencer\x10\x07\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n1snapchat/research/gbml/gigl_resource_config.proto\x12\x16snapchat.research.gbml\"Y\n\x13SparkResourceConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x16\n\x0enum_local_ssds\x18\x02 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x03 \x01(\r\"r\n\x16\x44\x61taflowResourceConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\x12\x17\n\x0fmax_num_workers\x18\x02 \x01(\r\x12\x14\n\x0cmachine_type\x18\x03 \x01(\t\x12\x14\n\x0c\x64isk_size_gb\x18\x04 \x01(\r\"\xbc\x01\n\x16\x44\x61taPreprocessorConfig\x12P\n\x18\x65\x64ge_preprocessor_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfig\x12P\n\x18node_preprocessor_config\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfig\"h\n\x15VertexAiTrainerConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x10\n\x08gpu_type\x18\x02 
\x01(\t\x12\x11\n\tgpu_limit\x18\x03 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x04 \x01(\r\"z\n\x10KFPTrainerConfig\x12\x13\n\x0b\x63pu_request\x18\x01 \x01(\t\x12\x16\n\x0ememory_request\x18\x02 \x01(\t\x12\x10\n\x08gpu_type\x18\x03 \x01(\t\x12\x11\n\tgpu_limit\x18\x04 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x05 \x01(\r\")\n\x12LocalTrainerConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\"\xb4\x01\n\x16VertexAiResourceConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x10\n\x08gpu_type\x18\x02 \x01(\t\x12\x11\n\tgpu_limit\x18\x03 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x04 \x01(\r\x12\x0f\n\x07timeout\x18\x05 \x01(\r\x12\x1b\n\x13gcp_region_override\x18\x06 \x01(\t\x12\x1b\n\x13scheduling_strategy\x18\x07 \x01(\t\"{\n\x11KFPResourceConfig\x12\x13\n\x0b\x63pu_request\x18\x01 \x01(\t\x12\x16\n\x0ememory_request\x18\x02 \x01(\t\x12\x10\n\x08gpu_type\x18\x03 \x01(\t\x12\x11\n\tgpu_limit\x18\x04 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x05 \x01(\r\"*\n\x13LocalResourceConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\"\x80\x02\n\x18VertexAiGraphStoreConfig\x12H\n\x10graph_store_pool\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfig\x12\x44\n\x0c\x63ompute_pool\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfig\x12)\n!num_processes_per_storage_machine\x18\x03 \x01(\x05\x12)\n!num_processes_per_compute_machine\x18\x04 \x01(\x05\"\x93\x02\n\x18\x44istributedTrainerConfig\x12Q\n\x18vertex_ai_trainer_config\x18\x01 \x01(\x0b\x32-.snapchat.research.gbml.VertexAiTrainerConfigH\x00\x12\x46\n\x12kfp_trainer_config\x18\x02 \x01(\x0b\x32(.snapchat.research.gbml.KFPTrainerConfigH\x00\x12J\n\x14local_trainer_config\x18\x03 \x01(\x0b\x32*.snapchat.research.gbml.LocalTrainerConfigH\x00\x42\x10\n\x0etrainer_config\"\xf5\x02\n\x15TrainerResourceConfig\x12R\n\x18vertex_ai_trainer_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfigH\x00\x12G\n\x12kfp_trainer_config\x18\x02 \x01(\x0b\x32).snapchat.research.gbml.KFPResourceConfigH\x00\x12K\n\x14local_trainer_config\x18\x03 \x01(\x0b\x32+.snapchat.research.gbml.LocalResourceConfigH\x00\x12`\n$vertex_ai_graph_store_trainer_config\x18\x04 \x01(\x0b\x32\x30.snapchat.research.gbml.VertexAiGraphStoreConfigH\x00\x42\x10\n\x0etrainer_config\"\x91\x03\n\x18InferencerResourceConfig\x12U\n\x1bvertex_ai_inferencer_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfigH\x00\x12T\n\x1a\x64\x61taflow_inferencer_config\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfigH\x00\x12N\n\x17local_inferencer_config\x18\x03 \x01(\x0b\x32+.snapchat.research.gbml.LocalResourceConfigH\x00\x12\x63\n\'vertex_ai_graph_store_inferencer_config\x18\x04 \x01(\x0b\x32\x30.snapchat.research.gbml.VertexAiGraphStoreConfigH\x00\x42\x13\n\x11inferencer_config\"\xa3\x04\n\x14SharedResourceConfig\x12Y\n\x0fresource_labels\x18\x01 \x03(\x0b\x32@.snapchat.research.gbml.SharedResourceConfig.ResourceLabelsEntry\x12_\n\x15\x63ommon_compute_config\x18\x02 \x01(\x0b\x32@.snapchat.research.gbml.SharedResourceConfig.CommonComputeConfig\x1a\x97\x02\n\x13\x43ommonComputeConfig\x12\x0f\n\x07project\x18\x01 \x01(\t\x12\x0e\n\x06region\x18\x02 \x01(\t\x12\x1a\n\x12temp_assets_bucket\x18\x03 \x01(\t\x12#\n\x1btemp_regional_assets_bucket\x18\x04 \x01(\t\x12\x1a\n\x12perm_assets_bucket\x18\x05 \x01(\t\x12#\n\x1btemp_assets_bq_dataset_name\x18\x06 \x01(\t\x12!\n\x19\x65mbedding_bq_dataset_name\x18\x07 \x01(\t\x12!\n\x19gcp_service_account_email\x18\x08 \x01(\t\x12\x17\n\x0f\x64\x61taflow_runner\x18\x0b 
\x01(\t\x1a\x35\n\x13ResourceLabelsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc8\x05\n\x12GiglResourceConfig\x12$\n\x1ashared_resource_config_uri\x18\x01 \x01(\tH\x00\x12N\n\x16shared_resource_config\x18\x02 \x01(\x0b\x32,.snapchat.research.gbml.SharedResourceConfigH\x00\x12K\n\x13preprocessor_config\x18\x0c \x01(\x0b\x32..snapchat.research.gbml.DataPreprocessorConfig\x12L\n\x17subgraph_sampler_config\x18\r \x01(\x0b\x32+.snapchat.research.gbml.SparkResourceConfig\x12K\n\x16split_generator_config\x18\x0e \x01(\x0b\x32+.snapchat.research.gbml.SparkResourceConfig\x12L\n\x0etrainer_config\x18\x0f \x01(\x0b\x32\x30.snapchat.research.gbml.DistributedTrainerConfigB\x02\x18\x01\x12M\n\x11inferencer_config\x18\x10 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfigB\x02\x18\x01\x12N\n\x17trainer_resource_config\x18\x11 \x01(\x0b\x32-.snapchat.research.gbml.TrainerResourceConfig\x12T\n\x1ainferencer_resource_config\x18\x12 \x01(\x0b\x32\x30.snapchat.research.gbml.InferencerResourceConfigB\x11\n\x0fshared_resource*\xf3\x01\n\tComponent\x12\x15\n\x11\x43omponent_Unknown\x10\x00\x12\x1e\n\x1a\x43omponent_Config_Validator\x10\x01\x12\x1e\n\x1a\x43omponent_Config_Populator\x10\x02\x12\x1f\n\x1b\x43omponent_Data_Preprocessor\x10\x03\x12\x1e\n\x1a\x43omponent_Subgraph_Sampler\x10\x04\x12\x1d\n\x19\x43omponent_Split_Generator\x10\x05\x12\x15\n\x11\x43omponent_Trainer\x10\x06\x12\x18\n\x14\x43omponent_Inferencer\x10\x07\x62\x06proto3') _COMPONENT = DESCRIPTOR.enum_types_by_name['Component'] Component = enum_type_wrapper.EnumTypeWrapper(_COMPONENT) @@ -176,8 +176,8 @@ _GIGLRESOURCECONFIG.fields_by_name['trainer_config']._serialized_options = b'\030\001' _GIGLRESOURCECONFIG.fields_by_name['inferencer_config']._options = None _GIGLRESOURCECONFIG.fields_by_name['inferencer_config']._serialized_options = b'\030\001' - _COMPONENT._serialized_start=3597 - _COMPONENT._serialized_end=3840 + _COMPONENT._serialized_start=3683 + _COMPONENT._serialized_end=3926 _SPARKRESOURCECONFIG._serialized_start=77 _SPARKRESOURCECONFIG._serialized_end=166 _DATAFLOWRESOURCECONFIG._serialized_start=168 @@ -197,19 +197,19 @@ _LOCALRESOURCECONFIG._serialized_start=1056 _LOCALRESOURCECONFIG._serialized_end=1098 _VERTEXAIGRAPHSTORECONFIG._serialized_start=1101 - _VERTEXAIGRAPHSTORECONFIG._serialized_end=1271 - _DISTRIBUTEDTRAINERCONFIG._serialized_start=1274 - _DISTRIBUTEDTRAINERCONFIG._serialized_end=1549 - _TRAINERRESOURCECONFIG._serialized_start=1552 - _TRAINERRESOURCECONFIG._serialized_end=1925 - _INFERENCERRESOURCECONFIG._serialized_start=1928 - _INFERENCERRESOURCECONFIG._serialized_end=2329 - _SHAREDRESOURCECONFIG._serialized_start=2332 - _SHAREDRESOURCECONFIG._serialized_end=2879 - _SHAREDRESOURCECONFIG_COMMONCOMPUTECONFIG._serialized_start=2545 - _SHAREDRESOURCECONFIG_COMMONCOMPUTECONFIG._serialized_end=2824 - _SHAREDRESOURCECONFIG_RESOURCELABELSENTRY._serialized_start=2826 - _SHAREDRESOURCECONFIG_RESOURCELABELSENTRY._serialized_end=2879 - _GIGLRESOURCECONFIG._serialized_start=2882 - _GIGLRESOURCECONFIG._serialized_end=3594 + _VERTEXAIGRAPHSTORECONFIG._serialized_end=1357 + _DISTRIBUTEDTRAINERCONFIG._serialized_start=1360 + _DISTRIBUTEDTRAINERCONFIG._serialized_end=1635 + _TRAINERRESOURCECONFIG._serialized_start=1638 + _TRAINERRESOURCECONFIG._serialized_end=2011 + _INFERENCERRESOURCECONFIG._serialized_start=2014 + _INFERENCERRESOURCECONFIG._serialized_end=2415 + _SHAREDRESOURCECONFIG._serialized_start=2418 + _SHAREDRESOURCECONFIG._serialized_end=2965 + 
_SHAREDRESOURCECONFIG_COMMONCOMPUTECONFIG._serialized_start=2631 + _SHAREDRESOURCECONFIG_COMMONCOMPUTECONFIG._serialized_end=2910 + _SHAREDRESOURCECONFIG_RESOURCELABELSENTRY._serialized_start=2912 + _SHAREDRESOURCECONFIG_RESOURCELABELSENTRY._serialized_end=2965 + _GIGLRESOURCECONFIG._serialized_start=2968 + _GIGLRESOURCECONFIG._serialized_end=3680 # @@protoc_insertion_point(module_scope) diff --git a/python/snapchat/research/gbml/gigl_resource_config_pb2.pyi b/python/snapchat/research/gbml/gigl_resource_config_pb2.pyi index 426e3db8d..fa3db2df4 100644 --- a/python/snapchat/research/gbml/gigl_resource_config_pb2.pyi +++ b/python/snapchat/research/gbml/gigl_resource_config_pb2.pyi @@ -321,18 +321,24 @@ class VertexAiGraphStoreConfig(google.protobuf.message.Message): GRAPH_STORE_POOL_FIELD_NUMBER: builtins.int COMPUTE_POOL_FIELD_NUMBER: builtins.int + NUM_PROCESSES_PER_STORAGE_MACHINE_FIELD_NUMBER: builtins.int + NUM_PROCESSES_PER_COMPUTE_MACHINE_FIELD_NUMBER: builtins.int @property def graph_store_pool(self) -> global___VertexAiResourceConfig: ... @property def compute_pool(self) -> global___VertexAiResourceConfig: ... + num_processes_per_storage_machine: builtins.int + num_processes_per_compute_machine: builtins.int def __init__( self, *, graph_store_pool: global___VertexAiResourceConfig | None = ..., compute_pool: global___VertexAiResourceConfig | None = ..., + num_processes_per_storage_machine: builtins.int = ..., + num_processes_per_compute_machine: builtins.int = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["compute_pool", b"compute_pool", "graph_store_pool", b"graph_store_pool"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["compute_pool", b"compute_pool", "graph_store_pool", b"graph_store_pool"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["compute_pool", b"compute_pool", "graph_store_pool", b"graph_store_pool", "num_processes_per_compute_machine", b"num_processes_per_compute_machine", "num_processes_per_storage_machine", b"num_processes_per_storage_machine"]) -> None: ... 
global___VertexAiGraphStoreConfig = VertexAiGraphStoreConfig diff --git a/python/tests/integration/distributed/server_client/__init__.py b/python/tests/integration/distributed/server_client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/tests/integration/distributed/server_client/utils_test.py b/python/tests/integration/distributed/server_client/utils_test.py new file mode 100644 index 000000000..762c3c5d7 --- /dev/null +++ b/python/tests/integration/distributed/server_client/utils_test.py @@ -0,0 +1,248 @@ +import os + +# Suppress TensorFlow logs +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # isort: skip +import collections +import unittest +from unittest import mock + +import torch +import torch.multiprocessing as mp +from graphlearn_torch.distributed import init_client, shutdown_client + +from gigl.common import Uri, UriFactory +from gigl.common.logger import Logger +from gigl.distributed.server_client.server_main import run_servers +from gigl.distributed.server_client.utils import get_sampler_input_for_inference +from gigl.distributed.utils import get_free_port +from gigl.distributed.utils.neighborloader import shard_nodes_by_process +from gigl.env.distributed import ( + GRAPH_STORE_PROCESSES_PER_COMPUTE_VAR_NAME, + GRAPH_STORE_PROCESSES_PER_STORAGE_VAR_NAME, + GraphStoreInfo, +) +from gigl.src.common.types.graph_data import NodeType +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.mocking.lib.versioning import get_mocked_dataset_artifact_metadata +from gigl.src.mocking.mocking_assets.mocked_datasets_for_pipeline_tests import ( + CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO, +) +from tests.test_assets.distributed.utils import assert_tensor_equality + +logger = Logger() + + +def _run_client_process( + client_rank: int, + cluster_info: GraphStoreInfo, + node_type: NodeType, + expected_sampler_input: dict[int, list[torch.Tensor]], +) -> None: + client_global_rank = ( + int(os.environ["RANK"]) * cluster_info.num_processes_per_compute + client_rank + ) + logger.info( + f"Initializing client {client_global_rank} / {cluster_info.compute_world_size}. on {cluster_info.cluster_master_ip}:{cluster_info.cluster_master_port}. 
OS rank: {os.environ['RANK']}, local client rank: {client_rank} on port: {cluster_info.cluster_master_port}" + ) + torch.distributed.init_process_group( + backend="gloo", + world_size=cluster_info.compute_world_size, + rank=client_global_rank, + init_method=f"tcp://{cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port}", + group_name="gigl_client_comms", + ) + init_client( + num_servers=cluster_info.storage_world_size, + num_clients=cluster_info.compute_world_size, + client_rank=client_global_rank, + master_addr=cluster_info.cluster_master_ip, + master_port=cluster_info.cluster_master_port, + client_group_name="gigl_client_rpc", + ) + + sampler_input = get_sampler_input_for_inference( + client_global_rank, + cluster_info, + node_type, + ) + + rank_expected_sampler_input = expected_sampler_input[client_global_rank] + for i in range(cluster_info.compute_world_size): + if i == client_global_rank: + assert len(sampler_input) == len(rank_expected_sampler_input) + for j, expected in enumerate(rank_expected_sampler_input): + assert_tensor_equality(sampler_input[j], expected) + logger.info( + f"{client_global_rank} / {cluster_info.compute_world_size} Sampler input verified" + ) + torch.distributed.barrier() + + torch.distributed.barrier() + logger.info( + f"{client_global_rank} / {cluster_info.compute_world_size} Shutting down client" + ) + shutdown_client() + + +def _client_process( + client_rank: int, + cluster_info: GraphStoreInfo, + node_type: NodeType, + expected_sampler_input: dict[int, list[torch.Tensor]], +) -> None: + logger.info( + f"Initializing client {client_rank} / {cluster_info.compute_world_size}. OS rank: {os.environ['RANK']}, OS world size: {os.environ['WORLD_SIZE']}, local client rank: {client_rank}" + ) + # torch.distributed.init_process_group() + logger.info(f"Client {client_rank} / {cluster_info.compute_world_size} initialized") + + # cluster_info = get_graph_store_info() + mp_context = torch.multiprocessing.get_context("spawn") + client_processes = [] + for i in range(cluster_info.num_processes_per_compute): + client_process = mp_context.Process( + target=_run_client_process, + args=[ + i, # client_rank + cluster_info, # cluster_info + node_type, # node_type + expected_sampler_input, # expected_sampler_input + ], + ) + client_processes.append(client_process) + for client_process in client_processes: + client_process.start() + for client_process in client_processes: + client_process.join() + + +def _run_server_processes( + server_rank: int, + cluster_info: GraphStoreInfo, + task_config_uri: Uri, + is_inference: bool, +) -> None: + logger.info( + f"Initializing server processes. OS rank: {os.environ['RANK']}, OS world size: {os.environ['WORLD_SIZE']}" + ) + run_servers( + server_rank=int(os.environ["RANK"]) - cluster_info.num_compute_nodes, + cluster_info=cluster_info, + task_config_uri=task_config_uri, + is_inference=is_inference, + ) + + +class TestUtils(unittest.TestCase): + def test_get_sampler_input_for_inference(self): + # Simulating two server (storage) machines and two compute machines. + # Each server machine runs one process; each compute machine runs two processes.
+ cora_supervised_info = get_mocked_dataset_artifact_metadata()[ + CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO.name + ] + task_config_uri = cora_supervised_info.frozen_gbml_config_uri + task_config = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=task_config_uri + ) + cluster_info = GraphStoreInfo( + num_cluster_nodes=4, + num_storage_nodes=2, + num_compute_nodes=2, + num_processes_per_storage=1, + num_processes_per_compute=2, + cluster_master_ip="localhost", + storage_cluster_master_ip="localhost", + compute_cluster_master_ip="localhost", + cluster_master_port=get_free_port(), + storage_cluster_master_port=get_free_port(), + compute_cluster_master_port=get_free_port(), + ) + + expected_sampler_input = collections.defaultdict(list) + num_cora_nodes = 2708 + all_nodes = torch.arange(num_cora_nodes, dtype=torch.int64) + all_nodes_generated_nodes = [] + for server_rank in range(cluster_info.storage_world_size): + server_node_start = ( + server_rank * num_cora_nodes // cluster_info.storage_world_size + ) + server_node_end = ( + (server_rank + 1) * num_cora_nodes // cluster_info.storage_world_size + ) + server_nodes = all_nodes[server_node_start:server_node_end] + logger.info( + f"Server rank {server_rank} nodes: {server_node_start}-{server_node_end}" + ) + for compute_rank in range(cluster_info.compute_world_size): + generated_nodes = shard_nodes_by_process( + server_nodes, compute_rank, cluster_info.compute_world_size + ) + all_nodes_generated_nodes.append(generated_nodes) + expected_sampler_input[compute_rank].append(generated_nodes) + + master_port = get_free_port() + ctx = mp.get_context("spawn") + client_processes: list = [] + for i in range(cluster_info.num_compute_nodes): + with mock.patch.dict( + os.environ, + { + "MASTER_ADDR": "localhost", + "MASTER_PORT": str(master_port), + "RANK": str(i), + "WORLD_SIZE": str(cluster_info.cluster_world_size), + GRAPH_STORE_PROCESSES_PER_STORAGE_VAR_NAME: str( + cluster_info.num_processes_per_storage + ), + GRAPH_STORE_PROCESSES_PER_COMPUTE_VAR_NAME: str( + cluster_info.num_processes_per_compute + ), + }, + clear=False, + ): + client_process = ctx.Process( + target=_client_process, + args=[ + i, # client_rank + cluster_info, # cluster_info + task_config.graph_metadata_pb_wrapper.homogeneous_node_type, # node_type + expected_sampler_input, # expected_sampler_input + ], + ) + client_process.start() + client_processes.append(client_process) + # Start server process + server_processes = [] + for i in range(cluster_info.num_storage_nodes): + with mock.patch.dict( + os.environ, + { + "MASTER_ADDR": "localhost", + "MASTER_PORT": str(master_port), + "RANK": str(i + cluster_info.num_compute_nodes), + "WORLD_SIZE": str(cluster_info.cluster_world_size), + GRAPH_STORE_PROCESSES_PER_STORAGE_VAR_NAME: str( + cluster_info.num_processes_per_storage + ), + GRAPH_STORE_PROCESSES_PER_COMPUTE_VAR_NAME: str( + cluster_info.num_processes_per_compute + ), + }, + clear=False, + ): + server_process = ctx.Process( + target=_run_server_processes, + args=[ + i, # server_rank + cluster_info, # cluster_info + UriFactory.create_uri(task_config_uri), # task_config_uri + True, # is_inference + ], + ) + server_process.start() + server_processes.append(server_process) + + for client_process in client_processes: + client_process.join() + for server_process in server_processes: + server_process.join() diff --git a/python/tests/test_assets/distributed/run_distributed_dataset.py b/python/tests/test_assets/distributed/run_distributed_dataset.py index 
9006ebce5..734a973d1 100644 --- a/python/tests/test_assets/distributed/run_distributed_dataset.py +++ b/python/tests/test_assets/distributed/run_distributed_dataset.py @@ -1,12 +1,15 @@ -from typing import MutableMapping, Optional, Type, Union +import copy +from typing import Literal, MutableMapping, Optional, Type, Union import torch.distributed as dist +from gigl.common import Uri from gigl.common.data.load_torch_tensors import SerializedGraphMetadata from gigl.common.utils.vertex_ai_context import DistributedContext from gigl.distributed.dataset_factory import build_dataset from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_partitioner import DistPartitioner +from gigl.distributed.dist_range_partitioner import DistRangePartitioner from gigl.distributed.utils.serialized_graph_metadata_translator import ( convert_pb_to_serialized_graph_metadata, ) @@ -107,3 +110,39 @@ def run_distributed_dataset( if output_dict is not None: output_dict[rank] = dataset return dataset + + +_DATASET_CACHE: dict[Uri, tuple] = {} + + +def build_dataset_for_testing( + task_config_uri: Uri, + edge_dir: Literal["in", "out"] = "out", + tfrecord_uri_pattern: str = ".*.tfrecord(.gz)?$", + partitioner_class: Type[DistPartitioner] = DistRangePartitioner, + splitter: Optional[Union[NodeAnchorLinkSplitter, NodeSplitter]] = None, + should_load_tensors_in_parallel: bool = True, + ssl_positive_label_percentage: Optional[float] = None, +) -> DistDataset: + if task_config_uri in _DATASET_CACHE: + ipc_handle = copy.deepcopy(_DATASET_CACHE[task_config_uri]) + return DistDataset.from_ipc_handle(ipc_handle) + gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=task_config_uri + ) + + serialized_graph_metadata = convert_pb_to_serialized_graph_metadata( + preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper, + graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper, + tfrecord_uri_pattern=tfrecord_uri_pattern, + ) + dataset = build_dataset( + serialized_graph_metadata=serialized_graph_metadata, + sample_edge_direction=edge_dir, + should_load_tensors_in_parallel=should_load_tensors_in_parallel, + partitioner_class=partitioner_class, + splitter=splitter, + _ssl_positive_label_percentage=ssl_positive_label_percentage, + ) + _DATASET_CACHE[task_config_uri] = dataset.share_ipc() + return dataset diff --git a/python/tests/test_assets/distributed/utils.py b/python/tests/test_assets/distributed/utils.py index 315b7241c..15a6f3a95 100644 --- a/python/tests/test_assets/distributed/utils.py +++ b/python/tests/test_assets/distributed/utils.py @@ -54,3 +54,16 @@ def get_process_group_init_method( str: The initialization method for the process group. """ return f"tcp://{host}:{port_picker()}" + + +def create_test_process_group() -> None: + """ + Creates a single node process group for testing. + Uses the "gloo" backend. 
+ """ + torch.distributed.init_process_group( + backend="gloo", + rank=0, + world_size=1, + init_method=get_process_group_init_method(), + ) diff --git a/python/tests/unit/distributed/dist_ablp_neighborloader_test.py b/python/tests/unit/distributed/dist_ablp_neighborloader_test.py index 822ab6554..5490367db 100644 --- a/python/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/python/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -11,7 +11,6 @@ from gigl.distributed.dataset_factory import build_dataset from gigl.distributed.dist_ablp_neighborloader import DistABLPLoader -from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_partitioner import DistPartitioner from gigl.distributed.dist_range_partitioner import DistRangePartitioner @@ -36,9 +35,12 @@ to_homogeneous, ) from gigl.utils.data_splitters import HashedNodeAnchorLinkSplitter +from tests.test_assets.distributed.run_distributed_dataset import ( + build_dataset_for_testing, +) from tests.test_assets.distributed.utils import ( assert_tensor_equality, - get_process_group_init_method, + create_test_process_group, ) _POSITIVE_EDGE_TYPE = message_passing_to_positive_label(DEFAULT_HOMOGENEOUS_EDGE_TYPE) @@ -60,7 +62,6 @@ _B_TO_A = EdgeType(_B, _TO, _A) _C_TO_A = EdgeType(_C, _TO, _A) -# TODO(svij) - swap the DistNeighborLoader tests to not user context/local_process_rank/local_process_world_size. # GLT requires subclasses of DistNeighborLoader to be run in a separate process. Otherwise, we may run into segmentation fault # or other memory issues. Calling these functions in separate proceses also allows us to use shutdown_rpc() to ensure cleanup of @@ -126,9 +127,7 @@ def _run_distributed_ablp_neighbor_loader( input_nodes = torch.tensor([10, 15]) batch_size = 2 - torch.distributed.init_process_group( - rank=0, world_size=1, init_method=get_process_group_init_method() - ) + create_test_process_group() loader = DistABLPLoader( dataset=dataset, num_neighbors=[2, 2], @@ -176,9 +175,7 @@ def _run_cora_supervised( dataset: DistDataset, expected_data_count: int, ): - torch.distributed.init_process_group( - rank=0, world_size=1, init_method=get_process_group_init_method() - ) + create_test_process_group() loader = DistABLPLoader( dataset=dataset, num_neighbors=[2, 2], @@ -202,7 +199,6 @@ def _run_cora_supervised( def _run_dblp_supervised( _, dataset: DistDataset, - context: DistributedContext, supervision_edge_types: list[EdgeType], ): assert ( @@ -215,13 +211,11 @@ def _run_dblp_supervised( assert isinstance(dataset.graph, dict) fanout = [2, 2] num_neighbors = {edge_type: fanout for edge_type in dataset.graph.keys()} + create_test_process_group() loader = DistABLPLoader( dataset=dataset, num_neighbors=num_neighbors, input_nodes=(anchor_node_type, dataset.train_node_ids[anchor_node_type]), - context=context, - local_process_rank=0, - local_process_world_size=1, supervision_edge_type=supervision_edge_type, pin_memory_device=torch.device("cpu"), ) @@ -245,7 +239,6 @@ def _run_dblp_supervised( def _run_toy_heterogeneous_ablp( _, dataset: DistDataset, - context: DistributedContext, supervision_edge_types: list[EdgeType], fanout: Union[list[int], dict[EdgeType, list[int]]], ): @@ -263,13 +256,11 @@ def _run_toy_heterogeneous_ablp( all_positive_supervision_nodes, all_anchor_nodes, _, _ = dataset.graph[ labeled_edge_type ].topo.to_coo() + create_test_process_group() loader = DistABLPLoader( dataset=dataset, num_neighbors=fanout, 
input_nodes=(anchor_node_type, dataset.train_node_ids[anchor_node_type]), - context=context, - local_process_rank=0, - local_process_world_size=1, supervision_edge_type=supervision_edge_type, # We set the batch size to the number of "user" nodes in the heterogeneous toy graph to guarantee that the dataloader completes an epoch in 1 batch batch_size=15, @@ -330,9 +321,7 @@ def _run_distributed_ablp_neighbor_loader_multiple_supervision_edge_types( ): batch_size = 1 - torch.distributed.init_process_group( - rank=0, world_size=1, init_method=get_process_group_init_method() - ) + create_test_process_group() loader = DistABLPLoader( dataset=dataset, num_neighbors=[2, 2], @@ -428,17 +417,6 @@ def _run_distributed_ablp_neighbor_loader_multiple_supervision_edge_types( class DistABLPLoaderTest(unittest.TestCase): - def setUp(self): - self._master_ip_address = "localhost" - self._world_size = 1 - self._num_rpc_threads = 4 - - self._context = DistributedContext( - main_worker_ip_address=self._master_ip_address, - global_rank=0, - global_world_size=self._world_size, - ) - def tearDown(self): if torch.distributed.is_initialized(): print("Destroying process group") @@ -550,31 +528,37 @@ def test_ablp_dataloader( ) def test_cora_supervised(self): + create_test_process_group() cora_supervised_info = get_mocked_dataset_artifact_metadata()[ CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO.name ] - gbml_config_pb_wrapper = ( - GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( - gbml_config_uri=cora_supervised_info.frozen_gbml_config_uri - ) - ) + # gbml_config_pb_wrapper = ( + # GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + # gbml_config_uri=cora_supervised_info.frozen_gbml_config_uri + # ) + # ) - serialized_graph_metadata = convert_pb_to_serialized_graph_metadata( - preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper, - graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper, - tfrecord_uri_pattern=".*.tfrecord(.gz)?$", - ) + # serialized_graph_metadata = convert_pb_to_serialized_graph_metadata( + # preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper, + # graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper, + # tfrecord_uri_pattern=".*.tfrecord(.gz)?$", + # ) splitter = HashedNodeAnchorLinkSplitter( sampling_direction="in", should_convert_labels_to_edges=True ) - dataset = build_dataset( - serialized_graph_metadata=serialized_graph_metadata, - distributed_context=self._context, - sample_edge_direction="in", + # dataset = build_dataset( + # serialized_graph_metadata=serialized_graph_metadata, + # sample_edge_direction="in", + # splitter=splitter, + # ) + dataset = build_dataset_for_testing( + task_config_uri=cora_supervised_info.frozen_gbml_config_uri, + edge_dir="in", splitter=splitter, + tfrecord_uri_pattern=".*.tfrecord(.gz)?$", ) assert dataset.train_node_ids is not None, "Train node ids must exist." 
@@ -592,6 +576,7 @@ def test_cora_supervised(self): # TODO: (mkolodner-sc) - Figure out why this test is failing on Google Cloud Build @unittest.skip("Failing on Google Cloud Build - skiping for now") def test_dblp_supervised(self): + create_test_process_group() dblp_supervised_info = get_mocked_dataset_artifact_metadata()[ DBLP_GRAPH_NODE_ANCHOR_MOCKED_DATASET_INFO.name ] @@ -620,7 +605,6 @@ def test_dblp_supervised(self): dataset = build_dataset( serialized_graph_metadata=serialized_graph_metadata, - distributed_context=self._context, sample_edge_direction="in", _ssl_positive_label_percentage=0.1, splitter=splitter, @@ -628,7 +612,7 @@ def test_dblp_supervised(self): mp.spawn( fn=_run_dblp_supervised, - args=(dataset, self._context, supervision_edge_types), + args=(dataset, supervision_edge_types), ) @parameterized.expand( @@ -665,6 +649,7 @@ def test_toy_heterogeneous_ablp( partitioner_class: type[DistPartitioner], fanout: Union[list[int], dict[EdgeType, list[int]]], ): + create_test_process_group() toy_heterogeneous_supervised_info = get_mocked_dataset_artifact_metadata()[ HETEROGENEOUS_TOY_GRAPH_NODE_ANCHOR_MOCKED_DATASET_INFO.name ] @@ -675,11 +660,11 @@ def test_toy_heterogeneous_ablp( ) ) - serialized_graph_metadata = convert_pb_to_serialized_graph_metadata( - preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper, - graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper, - tfrecord_uri_pattern=".*.tfrecord(.gz)?$", - ) + # serialized_graph_metadata = convert_pb_to_serialized_graph_metadata( + # preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper, + # graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper, + # tfrecord_uri_pattern=".*.tfrecord(.gz)?$", + # ) supervision_edge_types = ( gbml_config_pb_wrapper.task_metadata_pb_wrapper.get_supervision_edge_types() @@ -691,18 +676,25 @@ def test_toy_heterogeneous_ablp( should_convert_labels_to_edges=True, ) - dataset = build_dataset( - serialized_graph_metadata=serialized_graph_metadata, - distributed_context=self._context, - sample_edge_direction="in", - _ssl_positive_label_percentage=0.1, + # dataset = build_dataset( + # serialized_graph_metadata=serialized_graph_metadata, + # sample_edge_direction="in", + # _ssl_positive_label_percentage=0.1, + # splitter=splitter, + # partitioner_class=partitioner_class, + # ) + + dataset = build_dataset_for_testing( + task_config_uri=toy_heterogeneous_supervised_info.frozen_gbml_config_uri, + edge_dir="in", splitter=splitter, partitioner_class=partitioner_class, + ssl_positive_label_percentage=0.1, ) mp.spawn( fn=_run_toy_heterogeneous_ablp, - args=(dataset, self._context, supervision_edge_types, fanout), + args=(dataset, supervision_edge_types, fanout), ) @parameterized.expand( @@ -1007,9 +999,7 @@ def test_ablp_dataloader_invalid_inputs( expected_error_message: str, **kwargs, ): - torch.distributed.init_process_group( - rank=0, world_size=1, init_method=get_process_group_init_method() - ) + create_test_process_group() with self.assertRaises(expected_error, msg=expected_error_message): DistABLPLoader(**kwargs) diff --git a/python/tests/unit/distributed/distributed_neighborloader_test.py b/python/tests/unit/distributed/distributed_neighborloader_test.py index 8fa907d32..7896b7b39 100644 --- a/python/tests/unit/distributed/distributed_neighborloader_test.py +++ b/python/tests/unit/distributed/distributed_neighborloader_test.py @@ -6,16 +6,9 @@ from graphlearn_torch.distributed import 
shutdown_rpc from torch_geometric.data import Data, HeteroData -import gigl.distributed.utils -from gigl.distributed.dataset_factory import build_dataset -from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.distributed_neighborloader import DistNeighborLoader -from gigl.distributed.utils.serialized_graph_metadata_translator import ( - convert_pb_to_serialized_graph_metadata, -) from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation -from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.mocking.lib.versioning import get_mocked_dataset_artifact_metadata from gigl.src.mocking.mocking_assets.mocked_datasets_for_pipeline_tests import ( CORA_NODE_ANCHOR_MOCKED_DATASET_INFO, @@ -35,11 +28,11 @@ from gigl.utils.data_splitters import HashedNodeAnchorLinkSplitter, HashedNodeSplitter from gigl.utils.iterator import InfiniteIterator from tests.test_assets.distributed.run_distributed_dataset import ( - run_distributed_dataset, + build_dataset_for_testing, ) from tests.test_assets.distributed.utils import ( assert_tensor_equality, - get_process_group_init_method, + create_test_process_group, ) _POSITIVE_EDGE_TYPE = message_passing_to_positive_label(DEFAULT_HOMOGENEOUS_EDGE_TYPE) @@ -50,8 +43,6 @@ _USER_TO_STORY = EdgeType(_USER, Relation("to"), _STORY) _STORY_TO_USER = EdgeType(_STORY, Relation("to"), _USER) -# TODO(svij) - swap the DistNeighborLoader tests to not user context/local_process_rank/local_process_world_size. - # GLT requires subclasses of DistNeighborLoader to be run in a separate process. Otherwise, we may run into segmentation fault # or other memory issues. Calling these functions in separate proceses also allows us to use shutdown_rpc() to ensure cleanup of # ports, providing stronger guarantees of isolation between tests. 
@@ -61,15 +52,12 @@ def _run_distributed_neighbor_loader( _, dataset: DistDataset, - context: DistributedContext, expected_data_count: int, ): + create_test_process_group() loader = DistNeighborLoader( dataset=dataset, num_neighbors=[2, 2], - context=context, - local_process_rank=0, - local_process_world_size=1, pin_memory_device=torch.device("cpu"), ) @@ -88,17 +76,14 @@ def _run_distributed_neighbor_loader( def _run_distributed_neighbor_loader_labeled_homogeneous( _, dataset: DistDataset, - context: DistributedContext, expected_data_count: int, ): + create_test_process_group() assert isinstance(dataset.node_ids, Mapping) loader = DistNeighborLoader( dataset=dataset, input_nodes=to_homogeneous(dataset.node_ids), num_neighbors=[2, 2], - context=context, - local_process_rank=0, - local_process_world_size=1, pin_memory_device=torch.device("cpu"), ) @@ -117,15 +102,12 @@ def _run_distributed_neighbor_loader_labeled_homogeneous( def _run_infinite_distributed_neighbor_loader( _, dataset: DistDataset, - context: DistributedContext, max_num_batches: int, ): + create_test_process_group() loader = DistNeighborLoader( dataset=dataset, num_neighbors=[2, 2], - context=context, - local_process_rank=0, - local_process_world_size=1, pin_memory_device=torch.device("cpu"), ) @@ -147,17 +129,14 @@ def _run_infinite_distributed_neighbor_loader( def _run_distributed_heterogeneous_neighbor_loader( _, dataset: DistDataset, - context: DistributedContext, expected_data_count: int, ): + create_test_process_group() assert isinstance(dataset.node_ids, Mapping) loader = DistNeighborLoader( dataset=dataset, input_nodes=(NodeType("author"), dataset.node_ids[NodeType("author")]), num_neighbors=[2, 2], - context=context, - local_process_rank=0, - local_process_world_size=1, pin_memory_device=torch.device("cpu"), ) @@ -174,12 +153,9 @@ def _run_distributed_heterogeneous_neighbor_loader( def _run_multiple_neighbor_loader( _, dataset: DistDataset, - context: DistributedContext, expected_data_count: int, ): - torch.distributed.init_process_group( - rank=0, world_size=1, init_method=get_process_group_init_method() - ) + create_test_process_group() loader_one = DistNeighborLoader( dataset=dataset, num_neighbors=[2, 2], @@ -223,9 +199,7 @@ def _run_distributed_neighbor_loader_with_node_labels_homogeneous( dataset: DistDataset, batch_size: int, ): - torch.distributed.init_process_group( - rank=0, world_size=1, init_method=get_process_group_init_method() - ) + create_test_process_group() loader = DistNeighborLoader( dataset=dataset, @@ -250,9 +224,7 @@ def _run_distributed_neighbor_loader_with_node_labels_heterogeneous( dataset: DistDataset, batch_size: int, ): - torch.distributed.init_process_group( - rank=0, world_size=1, init_method=get_process_group_init_method() - ) + create_test_process_group() assert isinstance(dataset.node_ids, Mapping) @@ -299,9 +271,7 @@ def _run_cora_supervised_node_classification( batch_size: int, ): """Run CORA supervised node classification test using DistNeighborLoader.""" - torch.distributed.init_process_group( - rank=0, world_size=1, init_method=get_process_group_init_method() - ) + create_test_process_group() loader = DistNeighborLoader( dataset=dataset, @@ -326,17 +296,6 @@ def _run_cora_supervised_node_classification( class DistributedNeighborLoaderTest(unittest.TestCase): - def setUp(self): - self._master_ip_address = "localhost" - self._world_size = 1 - self._num_rpc_threads = 4 - - self._context = DistributedContext( - main_worker_ip_address=self._master_ip_address, - global_rank=0, - 
global_world_size=self._world_size, - ) - def tearDown(self): if torch.distributed.is_initialized(): print("Destroying process group") @@ -346,30 +305,32 @@ def tearDown(self): super().tearDown() def test_distributed_neighbor_loader(self): + create_test_process_group() expected_data_count = 2708 - port = gigl.distributed.utils.get_free_port() - - dataset = run_distributed_dataset( - rank=0, - world_size=self._world_size, - mocked_dataset_info=CORA_NODE_ANCHOR_MOCKED_DATASET_INFO, - should_load_tensors_in_parallel=True, - _port=port, + + cora_supervised_info = get_mocked_dataset_artifact_metadata()[ + CORA_NODE_ANCHOR_MOCKED_DATASET_INFO.name + ] + dataset = build_dataset_for_testing( + task_config_uri=cora_supervised_info.frozen_gbml_config_uri, + edge_dir="in", + tfrecord_uri_pattern=".*.tfrecord(.gz)?$", ) mp.spawn( fn=_run_distributed_neighbor_loader, - args=(dataset, self._context, expected_data_count), + args=(dataset, expected_data_count), ) def test_infinite_distributed_neighbor_loader(self): - port = gigl.distributed.utils.get_free_port() - dataset = run_distributed_dataset( - rank=0, - world_size=self._world_size, - mocked_dataset_info=CORA_NODE_ANCHOR_MOCKED_DATASET_INFO, - should_load_tensors_in_parallel=True, - _port=port, + create_test_process_group() + cora_supervised_info = get_mocked_dataset_artifact_metadata()[ + CORA_NODE_ANCHOR_MOCKED_DATASET_INFO.name + ] + dataset = build_dataset_for_testing( + task_config_uri=cora_supervised_info.frozen_gbml_config_uri, + edge_dir="in", + tfrecord_uri_pattern=".*.tfrecord(.gz)?$", ) assert isinstance(dataset.node_ids, torch.Tensor) @@ -381,74 +342,67 @@ def test_infinite_distributed_neighbor_loader(self): mp.spawn( fn=_run_infinite_distributed_neighbor_loader, - args=(dataset, self._context, max_num_batches), + args=(dataset, max_num_batches), ) # TODO: (svij) - Figure out why this test is failing on Google Cloud Build @unittest.skip("Failing on Google Cloud Build - skiping for now") def test_distributed_neighbor_loader_heterogeneous(self): + create_test_process_group() expected_data_count = 4057 - dataset = run_distributed_dataset( - rank=0, - world_size=self._world_size, - mocked_dataset_info=DBLP_GRAPH_NODE_ANCHOR_MOCKED_DATASET_INFO, - should_load_tensors_in_parallel=True, + dblp_supervised_info = get_mocked_dataset_artifact_metadata()[ + DBLP_GRAPH_NODE_ANCHOR_MOCKED_DATASET_INFO.name + ] + dataset = build_dataset_for_testing( + task_config_uri=dblp_supervised_info.frozen_gbml_config_uri, + edge_dir="in", + tfrecord_uri_pattern=".*.tfrecord(.gz)?$", ) mp.spawn( fn=_run_distributed_heterogeneous_neighbor_loader, - args=(dataset, self._context, expected_data_count), + args=(dataset, expected_data_count), ) def test_random_loading_labeled_homogeneous(self): + create_test_process_group() cora_supervised_info = get_mocked_dataset_artifact_metadata()[ CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO.name ] - gbml_config_pb_wrapper = ( - GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( - gbml_config_uri=cora_supervised_info.frozen_gbml_config_uri - ) - ) - - serialized_graph_metadata = convert_pb_to_serialized_graph_metadata( - preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper, - graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper, - tfrecord_uri_pattern=".*.tfrecord(.gz)?$", - ) - splitter = HashedNodeAnchorLinkSplitter( sampling_direction="in", should_convert_labels_to_edges=True ) - dataset = build_dataset( - serialized_graph_metadata=serialized_graph_metadata, - 
distributed_context=self._context, - sample_edge_direction="in", + dataset = build_dataset_for_testing( + task_config_uri=cora_supervised_info.frozen_gbml_config_uri, + edge_dir="in", splitter=splitter, + tfrecord_uri_pattern=".*.tfrecord(.gz)?$", ) assert isinstance(dataset.node_ids, Mapping) mp.spawn( fn=_run_distributed_neighbor_loader_labeled_homogeneous, - args=(dataset, self._context, to_homogeneous(dataset.node_ids).size(0)), + args=(dataset, to_homogeneous(dataset.node_ids).size(0)), ) def test_multiple_neighbor_loader(self): - port = gigl.distributed.utils.get_free_port() + create_test_process_group() expected_data_count = 2708 - dataset = run_distributed_dataset( - rank=0, - world_size=self._world_size, - mocked_dataset_info=CORA_NODE_ANCHOR_MOCKED_DATASET_INFO, - should_load_tensors_in_parallel=True, - _port=port, + cora_supervised_info = get_mocked_dataset_artifact_metadata()[ + CORA_NODE_ANCHOR_MOCKED_DATASET_INFO.name + ] + dataset = build_dataset_for_testing( + task_config_uri=cora_supervised_info.frozen_gbml_config_uri, + edge_dir="in", + tfrecord_uri_pattern=".*.tfrecord(.gz)?$", ) mp.spawn( fn=_run_multiple_neighbor_loader, - args=(dataset, self._context, expected_data_count), + args=(dataset, expected_data_count), ) def test_distributed_neighbor_loader_with_node_labels_homogeneous(self): @@ -530,32 +484,18 @@ def test_distributed_neighbor_loader_with_node_labels_heterogeneous(self): def test_cora_supervised_node_classification(self): """Test CORA dataset for supervised node classification task.""" - - torch.distributed.init_process_group( - rank=0, world_size=1, init_method=get_process_group_init_method() - ) + create_test_process_group() cora_supervised_info = get_mocked_dataset_artifact_metadata()[ CORA_NODE_CLASSIFICATION_MOCKED_DATASET_INFO.name ] - gbml_config_pb_wrapper = ( - GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( - gbml_config_uri=cora_supervised_info.frozen_gbml_config_uri - ) - ) - - serialized_graph_metadata = convert_pb_to_serialized_graph_metadata( - preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper, - graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper, - tfrecord_uri_pattern=".*.tfrecord(.gz)?$", - ) - splitter = HashedNodeSplitter() - dataset = build_dataset( - serialized_graph_metadata=serialized_graph_metadata, - sample_edge_direction="in", + dataset = build_dataset_for_testing( + task_config_uri=cora_supervised_info.frozen_gbml_config_uri, + edge_dir="in", splitter=splitter, + tfrecord_uri_pattern=".*.tfrecord(.gz)?$", ) mp.spawn( @@ -590,7 +530,7 @@ def test_isolated_heterogeneous_neighbor_loader( mp.spawn( fn=_run_distributed_heterogeneous_neighbor_loader, - args=(dataset, self._context, 18), + args=(dataset, 18), ) def test_isolated_homogeneous_neighbor_loader( @@ -613,7 +553,7 @@ def test_isolated_homogeneous_neighbor_loader( mp.spawn( fn=_run_distributed_neighbor_loader, - args=(dataset, self._context, 18), + args=(dataset, 18), ) diff --git a/python/tests/unit/distributed/server_client/__init__.py b/python/tests/unit/distributed/server_client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/tests/unit/distributed/server_client/remote_dataset_test.py b/python/tests/unit/distributed/server_client/remote_dataset_test.py new file mode 100644 index 000000000..28cdfa6bf --- /dev/null +++ b/python/tests/unit/distributed/server_client/remote_dataset_test.py @@ -0,0 +1,246 @@ +import unittest + +import torch + +from gigl.distributed.dist_dataset 
import DistDataset +from gigl.distributed.server_client import remote_dataset +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from gigl.types.graph import ( + FeatureInfo, + FeaturePartitionData, + GraphPartitionData, + PartitionOutput, +) +from tests.test_assets.distributed.utils import assert_tensor_equality + +_USER = NodeType("user") +_STORY = NodeType("story") +_USER_TO_STORY = EdgeType(_USER, Relation("to"), _STORY) +_STORY_TO_USER = EdgeType(_STORY, Relation("to"), _USER) + + +class TestRemoteDataset(unittest.TestCase): + def setUp(self) -> None: + """Reset the global dataset before each test.""" + remote_dataset._dataset = None + + def tearDown(self) -> None: + """Clean up after each test.""" + remote_dataset._dataset = None + + def _create_heterogeneous_dataset(self) -> DistDataset: + """Helper method to create a heterogeneous test dataset.""" + partition_output = PartitionOutput( + node_partition_book={ + _USER: torch.zeros(5, dtype=torch.int64), + _STORY: torch.zeros(5, dtype=torch.int64), + }, + edge_partition_book={ + _USER_TO_STORY: torch.zeros(5, dtype=torch.int64), + _STORY_TO_USER: torch.zeros(5, dtype=torch.int64), + }, + partitioned_edge_index={ + _USER_TO_STORY: GraphPartitionData( + edge_index=torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]), + edge_ids=None, + ), + _STORY_TO_USER: GraphPartitionData( + edge_index=torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]), + edge_ids=None, + ), + }, + partitioned_node_features={ + _USER: FeaturePartitionData( + feats=torch.zeros(5, 2), ids=torch.arange(5) + ), + _STORY: FeaturePartitionData( + feats=torch.zeros(5, 2), ids=torch.arange(5) + ), + }, + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels={ + _USER: FeaturePartitionData( + feats=torch.arange(5).unsqueeze(1), ids=torch.arange(5) + ), + _STORY: FeaturePartitionData( + feats=torch.arange(5).unsqueeze(1), ids=torch.arange(5) + ), + }, + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + return dataset + + def _create_homogeneous_dataset(self) -> DistDataset: + """Helper method to create a homogeneous test dataset.""" + partition_output = PartitionOutput( + node_partition_book=torch.zeros(10, dtype=torch.int64), + edge_partition_book=torch.zeros(10, dtype=torch.int64), + partitioned_edge_index=GraphPartitionData( + edge_index=torch.tensor( + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]] + ), + edge_ids=None, + ), + partitioned_node_features=FeaturePartitionData( + feats=torch.zeros(10, 3), ids=torch.arange(10) + ), + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + return dataset + + def test_register_dataset(self) -> None: + """Test that register_dataset correctly sets the global dataset.""" + dataset = self._create_heterogeneous_dataset() + remote_dataset.register_dataset(dataset) + + # Verify the dataset was registered + self.assertIsNotNone(remote_dataset._dataset) + self.assertEqual(remote_dataset._dataset, dataset) + + def test_reregister_dataset_raises_error(self) -> None: + """Test that reregistering a dataset raises an error.""" + dataset = self._create_heterogeneous_dataset() + remote_dataset.register_dataset(dataset) + with self.assertRaises(ValueError) as context: + 
remote_dataset.register_dataset(dataset) + self.assertIn("Dataset already registered!", str(context.exception)) + + def test_get_node_feature_info_with_heterogeneous_dataset(self) -> None: + """Test get_node_feature_info with a registered heterogeneous dataset.""" + dataset = self._create_heterogeneous_dataset() + remote_dataset.register_dataset(dataset) + + node_feature_info = remote_dataset.get_node_feature_info() + + # Verify it returns the correct feature info + self.assertIsNotNone(node_feature_info) + self.assertIsInstance(node_feature_info, dict) + self.assertIn(_USER, node_feature_info) + self.assertIn(_STORY, node_feature_info) + self.assertEqual( + node_feature_info[_USER], FeatureInfo(dim=2, dtype=torch.float32) + ) + self.assertEqual( + node_feature_info[_STORY], FeatureInfo(dim=2, dtype=torch.float32) + ) + + def test_get_node_feature_info_with_homogeneous_dataset(self) -> None: + """Test get_node_feature_info with a registered homogeneous dataset.""" + dataset = self._create_homogeneous_dataset() + remote_dataset.register_dataset(dataset) + + node_feature_info = remote_dataset.get_node_feature_info() + + # Verify it returns the correct feature info + self.assertIsNotNone(node_feature_info) + self.assertEqual(node_feature_info, FeatureInfo(dim=3, dtype=torch.float32)) + + def test_get_node_feature_info_without_registered_dataset(self) -> None: + """Test get_node_feature_info raises ValueError when no dataset is registered.""" + with self.assertRaises(ValueError) as context: + remote_dataset.get_node_feature_info() + + self.assertIn("Dataset not registered", str(context.exception)) + self.assertIn("register_dataset", str(context.exception)) + + def test_get_edge_feature_info_with_heterogeneous_dataset(self) -> None: + """Test get_edge_feature_info with a registered heterogeneous dataset.""" + dataset = self._create_heterogeneous_dataset() + remote_dataset.register_dataset(dataset) + + edge_feature_info = remote_dataset.get_edge_feature_info() + + # For this test dataset, edge features are None + self.assertIsNone(edge_feature_info) + + def test_get_edge_feature_info_with_homogeneous_dataset(self) -> None: + """Test get_edge_feature_info with a registered homogeneous dataset.""" + dataset = self._create_homogeneous_dataset() + remote_dataset.register_dataset(dataset) + + edge_feature_info = remote_dataset.get_edge_feature_info() + + # For this test dataset, edge features are None + self.assertIsNone(edge_feature_info) + + def test_get_edge_feature_info_without_registered_dataset(self) -> None: + """Test get_edge_feature_info raises ValueError when no dataset is registered.""" + with self.assertRaises(ValueError) as context: + remote_dataset.get_edge_feature_info() + + self.assertIn("Dataset not registered", str(context.exception)) + self.assertIn("register_dataset", str(context.exception)) + + def test_get_node_ids_for_rank_with_homogeneous_dataset(self) -> None: + """Test get_node_ids_for_rank with a homogeneous dataset.""" + dataset = self._create_homogeneous_dataset() + remote_dataset.register_dataset(dataset) + + # Test with world_size=1, rank=0 (should get all nodes) + node_ids = remote_dataset.get_node_ids_for_rank(rank=0, world_size=1) + self.assertIsInstance(node_ids, torch.Tensor) + self.assertEqual(node_ids.shape[0], 10) + assert_tensor_equality(node_ids, torch.arange(10)) + + def test_get_node_ids_for_rank_with_heterogeneous_dataset(self) -> None: + """Test get_node_ids_for_rank with a heterogeneous dataset.""" + dataset = self._create_heterogeneous_dataset() + 
remote_dataset.register_dataset(dataset) + + # Test with USER node type + user_node_ids = remote_dataset.get_node_ids_for_rank( + rank=0, world_size=1, node_type=_USER + ) + self.assertIsInstance(user_node_ids, torch.Tensor) + self.assertEqual(user_node_ids.shape[0], 5) + assert_tensor_equality(user_node_ids, torch.arange(5)) + + # Test with STORY node type + story_node_ids = remote_dataset.get_node_ids_for_rank( + rank=0, world_size=1, node_type=_STORY + ) + self.assertIsInstance(story_node_ids, torch.Tensor) + self.assertEqual(story_node_ids.shape[0], 5) + assert_tensor_equality(story_node_ids, torch.arange(5)) + + def test_get_node_ids_for_rank_with_multiple_ranks(self) -> None: + """Test get_node_ids_for_rank with multiple ranks to verify sharding.""" + dataset = self._create_homogeneous_dataset() + remote_dataset.register_dataset(dataset) + + # Test with world_size=2 + rank_0_nodes = remote_dataset.get_node_ids_for_rank(rank=0, world_size=2) + rank_1_nodes = remote_dataset.get_node_ids_for_rank(rank=1, world_size=2) + + # Verify each rank gets different nodes + assert_tensor_equality(rank_0_nodes, torch.arange(5)) + assert_tensor_equality(rank_1_nodes, torch.arange(5, 10)) + + # Test with world_size=3 (uneven split) + rank_0_nodes = remote_dataset.get_node_ids_for_rank(rank=0, world_size=3) + rank_1_nodes = remote_dataset.get_node_ids_for_rank(rank=1, world_size=3) + rank_2_nodes = remote_dataset.get_node_ids_for_rank(rank=2, world_size=3) + + assert_tensor_equality(rank_0_nodes, torch.arange(3)) + assert_tensor_equality(rank_1_nodes, torch.arange(3, 6)) + assert_tensor_equality(rank_2_nodes, torch.arange(6, 10)) + + def test_get_node_ids_for_rank_without_registered_dataset(self) -> None: + """Test get_node_ids_for_rank raises ValueError when no dataset is registered.""" + with self.assertRaises(ValueError) as context: + remote_dataset.get_node_ids_for_rank(rank=0, world_size=1) + + self.assertIn("Dataset not registered", str(context.exception)) + self.assertIn("register_dataset", str(context.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala index 2decc7545..084459025 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala @@ -55,68 +55,71 @@ object GiglResourceConfigProto extends _root_.scalapb.GeneratedFileObject { mVxdWVzdFINbWVtb3J5UmVxdWVzdBInCghncHVfdHlwZRgDIAEoCUIM4j8JEgdncHVUeXBlUgdncHVUeXBlEioKCWdwdV9saW1pd BgEIAEoDUIN4j8KEghncHVMaW1pdFIIZ3B1TGltaXQSMwoMbnVtX3JlcGxpY2FzGAUgASgNQhDiPw0SC251bVJlcGxpY2FzUgtud W1SZXBsaWNhcyJHChNMb2NhbFJlc291cmNlQ29uZmlnEjAKC251bV93b3JrZXJzGAEgASgNQg/iPwwSCm51bVdvcmtlcnNSCm51b - VdvcmtlcnMi7gEKGFZlcnRleEFpR3JhcGhTdG9yZUNvbmZpZxJtChBncmFwaF9zdG9yZV9wb29sGAEgASgLMi4uc25hcGNoYXQuc + VdvcmtlcnMiygMKGFZlcnRleEFpR3JhcGhTdG9yZUNvbmZpZxJtChBncmFwaF9zdG9yZV9wb29sGAEgASgLMi4uc25hcGNoYXQuc mVzZWFyY2guZ2JtbC5WZXJ0ZXhBaVJlc291cmNlQ29uZmlnQhPiPxASDmdyYXBoU3RvcmVQb29sUg5ncmFwaFN0b3JlUG9vbBJjC gxjb21wdXRlX3Bvb2wYAiABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpUmVzb3VyY2VDb25maWdCEOI/DRILY - 29tcHV0ZVBvb2xSC2NvbXB1dGVQb29sIp0DChhEaXN0cmlidXRlZFRyYWluZXJDb25maWcShAEKGHZlcnRleF9haV90cmFpbmVyX - 
2NvbmZpZxgBIAEoCzItLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlUcmFpbmVyQ29uZmlnQhriPxcSFXZlcnRleEFpV - HJhaW5lckNvbmZpZ0gAUhV2ZXJ0ZXhBaVRyYWluZXJDb25maWcSbwoSa2ZwX3RyYWluZXJfY29uZmlnGAIgASgLMiguc25hcGNoY - XQucmVzZWFyY2guZ2JtbC5LRlBUcmFpbmVyQ29uZmlnQhXiPxISEGtmcFRyYWluZXJDb25maWdIAFIQa2ZwVHJhaW5lckNvbmZpZ - xJ3ChRsb2NhbF90cmFpbmVyX2NvbmZpZxgDIAEoCzIqLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuTG9jYWxUcmFpbmVyQ29uZmlnQ - hfiPxQSEmxvY2FsVHJhaW5lckNvbmZpZ0gAUhJsb2NhbFRyYWluZXJDb25maWdCEAoOdHJhaW5lcl9jb25maWcixwQKFVRyYWluZ - XJSZXNvdXJjZUNvbmZpZxKFAQoYdmVydGV4X2FpX3RyYWluZXJfY29uZmlnGAEgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2Jtb - C5WZXJ0ZXhBaVJlc291cmNlQ29uZmlnQhriPxcSFXZlcnRleEFpVHJhaW5lckNvbmZpZ0gAUhV2ZXJ0ZXhBaVRyYWluZXJDb25ma - WcScAoSa2ZwX3RyYWluZXJfY29uZmlnGAIgASgLMikuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5LRlBSZXNvdXJjZUNvbmZpZ0IV4 - j8SEhBrZnBUcmFpbmVyQ29uZmlnSABSEGtmcFRyYWluZXJDb25maWcSeAoUbG9jYWxfdHJhaW5lcl9jb25maWcYAyABKAsyKy5zb - mFwY2hhdC5yZXNlYXJjaC5nYm1sLkxvY2FsUmVzb3VyY2VDb25maWdCF+I/FBISbG9jYWxUcmFpbmVyQ29uZmlnSABSEmxvY2FsV - HJhaW5lckNvbmZpZxKnAQokdmVydGV4X2FpX2dyYXBoX3N0b3JlX3RyYWluZXJfY29uZmlnGAQgASgLMjAuc25hcGNoYXQucmVzZ - WFyY2guZ2JtbC5WZXJ0ZXhBaUdyYXBoU3RvcmVDb25maWdCJOI/IRIfdmVydGV4QWlHcmFwaFN0b3JlVHJhaW5lckNvbmZpZ0gAU - h92ZXJ0ZXhBaUdyYXBoU3RvcmVUcmFpbmVyQ29uZmlnQhAKDnRyYWluZXJfY29uZmlnIocFChhJbmZlcmVuY2VyUmVzb3VyY2VDb - 25maWcSjgEKG3ZlcnRleF9haV9pbmZlcmVuY2VyX2NvbmZpZxgBIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4Q - WlSZXNvdXJjZUNvbmZpZ0Id4j8aEhh2ZXJ0ZXhBaUluZmVyZW5jZXJDb25maWdIAFIYdmVydGV4QWlJbmZlcmVuY2VyQ29uZmlnE - o0BChpkYXRhZmxvd19pbmZlcmVuY2VyX2NvbmZpZxgCIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuRGF0YWZsb3dSZXNvd - XJjZUNvbmZpZ0Id4j8aEhhkYXRhZmxvd0luZmVyZW5jZXJDb25maWdIAFIYZGF0YWZsb3dJbmZlcmVuY2VyQ29uZmlnEoEBChdsb - 2NhbF9pbmZlcmVuY2VyX2NvbmZpZxgDIAEoCzIrLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuTG9jYWxSZXNvdXJjZUNvbmZpZ0Ia4 - j8XEhVsb2NhbEluZmVyZW5jZXJDb25maWdIAFIVbG9jYWxJbmZlcmVuY2VyQ29uZmlnErABCid2ZXJ0ZXhfYWlfZ3JhcGhfc3Rvc - mVfaW5mZXJlbmNlcl9jb25maWcYBCABKAsyMC5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpR3JhcGhTdG9yZUNvbmZpZ - 0In4j8kEiJ2ZXJ0ZXhBaUdyYXBoU3RvcmVJbmZlcmVuY2VyQ29uZmlnSABSInZlcnRleEFpR3JhcGhTdG9yZUluZmVyZW5jZXJDb - 25maWdCEwoRaW5mZXJlbmNlcl9jb25maWcilwgKFFNoYXJlZFJlc291cmNlQ29uZmlnEn4KD3Jlc291cmNlX2xhYmVscxgBIAMoC - zJALnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU2hhcmVkUmVzb3VyY2VDb25maWcuUmVzb3VyY2VMYWJlbHNFbnRyeUIT4j8QEg5yZ - XNvdXJjZUxhYmVsc1IOcmVzb3VyY2VMYWJlbHMSjgEKFWNvbW1vbl9jb21wdXRlX2NvbmZpZxgCIAEoCzJALnNuYXBjaGF0LnJlc - 2VhcmNoLmdibWwuU2hhcmVkUmVzb3VyY2VDb25maWcuQ29tbW9uQ29tcHV0ZUNvbmZpZ0IY4j8VEhNjb21tb25Db21wdXRlQ29uZ - mlnUhNjb21tb25Db21wdXRlQ29uZmlnGpQFChNDb21tb25Db21wdXRlQ29uZmlnEiYKB3Byb2plY3QYASABKAlCDOI/CRIHcHJva - mVjdFIHcHJvamVjdBIjCgZyZWdpb24YAiABKAlCC+I/CBIGcmVnaW9uUgZyZWdpb24SQwoSdGVtcF9hc3NldHNfYnVja2V0GAMgA - SgJQhXiPxISEHRlbXBBc3NldHNCdWNrZXRSEHRlbXBBc3NldHNCdWNrZXQSXAobdGVtcF9yZWdpb25hbF9hc3NldHNfYnVja2V0G - AQgASgJQh3iPxoSGHRlbXBSZWdpb25hbEFzc2V0c0J1Y2tldFIYdGVtcFJlZ2lvbmFsQXNzZXRzQnVja2V0EkMKEnBlcm1fYXNzZ - XRzX2J1Y2tldBgFIAEoCUIV4j8SEhBwZXJtQXNzZXRzQnVja2V0UhBwZXJtQXNzZXRzQnVja2V0EloKG3RlbXBfYXNzZXRzX2JxX - 2RhdGFzZXRfbmFtZRgGIAEoCUIc4j8ZEhd0ZW1wQXNzZXRzQnFEYXRhc2V0TmFtZVIXdGVtcEFzc2V0c0JxRGF0YXNldE5hbWUSV - goZZW1iZWRkaW5nX2JxX2RhdGFzZXRfbmFtZRgHIAEoCUIb4j8YEhZlbWJlZGRpbmdCcURhdGFzZXROYW1lUhZlbWJlZGRpbmdCc - URhdGFzZXROYW1lElYKGWdjcF9zZXJ2aWNlX2FjY291bnRfZW1haWwYCCABKAlCG+I/GBIWZ2NwU2VydmljZUFjY291bnRFbWFpb - FIWZ2NwU2VydmljZUFjY291bnRFbWFpbBI8Cg9kYXRhZmxvd19ydW5uZXIYCyABKAlCE+I/EBIOZGF0YWZsb3dSdW5uZXJSDmRhd - 
GFmbG93UnVubmVyGlcKE1Jlc291cmNlTGFiZWxzRW50cnkSGgoDa2V5GAEgASgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgA - SgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAEi9wgKEkdpZ2xSZXNvdXJjZUNvbmZpZxJbChpzaGFyZWRfcmVzb3VyY2VfY29uZmlnX - 3VyaRgBIAEoCUIc4j8ZEhdzaGFyZWRSZXNvdXJjZUNvbmZpZ1VyaUgAUhdzaGFyZWRSZXNvdXJjZUNvbmZpZ1VyaRJ/ChZzaGFyZ - WRfcmVzb3VyY2VfY29uZmlnGAIgASgLMiwuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5TaGFyZWRSZXNvdXJjZUNvbmZpZ0IZ4j8WE - hRzaGFyZWRSZXNvdXJjZUNvbmZpZ0gAUhRzaGFyZWRSZXNvdXJjZUNvbmZpZxJ4ChNwcmVwcm9jZXNzb3JfY29uZmlnGAwgASgLM - i4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5EYXRhUHJlcHJvY2Vzc29yQ29uZmlnQhfiPxQSEnByZXByb2Nlc3NvckNvbmZpZ1ISc - HJlcHJvY2Vzc29yQ29uZmlnEn8KF3N1YmdyYXBoX3NhbXBsZXJfY29uZmlnGA0gASgLMisuc25hcGNoYXQucmVzZWFyY2guZ2Jtb - C5TcGFya1Jlc291cmNlQ29uZmlnQhriPxcSFXN1YmdyYXBoU2FtcGxlckNvbmZpZ1IVc3ViZ3JhcGhTYW1wbGVyQ29uZmlnEnwKF - nNwbGl0X2dlbmVyYXRvcl9jb25maWcYDiABKAsyKy5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlNwYXJrUmVzb3VyY2VDb25maWdCG - eI/FhIUc3BsaXRHZW5lcmF0b3JDb25maWdSFHNwbGl0R2VuZXJhdG9yQ29uZmlnEm0KDnRyYWluZXJfY29uZmlnGA8gASgLMjAuc - 25hcGNoYXQucmVzZWFyY2guZ2JtbC5EaXN0cmlidXRlZFRyYWluZXJDb25maWdCFBgB4j8PEg10cmFpbmVyQ29uZmlnUg10cmFpb - mVyQ29uZmlnEnQKEWluZmVyZW5jZXJfY29uZmlnGBAgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5EYXRhZmxvd1Jlc291c - mNlQ29uZmlnQhcYAeI/EhIQaW5mZXJlbmNlckNvbmZpZ1IQaW5mZXJlbmNlckNvbmZpZxKBAQoXdHJhaW5lcl9yZXNvdXJjZV9jb - 25maWcYESABKAsyLS5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlRyYWluZXJSZXNvdXJjZUNvbmZpZ0Ia4j8XEhV0cmFpbmVyUmVzb - 3VyY2VDb25maWdSFXRyYWluZXJSZXNvdXJjZUNvbmZpZxKNAQoaaW5mZXJlbmNlcl9yZXNvdXJjZV9jb25maWcYEiABKAsyMC5zb - mFwY2hhdC5yZXNlYXJjaC5nYm1sLkluZmVyZW5jZXJSZXNvdXJjZUNvbmZpZ0Id4j8aEhhpbmZlcmVuY2VyUmVzb3VyY2VDb25ma - WdSGGluZmVyZW5jZXJSZXNvdXJjZUNvbmZpZ0IRCg9zaGFyZWRfcmVzb3VyY2Uq4wMKCUNvbXBvbmVudBItChFDb21wb25lbnRfV - W5rbm93bhAAGhbiPxMSEUNvbXBvbmVudF9Vbmtub3duEj8KGkNvbXBvbmVudF9Db25maWdfVmFsaWRhdG9yEAEaH+I/HBIaQ29tc - G9uZW50X0NvbmZpZ19WYWxpZGF0b3ISPwoaQ29tcG9uZW50X0NvbmZpZ19Qb3B1bGF0b3IQAhof4j8cEhpDb21wb25lbnRfQ29uZ - mlnX1BvcHVsYXRvchJBChtDb21wb25lbnRfRGF0YV9QcmVwcm9jZXNzb3IQAxog4j8dEhtDb21wb25lbnRfRGF0YV9QcmVwcm9jZ - XNzb3ISPwoaQ29tcG9uZW50X1N1YmdyYXBoX1NhbXBsZXIQBBof4j8cEhpDb21wb25lbnRfU3ViZ3JhcGhfU2FtcGxlchI9ChlDb - 21wb25lbnRfU3BsaXRfR2VuZXJhdG9yEAUaHuI/GxIZQ29tcG9uZW50X1NwbGl0X0dlbmVyYXRvchItChFDb21wb25lbnRfVHJha - W5lchAGGhbiPxMSEUNvbXBvbmVudF9UcmFpbmVyEjMKFENvbXBvbmVudF9JbmZlcmVuY2VyEAcaGeI/FhIUQ29tcG9uZW50X0luZ - mVyZW5jZXJiBnByb3RvMw==""" + 29tcHV0ZVBvb2xSC2NvbXB1dGVQb29sEmwKIW51bV9wcm9jZXNzZXNfcGVyX3N0b3JhZ2VfbWFjaGluZRgDIAEoBUIi4j8fEh1ud + W1Qcm9jZXNzZXNQZXJTdG9yYWdlTWFjaGluZVIdbnVtUHJvY2Vzc2VzUGVyU3RvcmFnZU1hY2hpbmUSbAohbnVtX3Byb2Nlc3Nlc + 19wZXJfY29tcHV0ZV9tYWNoaW5lGAQgASgFQiLiPx8SHW51bVByb2Nlc3Nlc1BlckNvbXB1dGVNYWNoaW5lUh1udW1Qcm9jZXNzZ + XNQZXJDb21wdXRlTWFjaGluZSKdAwoYRGlzdHJpYnV0ZWRUcmFpbmVyQ29uZmlnEoQBChh2ZXJ0ZXhfYWlfdHJhaW5lcl9jb25ma + WcYASABKAsyLS5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpVHJhaW5lckNvbmZpZ0Ia4j8XEhV2ZXJ0ZXhBaVRyYWluZ + XJDb25maWdIAFIVdmVydGV4QWlUcmFpbmVyQ29uZmlnEm8KEmtmcF90cmFpbmVyX2NvbmZpZxgCIAEoCzIoLnNuYXBjaGF0LnJlc + 2VhcmNoLmdibWwuS0ZQVHJhaW5lckNvbmZpZ0IV4j8SEhBrZnBUcmFpbmVyQ29uZmlnSABSEGtmcFRyYWluZXJDb25maWcSdwoUb + G9jYWxfdHJhaW5lcl9jb25maWcYAyABKAsyKi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkxvY2FsVHJhaW5lckNvbmZpZ0IX4j8UE + hJsb2NhbFRyYWluZXJDb25maWdIAFISbG9jYWxUcmFpbmVyQ29uZmlnQhAKDnRyYWluZXJfY29uZmlnIscEChVUcmFpbmVyUmVzb + 3VyY2VDb25maWcShQEKGHZlcnRleF9haV90cmFpbmVyX2NvbmZpZxgBIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVmVyd + GV4QWlSZXNvdXJjZUNvbmZpZ0Ia4j8XEhV2ZXJ0ZXhBaVRyYWluZXJDb25maWdIAFIVdmVydGV4QWlUcmFpbmVyQ29uZmlnEnAKE + 
mtmcF90cmFpbmVyX2NvbmZpZxgCIAEoCzIpLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuS0ZQUmVzb3VyY2VDb25maWdCFeI/EhIQa + 2ZwVHJhaW5lckNvbmZpZ0gAUhBrZnBUcmFpbmVyQ29uZmlnEngKFGxvY2FsX3RyYWluZXJfY29uZmlnGAMgASgLMisuc25hcGNoY + XQucmVzZWFyY2guZ2JtbC5Mb2NhbFJlc291cmNlQ29uZmlnQhfiPxQSEmxvY2FsVHJhaW5lckNvbmZpZ0gAUhJsb2NhbFRyYWluZ + XJDb25maWcSpwEKJHZlcnRleF9haV9ncmFwaF9zdG9yZV90cmFpbmVyX2NvbmZpZxgEIAEoCzIwLnNuYXBjaGF0LnJlc2VhcmNoL + mdibWwuVmVydGV4QWlHcmFwaFN0b3JlQ29uZmlnQiTiPyESH3ZlcnRleEFpR3JhcGhTdG9yZVRyYWluZXJDb25maWdIAFIfdmVyd + GV4QWlHcmFwaFN0b3JlVHJhaW5lckNvbmZpZ0IQCg50cmFpbmVyX2NvbmZpZyKHBQoYSW5mZXJlbmNlclJlc291cmNlQ29uZmlnE + o4BCht2ZXJ0ZXhfYWlfaW5mZXJlbmNlcl9jb25maWcYASABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpUmVzb + 3VyY2VDb25maWdCHeI/GhIYdmVydGV4QWlJbmZlcmVuY2VyQ29uZmlnSABSGHZlcnRleEFpSW5mZXJlbmNlckNvbmZpZxKNAQoaZ + GF0YWZsb3dfaW5mZXJlbmNlcl9jb25maWcYAiABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkRhdGFmbG93UmVzb3VyY2VDb + 25maWdCHeI/GhIYZGF0YWZsb3dJbmZlcmVuY2VyQ29uZmlnSABSGGRhdGFmbG93SW5mZXJlbmNlckNvbmZpZxKBAQoXbG9jYWxfa + W5mZXJlbmNlcl9jb25maWcYAyABKAsyKy5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkxvY2FsUmVzb3VyY2VDb25maWdCGuI/FxIVb + G9jYWxJbmZlcmVuY2VyQ29uZmlnSABSFWxvY2FsSW5mZXJlbmNlckNvbmZpZxKwAQondmVydGV4X2FpX2dyYXBoX3N0b3JlX2luZ + mVyZW5jZXJfY29uZmlnGAQgASgLMjAuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaUdyYXBoU3RvcmVDb25maWdCJ+I/J + BIidmVydGV4QWlHcmFwaFN0b3JlSW5mZXJlbmNlckNvbmZpZ0gAUiJ2ZXJ0ZXhBaUdyYXBoU3RvcmVJbmZlcmVuY2VyQ29uZmlnQ + hMKEWluZmVyZW5jZXJfY29uZmlnIpcIChRTaGFyZWRSZXNvdXJjZUNvbmZpZxJ+Cg9yZXNvdXJjZV9sYWJlbHMYASADKAsyQC5zb + mFwY2hhdC5yZXNlYXJjaC5nYm1sLlNoYXJlZFJlc291cmNlQ29uZmlnLlJlc291cmNlTGFiZWxzRW50cnlCE+I/EBIOcmVzb3VyY + 2VMYWJlbHNSDnJlc291cmNlTGFiZWxzEo4BChVjb21tb25fY29tcHV0ZV9jb25maWcYAiABKAsyQC5zbmFwY2hhdC5yZXNlYXJja + C5nYm1sLlNoYXJlZFJlc291cmNlQ29uZmlnLkNvbW1vbkNvbXB1dGVDb25maWdCGOI/FRITY29tbW9uQ29tcHV0ZUNvbmZpZ1ITY + 29tbW9uQ29tcHV0ZUNvbmZpZxqUBQoTQ29tbW9uQ29tcHV0ZUNvbmZpZxImCgdwcm9qZWN0GAEgASgJQgziPwkSB3Byb2plY3RSB + 3Byb2plY3QSIwoGcmVnaW9uGAIgASgJQgviPwgSBnJlZ2lvblIGcmVnaW9uEkMKEnRlbXBfYXNzZXRzX2J1Y2tldBgDIAEoCUIV4 + j8SEhB0ZW1wQXNzZXRzQnVja2V0UhB0ZW1wQXNzZXRzQnVja2V0ElwKG3RlbXBfcmVnaW9uYWxfYXNzZXRzX2J1Y2tldBgEIAEoC + UId4j8aEhh0ZW1wUmVnaW9uYWxBc3NldHNCdWNrZXRSGHRlbXBSZWdpb25hbEFzc2V0c0J1Y2tldBJDChJwZXJtX2Fzc2V0c19id + WNrZXQYBSABKAlCFeI/EhIQcGVybUFzc2V0c0J1Y2tldFIQcGVybUFzc2V0c0J1Y2tldBJaCht0ZW1wX2Fzc2V0c19icV9kYXRhc + 2V0X25hbWUYBiABKAlCHOI/GRIXdGVtcEFzc2V0c0JxRGF0YXNldE5hbWVSF3RlbXBBc3NldHNCcURhdGFzZXROYW1lElYKGWVtY + mVkZGluZ19icV9kYXRhc2V0X25hbWUYByABKAlCG+I/GBIWZW1iZWRkaW5nQnFEYXRhc2V0TmFtZVIWZW1iZWRkaW5nQnFEYXRhc + 2V0TmFtZRJWChlnY3Bfc2VydmljZV9hY2NvdW50X2VtYWlsGAggASgJQhviPxgSFmdjcFNlcnZpY2VBY2NvdW50RW1haWxSFmdjc + FNlcnZpY2VBY2NvdW50RW1haWwSPAoPZGF0YWZsb3dfcnVubmVyGAsgASgJQhPiPxASDmRhdGFmbG93UnVubmVyUg5kYXRhZmxvd + 1J1bm5lchpXChNSZXNvdXJjZUxhYmVsc0VudHJ5EhoKA2tleRgBIAEoCUII4j8FEgNrZXlSA2tleRIgCgV2YWx1ZRgCIAEoCUIK4 + j8HEgV2YWx1ZVIFdmFsdWU6AjgBIvcIChJHaWdsUmVzb3VyY2VDb25maWcSWwoac2hhcmVkX3Jlc291cmNlX2NvbmZpZ191cmkYA + SABKAlCHOI/GRIXc2hhcmVkUmVzb3VyY2VDb25maWdVcmlIAFIXc2hhcmVkUmVzb3VyY2VDb25maWdVcmkSfwoWc2hhcmVkX3Jlc + 291cmNlX2NvbmZpZxgCIAEoCzIsLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU2hhcmVkUmVzb3VyY2VDb25maWdCGeI/FhIUc2hhc + mVkUmVzb3VyY2VDb25maWdIAFIUc2hhcmVkUmVzb3VyY2VDb25maWcSeAoTcHJlcHJvY2Vzc29yX2NvbmZpZxgMIAEoCzIuLnNuY + XBjaGF0LnJlc2VhcmNoLmdibWwuRGF0YVByZXByb2Nlc3NvckNvbmZpZ0IX4j8UEhJwcmVwcm9jZXNzb3JDb25maWdSEnByZXByb + 2Nlc3NvckNvbmZpZxJ/ChdzdWJncmFwaF9zYW1wbGVyX2NvbmZpZxgNIAEoCzIrLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU3Bhc + 
mtSZXNvdXJjZUNvbmZpZ0Ia4j8XEhVzdWJncmFwaFNhbXBsZXJDb25maWdSFXN1YmdyYXBoU2FtcGxlckNvbmZpZxJ8ChZzcGxpd + F9nZW5lcmF0b3JfY29uZmlnGA4gASgLMisuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5TcGFya1Jlc291cmNlQ29uZmlnQhniPxYSF + HNwbGl0R2VuZXJhdG9yQ29uZmlnUhRzcGxpdEdlbmVyYXRvckNvbmZpZxJtCg50cmFpbmVyX2NvbmZpZxgPIAEoCzIwLnNuYXBja + GF0LnJlc2VhcmNoLmdibWwuRGlzdHJpYnV0ZWRUcmFpbmVyQ29uZmlnQhQYAeI/DxINdHJhaW5lckNvbmZpZ1INdHJhaW5lckNvb + mZpZxJ0ChFpbmZlcmVuY2VyX2NvbmZpZxgQIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuRGF0YWZsb3dSZXNvdXJjZUNvb + mZpZ0IXGAHiPxISEGluZmVyZW5jZXJDb25maWdSEGluZmVyZW5jZXJDb25maWcSgQEKF3RyYWluZXJfcmVzb3VyY2VfY29uZmlnG + BEgASgLMi0uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5UcmFpbmVyUmVzb3VyY2VDb25maWdCGuI/FxIVdHJhaW5lclJlc291cmNlQ + 29uZmlnUhV0cmFpbmVyUmVzb3VyY2VDb25maWcSjQEKGmluZmVyZW5jZXJfcmVzb3VyY2VfY29uZmlnGBIgASgLMjAuc25hcGNoY + XQucmVzZWFyY2guZ2JtbC5JbmZlcmVuY2VyUmVzb3VyY2VDb25maWdCHeI/GhIYaW5mZXJlbmNlclJlc291cmNlQ29uZmlnUhhpb + mZlcmVuY2VyUmVzb3VyY2VDb25maWdCEQoPc2hhcmVkX3Jlc291cmNlKuMDCglDb21wb25lbnQSLQoRQ29tcG9uZW50X1Vua25vd + 24QABoW4j8TEhFDb21wb25lbnRfVW5rbm93bhI/ChpDb21wb25lbnRfQ29uZmlnX1ZhbGlkYXRvchABGh/iPxwSGkNvbXBvbmVud + F9Db25maWdfVmFsaWRhdG9yEj8KGkNvbXBvbmVudF9Db25maWdfUG9wdWxhdG9yEAIaH+I/HBIaQ29tcG9uZW50X0NvbmZpZ19Qb + 3B1bGF0b3ISQQobQ29tcG9uZW50X0RhdGFfUHJlcHJvY2Vzc29yEAMaIOI/HRIbQ29tcG9uZW50X0RhdGFfUHJlcHJvY2Vzc29yE + j8KGkNvbXBvbmVudF9TdWJncmFwaF9TYW1wbGVyEAQaH+I/HBIaQ29tcG9uZW50X1N1YmdyYXBoX1NhbXBsZXISPQoZQ29tcG9uZ + W50X1NwbGl0X0dlbmVyYXRvchAFGh7iPxsSGUNvbXBvbmVudF9TcGxpdF9HZW5lcmF0b3ISLQoRQ29tcG9uZW50X1RyYWluZXIQB + hoW4j8TEhFDb21wb25lbnRfVHJhaW5lchIzChRDb21wb25lbnRfSW5mZXJlbmNlchAHGhniPxYSFENvbXBvbmVudF9JbmZlcmVuY + 2VyYgZwcm90bzM=""" ).mkString) lazy val scalaDescriptor: _root_.scalapb.descriptors.FileDescriptor = { val scalaProto = com.google.protobuf.descriptor.FileDescriptorProto.parseFrom(ProtoBytes) diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiGraphStoreConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiGraphStoreConfig.scala index c07f1a3cb..30ded01c0 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiGraphStoreConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiGraphStoreConfig.scala @@ -16,6 +16,8 @@ package snapchat.research.gbml.gigl_resource_config final case class VertexAiGraphStoreConfig( graphStorePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None, computePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None, + numProcessesPerStorageMachine: _root_.scala.Int = 0, + numProcessesPerComputeMachine: _root_.scala.Int = 0, unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[VertexAiGraphStoreConfig] { @transient @@ -30,6 +32,20 @@ final case class VertexAiGraphStoreConfig( val __value = computePool.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; + + { + val __value = numProcessesPerStorageMachine + if (__value != 0) { + __size += _root_.com.google.protobuf.CodedOutputStream.computeInt32Size(3, __value) + } + }; + + { + val __value = numProcessesPerComputeMachine + if (__value != 0) { + __size += _root_.com.google.protobuf.CodedOutputStream.computeInt32Size(4, __value) + } + }; __size += 
unknownFields.serializedSize __size } @@ -55,6 +71,18 @@ final case class VertexAiGraphStoreConfig( _output__.writeUInt32NoTag(__m.serializedSize) __m.writeTo(_output__) }; + { + val __v = numProcessesPerStorageMachine + if (__v != 0) { + _output__.writeInt32(3, __v) + } + }; + { + val __v = numProcessesPerComputeMachine + if (__v != 0) { + _output__.writeInt32(4, __v) + } + }; unknownFields.writeTo(_output__) } def getGraphStorePool: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig = graphStorePool.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig.defaultInstance) @@ -63,12 +91,22 @@ final case class VertexAiGraphStoreConfig( def getComputePool: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig = computePool.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig.defaultInstance) def clearComputePool: VertexAiGraphStoreConfig = copy(computePool = _root_.scala.None) def withComputePool(__v: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig): VertexAiGraphStoreConfig = copy(computePool = Option(__v)) + def withNumProcessesPerStorageMachine(__v: _root_.scala.Int): VertexAiGraphStoreConfig = copy(numProcessesPerStorageMachine = __v) + def withNumProcessesPerComputeMachine(__v: _root_.scala.Int): VertexAiGraphStoreConfig = copy(numProcessesPerComputeMachine = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) def discardUnknownFields = copy(unknownFields = _root_.scalapb.UnknownFieldSet.empty) def getFieldByNumber(__fieldNumber: _root_.scala.Int): _root_.scala.Any = { (__fieldNumber: @_root_.scala.unchecked) match { case 1 => graphStorePool.orNull case 2 => computePool.orNull + case 3 => { + val __t = numProcessesPerStorageMachine + if (__t != 0) __t else null + } + case 4 => { + val __t = numProcessesPerComputeMachine + if (__t != 0) __t else null + } } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -76,6 +114,8 @@ final case class VertexAiGraphStoreConfig( (__field.number: @_root_.scala.unchecked) match { case 1 => graphStorePool.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 2 => computePool.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + case 3 => _root_.scalapb.descriptors.PInt(numProcessesPerStorageMachine) + case 4 => _root_.scalapb.descriptors.PInt(numProcessesPerComputeMachine) } } def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -88,6 +128,8 @@ object VertexAiGraphStoreConfig extends scalapb.GeneratedMessageCompanion[snapch def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig = { var __graphStorePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None var __computePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None + var __numProcessesPerStorageMachine: _root_.scala.Int = 0 + var __numProcessesPerComputeMachine: _root_.scala.Int = 0 var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null var _done__ = false while (!_done__) { @@ -98,6 +140,10 @@ object VertexAiGraphStoreConfig extends scalapb.GeneratedMessageCompanion[snapch __graphStorePool = 
Option(__graphStorePool.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case 18 => __computePool = Option(__computePool.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + case 24 => + __numProcessesPerStorageMachine = _input__.readInt32() + case 32 => + __numProcessesPerComputeMachine = _input__.readInt32() case tag => if (_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -108,6 +154,8 @@ object VertexAiGraphStoreConfig extends scalapb.GeneratedMessageCompanion[snapch snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig( graphStorePool = __graphStorePool, computePool = __computePool, + numProcessesPerStorageMachine = __numProcessesPerStorageMachine, + numProcessesPerComputeMachine = __numProcessesPerComputeMachine, unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() ) } @@ -116,7 +164,9 @@ object VertexAiGraphStoreConfig extends scalapb.GeneratedMessageCompanion[snapch _root_.scala.Predef.require(__fieldsMap.keys.forall(_.containingMessage eq scalaDescriptor), "FieldDescriptor does not match message type.") snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig( graphStorePool = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]), - computePool = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]) + computePool = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]), + numProcessesPerStorageMachine = __fieldsMap.get(scalaDescriptor.findFieldByNumber(3).get).map(_.as[_root_.scala.Int]).getOrElse(0), + numProcessesPerComputeMachine = __fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).map(_.as[_root_.scala.Int]).getOrElse(0) ) case _ => throw new RuntimeException("Expected PMessage") } @@ -134,22 +184,32 @@ object VertexAiGraphStoreConfig extends scalapb.GeneratedMessageCompanion[snapch def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): _root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber) lazy val defaultInstance = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig( graphStorePool = _root_.scala.None, - computePool = _root_.scala.None + computePool = _root_.scala.None, + numProcessesPerStorageMachine = 0, + numProcessesPerComputeMachine = 0 ) implicit class VertexAiGraphStoreConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig](_l) { def graphStorePool: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getGraphStorePool)((c_, f_) => c_.copy(graphStorePool = Option(f_))) def optionalGraphStorePool: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]] = field(_.graphStorePool)((c_, f_) => 
c_.copy(graphStorePool = f_)) def computePool: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getComputePool)((c_, f_) => c_.copy(computePool = Option(f_))) def optionalComputePool: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]] = field(_.computePool)((c_, f_) => c_.copy(computePool = f_)) + def numProcessesPerStorageMachine: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Int] = field(_.numProcessesPerStorageMachine)((c_, f_) => c_.copy(numProcessesPerStorageMachine = f_)) + def numProcessesPerComputeMachine: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Int] = field(_.numProcessesPerComputeMachine)((c_, f_) => c_.copy(numProcessesPerComputeMachine = f_)) } final val GRAPH_STORE_POOL_FIELD_NUMBER = 1 final val COMPUTE_POOL_FIELD_NUMBER = 2 + final val NUM_PROCESSES_PER_STORAGE_MACHINE_FIELD_NUMBER = 3 + final val NUM_PROCESSES_PER_COMPUTE_MACHINE_FIELD_NUMBER = 4 def of( graphStorePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig], - computePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] + computePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig], + numProcessesPerStorageMachine: _root_.scala.Int, + numProcessesPerComputeMachine: _root_.scala.Int ): _root_.snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig = _root_.snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig( graphStorePool, - computePool + computePool, + numProcessesPerStorageMachine, + numProcessesPerComputeMachine ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.VertexAiGraphStoreConfig]) } diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala index 2decc7545..084459025 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala @@ -55,68 +55,71 @@ object GiglResourceConfigProto extends _root_.scalapb.GeneratedFileObject { mVxdWVzdFINbWVtb3J5UmVxdWVzdBInCghncHVfdHlwZRgDIAEoCUIM4j8JEgdncHVUeXBlUgdncHVUeXBlEioKCWdwdV9saW1pd BgEIAEoDUIN4j8KEghncHVMaW1pdFIIZ3B1TGltaXQSMwoMbnVtX3JlcGxpY2FzGAUgASgNQhDiPw0SC251bVJlcGxpY2FzUgtud W1SZXBsaWNhcyJHChNMb2NhbFJlc291cmNlQ29uZmlnEjAKC251bV93b3JrZXJzGAEgASgNQg/iPwwSCm51bVdvcmtlcnNSCm51b - VdvcmtlcnMi7gEKGFZlcnRleEFpR3JhcGhTdG9yZUNvbmZpZxJtChBncmFwaF9zdG9yZV9wb29sGAEgASgLMi4uc25hcGNoYXQuc + VdvcmtlcnMiygMKGFZlcnRleEFpR3JhcGhTdG9yZUNvbmZpZxJtChBncmFwaF9zdG9yZV9wb29sGAEgASgLMi4uc25hcGNoYXQuc mVzZWFyY2guZ2JtbC5WZXJ0ZXhBaVJlc291cmNlQ29uZmlnQhPiPxASDmdyYXBoU3RvcmVQb29sUg5ncmFwaFN0b3JlUG9vbBJjC gxjb21wdXRlX3Bvb2wYAiABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpUmVzb3VyY2VDb25maWdCEOI/DRILY - 29tcHV0ZVBvb2xSC2NvbXB1dGVQb29sIp0DChhEaXN0cmlidXRlZFRyYWluZXJDb25maWcShAEKGHZlcnRleF9haV90cmFpbmVyX - 2NvbmZpZxgBIAEoCzItLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlUcmFpbmVyQ29uZmlnQhriPxcSFXZlcnRleEFpV - HJhaW5lckNvbmZpZ0gAUhV2ZXJ0ZXhBaVRyYWluZXJDb25maWcSbwoSa2ZwX3RyYWluZXJfY29uZmlnGAIgASgLMiguc25hcGNoY - XQucmVzZWFyY2guZ2JtbC5LRlBUcmFpbmVyQ29uZmlnQhXiPxISEGtmcFRyYWluZXJDb25maWdIAFIQa2ZwVHJhaW5lckNvbmZpZ - 
xJ3ChRsb2NhbF90cmFpbmVyX2NvbmZpZxgDIAEoCzIqLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuTG9jYWxUcmFpbmVyQ29uZmlnQ - hfiPxQSEmxvY2FsVHJhaW5lckNvbmZpZ0gAUhJsb2NhbFRyYWluZXJDb25maWdCEAoOdHJhaW5lcl9jb25maWcixwQKFVRyYWluZ - XJSZXNvdXJjZUNvbmZpZxKFAQoYdmVydGV4X2FpX3RyYWluZXJfY29uZmlnGAEgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2Jtb - C5WZXJ0ZXhBaVJlc291cmNlQ29uZmlnQhriPxcSFXZlcnRleEFpVHJhaW5lckNvbmZpZ0gAUhV2ZXJ0ZXhBaVRyYWluZXJDb25ma - WcScAoSa2ZwX3RyYWluZXJfY29uZmlnGAIgASgLMikuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5LRlBSZXNvdXJjZUNvbmZpZ0IV4 - j8SEhBrZnBUcmFpbmVyQ29uZmlnSABSEGtmcFRyYWluZXJDb25maWcSeAoUbG9jYWxfdHJhaW5lcl9jb25maWcYAyABKAsyKy5zb - mFwY2hhdC5yZXNlYXJjaC5nYm1sLkxvY2FsUmVzb3VyY2VDb25maWdCF+I/FBISbG9jYWxUcmFpbmVyQ29uZmlnSABSEmxvY2FsV - HJhaW5lckNvbmZpZxKnAQokdmVydGV4X2FpX2dyYXBoX3N0b3JlX3RyYWluZXJfY29uZmlnGAQgASgLMjAuc25hcGNoYXQucmVzZ - WFyY2guZ2JtbC5WZXJ0ZXhBaUdyYXBoU3RvcmVDb25maWdCJOI/IRIfdmVydGV4QWlHcmFwaFN0b3JlVHJhaW5lckNvbmZpZ0gAU - h92ZXJ0ZXhBaUdyYXBoU3RvcmVUcmFpbmVyQ29uZmlnQhAKDnRyYWluZXJfY29uZmlnIocFChhJbmZlcmVuY2VyUmVzb3VyY2VDb - 25maWcSjgEKG3ZlcnRleF9haV9pbmZlcmVuY2VyX2NvbmZpZxgBIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4Q - WlSZXNvdXJjZUNvbmZpZ0Id4j8aEhh2ZXJ0ZXhBaUluZmVyZW5jZXJDb25maWdIAFIYdmVydGV4QWlJbmZlcmVuY2VyQ29uZmlnE - o0BChpkYXRhZmxvd19pbmZlcmVuY2VyX2NvbmZpZxgCIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuRGF0YWZsb3dSZXNvd - XJjZUNvbmZpZ0Id4j8aEhhkYXRhZmxvd0luZmVyZW5jZXJDb25maWdIAFIYZGF0YWZsb3dJbmZlcmVuY2VyQ29uZmlnEoEBChdsb - 2NhbF9pbmZlcmVuY2VyX2NvbmZpZxgDIAEoCzIrLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuTG9jYWxSZXNvdXJjZUNvbmZpZ0Ia4 - j8XEhVsb2NhbEluZmVyZW5jZXJDb25maWdIAFIVbG9jYWxJbmZlcmVuY2VyQ29uZmlnErABCid2ZXJ0ZXhfYWlfZ3JhcGhfc3Rvc - mVfaW5mZXJlbmNlcl9jb25maWcYBCABKAsyMC5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpR3JhcGhTdG9yZUNvbmZpZ - 0In4j8kEiJ2ZXJ0ZXhBaUdyYXBoU3RvcmVJbmZlcmVuY2VyQ29uZmlnSABSInZlcnRleEFpR3JhcGhTdG9yZUluZmVyZW5jZXJDb - 25maWdCEwoRaW5mZXJlbmNlcl9jb25maWcilwgKFFNoYXJlZFJlc291cmNlQ29uZmlnEn4KD3Jlc291cmNlX2xhYmVscxgBIAMoC - zJALnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU2hhcmVkUmVzb3VyY2VDb25maWcuUmVzb3VyY2VMYWJlbHNFbnRyeUIT4j8QEg5yZ - XNvdXJjZUxhYmVsc1IOcmVzb3VyY2VMYWJlbHMSjgEKFWNvbW1vbl9jb21wdXRlX2NvbmZpZxgCIAEoCzJALnNuYXBjaGF0LnJlc - 2VhcmNoLmdibWwuU2hhcmVkUmVzb3VyY2VDb25maWcuQ29tbW9uQ29tcHV0ZUNvbmZpZ0IY4j8VEhNjb21tb25Db21wdXRlQ29uZ - mlnUhNjb21tb25Db21wdXRlQ29uZmlnGpQFChNDb21tb25Db21wdXRlQ29uZmlnEiYKB3Byb2plY3QYASABKAlCDOI/CRIHcHJva - mVjdFIHcHJvamVjdBIjCgZyZWdpb24YAiABKAlCC+I/CBIGcmVnaW9uUgZyZWdpb24SQwoSdGVtcF9hc3NldHNfYnVja2V0GAMgA - SgJQhXiPxISEHRlbXBBc3NldHNCdWNrZXRSEHRlbXBBc3NldHNCdWNrZXQSXAobdGVtcF9yZWdpb25hbF9hc3NldHNfYnVja2V0G - AQgASgJQh3iPxoSGHRlbXBSZWdpb25hbEFzc2V0c0J1Y2tldFIYdGVtcFJlZ2lvbmFsQXNzZXRzQnVja2V0EkMKEnBlcm1fYXNzZ - XRzX2J1Y2tldBgFIAEoCUIV4j8SEhBwZXJtQXNzZXRzQnVja2V0UhBwZXJtQXNzZXRzQnVja2V0EloKG3RlbXBfYXNzZXRzX2JxX - 2RhdGFzZXRfbmFtZRgGIAEoCUIc4j8ZEhd0ZW1wQXNzZXRzQnFEYXRhc2V0TmFtZVIXdGVtcEFzc2V0c0JxRGF0YXNldE5hbWUSV - goZZW1iZWRkaW5nX2JxX2RhdGFzZXRfbmFtZRgHIAEoCUIb4j8YEhZlbWJlZGRpbmdCcURhdGFzZXROYW1lUhZlbWJlZGRpbmdCc - URhdGFzZXROYW1lElYKGWdjcF9zZXJ2aWNlX2FjY291bnRfZW1haWwYCCABKAlCG+I/GBIWZ2NwU2VydmljZUFjY291bnRFbWFpb - FIWZ2NwU2VydmljZUFjY291bnRFbWFpbBI8Cg9kYXRhZmxvd19ydW5uZXIYCyABKAlCE+I/EBIOZGF0YWZsb3dSdW5uZXJSDmRhd - GFmbG93UnVubmVyGlcKE1Jlc291cmNlTGFiZWxzRW50cnkSGgoDa2V5GAEgASgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgA - SgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAEi9wgKEkdpZ2xSZXNvdXJjZUNvbmZpZxJbChpzaGFyZWRfcmVzb3VyY2VfY29uZmlnX - 3VyaRgBIAEoCUIc4j8ZEhdzaGFyZWRSZXNvdXJjZUNvbmZpZ1VyaUgAUhdzaGFyZWRSZXNvdXJjZUNvbmZpZ1VyaRJ/ChZzaGFyZ - 
WRfcmVzb3VyY2VfY29uZmlnGAIgASgLMiwuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5TaGFyZWRSZXNvdXJjZUNvbmZpZ0IZ4j8WE - hRzaGFyZWRSZXNvdXJjZUNvbmZpZ0gAUhRzaGFyZWRSZXNvdXJjZUNvbmZpZxJ4ChNwcmVwcm9jZXNzb3JfY29uZmlnGAwgASgLM - i4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5EYXRhUHJlcHJvY2Vzc29yQ29uZmlnQhfiPxQSEnByZXByb2Nlc3NvckNvbmZpZ1ISc - HJlcHJvY2Vzc29yQ29uZmlnEn8KF3N1YmdyYXBoX3NhbXBsZXJfY29uZmlnGA0gASgLMisuc25hcGNoYXQucmVzZWFyY2guZ2Jtb - C5TcGFya1Jlc291cmNlQ29uZmlnQhriPxcSFXN1YmdyYXBoU2FtcGxlckNvbmZpZ1IVc3ViZ3JhcGhTYW1wbGVyQ29uZmlnEnwKF - nNwbGl0X2dlbmVyYXRvcl9jb25maWcYDiABKAsyKy5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlNwYXJrUmVzb3VyY2VDb25maWdCG - eI/FhIUc3BsaXRHZW5lcmF0b3JDb25maWdSFHNwbGl0R2VuZXJhdG9yQ29uZmlnEm0KDnRyYWluZXJfY29uZmlnGA8gASgLMjAuc - 25hcGNoYXQucmVzZWFyY2guZ2JtbC5EaXN0cmlidXRlZFRyYWluZXJDb25maWdCFBgB4j8PEg10cmFpbmVyQ29uZmlnUg10cmFpb - mVyQ29uZmlnEnQKEWluZmVyZW5jZXJfY29uZmlnGBAgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5EYXRhZmxvd1Jlc291c - mNlQ29uZmlnQhcYAeI/EhIQaW5mZXJlbmNlckNvbmZpZ1IQaW5mZXJlbmNlckNvbmZpZxKBAQoXdHJhaW5lcl9yZXNvdXJjZV9jb - 25maWcYESABKAsyLS5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlRyYWluZXJSZXNvdXJjZUNvbmZpZ0Ia4j8XEhV0cmFpbmVyUmVzb - 3VyY2VDb25maWdSFXRyYWluZXJSZXNvdXJjZUNvbmZpZxKNAQoaaW5mZXJlbmNlcl9yZXNvdXJjZV9jb25maWcYEiABKAsyMC5zb - mFwY2hhdC5yZXNlYXJjaC5nYm1sLkluZmVyZW5jZXJSZXNvdXJjZUNvbmZpZ0Id4j8aEhhpbmZlcmVuY2VyUmVzb3VyY2VDb25ma - WdSGGluZmVyZW5jZXJSZXNvdXJjZUNvbmZpZ0IRCg9zaGFyZWRfcmVzb3VyY2Uq4wMKCUNvbXBvbmVudBItChFDb21wb25lbnRfV - W5rbm93bhAAGhbiPxMSEUNvbXBvbmVudF9Vbmtub3duEj8KGkNvbXBvbmVudF9Db25maWdfVmFsaWRhdG9yEAEaH+I/HBIaQ29tc - G9uZW50X0NvbmZpZ19WYWxpZGF0b3ISPwoaQ29tcG9uZW50X0NvbmZpZ19Qb3B1bGF0b3IQAhof4j8cEhpDb21wb25lbnRfQ29uZ - mlnX1BvcHVsYXRvchJBChtDb21wb25lbnRfRGF0YV9QcmVwcm9jZXNzb3IQAxog4j8dEhtDb21wb25lbnRfRGF0YV9QcmVwcm9jZ - XNzb3ISPwoaQ29tcG9uZW50X1N1YmdyYXBoX1NhbXBsZXIQBBof4j8cEhpDb21wb25lbnRfU3ViZ3JhcGhfU2FtcGxlchI9ChlDb - 21wb25lbnRfU3BsaXRfR2VuZXJhdG9yEAUaHuI/GxIZQ29tcG9uZW50X1NwbGl0X0dlbmVyYXRvchItChFDb21wb25lbnRfVHJha - W5lchAGGhbiPxMSEUNvbXBvbmVudF9UcmFpbmVyEjMKFENvbXBvbmVudF9JbmZlcmVuY2VyEAcaGeI/FhIUQ29tcG9uZW50X0luZ - mVyZW5jZXJiBnByb3RvMw==""" + 29tcHV0ZVBvb2xSC2NvbXB1dGVQb29sEmwKIW51bV9wcm9jZXNzZXNfcGVyX3N0b3JhZ2VfbWFjaGluZRgDIAEoBUIi4j8fEh1ud + W1Qcm9jZXNzZXNQZXJTdG9yYWdlTWFjaGluZVIdbnVtUHJvY2Vzc2VzUGVyU3RvcmFnZU1hY2hpbmUSbAohbnVtX3Byb2Nlc3Nlc + 19wZXJfY29tcHV0ZV9tYWNoaW5lGAQgASgFQiLiPx8SHW51bVByb2Nlc3Nlc1BlckNvbXB1dGVNYWNoaW5lUh1udW1Qcm9jZXNzZ + XNQZXJDb21wdXRlTWFjaGluZSKdAwoYRGlzdHJpYnV0ZWRUcmFpbmVyQ29uZmlnEoQBChh2ZXJ0ZXhfYWlfdHJhaW5lcl9jb25ma + WcYASABKAsyLS5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpVHJhaW5lckNvbmZpZ0Ia4j8XEhV2ZXJ0ZXhBaVRyYWluZ + XJDb25maWdIAFIVdmVydGV4QWlUcmFpbmVyQ29uZmlnEm8KEmtmcF90cmFpbmVyX2NvbmZpZxgCIAEoCzIoLnNuYXBjaGF0LnJlc + 2VhcmNoLmdibWwuS0ZQVHJhaW5lckNvbmZpZ0IV4j8SEhBrZnBUcmFpbmVyQ29uZmlnSABSEGtmcFRyYWluZXJDb25maWcSdwoUb + G9jYWxfdHJhaW5lcl9jb25maWcYAyABKAsyKi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkxvY2FsVHJhaW5lckNvbmZpZ0IX4j8UE + hJsb2NhbFRyYWluZXJDb25maWdIAFISbG9jYWxUcmFpbmVyQ29uZmlnQhAKDnRyYWluZXJfY29uZmlnIscEChVUcmFpbmVyUmVzb + 3VyY2VDb25maWcShQEKGHZlcnRleF9haV90cmFpbmVyX2NvbmZpZxgBIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVmVyd + GV4QWlSZXNvdXJjZUNvbmZpZ0Ia4j8XEhV2ZXJ0ZXhBaVRyYWluZXJDb25maWdIAFIVdmVydGV4QWlUcmFpbmVyQ29uZmlnEnAKE + mtmcF90cmFpbmVyX2NvbmZpZxgCIAEoCzIpLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuS0ZQUmVzb3VyY2VDb25maWdCFeI/EhIQa + 2ZwVHJhaW5lckNvbmZpZ0gAUhBrZnBUcmFpbmVyQ29uZmlnEngKFGxvY2FsX3RyYWluZXJfY29uZmlnGAMgASgLMisuc25hcGNoY + XQucmVzZWFyY2guZ2JtbC5Mb2NhbFJlc291cmNlQ29uZmlnQhfiPxQSEmxvY2FsVHJhaW5lckNvbmZpZ0gAUhJsb2NhbFRyYWluZ + 
XJDb25maWcSpwEKJHZlcnRleF9haV9ncmFwaF9zdG9yZV90cmFpbmVyX2NvbmZpZxgEIAEoCzIwLnNuYXBjaGF0LnJlc2VhcmNoL + mdibWwuVmVydGV4QWlHcmFwaFN0b3JlQ29uZmlnQiTiPyESH3ZlcnRleEFpR3JhcGhTdG9yZVRyYWluZXJDb25maWdIAFIfdmVyd + GV4QWlHcmFwaFN0b3JlVHJhaW5lckNvbmZpZ0IQCg50cmFpbmVyX2NvbmZpZyKHBQoYSW5mZXJlbmNlclJlc291cmNlQ29uZmlnE + o4BCht2ZXJ0ZXhfYWlfaW5mZXJlbmNlcl9jb25maWcYASABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpUmVzb + 3VyY2VDb25maWdCHeI/GhIYdmVydGV4QWlJbmZlcmVuY2VyQ29uZmlnSABSGHZlcnRleEFpSW5mZXJlbmNlckNvbmZpZxKNAQoaZ + GF0YWZsb3dfaW5mZXJlbmNlcl9jb25maWcYAiABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkRhdGFmbG93UmVzb3VyY2VDb + 25maWdCHeI/GhIYZGF0YWZsb3dJbmZlcmVuY2VyQ29uZmlnSABSGGRhdGFmbG93SW5mZXJlbmNlckNvbmZpZxKBAQoXbG9jYWxfa + W5mZXJlbmNlcl9jb25maWcYAyABKAsyKy5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkxvY2FsUmVzb3VyY2VDb25maWdCGuI/FxIVb + G9jYWxJbmZlcmVuY2VyQ29uZmlnSABSFWxvY2FsSW5mZXJlbmNlckNvbmZpZxKwAQondmVydGV4X2FpX2dyYXBoX3N0b3JlX2luZ + mVyZW5jZXJfY29uZmlnGAQgASgLMjAuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaUdyYXBoU3RvcmVDb25maWdCJ+I/J + BIidmVydGV4QWlHcmFwaFN0b3JlSW5mZXJlbmNlckNvbmZpZ0gAUiJ2ZXJ0ZXhBaUdyYXBoU3RvcmVJbmZlcmVuY2VyQ29uZmlnQ + hMKEWluZmVyZW5jZXJfY29uZmlnIpcIChRTaGFyZWRSZXNvdXJjZUNvbmZpZxJ+Cg9yZXNvdXJjZV9sYWJlbHMYASADKAsyQC5zb + mFwY2hhdC5yZXNlYXJjaC5nYm1sLlNoYXJlZFJlc291cmNlQ29uZmlnLlJlc291cmNlTGFiZWxzRW50cnlCE+I/EBIOcmVzb3VyY + 2VMYWJlbHNSDnJlc291cmNlTGFiZWxzEo4BChVjb21tb25fY29tcHV0ZV9jb25maWcYAiABKAsyQC5zbmFwY2hhdC5yZXNlYXJja + C5nYm1sLlNoYXJlZFJlc291cmNlQ29uZmlnLkNvbW1vbkNvbXB1dGVDb25maWdCGOI/FRITY29tbW9uQ29tcHV0ZUNvbmZpZ1ITY + 29tbW9uQ29tcHV0ZUNvbmZpZxqUBQoTQ29tbW9uQ29tcHV0ZUNvbmZpZxImCgdwcm9qZWN0GAEgASgJQgziPwkSB3Byb2plY3RSB + 3Byb2plY3QSIwoGcmVnaW9uGAIgASgJQgviPwgSBnJlZ2lvblIGcmVnaW9uEkMKEnRlbXBfYXNzZXRzX2J1Y2tldBgDIAEoCUIV4 + j8SEhB0ZW1wQXNzZXRzQnVja2V0UhB0ZW1wQXNzZXRzQnVja2V0ElwKG3RlbXBfcmVnaW9uYWxfYXNzZXRzX2J1Y2tldBgEIAEoC + UId4j8aEhh0ZW1wUmVnaW9uYWxBc3NldHNCdWNrZXRSGHRlbXBSZWdpb25hbEFzc2V0c0J1Y2tldBJDChJwZXJtX2Fzc2V0c19id + WNrZXQYBSABKAlCFeI/EhIQcGVybUFzc2V0c0J1Y2tldFIQcGVybUFzc2V0c0J1Y2tldBJaCht0ZW1wX2Fzc2V0c19icV9kYXRhc + 2V0X25hbWUYBiABKAlCHOI/GRIXdGVtcEFzc2V0c0JxRGF0YXNldE5hbWVSF3RlbXBBc3NldHNCcURhdGFzZXROYW1lElYKGWVtY + mVkZGluZ19icV9kYXRhc2V0X25hbWUYByABKAlCG+I/GBIWZW1iZWRkaW5nQnFEYXRhc2V0TmFtZVIWZW1iZWRkaW5nQnFEYXRhc + 2V0TmFtZRJWChlnY3Bfc2VydmljZV9hY2NvdW50X2VtYWlsGAggASgJQhviPxgSFmdjcFNlcnZpY2VBY2NvdW50RW1haWxSFmdjc + FNlcnZpY2VBY2NvdW50RW1haWwSPAoPZGF0YWZsb3dfcnVubmVyGAsgASgJQhPiPxASDmRhdGFmbG93UnVubmVyUg5kYXRhZmxvd + 1J1bm5lchpXChNSZXNvdXJjZUxhYmVsc0VudHJ5EhoKA2tleRgBIAEoCUII4j8FEgNrZXlSA2tleRIgCgV2YWx1ZRgCIAEoCUIK4 + j8HEgV2YWx1ZVIFdmFsdWU6AjgBIvcIChJHaWdsUmVzb3VyY2VDb25maWcSWwoac2hhcmVkX3Jlc291cmNlX2NvbmZpZ191cmkYA + SABKAlCHOI/GRIXc2hhcmVkUmVzb3VyY2VDb25maWdVcmlIAFIXc2hhcmVkUmVzb3VyY2VDb25maWdVcmkSfwoWc2hhcmVkX3Jlc + 291cmNlX2NvbmZpZxgCIAEoCzIsLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU2hhcmVkUmVzb3VyY2VDb25maWdCGeI/FhIUc2hhc + mVkUmVzb3VyY2VDb25maWdIAFIUc2hhcmVkUmVzb3VyY2VDb25maWcSeAoTcHJlcHJvY2Vzc29yX2NvbmZpZxgMIAEoCzIuLnNuY + XBjaGF0LnJlc2VhcmNoLmdibWwuRGF0YVByZXByb2Nlc3NvckNvbmZpZ0IX4j8UEhJwcmVwcm9jZXNzb3JDb25maWdSEnByZXByb + 2Nlc3NvckNvbmZpZxJ/ChdzdWJncmFwaF9zYW1wbGVyX2NvbmZpZxgNIAEoCzIrLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU3Bhc + mtSZXNvdXJjZUNvbmZpZ0Ia4j8XEhVzdWJncmFwaFNhbXBsZXJDb25maWdSFXN1YmdyYXBoU2FtcGxlckNvbmZpZxJ8ChZzcGxpd + F9nZW5lcmF0b3JfY29uZmlnGA4gASgLMisuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5TcGFya1Jlc291cmNlQ29uZmlnQhniPxYSF + HNwbGl0R2VuZXJhdG9yQ29uZmlnUhRzcGxpdEdlbmVyYXRvckNvbmZpZxJtCg50cmFpbmVyX2NvbmZpZxgPIAEoCzIwLnNuYXBja + 
GF0LnJlc2VhcmNoLmdibWwuRGlzdHJpYnV0ZWRUcmFpbmVyQ29uZmlnQhQYAeI/DxINdHJhaW5lckNvbmZpZ1INdHJhaW5lckNvb + mZpZxJ0ChFpbmZlcmVuY2VyX2NvbmZpZxgQIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuRGF0YWZsb3dSZXNvdXJjZUNvb + mZpZ0IXGAHiPxISEGluZmVyZW5jZXJDb25maWdSEGluZmVyZW5jZXJDb25maWcSgQEKF3RyYWluZXJfcmVzb3VyY2VfY29uZmlnG + BEgASgLMi0uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5UcmFpbmVyUmVzb3VyY2VDb25maWdCGuI/FxIVdHJhaW5lclJlc291cmNlQ + 29uZmlnUhV0cmFpbmVyUmVzb3VyY2VDb25maWcSjQEKGmluZmVyZW5jZXJfcmVzb3VyY2VfY29uZmlnGBIgASgLMjAuc25hcGNoY + XQucmVzZWFyY2guZ2JtbC5JbmZlcmVuY2VyUmVzb3VyY2VDb25maWdCHeI/GhIYaW5mZXJlbmNlclJlc291cmNlQ29uZmlnUhhpb + mZlcmVuY2VyUmVzb3VyY2VDb25maWdCEQoPc2hhcmVkX3Jlc291cmNlKuMDCglDb21wb25lbnQSLQoRQ29tcG9uZW50X1Vua25vd + 24QABoW4j8TEhFDb21wb25lbnRfVW5rbm93bhI/ChpDb21wb25lbnRfQ29uZmlnX1ZhbGlkYXRvchABGh/iPxwSGkNvbXBvbmVud + F9Db25maWdfVmFsaWRhdG9yEj8KGkNvbXBvbmVudF9Db25maWdfUG9wdWxhdG9yEAIaH+I/HBIaQ29tcG9uZW50X0NvbmZpZ19Qb + 3B1bGF0b3ISQQobQ29tcG9uZW50X0RhdGFfUHJlcHJvY2Vzc29yEAMaIOI/HRIbQ29tcG9uZW50X0RhdGFfUHJlcHJvY2Vzc29yE + j8KGkNvbXBvbmVudF9TdWJncmFwaF9TYW1wbGVyEAQaH+I/HBIaQ29tcG9uZW50X1N1YmdyYXBoX1NhbXBsZXISPQoZQ29tcG9uZ + W50X1NwbGl0X0dlbmVyYXRvchAFGh7iPxsSGUNvbXBvbmVudF9TcGxpdF9HZW5lcmF0b3ISLQoRQ29tcG9uZW50X1RyYWluZXIQB + hoW4j8TEhFDb21wb25lbnRfVHJhaW5lchIzChRDb21wb25lbnRfSW5mZXJlbmNlchAHGhniPxYSFENvbXBvbmVudF9JbmZlcmVuY + 2VyYgZwcm90bzM=""" ).mkString) lazy val scalaDescriptor: _root_.scalapb.descriptors.FileDescriptor = { val scalaProto = com.google.protobuf.descriptor.FileDescriptorProto.parseFrom(ProtoBytes) diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiGraphStoreConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiGraphStoreConfig.scala index c07f1a3cb..30ded01c0 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiGraphStoreConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiGraphStoreConfig.scala @@ -16,6 +16,8 @@ package snapchat.research.gbml.gigl_resource_config final case class VertexAiGraphStoreConfig( graphStorePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None, computePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None, + numProcessesPerStorageMachine: _root_.scala.Int = 0, + numProcessesPerComputeMachine: _root_.scala.Int = 0, unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[VertexAiGraphStoreConfig] { @transient @@ -30,6 +32,20 @@ final case class VertexAiGraphStoreConfig( val __value = computePool.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; + + { + val __value = numProcessesPerStorageMachine + if (__value != 0) { + __size += _root_.com.google.protobuf.CodedOutputStream.computeInt32Size(3, __value) + } + }; + + { + val __value = numProcessesPerComputeMachine + if (__value != 0) { + __size += _root_.com.google.protobuf.CodedOutputStream.computeInt32Size(4, __value) + } + }; __size += unknownFields.serializedSize __size } @@ -55,6 +71,18 @@ final case class VertexAiGraphStoreConfig( _output__.writeUInt32NoTag(__m.serializedSize) __m.writeTo(_output__) }; + { + val __v = numProcessesPerStorageMachine + if (__v != 0) { + _output__.writeInt32(3, __v) + } + }; + { + 
val __v = numProcessesPerComputeMachine + if (__v != 0) { + _output__.writeInt32(4, __v) + } + }; unknownFields.writeTo(_output__) } def getGraphStorePool: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig = graphStorePool.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig.defaultInstance) @@ -63,12 +91,22 @@ final case class VertexAiGraphStoreConfig( def getComputePool: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig = computePool.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig.defaultInstance) def clearComputePool: VertexAiGraphStoreConfig = copy(computePool = _root_.scala.None) def withComputePool(__v: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig): VertexAiGraphStoreConfig = copy(computePool = Option(__v)) + def withNumProcessesPerStorageMachine(__v: _root_.scala.Int): VertexAiGraphStoreConfig = copy(numProcessesPerStorageMachine = __v) + def withNumProcessesPerComputeMachine(__v: _root_.scala.Int): VertexAiGraphStoreConfig = copy(numProcessesPerComputeMachine = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) def discardUnknownFields = copy(unknownFields = _root_.scalapb.UnknownFieldSet.empty) def getFieldByNumber(__fieldNumber: _root_.scala.Int): _root_.scala.Any = { (__fieldNumber: @_root_.scala.unchecked) match { case 1 => graphStorePool.orNull case 2 => computePool.orNull + case 3 => { + val __t = numProcessesPerStorageMachine + if (__t != 0) __t else null + } + case 4 => { + val __t = numProcessesPerComputeMachine + if (__t != 0) __t else null + } } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -76,6 +114,8 @@ final case class VertexAiGraphStoreConfig( (__field.number: @_root_.scala.unchecked) match { case 1 => graphStorePool.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 2 => computePool.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + case 3 => _root_.scalapb.descriptors.PInt(numProcessesPerStorageMachine) + case 4 => _root_.scalapb.descriptors.PInt(numProcessesPerComputeMachine) } } def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -88,6 +128,8 @@ object VertexAiGraphStoreConfig extends scalapb.GeneratedMessageCompanion[snapch def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig = { var __graphStorePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None var __computePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None + var __numProcessesPerStorageMachine: _root_.scala.Int = 0 + var __numProcessesPerComputeMachine: _root_.scala.Int = 0 var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null var _done__ = false while (!_done__) { @@ -98,6 +140,10 @@ object VertexAiGraphStoreConfig extends scalapb.GeneratedMessageCompanion[snapch __graphStorePool = Option(__graphStorePool.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case 18 => __computePool = Option(__computePool.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, 
_))) + case 24 => + __numProcessesPerStorageMachine = _input__.readInt32() + case 32 => + __numProcessesPerComputeMachine = _input__.readInt32() case tag => if (_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -108,6 +154,8 @@ object VertexAiGraphStoreConfig extends scalapb.GeneratedMessageCompanion[snapch snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig( graphStorePool = __graphStorePool, computePool = __computePool, + numProcessesPerStorageMachine = __numProcessesPerStorageMachine, + numProcessesPerComputeMachine = __numProcessesPerComputeMachine, unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() ) } @@ -116,7 +164,9 @@ object VertexAiGraphStoreConfig extends scalapb.GeneratedMessageCompanion[snapch _root_.scala.Predef.require(__fieldsMap.keys.forall(_.containingMessage eq scalaDescriptor), "FieldDescriptor does not match message type.") snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig( graphStorePool = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]), - computePool = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]) + computePool = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]), + numProcessesPerStorageMachine = __fieldsMap.get(scalaDescriptor.findFieldByNumber(3).get).map(_.as[_root_.scala.Int]).getOrElse(0), + numProcessesPerComputeMachine = __fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).map(_.as[_root_.scala.Int]).getOrElse(0) ) case _ => throw new RuntimeException("Expected PMessage") } @@ -134,22 +184,32 @@ object VertexAiGraphStoreConfig extends scalapb.GeneratedMessageCompanion[snapch def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): _root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber) lazy val defaultInstance = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig( graphStorePool = _root_.scala.None, - computePool = _root_.scala.None + computePool = _root_.scala.None, + numProcessesPerStorageMachine = 0, + numProcessesPerComputeMachine = 0 ) implicit class VertexAiGraphStoreConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig](_l) { def graphStorePool: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getGraphStorePool)((c_, f_) => c_.copy(graphStorePool = Option(f_))) def optionalGraphStorePool: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]] = field(_.graphStorePool)((c_, f_) => c_.copy(graphStorePool = f_)) def computePool: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getComputePool)((c_, f_) => c_.copy(computePool = Option(f_))) def optionalComputePool: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]] = field(_.computePool)((c_, f_) => c_.copy(computePool = 
f_)) + def numProcessesPerStorageMachine: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Int] = field(_.numProcessesPerStorageMachine)((c_, f_) => c_.copy(numProcessesPerStorageMachine = f_)) + def numProcessesPerComputeMachine: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Int] = field(_.numProcessesPerComputeMachine)((c_, f_) => c_.copy(numProcessesPerComputeMachine = f_)) } final val GRAPH_STORE_POOL_FIELD_NUMBER = 1 final val COMPUTE_POOL_FIELD_NUMBER = 2 + final val NUM_PROCESSES_PER_STORAGE_MACHINE_FIELD_NUMBER = 3 + final val NUM_PROCESSES_PER_COMPUTE_MACHINE_FIELD_NUMBER = 4 def of( graphStorePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig], - computePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] + computePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig], + numProcessesPerStorageMachine: _root_.scala.Int, + numProcessesPerComputeMachine: _root_.scala.Int ): _root_.snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig = _root_.snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig( graphStorePool, - computePool + computePool, + numProcessesPerStorageMachine, + numProcessesPerComputeMachine ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.VertexAiGraphStoreConfig]) }
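
A minimal usage sketch (not part of the diff) of the regenerated ScalaPB API above, for sanity-checking the new fields. The pool configs and the process counts used here (2 and 4) are placeholder values, and the byte-array round-trip relies on the standard ScalaPB toByteArray/parseFrom helpers rather than anything added in this change:

import snapchat.research.gbml.gigl_resource_config.{VertexAiGraphStoreConfig, VertexAiResourceConfig}

// Build a config carrying the new per-machine process counts (proto fields 3 and 4).
val cfg = VertexAiGraphStoreConfig.of(
  graphStorePool = Some(VertexAiResourceConfig.defaultInstance),
  computePool = Some(VertexAiResourceConfig.defaultInstance),
  numProcessesPerStorageMachine = 2, // placeholder
  numProcessesPerComputeMachine = 4  // placeholder
)

// Round-trip through the wire format: writeTo emits varint tags 24 and 32 for the new
// fields, and parseFrom's case 24 / case 32 branches read them back.
val parsed = VertexAiGraphStoreConfig.parseFrom(cfg.toByteArray)
assert(parsed.numProcessesPerStorageMachine == 2)
assert(parsed.numProcessesPerComputeMachine == 4)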