33 changes: 29 additions & 4 deletions Ironwood/src/benchmark_collectives.py
@@ -5,6 +5,7 @@
import os
from typing import Any, Dict

from benchmark_utils import find_sparsecore_usage_from_xplane
from benchmark_utils import get_lhs_named_shading
from benchmark_utils import get_out_sharding
from benchmark_utils import MetricsStatistics
@@ -26,7 +27,7 @@
SEED = 0
GLOBAL_SHARDING_STRATEGY = ShardingStrategy.NO_SHARDING
GLOBAL_PSTATE = 7

LOG_SPARSECORE_USAGE = False
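# When enabled, unified_ici_collectives_metrics parses the newest xplane dump
# under the given trace_dir (via find_sparsecore_usage_from_xplane) and records
# the result in the per-run metadata as "sparsecore_used".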

def create_mesh(ici_size: int, mesh_shape: str) -> Mesh:
"""Creates a mesh with the given ICI size."""
@@ -85,6 +86,7 @@ def unified_ici_collectives_metrics(
ici_average_time_ms_list: list[float],
iteration: int,
op_type: str,
    trace_dir: str | None = None,
) -> Dict[str, Any]:
"""Calculates the metrics for the ICI collectives benchmark."""

@@ -147,8 +149,15 @@ def unified_ici_collectives_metrics(
/ rank
)



sparsecore_used = "NA"
if LOG_SPARSECORE_USAGE:
print("trace_dir: ", trace_dir)
if trace_dir:
sparsecore_used = find_sparsecore_usage_from_xplane(trace_dir)
print("sparsecore_used: ", sparsecore_used)
print("hlo first replica group: ", hlo_first_replica_group)

metadata = {
"iteration": iteration,
"op_type": op_type,
@@ -164,6 +173,7 @@
"hlo_input_shape": json.dumps(hlo_input_shape),
"hlo_output_shape": json.dumps(hlo_output_shape),
"hlo_replica_groups": json.dumps(hlo_replica_groups),
"sparsecore_used": sparsecore_used,
}
achieved_bw = [transferred_data*1000/my_time for my_time in ici_average_time_ms_list]
achieved_bw_statistics = MetricsStatistics(
@@ -294,6 +304,7 @@ def data_generator():
"ici_average_time_ms_list": time_ms_list,
"matrix_shape": (m, n, k),
"op_type": "AR",
"trace_dir": trace_dir,
}


@@ -308,6 +319,7 @@ def psum_benchmark_calculate_metrics(
matrix_shape: tuple[int, int, int],
xla_output: str,
op_type: str,
trace_dir: str,
) -> Dict[str, Any]:
"""Calculates the metrics for the psum benchmark."""
# Build dictionary of all the parameters in the function
@@ -322,8 +334,10 @@
ici_average_time_ms_list,
matrix_dim,
op_type,
trace_dir,
)


def psum_scatter_benchmark(
matrix_dim: int,
dtype: jnp.dtype,
@@ -382,7 +396,7 @@ def f(x):
)
sharding_strategy_tuple = tuple(map(int, sharding_strategy.split("x")))
op_dimension_tuple_multiplier = math.prod(sharding_strategy_tuple)
m = op_dimension_tuple_multiplier * 2
m = op_dimension_tuple_multiplier
n = matrix_dim
k = 256

@@ -405,6 +419,7 @@ def data_generator():
"ici_average_time_ms_list": time_ms_list,
"matrix_shape": (m, n, k),
"op_type": "RS",
"trace_dir": trace_dir,
}


@@ -419,6 +434,7 @@ def psum_scatter_benchmark_calculate_metrics(
matrix_shape: tuple[int, int, int],
xla_output: str,
op_type: str,
trace_dir: str,
) -> Dict[str, Any]:
"""Calculates the metrics for the psum_scatter benchmark."""
# Build dictionary of all the parameters in the function
@@ -433,6 +449,7 @@ def psum_scatter_benchmark_calculate_metrics(
ici_average_time_ms_list,
matrix_dim,
op_type,
trace_dir,
)

def all_gather_benchmark(
@@ -473,7 +490,9 @@ def all_gather_benchmark(
"--xla_tpu_use_single_sparse_core_for_all_gather_offload=true",
"--xla_tpu_use_tc_device_shape_on_sc=true",
f"--xla_tpu_dvfs_p_state={GLOBAL_PSTATE}",
"--xla_tpu_scoped_vmem_limit_kib=65536",
]
os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args)
mesh = create_mesh(ici_size, mesh_shape)

@@ -513,7 +532,8 @@ def data_generator():
return {
"ici_average_time_ms_list": time_ms_list,
"matrix_shape": (m, n, k),
"op_type": "AG"
"op_type": "AG",
"trace_dir": trace_dir,
}


@@ -528,6 +548,7 @@ def all_gather_benchmark_calculate_metrics(
matrix_shape: tuple[int, int, int],
xla_output: str,
op_type: str,
trace_dir: str,
) -> Dict[str, Any]:
"""Calculates the metrics for the all_gather benchmark."""
# Build dictionary of all the parameters in the function
@@ -542,6 +563,7 @@
ici_average_time_ms_list,
matrix_dim,
op_type,
trace_dir,
)


@@ -621,6 +643,7 @@ def data_generator():
"ici_average_time_ms_list": time_ms_list,
"matrix_shape": (m, n, k),
"op_type": "A2A",
"trace_dir": trace_dir,
}


@@ -635,6 +658,7 @@ def all_to_all_benchmark_calculate_metrics(
matrix_shape: tuple[int, int, int],
xla_output: str,
op_type: str,
trace_dir: str,
) -> Dict[str, Any]:
"""Calculates the metrics for the all_to_all benchmark."""
# Build dictionary of all the parameters in the function
@@ -649,5 +673,6 @@
ici_average_time_ms_list,
matrix_dim,
op_type,
trace_dir,
)
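
Taken together, the changes above thread each benchmark's profiler directory through its *_calculate_metrics wrapper into unified_ici_collectives_metrics, which consults it only when LOG_SPARSECORE_USAGE is set. The following is a minimal, self-contained sketch of that gate; find_sparsecore_usage_from_xplane is stubbed out so the snippet runs without a real trace, and all paths and values are illustrative only.

LOG_SPARSECORE_USAGE = True

def find_sparsecore_usage_from_xplane(trace_dir: str) -> bool:
  # Stub for the real helper in benchmark_utils.py, which parses
  # <trace_dir>/plugins/profile/<run>/*.xplane.pb.
  return False

def sparsecore_metadata_field(trace_dir):
  # Mirrors the logic added to unified_ici_collectives_metrics: default to
  # "NA" and consult the xplane dump only when logging is enabled and a
  # trace directory was actually passed through.
  sparsecore_used = "NA"
  if LOG_SPARSECORE_USAGE and trace_dir:
    sparsecore_used = find_sparsecore_usage_from_xplane(trace_dir)
  return sparsecore_used

print(sparsecore_metadata_field(None))            # NA
print(sparsecore_metadata_field("/tmp/ici_run"))  # False (stub)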

3 changes: 1 addition & 2 deletions Ironwood/src/benchmark_send_recv.py
@@ -1,13 +1,12 @@
"""Benchmarking p2p source target transfer."""

import os
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, Tuple
import jax
from jax.experimental import mesh_utils
import jax.numpy as jnp
import jax.sharding
from benchmark_utils import (
MetricsStatistics,
get_trace,
)
from common import MARKER
73 changes: 72 additions & 1 deletion Ironwood/src/benchmark_utils.py
@@ -24,6 +24,9 @@
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
import gc
import jax.extend
from tensorflow.tsl.profiler.protobuf import xplane_pb2
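# xplane_pb2 supplies the XSpace/XPlane protobuf types that
# find_sparsecore_usage_from_xplane (added below) uses to parse the
# profiler's .xplane.pb dumps.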

# The dictionary to map a JAX (collective) function to its main HLO.
TARGET_TASK_NAME_COLLECTIVES_MAP = {
@@ -121,6 +124,13 @@ def multiple_iteration_timeit_from_trace_throttling(
return multiple_iteration_get_metrics_from_trace(trace)


def clear_jax_memory():
  """Deletes all live device buffers on the current backend, then runs GC."""
  backend = jax.extend.backend.get_backend()
  for buf in backend.live_buffers():
    buf.delete()
  gc.collect()
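
# Sketch of intended use (mirrors the call added to
# multiple_iteration_timeit_from_trace below): free device buffers once the
# timed result has been materialized, so iterations don't accumulate HBM.
#
#   result = compute_func(*data_args)
#   jax.block_until_ready(result)
#   clear_jax_memory()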


def multiple_iteration_timeit_from_trace(
compute_func: Callable,
data_generator: Callable,
@@ -153,11 +163,13 @@ def multiple_iteration_timeit_from_trace(
print(f"[{task}] Running iteration {i} of {tries} with {matrix_dim}...")
data_args = data_generator()
jax.devices()

with jax.profiler.StepTraceAnnotation(task, step_num=i):
with jax.named_scope(f"{MARKER}_{i}"):

result = compute_func(*data_args)
jax.block_until_ready(result)

clear_jax_memory()
trace = get_trace(tmp_trace_dir)

if trace_full_dir != tmp_trace_dir:
@@ -502,6 +514,65 @@ def get_trace(log_dir: str) -> dict[str, Any]:
return trace


def find_sparsecore_usage_from_xplane(log_dir: str) -> bool:
  """Checks the latest XPlane dump under `log_dir` for SparseCore activity.

  Returns:
    True if any plane in the XSpace proto has "SparseCore" in its name,
    False otherwise.
  """
print("find_sparsecore_usage_from_xplane: ", log_dir)

# Handle partial log_dir
if not (pathlib.Path(log_dir) / "plugins" / "profile").exists():
potential_dirs = glob.glob(f"{log_dir}*")
potential_dirs = [d for d in potential_dirs if os.path.isdir(d)]
potential_dirs.sort(key=os.path.getmtime, reverse=True)

for d in potential_dirs:
d_path = pathlib.Path(d)
if (d_path / "plugins" / "profile").exists():
log_dir = d
print(f"Updated log_dir to match partial path: {log_dir}")
break

# Check subdirectories recursively
candidates = list(d_path.glob("**/plugins/profile"))
if candidates:
latest = max(candidates, key=lambda p: p.stat().st_mtime)
log_dir = str(latest.parent.parent)
print(f"Updated log_dir via recursive search: {log_dir}")
break

trace_folders = (
pathlib.Path(log_dir).absolute() / "plugins" / "profile"
).iterdir()
latest_trace_folder = max(trace_folders, key=os.path.getmtime)

# XPlane files usually end with .xplane.pb
xplane_files = list(latest_trace_folder.glob("*.xplane.pb"))
try:
(xplane_file,) = xplane_files
except ValueError as value_error:
raise ValueError(
f"Invalid trace folder: {latest_trace_folder}. Expected 1"
f" '*.xplane.pb' file, but found {len(xplane_files)}."
) from value_error

with open(xplane_file, "rb") as f:
serialized_space = f.read()

space = xplane_pb2.XSpace()
space.ParseFromString(serialized_space)
# print("space: ", space)
sparsecore_found = False
for _, plane in enumerate(space.planes):
print("plane: ", plane.name)
if "SparseCore" in plane.name:
sparsecore_found = True
break
return sparsecore_found
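
# Usage sketch (path is hypothetical): after a profiled run has written
#   <log_dir>/plugins/profile/<timestamp>/<host>.xplane.pb
# the check is purely name-based:
#
#   if find_sparsecore_usage_from_xplane("/tmp/ici_trace"):
#     print("collective ran on SparseCore")
#
# i.e. it returns True iff some plane's name contains "SparseCore".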


def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]:
  # Check if the given task name is a collective with a corresponding TPU operation.
# This is a workaround and should be reverted or refactored in future.