diff --git a/Ironwood/src/benchmark_collectives.py b/Ironwood/src/benchmark_collectives.py index 30e053e..69b4d21 100644 --- a/Ironwood/src/benchmark_collectives.py +++ b/Ironwood/src/benchmark_collectives.py @@ -5,6 +5,7 @@ import os from typing import Any, Dict +from benchmark_utils import find_sparsecore_usage_from_xplane from benchmark_utils import get_lhs_named_shading from benchmark_utils import get_out_sharding from benchmark_utils import MetricsStatistics @@ -26,7 +27,7 @@ SEED = 0 GLOBAL_SHARDING_STRATEGY = ShardingStrategy.NO_SHARDING GLOBAL_PSTATE = 7 - +LOG_SPARSECORE_USAGE = False def create_mesh(ici_size: int, mesh_shape: str) -> Mesh: """Creates a mesh with the given ICI size.""" @@ -85,6 +86,7 @@ def unified_ici_collectives_metrics( ici_average_time_ms_list: list[float], iteration: int, op_type: str, + trace_dir: str = None, ) -> Dict[str, Any]: """Calculates the metrics for the ICI collectives benchmark.""" @@ -147,8 +149,15 @@ def unified_ici_collectives_metrics( / rank ) - + + sparsecore_used = "NA" + if LOG_SPARSECORE_USAGE: + print("trace_dir: ", trace_dir) + if trace_dir: + sparsecore_used = find_sparsecore_usage_from_xplane(trace_dir) + print("sparsecore_used: ", sparsecore_used) print("hlo first replica group: ", hlo_first_replica_group) + metadata = { "iteration": iteration, "op_type": op_type, @@ -164,6 +173,7 @@ def unified_ici_collectives_metrics( "hlo_input_shape": json.dumps(hlo_input_shape), "hlo_output_shape": json.dumps(hlo_output_shape), "hlo_replica_groups": json.dumps(hlo_replica_groups), + "sparsecore_used": sparsecore_used, } achieved_bw = [transferred_data*1000/my_time for my_time in ici_average_time_ms_list] achieved_bw_statistics = MetricsStatistics( @@ -294,6 +304,7 @@ def data_generator(): "ici_average_time_ms_list": time_ms_list, "matrix_shape": (m, n, k), "op_type": "AR", + "trace_dir": trace_dir, } @@ -308,6 +319,7 @@ def psum_benchmark_calculate_metrics( matrix_shape: tuple[int, int, int], xla_output: str, op_type: 
str, + trace_dir: str, ) -> Dict[str, Any]: """Calculates the metrics for the psum benchmark.""" # Build dictionary of all the parameters in the function @@ -322,8 +334,10 @@ def psum_benchmark_calculate_metrics( ici_average_time_ms_list, matrix_dim, op_type, + trace_dir, ) + def psum_scatter_benchmark( matrix_dim: int, dtype: jnp.dtype, @@ -382,7 +396,7 @@ def f(x): ) sharding_strategy_tuple = tuple(map(int, sharding_strategy.split("x"))) op_dimension_tuple_multiplier = math.prod(sharding_strategy_tuple) - m = op_dimension_tuple_multiplier * 2 + m = op_dimension_tuple_multiplier n = matrix_dim k = 256 @@ -405,6 +419,7 @@ def data_generator(): "ici_average_time_ms_list": time_ms_list, "matrix_shape": (m, n, k), "op_type": "RS", + "trace_dir": trace_dir, } @@ -419,6 +434,7 @@ def psum_scatter_benchmark_calculate_metrics( matrix_shape: tuple[int, int, int], xla_output: str, op_type: str, + trace_dir: str, ) -> Dict[str, Any]: """Calculates the metrics for the psum_scatter benchmark.""" # Build dictionary of all the parameters in the function @@ -433,6 +449,7 @@ def psum_scatter_benchmark_calculate_metrics( ici_average_time_ms_list, matrix_dim, op_type, + trace_dir, ) def all_gather_benchmark( @@ -473,7 +490,9 @@ def all_gather_benchmark( "--xla_tpu_use_single_sparse_core_for_all_gather_offload=true", "--xla_tpu_use_tc_device_shape_on_sc=true", f"--xla_tpu_dvfs_p_state={GLOBAL_PSTATE}", + "--xla_tpu_scoped_vmem_limit_kib=65536", ] + # libtpu_init_args=[ ] os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args) mesh = create_mesh(ici_size, mesh_shape) @@ -513,7 +532,8 @@ def data_generator(): return { "ici_average_time_ms_list": time_ms_list, "matrix_shape": (m, n, k), - "op_type": "AG" + "op_type": "AG", + "trace_dir": trace_dir, } @@ -528,6 +548,7 @@ def all_gather_benchmark_calculate_metrics( matrix_shape: tuple[int, int, int], xla_output: str, op_type: str, + trace_dir: str, ) -> Dict[str, Any]: """Calculates the metrics for the all_gather benchmark.""" # Build 
dictionary of all the parameters in the function @@ -542,6 +563,7 @@ def all_gather_benchmark_calculate_metrics( ici_average_time_ms_list, matrix_dim, op_type, + trace_dir, ) @@ -621,6 +643,7 @@ def data_generator(): "ici_average_time_ms_list": time_ms_list, "matrix_shape": (m, n, k), "op_type": "A2A", + "trace_dir": trace_dir, } @@ -635,6 +658,7 @@ def all_to_all_benchmark_calculate_metrics( matrix_shape: tuple[int, int, int], xla_output: str, op_type: str, + trace_dir: str, ) -> Dict[str, Any]: """Calculates the metrics for the all_to_all benchmark.""" # Build dictionary of all the parameters in the function @@ -649,5 +673,6 @@ def all_to_all_benchmark_calculate_metrics( ici_average_time_ms_list, matrix_dim, op_type, + trace_dir, ) diff --git a/Ironwood/src/benchmark_send_recv.py b/Ironwood/src/benchmark_send_recv.py index 72eb2f1..9095000 100644 --- a/Ironwood/src/benchmark_send_recv.py +++ b/Ironwood/src/benchmark_send_recv.py @@ -1,13 +1,12 @@ """Benchmarking p2p source target transfer.""" import os -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, Tuple import jax from jax.experimental import mesh_utils import jax.numpy as jnp import jax.sharding from benchmark_utils import ( - MetricsStatistics, get_trace, ) from common import MARKER diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index 475d73b..19a6642 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -24,6 +24,9 @@ from jax.sharding import Mesh from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P +import gc +import jax.extend +from tensorflow.tsl.profiler.protobuf import xplane_pb2 # The dictionary to map a JAX (collective) function to its main HLO. 
TARGET_TASK_NAME_COLLECTIVES_MAP = { @@ -121,6 +124,13 @@ def multiple_iteration_timeit_from_trace_throttling( return multiple_iteration_get_metrics_from_trace(trace) +def clear_jax_memory(): + backend = jax.extend.backend.get_backend() + for buf in backend.live_buffers(): + buf.delete() + gc.collect() + + def multiple_iteration_timeit_from_trace( compute_func: Callable, data_generator: Callable, @@ -153,11 +163,13 @@ def multiple_iteration_timeit_from_trace( print(f"[{task}] Running iteration {i} of {tries} with {matrix_dim}...") data_args = data_generator() jax.devices() + with jax.profiler.StepTraceAnnotation(task, step_num=i): with jax.named_scope(f"{MARKER}_{i}"): + result = compute_func(*data_args) jax.block_until_ready(result) - + clear_jax_memory() trace = get_trace(tmp_trace_dir) if trace_full_dir != tmp_trace_dir: @@ -502,6 +514,65 @@ def get_trace(log_dir: str) -> dict[str, Any]: return trace +def find_sparsecore_usage_from_xplane(log_dir: str) -> bool: + """Check whether a SparseCore plane is present in the profile's XSpace. + + Returns: + True if any XPlane whose name contains "SparseCore" is found, else False.
+ """ + print("find_sparsecore_usage_from_xplane: ", log_dir) + + # Handle partial log_dir + if not (pathlib.Path(log_dir) / "plugins" / "profile").exists(): + potential_dirs = glob.glob(f"{log_dir}*") + potential_dirs = [d for d in potential_dirs if os.path.isdir(d)] + potential_dirs.sort(key=os.path.getmtime, reverse=True) + + for d in potential_dirs: + d_path = pathlib.Path(d) + if (d_path / "plugins" / "profile").exists(): + log_dir = d + print(f"Updated log_dir to match partial path: {log_dir}") + break + + # Check subdirectories recursively + candidates = list(d_path.glob("**/plugins/profile")) + if candidates: + latest = max(candidates, key=lambda p: p.stat().st_mtime) + log_dir = str(latest.parent.parent) + print(f"Updated log_dir via recursive search: {log_dir}") + break + + trace_folders = ( + pathlib.Path(log_dir).absolute() / "plugins" / "profile" + ).iterdir() + latest_trace_folder = max(trace_folders, key=os.path.getmtime) + + # XPlane files usually end with .xplane.pb + xplane_files = list(latest_trace_folder.glob("*.xplane.pb")) + try: + (xplane_file,) = xplane_files + except ValueError as value_error: + raise ValueError( + f"Invalid trace folder: {latest_trace_folder}. Expected 1" + f" '*.xplane.pb' file, but found {len(xplane_files)}." + ) from value_error + + with open(xplane_file, "rb") as f: + serialized_space = f.read() + + space = xplane_pb2.XSpace() + space.ParseFromString(serialized_space) + # print("space: ", space) + sparsecore_found = False + for _, plane in enumerate(space.planes): + print("plane: ", plane.name) + if "SparseCore" in plane.name: + sparsecore_found = True + break + return sparsecore_found + + def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]: # Check if the given task name is a collective with corresponding TPU opertion. # This is a workaround and should be reverted or refactored in future.