From aa4621122d393430cc4ba37363b73b1817af9979 Mon Sep 17 00:00:00 2001 From: hgt312 Date: Tue, 31 Mar 2026 12:12:13 +0000 Subject: [PATCH] fix(trace): always rename non-aliased outputs to output{idx} Non-aliased outputs that already had a name from tracing kept that name instead of the canonical "output{idx}", breaking NEFF I/O name matching. Also adds alias input auto-resolution and better _validate_io errors in Spike, plus tests for all alias naming patterns. --- nkipy/src/nkipy/core/trace.py | 16 ++- nkipy/src/nkipy/runtime/device_kernel.py | 12 +- spike/src/spike/spike_model.py | 46 +++++++ tests/unit/test_alias.py | 145 +++++++++++++++++++++++ tests/unit/test_device_kernel_cc.py | 38 +++--- 5 files changed, 225 insertions(+), 32 deletions(-) diff --git a/nkipy/src/nkipy/core/trace.py b/nkipy/src/nkipy/core/trace.py index 88b908c..3533c70 100644 --- a/nkipy/src/nkipy/core/trace.py +++ b/nkipy/src/nkipy/core/trace.py @@ -217,11 +217,20 @@ def _mark_hlo_outputs(self, code: HLOModule, ret, param_tensor_refs): copy_tensor = ctx.build_op("copy", [bt], bt.shape, bt.dtype) ret[i] = NKIPyTensorRef(copy_tensor, name="") - # Step 3: Assign output names and build AliasInfo list + # Step 3: Assign output names and build AliasInfo list. + # + # Aliased outputs keep their parameter name (e.g. "kv_cache") so that + # the Neuron compiler + Spike runtime can bind the input and output to + # the same device buffer via the ".must_alias_input" convention. + # + # Non-aliased outputs are *always* renamed to "output{idx}" — even if + # they already carry a name from tracing — to prevent the Neuron + # compiler from folding (optimizing away) the output variable. for idx, r in enumerate(ret): if not isinstance(r, NKIPyTensorRef): raise RuntimeError(f"Unexpected return value type: {type(r)}") + is_alias_output = False if idx in aliased_return_positions: param_name, param_index = aliased_return_positions[idx] code.aliases.append( @@ -233,10 +242,9 @@ def _mark_hlo_outputs(self, code: HLOModule, ret, param_tensor_refs): ) ) r.backend_tensor.name = param_name + is_alias_output = True - # N.B.: the name "output{idx}" is specific - # it avoids variable folding in HLO lowering in Neuron Compiler - if not r.backend_tensor.name: + if not is_alias_output: r.backend_tensor.name = f"output{idx}" result_tensors = [r.backend_tensor for r in ret] diff --git a/nkipy/src/nkipy/runtime/device_kernel.py b/nkipy/src/nkipy/runtime/device_kernel.py index b3b13fb..2960fe9 100644 --- a/nkipy/src/nkipy/runtime/device_kernel.py +++ b/nkipy/src/nkipy/runtime/device_kernel.py @@ -124,8 +124,10 @@ def compile_and_load( # In MPMD mode, namespace build dir by rank to avoid concurrent writes # when different ranks produce the same content hash. if not is_spmd: - effective_rank = rank_id if rank_id is not None else ( - dist.get_rank() if distributed else None + effective_rank = ( + rank_id + if rank_id is not None + else (dist.get_rank() if distributed else None) ) if effective_rank is not None: compile_build_dir = os.path.join( @@ -174,11 +176,13 @@ def compile_and_load( # --- 2. Resolve CC parameters for loading --- resolved_cc = cc_enabled if cc_enabled is not None else distributed resolved_rank = ( - rank_id if rank_id is not None + rank_id + if rank_id is not None else (dist.get_rank() if distributed else None) ) resolved_world = ( - world_size if world_size is not None + world_size + if world_size is not None else (dist.get_world_size() if distributed else None) ) diff --git a/spike/src/spike/spike_model.py b/spike/src/spike/spike_model.py index b9aa42f..e700a95 100644 --- a/spike/src/spike/spike_model.py +++ b/spike/src/spike/spike_model.py @@ -195,8 +195,49 @@ def _check_dtype_compatibility( f"got {actual_dtype}" ) + _ALIAS_SUFFIX = ".must_alias_input" + + def _resolve_alias_inputs(self, inputs): + """Auto-remap aliased input names so callers can use original param names. + + NKIPy's tracer renames mutated (aliased) parameters from ``"X"`` to + ``"X.must_alias_input"`` in the compiled NEFF. This method lets + callers pass the natural name ``"X"`` and transparently appends the + suffix when the NEFF expects it. Inputs that already carry the + suffix or are not aliased are passed through unchanged. + """ + resolved = {} + for k, v in inputs.items(): + if k not in self.input_tensors_info: + alias_key = k + self._ALIAS_SUFFIX + if alias_key in self.input_tensors_info: + resolved[alias_key] = v + continue + resolved[k] = v + return resolved + def _validate_io(self, inputs, outputs): + """Validate that caller-supplied I/O dicts match the compiled NEFF. + + Checks tensor names, shapes, dtypes, and core placement. Raises + ``ValueError`` with the expected NEFF names on any name mismatch, + so callers get actionable diagnostics instead of a bare ``KeyError``. + """ model_core_id = self.model_ref.core_id + + unknown_inputs = set(inputs) - set(self.input_tensors_info) + if unknown_inputs: + raise ValueError( + f"Unknown input(s) {unknown_inputs} for model '{self.name}'. " + f"Expected inputs: {list(self.input_tensors_info.keys())}" + ) + unknown_outputs = set(outputs) - set(self.output_tensors_info) + if unknown_outputs: + raise ValueError( + f"Unknown output(s) {unknown_outputs} for model '{self.name}'. " + f"Expected outputs: {list(self.output_tensors_info.keys())}" + ) + for k, v in inputs.items(): tensor_core_id = v.tensor_ref.core_id assert tensor_core_id == model_core_id, ( @@ -246,6 +287,11 @@ def __call__( outputs = {tensor.name: tensor for tensor in output_tensors} auto_allocated = True + # Auto-resolve alias input naming: if caller passes "X" but the NEFF + # expects "X.must_alias_input", remap transparently so callers don't + # need to know about the alias suffix convention. + inputs = self._resolve_alias_inputs(inputs) + self._validate_io(inputs, outputs) input_refs = {k: v.tensor_ref for k, v in inputs.items()} diff --git a/tests/unit/test_alias.py b/tests/unit/test_alias.py index ebc5f47..6d4b49c 100644 --- a/tests/unit/test_alias.py +++ b/tests/unit/test_alias.py @@ -31,6 +31,13 @@ def nkipy_kernel_multi_alias(a_input, b_input, c_input): return a_input, c_input +def nkipy_kernel_named_intermediate(a_input, b_input): + """Kernel that returns an op result with an existing intermediate name.""" + out = np.add(a_input, b_input) + out.backend_tensor.name = "intermediate0" + return out + + def nkipy_kernel_no_return(a_input, b_input): """Kernel that mutates a_input but does not return anything.""" a_input[0, :] = b_input[1, :] @@ -204,5 +211,143 @@ def test_mixed_return_alias(trace_mode): trace_and_compile(nkipy_kernel_mixed_return, trace_mode, A.copy(), B) +def test_non_alias_outputs_are_renamed_to_output_names(): + """Non-aliased outputs must be renamed even if tracing assigned a temp name.""" + + a = ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16) + b = ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16) + + traced = NKIPyKernel.trace(nkipy_kernel_named_intermediate, backend="hlo") + hlo = traced.specialize(a, b) + + assert len(hlo.outputs) == 1 + assert hlo.outputs[0].name == "output0" + + +# ------------------------------------------------------------------ # +# Output naming contract tests +# +# These verify the exact I/O names that end up in compiled NEFFs, +# which callers must match when invoking kernel(inputs={...}, outputs={...}). +# The patterns below were discovered while integrating with sglang-nkipy. +# ------------------------------------------------------------------ # + + +def test_alias_output_naming_simple(): + """Direct alias: mutated param returned by identity keeps param name. + + Pattern: update_kv_cache(kv_cache: mutable) -> kv_cache + Expected NEFF: + input = "a_input.must_alias_input" + output = "a_input" + """ + a = ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16) + b = ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16) + + traced = NKIPyKernel.trace(nkipy_kernel_single_alias, backend="hlo") + hlo = traced.specialize(a, b) + + # 1 output: the aliased param + assert len(hlo.outputs) == 1 + assert hlo.outputs[0].name == "a_input" + + # Input param renamed with alias suffix + param_names = [p.name for p in hlo.parameters] + assert "a_input.must_alias_input" in param_names + assert "b_input" in param_names + + +def test_alias_auto_appended_output_naming(): + """Broken-identity alias: mutated param NOT returned → auto-appended. + + Pattern: prefill_post_moe_fn(output: mutable, ...) where output is mutated + but the function returns a *different* computed value. The tracer + auto-appends the original mutated param as an extra output. + + Expected NEFF: + inputs = "a_input.must_alias_input", "b_input" + outputs = "output0" (the sum), "a_input" (auto-appended alias) + """ + a = ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16) + b = ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16) + + traced = NKIPyKernel.trace(nkipy_kernel_mixed_return, backend="hlo") + hlo = traced.specialize(a, b) + + # 2 outputs: user-returned sum (output0) + auto-appended alias (a_input) + assert len(hlo.outputs) == 2 + assert hlo.outputs[0].name == "output0" + assert hlo.outputs[1].name == "a_input" + + # Alias metadata + assert len(hlo.aliases) == 1 + assert hlo.aliases[0].param_name == "a_input" + assert hlo.aliases[0].output_index == 1 + assert hlo.aliases[0].is_user_returned is False + + # Input param renamed + param_names = [p.name for p in hlo.parameters] + assert "a_input.must_alias_input" in param_names + + +def nkipy_kernel_alias_with_multiple_outputs(a_input, b_input, c_input): + """Kernel that aliases a_input and c_input, returns them plus a computed value. + + Pattern: fused pre_moe graph returning (kv_cache, hidden, topk, ...) + where kv_cache is aliased but hidden/topk/... are not. + """ + a_input[0:1, :] = b_input[0:1, :] + c_input[2:3, :] = b_input[2:3, :] + computed = np.add(a_input, b_input) + return a_input, computed, c_input + + +def test_alias_mixed_with_non_alias_outputs(): + """Multiple outputs where some are aliased and some are not. + + Expected NEFF: + inputs = "a_input.must_alias_input", "b_input", "c_input.must_alias_input" + outputs = "a_input" (alias), "output1" (computed), "c_input" (alias) + """ + a = ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16) + b = ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16) + c = ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16) + + traced = NKIPyKernel.trace(nkipy_kernel_alias_with_multiple_outputs, backend="hlo") + hlo = traced.specialize(a, b, c) + + assert len(hlo.outputs) == 3 + # Aliased outputs keep param names; non-aliased get output{idx} + assert hlo.outputs[0].name == "a_input" + assert hlo.outputs[1].name == "output1" + assert hlo.outputs[2].name == "c_input" + + # 2 aliases + assert len(hlo.aliases) == 2 + alias_names = {a.param_name for a in hlo.aliases} + assert alias_names == {"a_input", "c_input"} + + +def test_no_return_alias_output_naming(): + """Mutation-only kernel: auto-appended alias is the sole output. + + Expected NEFF: + input = "a_input.must_alias_input", "b_input" + output = "a_input" + """ + a = ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16) + b = ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16) + + traced = NKIPyKernel.trace(nkipy_kernel_no_return, backend="hlo") + hlo = traced.specialize(a, b) + + assert len(hlo.outputs) == 1 + assert hlo.outputs[0].name == "a_input" + + param_names = [p.name for p in hlo.parameters] + assert "a_input.must_alias_input" in param_names + assert "b_input" in param_names + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_device_kernel_cc.py b/tests/unit/test_device_kernel_cc.py index a7261b7..af9458c 100644 --- a/tests/unit/test_device_kernel_cc.py +++ b/tests/unit/test_device_kernel_cc.py @@ -41,9 +41,10 @@ def mock_load_from_neff(): @pytest.fixture def mock_dist(): """Mock torch.distributed as initialized with world_size=2, rank=0.""" - with patch("nkipy.runtime.device_kernel._is_distributed", return_value=True), patch( - "nkipy.runtime.device_kernel.dist", create=True - ) as mock_d: + with ( + patch("nkipy.runtime.device_kernel._is_distributed", return_value=True), + patch("nkipy.runtime.device_kernel.dist", create=True) as mock_d, + ): mock_d.get_rank.return_value = 0 mock_d.get_world_size.return_value = 2 yield mock_d @@ -52,9 +53,10 @@ def mock_dist(): @pytest.fixture def mock_dist_rank1(): """Mock torch.distributed as initialized with world_size=2, rank=1.""" - with patch("nkipy.runtime.device_kernel._is_distributed", return_value=True), patch( - "nkipy.runtime.device_kernel.dist", create=True - ) as mock_d: + with ( + patch("nkipy.runtime.device_kernel._is_distributed", return_value=True), + patch("nkipy.runtime.device_kernel.dist", create=True) as mock_d, + ): mock_d.get_rank.return_value = 1 mock_d.get_world_size.return_value = 2 yield mock_d @@ -243,9 +245,7 @@ def test_cc_enabled_without_rank_raises( self, mock_trace_and_compile, mock_load_from_neff ): """cc_enabled=True without rank_id/world_size and no dist raises ValueError.""" - with patch( - "nkipy.runtime.device_kernel._is_distributed", return_value=False - ): + with patch("nkipy.runtime.device_kernel._is_distributed", return_value=False): with pytest.raises(ValueError, match="rank_id and world_size are required"): DeviceKernel.compile_and_load(_dummy_kernel, cc_enabled=True) @@ -253,13 +253,9 @@ def test_cc_enabled_without_world_size_raises( self, mock_trace_and_compile, mock_load_from_neff ): """cc_enabled=True with rank_id but no world_size and no dist raises.""" - with patch( - "nkipy.runtime.device_kernel._is_distributed", return_value=False - ): + with patch("nkipy.runtime.device_kernel._is_distributed", return_value=False): with pytest.raises(ValueError, match="rank_id and world_size are required"): - DeviceKernel.compile_and_load( - _dummy_kernel, cc_enabled=True, rank_id=0 - ) + DeviceKernel.compile_and_load(_dummy_kernel, cc_enabled=True, rank_id=0) class TestNonDistributed: @@ -267,9 +263,7 @@ class TestNonDistributed: def test_no_dist_no_cc(self, mock_trace_and_compile, mock_load_from_neff): """Without distributed, loads without CC by default.""" - with patch( - "nkipy.runtime.device_kernel._is_distributed", return_value=False - ): + with patch("nkipy.runtime.device_kernel._is_distributed", return_value=False): DeviceKernel.compile_and_load(_dummy_kernel) mock_load_from_neff.assert_called_once_with( @@ -278,9 +272,7 @@ def test_no_dist_no_cc(self, mock_trace_and_compile, mock_load_from_neff): def test_no_dist_explicit_cc(self, mock_trace_and_compile, mock_load_from_neff): """Without torch.distributed, explicit CC params still work.""" - with patch( - "nkipy.runtime.device_kernel._is_distributed", return_value=False - ): + with patch("nkipy.runtime.device_kernel._is_distributed", return_value=False): DeviceKernel.compile_and_load( _dummy_kernel, cc_enabled=True, rank_id=0, world_size=2 ) @@ -295,9 +287,7 @@ def test_no_dist_explicit_cc(self, mock_trace_and_compile, mock_load_from_neff): def test_no_dist_mpmd(self, mock_trace_and_compile, mock_load_from_neff): """MPMD without torch.distributed works with explicit CC.""" - with patch( - "nkipy.runtime.device_kernel._is_distributed", return_value=False - ): + with patch("nkipy.runtime.device_kernel._is_distributed", return_value=False): DeviceKernel.compile_and_load( _dummy_kernel, is_spmd=False,