diff --git a/python/gigl/distributed/dist_ablp_neighborloader.py b/python/gigl/distributed/dist_ablp_neighborloader.py index 5ae260c39..24785f814 100644 --- a/python/gigl/distributed/dist_ablp_neighborloader.py +++ b/python/gigl/distributed/dist_ablp_neighborloader.py @@ -280,6 +280,8 @@ def __init__( anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE supervision_edge_type = DEFAULT_HOMOGENEOUS_EDGE_TYPE supervision_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE + logger.info(f"local rank: {local_rank}, node rank: {node_rank}, anchor node type: {anchor_node_type}, " + f"supervision edge type: {supervision_edge_type}, supervision node type: {supervision_node_type}") missing_edge_types = set([supervision_edge_type]) - set(dataset.graph.keys()) if missing_edge_types: @@ -296,6 +298,7 @@ def __init__( self._negative_label_edge_type, ) = select_label_edge_types(supervision_edge_type, dataset.graph.keys()) self._supervision_edge_type = supervision_edge_type + logger.info(f"Local rank: {local_rank}, node rank: {node_rank}, supervision edge type: {supervision_edge_type}") positive_labels, negative_labels = get_labels_for_anchor_nodes( dataset=dataset, @@ -303,6 +306,7 @@ def __init__( positive_label_edge_type=self._positive_label_edge_type, negative_label_edge_type=self._negative_label_edge_type, ) + logger.info(f"Local rank: {local_rank}, node rank: {node_rank}, Got labels for anchor nodes") self.to_device = ( pin_memory_device @@ -315,12 +319,14 @@ def __init__( num_neighbors = patch_fanout_for_sampling( dataset.get_edge_types(), num_neighbors ) + logger.info(f"Local rank: {local_rank}, node rank: {node_rank}, Number of neighbors: {num_neighbors}") curr_process_nodes = shard_nodes_by_process( input_nodes=anchor_node_ids, local_process_rank=local_rank, local_process_world_size=local_world_size, ) + logger.info(f"local rank: {local_rank}, node rank: {node_rank}, current process nodes: {curr_process_nodes}") self._node_feature_info = dataset.node_feature_info 
self._edge_feature_info = dataset.edge_feature_info diff --git a/python/gigl/utils/data_splitters.py b/python/gigl/utils/data_splitters.py index 00ad8c6e4..c02b9d3a5 100644 --- a/python/gigl/utils/data_splitters.py +++ b/python/gigl/utils/data_splitters.py @@ -26,11 +26,19 @@ message_passing_to_positive_label, reverse_edge_type, ) +import psutil +import os +import gc logger = Logger() PADDING_NODE: Final[torch.Tensor] = torch.tensor(-1, dtype=torch.int64) +def _debug_memory_usage(prefix: str): + process = psutil.Process(os.getpid()) + memory_info = process.memory_info() + logger.info(f"{prefix} Memory usage: {memory_info.rss / 1024 / 1024:.2f} MB (out of {psutil.virtual_memory().total / 1024 / 1024:.2f} MB)") + class NodeAnchorLinkSplitter(Protocol): """Protocol that should be satisfied for anything that is used to split on edges. @@ -648,6 +656,8 @@ def _get_padded_labels( # and indices is the COL_INDEX of a CSR matrix. # See https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format) # Note that GLT defaults to CSR under the hood, if this changes, we will need to update this. 
+ _debug_memory_usage("Before indptr and indices") + indptr = topo.indptr # [N] indices = topo.indices # [M] extra_nodes_to_pad = 0 @@ -657,20 +666,42 @@ anchor_node_ids = anchor_node_ids[valid_ids] starts = indptr[anchor_node_ids] # [N] ends = indptr[anchor_node_ids + 1] # [N] + _debug_memory_usage("After starts and ends") max_range = int(torch.max(ends - starts).item()) + logger.info(f"max range {max_range}") + mask = torch.arange(max_range) >= (ends - starts).unsqueeze(1) + max_end_value = ends.max().item() + _debug_memory_usage("After max_end_value") + # del ends + gc.collect() + _debug_memory_usage("After ends gc") + + logger.info(f"Global rank {torch.distributed.get_rank()}, " + f"world size {torch.distributed.get_world_size()}: " + f"Get padded labels") # Sample all labels based on the CSR start/stop indices. # Creates "indices" for us to us, e.g [[0, 1], [2, 3]] ranges = starts.unsqueeze(1) + torch.arange(max_range) # [N, max_range] + _debug_memory_usage("After ranges") + del starts + gc.collect() + _debug_memory_usage("After starts gc") + # Clamp the ranges to be valid indices into `indices`. ranges.clamp_(min=0, max=max_end_value - 1) + _debug_memory_usage("After clamp") # Mask out the parts of "ranges" that are not applicable to the current label # filling out the rest with `PADDING_NODE`. - mask = torch.arange(max_range) >= (ends - starts).unsqueeze(1) + # mask = torch.arange(max_range) >= (ends - starts).unsqueeze(1) labels = torch.where( mask, torch.full_like(ranges, PADDING_NODE.item()), indices[ranges] ) + _debug_memory_usage("After labels") + del ranges + gc.collect() + _debug_memory_usage("After ranges gc") labels = torch.cat( [ labels, @@ -678,6 +709,7 @@ ], dim=0, ) + _debug_memory_usage("After cat") return labels