From 59670d0b032cc35b463209e5accc037b37cf6109 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 11 Nov 2025 14:42:14 -0800 Subject: [PATCH] Custom opify triton kernel until local_map functionalization is fixed stack-info: PR: https://github.com/meta-pytorch/autoparallel/pull/245, branch: xmfan/stack/19 --- autoparallel/_testing/models/dsv3.py | 25 +++++++++++++++++++++---- examples/example_ds3_local_map.py | 6 +++++- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/autoparallel/_testing/models/dsv3.py b/autoparallel/_testing/models/dsv3.py index 694b866..2e038d6 100644 --- a/autoparallel/_testing/models/dsv3.py +++ b/autoparallel/_testing/models/dsv3.py @@ -75,7 +75,9 @@ def _fill_indices_kernel( # ============== -def fill_indices_wrapper( +# workaround until local_map functionalization is fixed: https://github.com/pytorch/pytorch/issues/167568 +@torch.library.custom_op("autoparallel::fill_indices_functional", mutates_args=()) +def fill_indices_functional( tokens_per_expert_group: torch.Tensor, start_index_values: torch.Tensor, write_offsets: torch.Tensor, @@ -84,7 +86,7 @@ def fill_indices_wrapper( max_len: int, block_size: int = 128, max_blocks: int = 1024, # cap on total number of blocks to launch -): +) -> torch.Tensor: # preallocate output permuted_indices = torch.full( (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device @@ -108,6 +110,22 @@ def fill_indices_wrapper( return permuted_indices +@fill_indices_functional.register_fake +def _( + tokens_per_expert_group: torch.Tensor, + start_index_values: torch.Tensor, + write_offsets: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, + block_size: int = 128, + max_blocks: int = 1024, # cap on total number of blocks to launch +) -> torch.Tensor: + return torch.full( + (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device + ) + + # reference def fill_indices_cpu( tokens_per_expert_group: torch.Tensor, @@ -143,7 +161,6 @@ def fill_indices_cpu( start_index, start_index + (end_idx - write_start), dtype=torch.int32, - # device=device, ) write_start += length return permuted_indices @@ -213,7 +230,7 @@ def generate_permute_indices( max_len, ) else: - permuted_indices = fill_indices_wrapper( + permuted_indices = fill_indices_functional( tokens_per_expert_group, start_index_values, write_offsets, diff --git a/examples/example_ds3_local_map.py b/examples/example_ds3_local_map.py index d5958fe..d4e6b91 100644 --- a/examples/example_ds3_local_map.py +++ b/examples/example_ds3_local_map.py @@ -153,7 +153,8 @@ def input_fn(): # ) # maybe not correct value parallel_mod.init_weights(buffer_device=device, seed=rng_seed) if rng_seed is not None: - NumericsLogger(logs_dir).log_model_weights(parallel_mod) + numerics_logger = NumericsLogger(logs_dir) + numerics_logger.log_model_weights(parallel_mod) x = ( torch.randint( @@ -177,6 +178,9 @@ def input_fn(): out.backward(torch.randn_like(out)) else: out = parallel_mod(*x) + assert not torch.any(torch.isnan(out)), "Found NaNs in forward output" + if rng_seed is not None: + numerics_logger.log_forward_output(out) out.backward(torch.randn_like(out)) print("All good!")