From bcc90b1a5c9534e158c82084788f3d3dfa00c75a Mon Sep 17 00:00:00 2001 From: JianboDong Date: Thu, 4 Sep 2025 19:36:28 +0800 Subject: [PATCH 1/4] Update test_low_latency.py --- tests/test_low_latency.py | 229 +++++++++++++++++++++++++++++++++++--- 1 file changed, 214 insertions(+), 15 deletions(-) diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py index aa928aab..ab644778 100644 --- a/tests/test_low_latency.py +++ b/tests/test_low_latency.py @@ -7,13 +7,14 @@ import numpy as np from functools import partial from typing import Optional - +import pandas as pd import deep_ep -from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back +from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back, get_global_token_indices def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, + imbalance_factor: float = 1.0, distribution: str = 'lognormal', print_res: bool = True, use_logfmt: bool = False, seed: int = 0): torch.manual_seed(seed + rank) random.seed(seed + rank) @@ -34,9 +35,21 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, # NOTES: the last one is for performance testing # Most of the values in the perf case is lower than the threshold, casting most channels x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1) + + scatter_list = None + if rank == 0: + global_topk_idx = get_global_token_indices( + distribution, num_experts, num_tokens, num_ranks, num_topk, imbalance_factor, seed + ) + scatter_list = [ + chunk.contiguous() for chunk in torch.chunk(global_topk_idx, num_ranks, dim=0) + ] + topk_idx = torch.empty(num_tokens, num_topk, dtype=torch.long, device='cuda') + dist.scatter(tensor=topk_idx, scatter_list=scatter_list, src=0, group=group) - scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 - topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] + results = {} + results['topk_idx'] = topk_idx.cpu() + topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs() # Randomly mask some positions @@ -142,11 +155,17 @@ def test_func(return_recv_hook: bool): num_selections = (topk_idx[i] != -1).sum().item() num_dispatch_comm_bytes += num_fp8_bytes * num_selections num_combine_comm_bytes += (num_logfmt10_bytes if use_logfmt else num_bf16_bytes) * num_selections + results['dispatch_comm_bytes'] = num_dispatch_comm_bytes + results['combine_comm_bytes'] = num_combine_comm_bytes # Dispatch + combine testing avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False)) - print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, ' - f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True) + if print_res: + print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, ' + f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True) + results['total_time_avg'] = avg_t * 1e6 + results['total_time_min'] = min_t * 1e6 + results['total_time_max'] = max_t * 1e6 # Separate profiling for return_recv_hook in (False, True): @@ -155,14 +174,161 @@ def test_func(return_recv_hook: bool): kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True, suppress_kineto_output=True, num_kernels_per_period=2 if return_recv_hook else 1) if not return_recv_hook: - print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | ' - f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True) + if print_res: + print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | ' + f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True) + results['dispatch_time'] = dispatch_t * 1e6 + results['combine_time'] = combine_t * 1e6 else: - print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | ' - f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True) - return hash_value + if print_res: + print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | ' + f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True) + results['dispatch_send_time'] = dispatch_t[0] * 1e6 + results['dispatch_recv_time'] = dispatch_t[1] * 1e6 + results['combine_send_time'] = combine_t[0] * 1e6 + results['combine_recv_time'] = combine_t[1] * 1e6 + return results, hash_value + +def process_and_display_results(all_results, num_experts, num_ranks, imbalance_factor): + all_topk_idx_list = [result['topk_idx'] for result in all_results] + global_topk_idx = torch.cat(all_topk_idx_list, dim=0) + + num_local_experts = num_experts // num_ranks + rank_counts = torch.zeros(num_ranks, dtype=torch.int64) + + valid_indices = global_topk_idx[global_topk_idx >= 0] + for rank in range(num_ranks): + start_expert = rank * num_local_experts + end_expert = (rank + 1) * num_local_experts + + mask = (valid_indices >= start_expert) & (valid_indices < end_expert) + rank_counts[rank] = mask.sum().item() + + max_count = rank_counts.max().item() + min_count = rank_counts.min().item() + avg_count = rank_counts.float().mean().item() + median_count = rank_counts.float().median().item() + std_count = rank_counts.float().std().item() + actual_max_avg = max_count / avg_count if avg_count > 0 else 0 + + avg_dispatch_bytes = sum(result['dispatch_comm_bytes'] for result in all_results) / len(all_results) + avg_combine_bytes = sum(result['combine_comm_bytes'] for result in all_results) / len(all_results) + + avg_total_time = sum(result['total_time_avg'] for result in all_results) / len(all_results) + avg_dispatch_time = sum(result['dispatch_time'] for result in all_results) / len(all_results) + avg_combine_time = sum(result['combine_time'] for result in all_results) / len(all_results) + + total_bw = (avg_dispatch_bytes + avg_combine_bytes) / 1e9 / (avg_total_time / 1e6) + dispatch_bw = avg_dispatch_bytes / 1e9 / (avg_dispatch_time / 1e6) + combine_bw = avg_combine_bytes / 1e9 / (avg_combine_time / 1e6) + + for result in all_results: + if 'topk_idx' in result: + del result['topk_idx'] + for key in ['dispatch_comm_bytes', 'combine_comm_bytes']: + if key in result: + del result[key] + + df = pd.DataFrame(all_results) + mean_series = df.mean() + + mean_series['total_bw'] = total_bw + mean_series['dispatch_bw'] = dispatch_bw + mean_series['combine_bw'] = combine_bw + mean_series['imbalance_factor'] = imbalance_factor + mean_series['max_count'] = float(max_count) + mean_series['min_count'] = float(min_count) + mean_series['avg_count'] = float(avg_count) + mean_series['median_count'] = float(median_count) + mean_series['std_count'] = float(std_count) + mean_series['actual_max_avg'] = float(actual_max_avg) + + return mean_series +def print_summary_tables(final_df): + print("\n" + "="*120) + print(" PERFORMANCE SUMMARY (Statistics across all ranks)") + print("="*120) + + # Table 1: Token Distribution Statistics + print("\n--- Token Distribution per Rank ---") + imbalance_df = final_df[['max_count', 'min_count', 'avg_count', 'median_count', 'std_count']].copy() + imbalance_df.columns = ['Max', 'Min', 'Avg', 'Median', 'Std Dev'] + + formatters = { + 'Max': lambda x: f"{x:4.0f}", + 'Min': lambda x: f"{x:3.0f}", + 'Avg': lambda x: f"{x:5.1f}", + 'Median': lambda x: f"{x:6.1f}", + 'Std Dev': lambda x: f"{x:6.1f}" + } + + for col, formatter in formatters.items(): + imbalance_df[col] = imbalance_df[col].apply(formatter) + + print(imbalance_df.to_string()) + + # Table 2: Total Performance + print("\n--- Total Performance (Dispatch + Combine) ---") + total_perf_df = final_df[['total_bw', 'total_time_avg', 'total_time_min', 'total_time_max']].copy() + total_perf_df.columns = ['Total BW', 'Avg Time', 'Min Time', 'Max Time'] + + formatters = { + 'Total BW': lambda x: f"{x:.2f} GB/s", + 'Avg Time': lambda x: f"{x:.2f} us", + 'Min Time': lambda x: f"{x:.2f} us", + 'Max Time': lambda x: f"{x:.2f} us" + } + + for col, formatter in formatters.items(): + total_perf_df[col] = total_perf_df[col].apply(formatter) + + print(total_perf_df.to_string()) + + # Table 3: Separate Performance + print("\n--- Separate Dispatch & Combine Performance ---") + separate_perf_df = final_df[['dispatch_bw', 'dispatch_time', 'combine_bw', 'combine_time']].copy() + separate_perf_df.columns = ['Dispatch BW', 'Dispatch Time', 'Combine BW', 'Combine Time'] + + formatters = { + 'Dispatch BW': lambda x: f"{x:.2f} GB/s", + 'Dispatch Time': lambda x: f"{x:.2f} us", + 'Combine BW': lambda x: f"{x:.2f} GB/s", + 'Combine Time': lambda x: f"{x:.2f} us" + } + + for col, formatter in formatters.items(): + separate_perf_df[col] = separate_perf_df[col].apply(formatter) + + print(separate_perf_df.to_string()) + + # Table 4: Hook Performance + print("\n--- Send/Recv Timings (Hook=True) ---") + + hook_data = [] + for idx in final_df.index: + row = final_df.loc[idx] + hook_data.append([ + f"{row['dispatch_send_time']:>6.2f} us", + f"{row['dispatch_recv_time']:>6.2f} us", + f"{row['combine_send_time']:>6.2f} us", + f"{row['combine_recv_time']:>6.2f} us" + ]) + + columns = pd.MultiIndex.from_tuples([ + (' ', 'Send'), + ('Dispatch', 'Recv'), + (' ', 'Send'), + ('Combine', 'Recv') + ]) + + hook_df = pd.DataFrame(hook_data, index=final_df.index, columns=columns) + + print(hook_df.to_string()) + + print("\n" + "="*120) + # noinspection PyUnboundLocalVariable,PyShadowingNames def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) @@ -178,16 +344,44 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): allow_mnnvl=args.allow_mnnvl) test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, use_logfmt=args.use_logfmt, seed=1) - + dist.barrier() + if rank == 0: + all_imbalance_summaries = [] + for imbalance_factor in args.imbalance_factors: + if rank == 0: + print(f"\n--> Running test for target imbalance factor: {imbalance_factor}", flush=True) + results_dict, _ = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, + imbalance_factor=imbalance_factor, distribution = args.distribution, + print_res=False, seed=1) + if rank == 0: + all_results = [None] * num_ranks + dist.gather_object(results_dict, all_results, dst=0, group=group) + + mean_series = process_and_display_results(all_results, num_experts, num_ranks, imbalance_factor) + all_imbalance_summaries.append(mean_series) + else: + dist.gather_object(results_dict, None, dst=0, group=group) + + if rank == 0 and all_imbalance_summaries: + df = pd.DataFrame(all_imbalance_summaries) + + df['display_index'] = df.apply( + lambda row: f"{row['imbalance_factor']:.1f} (Actual: {row['actual_max_avg']:.2f})", + axis=1 + ) + df.set_index('display_index', inplace=True) + df.index.name = "Max/Avg Ratio" + print_summary_tables(df) do_pressure_test = args.pressure_test for seed in range(int(1e9) if do_pressure_test else 0): if local_rank == 0: print(f'Testing with seed {seed} ...', flush=True) - ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, + _, ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, use_logfmt=args.use_logfmt, seed=seed) for i in range(20): - assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, - use_logfmt=args.use_logfmt, seed=seed) == ref_hash, f'Error: seed={seed}' + _, current_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, + use_logfmt=args.use_logfmt, seed=seed) + assert current_hash == ref_hash, f'Error: seed={seed}' # Destroy the buffer runtime and communication group buffer.destroy() @@ -217,6 +411,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): help='Whether to test LogFMT combine') parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test') + parser.add_argument('--imbalance-factors', type=float, nargs='+', default=[1.0, 2.0, 3.0], + help='A list of target max/avg ratios for per-rank expert load (tokens per expert). ' + 'Higher values create more load imbalance (e.g., 1.0, 2.0, 3.0).' + 'Note: actual ratios may be lower than targets due to token count constraints.') + parser.add_argument('--distribution', type=str, default='lognormal', choices=['lognormal','powerlaw', 'gamma']) args = parser.parse_args() num_processes = args.num_processes From 212ea443bab6888d94737fa7a0ed364915400149 Mon Sep 17 00:00:00 2001 From: JianboDong Date: Thu, 4 Sep 2025 19:38:39 +0800 Subject: [PATCH 2/4] Update utils.py --- tests/utils.py | 200 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) diff --git a/tests/utils.py b/tests/utils.py index a64cc0ae..482fdfcc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -37,11 +37,211 @@ def init_dist(local_rank: int, num_local_ranks: int): def calc_diff(x: torch.Tensor, y: torch.Tensor): + # avoid NaN + if x.numel() == 0 and y.numel() == 0: + return 0.0 x, y = x.double() + 1, y.double() + 1 denominator = (x * x + y * y).sum() sim = 2 * (x * y).sum() / denominator return (1 - sim).item() +def get_global_token_indices(distribution_type, num_experts, num_tokens, num_ranks, num_topk, + imbalance_factor, simulation_seed=0): + """ + Generates and returns the global top-k indices for all tokens across all ranks, + matching the target imbalance factor. + """ + if imbalance_factor <= 1.0: + # use uniform + print(f"Distribution: uniform, Target ratio: {imbalance_factor}", flush=True) + expert_probs = torch.ones(num_experts, dtype=torch.float32, device='cuda') + expert_probs /= expert_probs.sum() + + total_tokens = num_tokens * num_ranks + token_probs_expanded = expert_probs.unsqueeze(0).expand(total_tokens, -1) + + torch.manual_seed(simulation_seed) + global_topk_idx = torch.multinomial(token_probs_expanded, num_samples=num_topk, replacement=False) + return global_topk_idx + else: + # Imbalanced case: find params and get indices directly + params, global_topk_idx = find_distribution_parameters_for_target_ratio( + distribution_type=distribution_type, + target_imbalance_ratio=imbalance_factor, + num_experts=num_experts, + num_tokens_per_rank=num_tokens, + num_ranks=num_ranks, + num_topk=num_topk, + simulation_seed=simulation_seed + ) + + print(f'Distribution: {distribution_type}, Target ratio: {imbalance_factor}, ' + f'Found parameters: {params}', flush=True) + return global_topk_idx.contiguous() + +def _simulate_global_sampling( + distribution_type: str, + params_dict: dict, + num_experts: int, + total_tokens: int, + num_ranks: int, + num_topk: int, + simulation_seed: int = 0, + permutation: torch.Tensor = None +): + """ + Simulates global sampling and returns both the resulting imbalance ratio and the generated indices. + """ + num_local_experts = num_experts // num_ranks + expert_probs = _generate_expert_probs( + distribution_type, params_dict, num_experts, permutation=permutation + ) + token_probs_expanded = expert_probs.unsqueeze(0).expand(total_tokens, -1) + + torch.manual_seed(simulation_seed) + global_topk_idx = torch.multinomial( + token_probs_expanded, num_samples=num_topk, replacement=False + ) + + rank_counts = torch.zeros(num_ranks, dtype=torch.int64, device='cuda') + valid_indices = global_topk_idx.flatten() + valid_indices = valid_indices[valid_indices >= 0] + + for rank in range(num_ranks): + start_expert = rank * num_local_experts + end_expert = (rank + 1) * num_local_experts + mask = (valid_indices >= start_expert) & (valid_indices < end_expert) + rank_counts[rank] = mask.sum().item() + + max_count = rank_counts.max().item() + avg_count = rank_counts.float().mean().item() + + imbalance_ratio = max_count / avg_count if avg_count > 0 else 0.0 + + return imbalance_ratio, global_topk_idx +def find_distribution_parameters_for_target_ratio( + distribution_type: str, + target_imbalance_ratio: float, + num_experts: int, + num_tokens_per_rank: int, + num_ranks: int, + num_topk: int, + max_iterations=20, + tolerance=0.02, + simulation_seed=0 +): + """ + Finds parameters and returns them along with the final sampled indices. + Returns: + tuple: (dict of parameters, torch.Tensor of global_topk_idx) + """ + total_tokens = num_tokens_per_rank * num_ranks + torch.manual_seed(simulation_seed) + permutation = torch.randperm(num_experts, device='cuda') + + def simulate_ratio_only(params_dict): + ratio, _ = _simulate_global_sampling( + distribution_type, params_dict, num_experts, total_tokens, + num_ranks, num_topk, simulation_seed, permutation=permutation + ) + return ratio + + search_config = _get_search_config(distribution_type, target_imbalance_ratio) + final_params = _binary_search_parameters( + simulate_ratio_only, + search_config, + target_imbalance_ratio, + max_iterations, + tolerance + ) + _, final_global_topk_idx = _simulate_global_sampling( + distribution_type, final_params, num_experts, total_tokens, + num_ranks, num_topk, simulation_seed, permutation=permutation + ) + return final_params, final_global_topk_idx +def _generate_expert_probs(distribution_type: str, params: dict, num_experts: int, permutation: torch.Tensor = None): + """Generate expert probabilities, with optional shuffling.""" + if distribution_type == 'powerlaw': + alpha = params['alpha'] + ranks = torch.arange(1, num_experts + 1, device='cuda', dtype=torch.float32) + popularity_values = ranks ** (-alpha if alpha > 0 else 0) + + elif distribution_type == 'lognormal': + sigma = params['sigma'] + log_normal_dist = torch.distributions.LogNormal(loc=0.0, scale=sigma) + popularity_values = log_normal_dist.sample((num_experts,)).to('cuda') + popularity_values, _ = torch.sort(popularity_values, descending=True) + popularity_values.clamp_(min=1e-9) + + elif distribution_type == 'gamma': + shape = params['shape'] + gamma_dist = torch.distributions.Gamma(concentration=shape, rate=2.0) + popularity_values = gamma_dist.sample((num_experts,)).to('cuda') + popularity_values, _ = torch.sort(popularity_values, descending=True) + popularity_values.clamp_(min=1e-9) + + else: + raise ValueError(f"Unsupported distribution: {distribution_type}") + + if permutation is not None: + shuffled_values = torch.zeros_like(popularity_values) + shuffled_values.scatter_(0, permutation, popularity_values) + popularity_values = shuffled_values + + return popularity_values / popularity_values.sum() +def _get_search_config(distribution_type: str, target_ratio: float): + """Get search bounds and parameter names for each distribution""" + + if distribution_type == 'powerlaw': + return { + 'param_name': 'alpha', + 'low': 0.0, + 'high': 5.0 + } + + elif distribution_type == 'lognormal': + return { + 'param_name': 'sigma', + 'low': 0.1, + 'high': 5.0 + } + + elif distribution_type == 'gamma': + return { + 'param_name': 'shape', + 'low': 0.001, # Shape must be > 0 + 'high': 2.0, + 'inverse_relationship': True # reverse binary search + } + + else: + raise ValueError(f"No search config for distribution: {distribution_type}") +def _binary_search_parameters(simulate_fn, search_config, target_ratio, max_iterations, tolerance): + """Generic binary search for distribution parameters""" + + param_name = search_config['param_name'] + param_low = search_config['low'] + param_high = search_config['high'] + inverse_relationship = search_config.get('inverse_relationship', False) + + for i in range(max_iterations): + param_mid = (param_low + param_high) / 2 + actual_ratio = simulate_fn({param_name: param_mid}) + + if abs(actual_ratio - target_ratio) / target_ratio < tolerance: + return {param_name: param_mid} + + condition_for_increasing_param = actual_ratio < target_ratio + if inverse_relationship: + condition_for_increasing_param = not condition_for_increasing_param + + if condition_for_increasing_param: + param_low = param_mid + else: + param_high = param_mid + + final_param = (param_low + param_high) / 2 + return {param_name: final_param} def per_token_cast_to_fp8(x: torch.Tensor): assert x.dim() == 2 and x.size(1) % 128 == 0 From 1ae1d7f8304398a8e9af3f1f66a0c9ea80f9138b Mon Sep 17 00:00:00 2001 From: JianboDong Date: Thu, 30 Oct 2025 11:11:33 +0800 Subject: [PATCH 3/4] Update test_low_latency.py to solve conflicts --- tests/test_low_latency.py | 360 +++++++++++++++++++++++++------------- 1 file changed, 235 insertions(+), 125 deletions(-) diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py index ab644778..6bcfbdeb 100644 --- a/tests/test_low_latency.py +++ b/tests/test_low_latency.py @@ -1,21 +1,53 @@ import argparse import random -import time -import os import torch import torch.distributed as dist -import numpy as np from functools import partial -from typing import Optional +from typing import Literal, Set import pandas as pd import deep_ep from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back, get_global_token_indices -def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, - rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, - imbalance_factor: float = 1.0, distribution: str = 'lognormal', print_res: bool = True, - use_logfmt: bool = False, seed: int = 0): +def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"], expected_masked_ranks: Set[int]): + # Simulates rank failure when the rank first calls the corresponding communication API + failed_api_ranks = { + # API -> rank to fail (rank fails when it first calls the corresponding communication API) + 'dispatch': 1, + 'combine': 3, + 'clean': 5 + } + if rank in expected_masked_ranks: + # Rank already failed + return True + if api in failed_api_ranks.keys(): + expected_masked_ranks.add(failed_api_ranks[api]) + if failed_api_ranks[api] == rank: + print(f"Rank {rank} failed when first calling {api} communication API, exit...", flush=True) + return True + return False + + +def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], buffer: deep_ep.Buffer, mask_status: torch.Tensor, + expected_masked_ranks: Set[int]): + buffer.low_latency_query_mask_buffer(mask_status) + assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks + + +def test_main(num_tokens: int, + hidden: int, + num_experts: int, + num_topk: int, + rank: int, + num_ranks: int, + group: dist.ProcessGroup, + buffer: deep_ep.Buffer, + imbalance_factor: float = 1.0, + distribution: str = 'lognormal', + print_res: bool = True, + use_logfmt: bool = False, + shrink_test: bool = False, + seed: int = 0): torch.manual_seed(seed + rank) random.seed(seed + rank) @@ -29,33 +61,37 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset) x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1) x_list = [x] - for i in range(4 if use_logfmt else 0): + for _ in range(4 if use_logfmt else 0): # NOTES: make more LogFMT casts and also with some BF16 x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.5 * random.random()) # NOTES: the last one is for performance testing # Most of the values in the perf case is lower than the threshold, casting most channels x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1) - + scatter_list = None if rank == 0: - global_topk_idx = get_global_token_indices( - distribution, num_experts, num_tokens, num_ranks, num_topk, imbalance_factor, seed - ) - scatter_list = [ - chunk.contiguous() for chunk in torch.chunk(global_topk_idx, num_ranks, dim=0) - ] + global_topk_idx = get_global_token_indices(distribution, num_experts, num_tokens, num_ranks, num_topk, imbalance_factor, seed) + scatter_list = [chunk.contiguous() for chunk in torch.chunk(global_topk_idx, num_ranks, dim=0)] topk_idx = torch.empty(num_tokens, num_topk, dtype=torch.long, device='cuda') dist.scatter(tensor=topk_idx, scatter_list=scatter_list, src=0, group=group) results = {} results['topk_idx'] = topk_idx.cpu() - + + topk_idx = topk_idx.to(deep_ep.topk_idx_t) topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs() # Randomly mask some positions - for i in range(10): + for _ in range(10): topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1 + all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda') + dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group) + + # For failure simulation and shrink testing + mask_status = torch.zeros((num_ranks, ), dtype=torch.int, device='cuda') + expected_masked_ranks = set() + # Check dispatch correctness do_check = True hash_value, num_times = 0, 0 @@ -64,8 +100,10 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, for dispatch_use_fp8 in (False, True): for round_scale in (False, True) if dispatch_use_fp8 else (False, ): for use_ue8m0 in (False, True) if round_scale else (False, ): + if shrink_test and simulate_failure_and_skip(rank, "dispatch", expected_masked_ranks): + break num_times += 1 - for i in range((num_times % 2) + 1): + for _ in range((num_times % 2) + 1): cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda') packed_recv_x, packed_recv_count, handle, event, hook = \ buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, @@ -73,22 +111,26 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) hook() if return_recv_hook else event.current_stream_wait() + if shrink_test: + query_mask_buffer_and_check("dispatch", buffer, mask_status, expected_masked_ranks) packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \ if dispatch_use_fp8 else packed_recv_x.clone() - all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda') - dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group) for i in range(num_local_experts if do_check else 0): expert_id = rank * num_local_experts + i recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i] recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i] # Check expert indices - int_mask = (2 ** 32) - 1 + int_mask = (2**32) - 1 num_valid_tokens = recv_count.item() - assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}' - assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()' - assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}' + assert cumulative_local_expert_recv_stats[i].item( + ) == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}' + assert num_valid_tokens == ( + recv_layout_range + & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()' + assert num_valid_tokens == (all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item( + ), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status==0].sum().item()}' if num_valid_tokens == 0: continue @@ -103,6 +145,8 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, else: assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0 for j in range(num_ranks): + if shrink_test and mask_status[j]: + continue begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item() if not round_scale: assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item() @@ -114,21 +158,49 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens]) # Check combine correctness + if shrink_test and simulate_failure_and_skip(rank, "combine", expected_masked_ranks): + break for zero_copy in (False, ) if use_logfmt else (False, True): if zero_copy: buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, - use_logfmt=use_logfmt, - async_finish=not return_recv_hook, zero_copy=zero_copy, - return_recv_hook=return_recv_hook, out=out) + combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, + topk_idx, + topk_weights, + handle, + use_logfmt=use_logfmt, + async_finish=not return_recv_hook, + zero_copy=zero_copy, + return_recv_hook=return_recv_hook, + out=out) hook() if return_recv_hook else event.current_stream_wait() + if shrink_test: + query_mask_buffer_and_check("combine", buffer, mask_status, expected_masked_ranks) if do_check: + if shrink_test: + owner_by_expert = (torch.arange(num_experts, device='cuda') // num_local_experts) + fail_owner_mask = (mask_status == 1).index_select(0, owner_by_expert) + valid_topk_idx = topk_idx >= 0 + failed_topk_idx = torch.zeros_like(topk_idx, device='cuda', dtype=torch.bool) + failed_topk_idx[valid_topk_idx] = fail_owner_mask.index_select(0, topk_idx[valid_topk_idx]) + topk_idx[failed_topk_idx] = -1 diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) assert torch.isnan(combined_x).sum().item() == 0 - assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}' + if not round_scale: + assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}' hash_value ^= hash_tensor(combined_x) + # Clean buffer API + if shrink_test: + if simulate_failure_and_skip(rank, "clean", expected_masked_ranks): + break + + buffer.clean_low_latency_buffer(num_tokens, hidden, num_experts) + query_mask_buffer_and_check("clean", buffer, mask_status, expected_masked_ranks) + + if shrink_test: + return + # noinspection PyShadowingNames def large_gemm_with_hook(hook): mat_0 = torch.randn((8192, 8192), dtype=torch.float) @@ -143,8 +215,12 @@ def test_func(return_recv_hook: bool): cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook) large_gemm_with_hook(hook) if return_recv_hook else None - combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, - use_logfmt=use_logfmt, return_recv_hook=return_recv_hook) + combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, + topk_idx, + topk_weights, + handle, + use_logfmt=use_logfmt, + return_recv_hook=return_recv_hook) large_gemm_with_hook(hook) if return_recv_hook else None # Calculate bandwidth @@ -161,8 +237,10 @@ def test_func(return_recv_hook: bool): # Dispatch + combine testing avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False)) if print_res: - print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, ' - f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True) + print( + f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, ' + f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', + flush=True) results['total_time_avg'] = avg_t * 1e6 results['total_time_min'] = min_t * 1e6 results['total_time_max'] = max_t * 1e6 @@ -171,68 +249,75 @@ def test_func(return_recv_hook: bool): for return_recv_hook in (False, True): group.barrier() dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook), - kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True, - suppress_kineto_output=True, num_kernels_per_period=2 if return_recv_hook else 1) + kernel_names=('dispatch', 'combine'), + barrier_comm_profiling=True, + suppress_kineto_output=True, + num_kernels_per_period=2 if return_recv_hook else 1) if not return_recv_hook: if print_res: - print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | ' - f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True) + print( + f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | ' + f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', + flush=True) results['dispatch_time'] = dispatch_t * 1e6 - results['combine_time'] = combine_t * 1e6 + results['combine_time'] = combine_t * 1e6 else: if print_res: - print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | ' - f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True) + print( + f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | ' + f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', + flush=True) results['dispatch_send_time'] = dispatch_t[0] * 1e6 results['dispatch_recv_time'] = dispatch_t[1] * 1e6 results['combine_send_time'] = combine_t[0] * 1e6 results['combine_recv_time'] = combine_t[1] * 1e6 return results, hash_value + def process_and_display_results(all_results, num_experts, num_ranks, imbalance_factor): all_topk_idx_list = [result['topk_idx'] for result in all_results] global_topk_idx = torch.cat(all_topk_idx_list, dim=0) - + num_local_experts = num_experts // num_ranks rank_counts = torch.zeros(num_ranks, dtype=torch.int64) - + valid_indices = global_topk_idx[global_topk_idx >= 0] for rank in range(num_ranks): start_expert = rank * num_local_experts end_expert = (rank + 1) * num_local_experts - + mask = (valid_indices >= start_expert) & (valid_indices < end_expert) rank_counts[rank] = mask.sum().item() - + max_count = rank_counts.max().item() min_count = rank_counts.min().item() avg_count = rank_counts.float().mean().item() median_count = rank_counts.float().median().item() std_count = rank_counts.float().std().item() actual_max_avg = max_count / avg_count if avg_count > 0 else 0 - + avg_dispatch_bytes = sum(result['dispatch_comm_bytes'] for result in all_results) / len(all_results) avg_combine_bytes = sum(result['combine_comm_bytes'] for result in all_results) / len(all_results) - + avg_total_time = sum(result['total_time_avg'] for result in all_results) / len(all_results) avg_dispatch_time = sum(result['dispatch_time'] for result in all_results) / len(all_results) avg_combine_time = sum(result['combine_time'] for result in all_results) / len(all_results) - + total_bw = (avg_dispatch_bytes + avg_combine_bytes) / 1e9 / (avg_total_time / 1e6) dispatch_bw = avg_dispatch_bytes / 1e9 / (avg_dispatch_time / 1e6) combine_bw = avg_combine_bytes / 1e9 / (avg_combine_time / 1e6) - + for result in all_results: if 'topk_idx' in result: del result['topk_idx'] for key in ['dispatch_comm_bytes', 'combine_comm_bytes']: if key in result: del result[key] - + df = pd.DataFrame(all_results) mean_series = df.mean() - + mean_series['total_bw'] = total_bw mean_series['dispatch_bw'] = dispatch_bw mean_series['combine_bw'] = combine_bw @@ -243,19 +328,20 @@ def process_and_display_results(all_results, num_experts, num_ranks, imbalance_f mean_series['median_count'] = float(median_count) mean_series['std_count'] = float(std_count) mean_series['actual_max_avg'] = float(actual_max_avg) - + return mean_series + def print_summary_tables(final_df): - print("\n" + "="*120) + print("\n" + "=" * 120) print(" PERFORMANCE SUMMARY (Statistics across all ranks)") - print("="*120) - + print("=" * 120) + # Table 1: Token Distribution Statistics print("\n--- Token Distribution per Rank ---") imbalance_df = final_df[['max_count', 'min_count', 'avg_count', 'median_count', 'std_count']].copy() imbalance_df.columns = ['Max', 'Min', 'Avg', 'Median', 'Std Dev'] - + formatters = { 'Max': lambda x: f"{x:4.0f}", 'Min': lambda x: f"{x:3.0f}", @@ -263,72 +349,66 @@ def print_summary_tables(final_df): 'Median': lambda x: f"{x:6.1f}", 'Std Dev': lambda x: f"{x:6.1f}" } - + for col, formatter in formatters.items(): imbalance_df[col] = imbalance_df[col].apply(formatter) - + print(imbalance_df.to_string()) - + # Table 2: Total Performance print("\n--- Total Performance (Dispatch + Combine) ---") total_perf_df = final_df[['total_bw', 'total_time_avg', 'total_time_min', 'total_time_max']].copy() total_perf_df.columns = ['Total BW', 'Avg Time', 'Min Time', 'Max Time'] - + formatters = { 'Total BW': lambda x: f"{x:.2f} GB/s", 'Avg Time': lambda x: f"{x:.2f} us", 'Min Time': lambda x: f"{x:.2f} us", 'Max Time': lambda x: f"{x:.2f} us" } - + for col, formatter in formatters.items(): total_perf_df[col] = total_perf_df[col].apply(formatter) - + print(total_perf_df.to_string()) - + # Table 3: Separate Performance print("\n--- Separate Dispatch & Combine Performance ---") separate_perf_df = final_df[['dispatch_bw', 'dispatch_time', 'combine_bw', 'combine_time']].copy() separate_perf_df.columns = ['Dispatch BW', 'Dispatch Time', 'Combine BW', 'Combine Time'] - + formatters = { 'Dispatch BW': lambda x: f"{x:.2f} GB/s", 'Dispatch Time': lambda x: f"{x:.2f} us", 'Combine BW': lambda x: f"{x:.2f} GB/s", 'Combine Time': lambda x: f"{x:.2f} us" } - + for col, formatter in formatters.items(): separate_perf_df[col] = separate_perf_df[col].apply(formatter) - + print(separate_perf_df.to_string()) - + # Table 4: Hook Performance print("\n--- Send/Recv Timings (Hook=True) ---") - + hook_data = [] for idx in final_df.index: row = final_df.loc[idx] hook_data.append([ - f"{row['dispatch_send_time']:>6.2f} us", - f"{row['dispatch_recv_time']:>6.2f} us", - f"{row['combine_send_time']:>6.2f} us", + f"{row['dispatch_send_time']:>6.2f} us", f"{row['dispatch_recv_time']:>6.2f} us", f"{row['combine_send_time']:>6.2f} us", f"{row['combine_recv_time']:>6.2f} us" ]) - - columns = pd.MultiIndex.from_tuples([ - (' ', 'Send'), - ('Dispatch', 'Recv'), - (' ', 'Send'), - ('Combine', 'Recv') - ]) - + + columns = pd.MultiIndex.from_tuples([(' ', 'Send'), ('Dispatch', 'Recv'), (' ', 'Send'), ('Combine', 'Recv')]) + hook_df = pd.DataFrame(hook_data, index=final_df.index, columns=columns) - + print(hook_df.to_string()) - - print("\n" + "="*120) - + + print("\n" + "=" * 120) + + # noinspection PyUnboundLocalVariable,PyShadowingNames def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) @@ -338,37 +418,56 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts) if local_rank == 0: print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True) - buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True, + buffer = deep_ep.Buffer(group, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=True, num_qps_per_rank=num_experts // num_ranks, - allow_nvlink_for_low_latency_mode=not args.disable_nvlink, explicitly_destroy=True, - allow_mnnvl=args.allow_mnnvl) - test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, - use_logfmt=args.use_logfmt, seed=1) + allow_nvlink_for_low_latency_mode=not args.disable_nvlink, + explicitly_destroy=True, + allow_mnnvl=args.allow_mnnvl, + enable_shrink=args.shrink_test) + test_main(num_tokens, + hidden, + num_experts, + num_topk, + rank, + num_ranks, + group, + buffer, + use_logfmt=args.use_logfmt, + shrink_test=args.shrink_test, + seed=1) dist.barrier() if rank == 0: all_imbalance_summaries = [] for imbalance_factor in args.imbalance_factors: if rank == 0: print(f"\n--> Running test for target imbalance factor: {imbalance_factor}", flush=True) - results_dict, _ = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, - imbalance_factor=imbalance_factor, distribution = args.distribution, - print_res=False, seed=1) + results_dict, _ = test_main(num_tokens, + hidden, + num_experts, + num_topk, + rank, + num_ranks, + group, + buffer, + imbalance_factor=imbalance_factor, + distribution=args.distribution, + print_res=False, + seed=1) if rank == 0: all_results = [None] * num_ranks dist.gather_object(results_dict, all_results, dst=0, group=group) - + mean_series = process_and_display_results(all_results, num_experts, num_ranks, imbalance_factor) all_imbalance_summaries.append(mean_series) else: dist.gather_object(results_dict, None, dst=0, group=group) - + if rank == 0 and all_imbalance_summaries: df = pd.DataFrame(all_imbalance_summaries) - - df['display_index'] = df.apply( - lambda row: f"{row['imbalance_factor']:.1f} (Actual: {row['actual_max_avg']:.2f})", - axis=1 - ) + + df['display_index'] = df.apply(lambda row: f"{row['imbalance_factor']:.1f} (Actual: {row['actual_max_avg']:.2f})", axis=1) df.set_index('display_index', inplace=True) df.index.name = "Max/Avg Ratio" print_summary_tables(df) @@ -376,11 +475,27 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): for seed in range(int(1e9) if do_pressure_test else 0): if local_rank == 0: print(f'Testing with seed {seed} ...', flush=True) - _, ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, - use_logfmt=args.use_logfmt, seed=seed) - for i in range(20): - _, current_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, - use_logfmt=args.use_logfmt, seed=seed) + __, ref_hash = test_main(num_tokens, + hidden, + num_experts, + num_topk, + rank, + num_ranks, + group, + buffer, + use_logfmt=args.use_logfmt, + seed=seed) + for _ in range(20): + __, current_hash = test_main(num_tokens, + hidden, + num_experts, + num_topk, + rank, + num_ranks, + group, + buffer, + use_logfmt=args.use_logfmt, + seed=seed) assert current_hash == ref_hash, f'Error: seed={seed}' # Destroy the buffer runtime and communication group @@ -393,29 +508,24 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): # TODO: you may modify NUMA binding for less CPU overhead # TODO: buggy with `num_tokens=512` parser = argparse.ArgumentParser(description='Test low-latency EP kernels') - parser.add_argument('--num-processes', type=int, default=8, - help='Number of processes to spawn (default: 8)') - parser.add_argument('--num-tokens', type=int, default=128, - help='Number of tokens (default: 128)') - parser.add_argument('--hidden', type=int, default=7168, - help='Hidden dimension size (default: 7168)') - parser.add_argument('--num-topk', type=int, default=8, - help='Number of top-k experts (default: 8)') - parser.add_argument('--num-experts', type=int, default=288, - help='Number of experts (default: 288)') - parser.add_argument('--allow-mnnvl', action="store_true", - help='Allow MNNVL for communication') - parser.add_argument('--disable-nvlink', action='store_true', - help='Whether to disable NVLink for testing') - parser.add_argument('--use-logfmt', action='store_true', - help='Whether to test LogFMT combine') - parser.add_argument("--pressure-test", action='store_true', - help='Whether to do pressure test') - parser.add_argument('--imbalance-factors', type=float, nargs='+', default=[1.0, 2.0, 3.0], + parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)') + parser.add_argument('--num-tokens', type=int, default=128, help='Number of tokens (default: 128)') + parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)') + parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)') + parser.add_argument('--num-experts', type=int, default=288, help='Number of experts (default: 288)') + parser.add_argument('--allow-mnnvl', action="store_true", help='Allow MNNVL for communication') + parser.add_argument('--disable-nvlink', action='store_true', help='Whether to disable NVLink for testing') + parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine') + parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test') + parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode') + parser.add_argument('--imbalance-factors', + type=float, + nargs='+', + default=[1.0, 2.0, 3.0], help='A list of target max/avg ratios for per-rank expert load (tokens per expert). ' - 'Higher values create more load imbalance (e.g., 1.0, 2.0, 3.0).' - 'Note: actual ratios may be lower than targets due to token count constraints.') - parser.add_argument('--distribution', type=str, default='lognormal', choices=['lognormal','powerlaw', 'gamma']) + 'Higher values create more load imbalance (e.g., 1.0, 2.0, 3.0).' + 'Note: actual ratios may be lower than targets due to token count constraints.') + parser.add_argument('--distribution', type=str, default='lognormal', choices=['lognormal', 'powerlaw', 'gamma']) args = parser.parse_args() num_processes = args.num_processes From c0c1f7626bd65c6ce94fe3ea9b736548dab3cda5 Mon Sep 17 00:00:00 2001 From: JianboDong Date: Thu, 30 Oct 2025 11:12:03 +0800 Subject: [PATCH 4/4] Update utils.py to solve conflicts --- tests/utils.py | 229 ++++++++++++++++++++++++++----------------------- 1 file changed, 120 insertions(+), 109 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 482fdfcc..6cd47520 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -45,8 +45,8 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): sim = 2 * (x * y).sum() / denominator return (1 - sim).item() -def get_global_token_indices(distribution_type, num_experts, num_tokens, num_ranks, num_topk, - imbalance_factor, simulation_seed=0): + +def get_global_token_indices(distribution_type, num_experts, num_tokens, num_ranks, num_topk, imbalance_factor, simulation_seed=0): """ Generates and returns the global top-k indices for all tokens across all ranks, matching the target imbalance factor. @@ -56,80 +56,73 @@ def get_global_token_indices(distribution_type, num_experts, num_tokens, num_ran print(f"Distribution: uniform, Target ratio: {imbalance_factor}", flush=True) expert_probs = torch.ones(num_experts, dtype=torch.float32, device='cuda') expert_probs /= expert_probs.sum() - + total_tokens = num_tokens * num_ranks token_probs_expanded = expert_probs.unsqueeze(0).expand(total_tokens, -1) - + torch.manual_seed(simulation_seed) global_topk_idx = torch.multinomial(token_probs_expanded, num_samples=num_topk, replacement=False) return global_topk_idx else: # Imbalanced case: find params and get indices directly - params, global_topk_idx = find_distribution_parameters_for_target_ratio( - distribution_type=distribution_type, - target_imbalance_ratio=imbalance_factor, - num_experts=num_experts, - num_tokens_per_rank=num_tokens, - num_ranks=num_ranks, - num_topk=num_topk, - simulation_seed=simulation_seed - ) - + params, global_topk_idx = find_distribution_parameters_for_target_ratio(distribution_type=distribution_type, + target_imbalance_ratio=imbalance_factor, + num_experts=num_experts, + num_tokens_per_rank=num_tokens, + num_ranks=num_ranks, + num_topk=num_topk, + simulation_seed=simulation_seed) + print(f'Distribution: {distribution_type}, Target ratio: {imbalance_factor}, ' f'Found parameters: {params}', flush=True) return global_topk_idx.contiguous() - -def _simulate_global_sampling( - distribution_type: str, - params_dict: dict, - num_experts: int, - total_tokens: int, - num_ranks: int, - num_topk: int, - simulation_seed: int = 0, - permutation: torch.Tensor = None -): + + +def _simulate_global_sampling(distribution_type: str, + params_dict: dict, + num_experts: int, + total_tokens: int, + num_ranks: int, + num_topk: int, + simulation_seed: int = 0, + permutation: torch.Tensor = None): """ Simulates global sampling and returns both the resulting imbalance ratio and the generated indices. """ num_local_experts = num_experts // num_ranks - expert_probs = _generate_expert_probs( - distribution_type, params_dict, num_experts, permutation=permutation - ) + expert_probs = _generate_expert_probs(distribution_type, params_dict, num_experts, permutation=permutation) token_probs_expanded = expert_probs.unsqueeze(0).expand(total_tokens, -1) - + torch.manual_seed(simulation_seed) - global_topk_idx = torch.multinomial( - token_probs_expanded, num_samples=num_topk, replacement=False - ) - + global_topk_idx = torch.multinomial(token_probs_expanded, num_samples=num_topk, replacement=False) + rank_counts = torch.zeros(num_ranks, dtype=torch.int64, device='cuda') valid_indices = global_topk_idx.flatten() valid_indices = valid_indices[valid_indices >= 0] - + for rank in range(num_ranks): start_expert = rank * num_local_experts end_expert = (rank + 1) * num_local_experts mask = (valid_indices >= start_expert) & (valid_indices < end_expert) rank_counts[rank] = mask.sum().item() - + max_count = rank_counts.max().item() avg_count = rank_counts.float().mean().item() - + imbalance_ratio = max_count / avg_count if avg_count > 0 else 0.0 - + return imbalance_ratio, global_topk_idx -def find_distribution_parameters_for_target_ratio( - distribution_type: str, - target_imbalance_ratio: float, - num_experts: int, - num_tokens_per_rank: int, - num_ranks: int, - num_topk: int, - max_iterations=20, - tolerance=0.02, - simulation_seed=0 -): + + +def find_distribution_parameters_for_target_ratio(distribution_type: str, + target_imbalance_ratio: float, + num_experts: int, + num_tokens_per_rank: int, + num_ranks: int, + num_topk: int, + max_iterations=20, + tolerance=0.02, + simulation_seed=0): """ Finds parameters and returns them along with the final sampled indices. Returns: @@ -138,128 +131,141 @@ def find_distribution_parameters_for_target_ratio( total_tokens = num_tokens_per_rank * num_ranks torch.manual_seed(simulation_seed) permutation = torch.randperm(num_experts, device='cuda') - + def simulate_ratio_only(params_dict): - ratio, _ = _simulate_global_sampling( - distribution_type, params_dict, num_experts, total_tokens, - num_ranks, num_topk, simulation_seed, permutation=permutation - ) + ratio, _ = _simulate_global_sampling(distribution_type, + params_dict, + num_experts, + total_tokens, + num_ranks, + num_topk, + simulation_seed, + permutation=permutation) return ratio - + search_config = _get_search_config(distribution_type, target_imbalance_ratio) - final_params = _binary_search_parameters( - simulate_ratio_only, - search_config, - target_imbalance_ratio, - max_iterations, - tolerance - ) - _, final_global_topk_idx = _simulate_global_sampling( - distribution_type, final_params, num_experts, total_tokens, - num_ranks, num_topk, simulation_seed, permutation=permutation - ) + final_params = _binary_search_parameters(simulate_ratio_only, search_config, target_imbalance_ratio, max_iterations, tolerance) + _, final_global_topk_idx = _simulate_global_sampling(distribution_type, + final_params, + num_experts, + total_tokens, + num_ranks, + num_topk, + simulation_seed, + permutation=permutation) return final_params, final_global_topk_idx + + def _generate_expert_probs(distribution_type: str, params: dict, num_experts: int, permutation: torch.Tensor = None): """Generate expert probabilities, with optional shuffling.""" if distribution_type == 'powerlaw': alpha = params['alpha'] ranks = torch.arange(1, num_experts + 1, device='cuda', dtype=torch.float32) - popularity_values = ranks ** (-alpha if alpha > 0 else 0) - + popularity_values = ranks**(-alpha if alpha > 0 else 0) + elif distribution_type == 'lognormal': sigma = params['sigma'] log_normal_dist = torch.distributions.LogNormal(loc=0.0, scale=sigma) - popularity_values = log_normal_dist.sample((num_experts,)).to('cuda') + popularity_values = log_normal_dist.sample((num_experts, )).to('cuda') popularity_values, _ = torch.sort(popularity_values, descending=True) popularity_values.clamp_(min=1e-9) - + elif distribution_type == 'gamma': shape = params['shape'] gamma_dist = torch.distributions.Gamma(concentration=shape, rate=2.0) - popularity_values = gamma_dist.sample((num_experts,)).to('cuda') + popularity_values = gamma_dist.sample((num_experts, )).to('cuda') popularity_values, _ = torch.sort(popularity_values, descending=True) popularity_values.clamp_(min=1e-9) - + else: raise ValueError(f"Unsupported distribution: {distribution_type}") - + if permutation is not None: shuffled_values = torch.zeros_like(popularity_values) shuffled_values.scatter_(0, permutation, popularity_values) popularity_values = shuffled_values - + return popularity_values / popularity_values.sum() + + def _get_search_config(distribution_type: str, target_ratio: float): """Get search bounds and parameter names for each distribution""" - + if distribution_type == 'powerlaw': - return { - 'param_name': 'alpha', - 'low': 0.0, - 'high': 5.0 - } - + return {'param_name': 'alpha', 'low': 0.0, 'high': 5.0} + elif distribution_type == 'lognormal': - return { - 'param_name': 'sigma', - 'low': 0.1, - 'high': 5.0 - } - + return {'param_name': 'sigma', 'low': 0.1, 'high': 5.0} + elif distribution_type == 'gamma': return { 'param_name': 'shape', 'low': 0.001, # Shape must be > 0 'high': 2.0, - 'inverse_relationship': True # reverse binary search + 'inverse_relationship': True # reverse binary search } - + else: raise ValueError(f"No search config for distribution: {distribution_type}") + + def _binary_search_parameters(simulate_fn, search_config, target_ratio, max_iterations, tolerance): """Generic binary search for distribution parameters""" - + param_name = search_config['param_name'] param_low = search_config['low'] param_high = search_config['high'] - inverse_relationship = search_config.get('inverse_relationship', False) - - for i in range(max_iterations): + inverse_relationship = search_config.get('inverse_relationship', False) + + for _ in range(max_iterations): param_mid = (param_low + param_high) / 2 actual_ratio = simulate_fn({param_name: param_mid}) - + if abs(actual_ratio - target_ratio) / target_ratio < tolerance: return {param_name: param_mid} - + condition_for_increasing_param = actual_ratio < target_ratio if inverse_relationship: condition_for_increasing_param = not condition_for_increasing_param - + if condition_for_increasing_param: param_low = param_mid else: param_high = param_mid - + final_param = (param_low + param_high) / 2 return {param_name: final_param} + +def align_up(x, y): + return (x + y - 1) // y * y + + def per_token_cast_to_fp8(x: torch.Tensor): - assert x.dim() == 2 and x.size(1) % 128 == 0 + assert x.dim() == 2 m, n = x.shape - x_view = x.view(m, -1, 128) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) + aligned_n = align_up(n, 128) + x_padded = torch.nn.functional.pad(x, (0, aligned_n - n), mode='constant', value=0) + x_padded_view = x_padded.view(m, -1, 128) + x_amax = x_padded_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + return (x_padded_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view( + m, aligned_n)[:, :n].contiguous(), (x_amax / 448.0).view(m, -1) def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor): if x_fp8.numel() == 0: return x_fp8.to(torch.bfloat16) + + assert x_fp8.dim() == 2 + m, n = x_fp8.shape + aligned_n = align_up(n, 128) + x_fp8_padded = torch.nn.functional.pad(x_fp8, (0, aligned_n - n), mode='constant', value=0) if x_scales.dtype == torch.int: x_scales = x_scales.view(dtype=torch.uint8).to(torch.int) << 23 x_scales = x_scales.view(dtype=torch.float) - x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128) + x_fp32_padded = x_fp8_padded.to(torch.float32).view(x_fp8.size(0), -1, 128) x_scales = x_scales.view(x_fp8.size(0), -1, 1) - return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16) + return (x_fp32_padded * x_scales).view(x_fp8_padded.shape).to(torch.bfloat16)[:, :n].contiguous() def inplace_unique(x: torch.Tensor, num_slots: int): @@ -314,6 +320,7 @@ def bench(fn, num_warmups: int = 50, num_tests: int = 50, post_fn=None): class empty_suppress: + def __enter__(self): return self @@ -322,6 +329,7 @@ def __exit__(self, *_): class suppress_stdout_stderr: + def __enter__(self): self.outnull_file = open(os.devnull, 'w') self.errnull_file = open(os.devnull, 'w') @@ -356,15 +364,19 @@ def __exit__(self, *_): self.errnull_file.close() -def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppress_kineto_output: bool = False, - trace_path: Optional[str] = None, barrier_comm_profiling: bool = False, +def bench_kineto(fn, + kernel_names: Union[str, tuple], + num_tests: int = 30, + suppress_kineto_output: bool = False, + trace_path: Optional[str] = None, + barrier_comm_profiling: bool = False, num_kernels_per_period: int = 1): # Profile suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress with suppress(): - schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) + schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) as prof: - for i in range(2): + for _ in range(2): # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead if barrier_comm_profiling: lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') @@ -377,7 +389,7 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr prof.step() # Parse the profiling table - assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + assert isinstance(kernel_names, (str, tuple)) is_tuple = isinstance(kernel_names, tuple) prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names @@ -414,8 +426,7 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr durations = [event['dur'] / 1e6 for event in events] assert len(durations) % num_kernels_per_period == 0 num_kernel_patterns = len(durations) // num_kernels_per_period - kernel_durations[i] = [sum(durations[j::num_kernels_per_period]) / num_kernel_patterns - for j in range(num_kernels_per_period)] + kernel_durations[i] = [sum(durations[j::num_kernels_per_period]) / num_kernel_patterns for j in range(num_kernels_per_period)] # Return execution durations return kernel_durations if is_tuple else kernel_durations[0]