Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions nkipy/src/nkipy/core/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,20 @@ def _mark_hlo_outputs(self, code: HLOModule, ret, param_tensor_refs):
copy_tensor = ctx.build_op("copy", [bt], bt.shape, bt.dtype)
ret[i] = NKIPyTensorRef(copy_tensor, name="")

# Step 3: Assign output names and build AliasInfo list
# Step 3: Assign output names and build AliasInfo list.
#
# Aliased outputs keep their parameter name (e.g. "kv_cache") so that
# the Neuron compiler + Spike runtime can bind the input and output to
# the same device buffer via the ".must_alias_input" convention.
#
# Non-aliased outputs are *always* renamed to "output{idx}" — even if
# they already carry a name from tracing — to prevent the Neuron
# compiler from folding (optimizing away) the output variable.
for idx, r in enumerate(ret):
if not isinstance(r, NKIPyTensorRef):
raise RuntimeError(f"Unexpected return value type: {type(r)}")

is_alias_output = False
if idx in aliased_return_positions:
param_name, param_index = aliased_return_positions[idx]
code.aliases.append(
Expand All @@ -233,10 +242,9 @@ def _mark_hlo_outputs(self, code: HLOModule, ret, param_tensor_refs):
)
)
r.backend_tensor.name = param_name
is_alias_output = True

# N.B.: the name "output{idx}" is specific
# it avoids variable folding in HLO lowering in Neuron Compiler
if not r.backend_tensor.name:
if not is_alias_output:
r.backend_tensor.name = f"output{idx}"

result_tensors = [r.backend_tensor for r in ret]
Expand Down
12 changes: 8 additions & 4 deletions nkipy/src/nkipy/runtime/device_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,10 @@ def compile_and_load(
# In MPMD mode, namespace build dir by rank to avoid concurrent writes
# when different ranks produce the same content hash.
if not is_spmd:
effective_rank = rank_id if rank_id is not None else (
dist.get_rank() if distributed else None
effective_rank = (
rank_id
if rank_id is not None
else (dist.get_rank() if distributed else None)
)
if effective_rank is not None:
compile_build_dir = os.path.join(
Expand Down Expand Up @@ -174,11 +176,13 @@ def compile_and_load(
# --- 2. Resolve CC parameters for loading ---
resolved_cc = cc_enabled if cc_enabled is not None else distributed
resolved_rank = (
rank_id if rank_id is not None
rank_id
if rank_id is not None
else (dist.get_rank() if distributed else None)
)
resolved_world = (
world_size if world_size is not None
world_size
if world_size is not None
else (dist.get_world_size() if distributed else None)
)

Expand Down
46 changes: 46 additions & 0 deletions spike/src/spike/spike_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,49 @@ def _check_dtype_compatibility(
f"got {actual_dtype}"
)

_ALIAS_SUFFIX = ".must_alias_input"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to complicate the handling of alias naming in spike.
Spike is a wrapper on runtime, it should not need to know how we lower the function into NEFFs. It should only deal with what's available in the NEFFs.

This specific problem can be addressed at the user level? The caller can pass .must_alias_input in the input tensor list.

A proper solution can be in the NEFF lowering in NKIPy (but we want to make sure we are aligned with NKI)


def _resolve_alias_inputs(self, inputs):
"""Auto-remap aliased input names so callers can use original param names.

NKIPy's tracer renames mutated (aliased) parameters from ``"X"`` to
``"X.must_alias_input"`` in the compiled NEFF. This method lets
callers pass the natural name ``"X"`` and transparently appends the
suffix when the NEFF expects it. Inputs that already carry the
suffix or are not aliased are passed through unchanged.
"""
resolved = {}
for k, v in inputs.items():
if k not in self.input_tensors_info:
alias_key = k + self._ALIAS_SUFFIX
if alias_key in self.input_tensors_info:
resolved[alias_key] = v
continue
resolved[k] = v
return resolved

def _validate_io(self, inputs, outputs):
"""Validate that caller-supplied I/O dicts match the compiled NEFF.

Checks tensor names, shapes, dtypes, and core placement. Raises
``ValueError`` with the expected NEFF names on any name mismatch,
so callers get actionable diagnostics instead of a bare ``KeyError``.
"""
model_core_id = self.model_ref.core_id

unknown_inputs = set(inputs) - set(self.input_tensors_info)
if unknown_inputs:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the checks!

raise ValueError(
f"Unknown input(s) {unknown_inputs} for model '{self.name}'. "
f"Expected inputs: {list(self.input_tensors_info.keys())}"
)
unknown_outputs = set(outputs) - set(self.output_tensors_info)
if unknown_outputs:
raise ValueError(
f"Unknown output(s) {unknown_outputs} for model '{self.name}'. "
f"Expected outputs: {list(self.output_tensors_info.keys())}"
)

for k, v in inputs.items():
tensor_core_id = v.tensor_ref.core_id
assert tensor_core_id == model_core_id, (
Expand Down Expand Up @@ -246,6 +287,11 @@ def __call__(
outputs = {tensor.name: tensor for tensor in output_tensors}
auto_allocated = True

# Auto-resolve alias input naming: if caller passes "X" but the NEFF
# expects "X.must_alias_input", remap transparently so callers don't
# need to know about the alias suffix convention.
inputs = self._resolve_alias_inputs(inputs)

self._validate_io(inputs, outputs)

input_refs = {k: v.tensor_ref for k, v in inputs.items()}
Expand Down
145 changes: 145 additions & 0 deletions tests/unit/test_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ def nkipy_kernel_multi_alias(a_input, b_input, c_input):
return a_input, c_input


def nkipy_kernel_named_intermediate(a_input, b_input):
    """Kernel returning an op result that already carries an intermediate name.

    Used to verify that tracing renames non-aliased outputs even when a
    name was assigned during the trace.
    """
    summed = np.add(a_input, b_input)
    summed.backend_tensor.name = "intermediate0"
    return summed


def nkipy_kernel_no_return(a_input, b_input):
"""Kernel that mutates a_input but does not return anything."""
a_input[0, :] = b_input[1, :]
Expand Down Expand Up @@ -204,5 +211,143 @@ def test_mixed_return_alias(trace_mode):
trace_and_compile(nkipy_kernel_mixed_return, trace_mode, A.copy(), B)


def test_non_alias_outputs_are_renamed_to_output_names():
    """A non-aliased output is renamed even if tracing assigned a temp name."""

    def rand_fp16():
        return ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16)

    kernel = NKIPyKernel.trace(nkipy_kernel_named_intermediate, backend="hlo")
    hlo = kernel.specialize(rand_fp16(), rand_fp16())

    # The tracer-assigned "intermediate0" must be replaced by "output0".
    assert len(hlo.outputs) == 1
    assert hlo.outputs[0].name == "output0"


# ------------------------------------------------------------------ #
# Output naming contract tests
#
# These verify the exact I/O names that end up in compiled NEFFs,
# which callers must match when invoking kernel(inputs={...}, outputs={...}).
# The patterns below were discovered while integrating with sglang-nkipy.
# ------------------------------------------------------------------ #


def test_alias_output_naming_simple():
    """Direct alias: a mutated param returned by identity keeps its param name.

    Pattern: update_kv_cache(kv_cache: mutable) -> kv_cache
    Expected NEFF:
        input  = "a_input.must_alias_input"
        output = "a_input"
    """

    def rand_fp16():
        return ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16)

    hlo = NKIPyKernel.trace(nkipy_kernel_single_alias, backend="hlo").specialize(
        rand_fp16(), rand_fp16()
    )

    # Exactly one output: the aliased parameter, under its original name.
    assert len(hlo.outputs) == 1
    assert hlo.outputs[0].name == "a_input"

    # The mutated input carries the alias suffix; the read-only one does not.
    param_names = {p.name for p in hlo.parameters}
    assert "a_input.must_alias_input" in param_names
    assert "b_input" in param_names


def test_alias_auto_appended_output_naming():
    """Broken-identity alias: a mutated param NOT returned is auto-appended.

    Pattern: prefill_post_moe_fn(output: mutable, ...) — the function mutates
    ``output`` but returns a *different* computed value, so the tracer
    appends the mutated param as an extra output.

    Expected NEFF:
        inputs  = "a_input.must_alias_input", "b_input"
        outputs = "output0" (the sum), "a_input" (auto-appended alias)
    """

    def rand_fp16():
        return ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16)

    hlo = NKIPyKernel.trace(nkipy_kernel_mixed_return, backend="hlo").specialize(
        rand_fp16(), rand_fp16()
    )

    # Two outputs: the user-returned sum first, the appended alias second.
    assert len(hlo.outputs) == 2
    assert [out.name for out in hlo.outputs] == ["output0", "a_input"]

    # Alias metadata points the appended output back at the mutated param.
    assert len(hlo.aliases) == 1
    alias = hlo.aliases[0]
    assert alias.param_name == "a_input"
    assert alias.output_index == 1
    assert alias.is_user_returned is False

    # The mutated input parameter was renamed with the alias suffix.
    assert "a_input.must_alias_input" in {p.name for p in hlo.parameters}


def nkipy_kernel_alias_with_multiple_outputs(a_input, b_input, c_input):
    """Mutate a_input and c_input in place, returning both plus a fresh sum.

    Pattern: a fused pre_moe graph returning (kv_cache, hidden, topk, ...)
    where kv_cache is aliased but hidden/topk/... are not.
    """
    a_input[0:1, :] = b_input[0:1, :]
    c_input[2:3, :] = b_input[2:3, :]
    total = np.add(a_input, b_input)
    return a_input, total, c_input


def test_alias_mixed_with_non_alias_outputs():
    """Multiple outputs where some are aliased and some are not.

    Expected NEFF:
        inputs  = "a_input.must_alias_input", "b_input", "c_input.must_alias_input"
        outputs = "a_input" (alias), "output1" (computed), "c_input" (alias)
    """
    a = ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16)
    b = ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16)
    c = ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16)

    traced = NKIPyKernel.trace(nkipy_kernel_alias_with_multiple_outputs, backend="hlo")
    hlo = traced.specialize(a, b, c)

    assert len(hlo.outputs) == 3
    # Aliased outputs keep param names; non-aliased get output{idx}.
    assert hlo.outputs[0].name == "a_input"
    assert hlo.outputs[1].name == "output1"
    assert hlo.outputs[2].name == "c_input"

    # Exactly two aliases, one per mutated param. The comprehension uses a
    # distinct loop variable so it does not shadow the input array `a`
    # defined above (previously `for a in hlo.aliases`, which linters flag
    # and readers can easily misread).
    assert len(hlo.aliases) == 2
    alias_names = {alias.param_name for alias in hlo.aliases}
    assert alias_names == {"a_input", "c_input"}


def test_no_return_alias_output_naming():
    """Mutation-only kernel: the auto-appended alias is the sole output.

    Expected NEFF:
        inputs = "a_input.must_alias_input", "b_input"
        output = "a_input"
    """

    def rand_fp16():
        return ((np.random.rand(128, 512) - 0.5) * 2).astype(np.float16)

    hlo = NKIPyKernel.trace(nkipy_kernel_no_return, backend="hlo").specialize(
        rand_fp16(), rand_fp16()
    )

    # The mutated param is appended as the only output, under its own name.
    assert len(hlo.outputs) == 1
    assert hlo.outputs[0].name == "a_input"

    # Mutated input renamed with the alias suffix; read-only input untouched.
    param_names = {p.name for p in hlo.parameters}
    assert "a_input.must_alias_input" in param_names
    assert "b_input" in param_names


# Allow running this test module directly (outside a pytest invocation),
# e.g. `python test_alias.py`, with verbose per-test output.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
38 changes: 14 additions & 24 deletions tests/unit/test_device_kernel_cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ def mock_load_from_neff():
@pytest.fixture
def mock_dist():
"""Mock torch.distributed as initialized with world_size=2, rank=0."""
with patch("nkipy.runtime.device_kernel._is_distributed", return_value=True), patch(
"nkipy.runtime.device_kernel.dist", create=True
) as mock_d:
with (
patch("nkipy.runtime.device_kernel._is_distributed", return_value=True),
patch("nkipy.runtime.device_kernel.dist", create=True) as mock_d,
):
mock_d.get_rank.return_value = 0
mock_d.get_world_size.return_value = 2
yield mock_d
Expand All @@ -52,9 +53,10 @@ def mock_dist():
@pytest.fixture
def mock_dist_rank1():
"""Mock torch.distributed as initialized with world_size=2, rank=1."""
with patch("nkipy.runtime.device_kernel._is_distributed", return_value=True), patch(
"nkipy.runtime.device_kernel.dist", create=True
) as mock_d:
with (
patch("nkipy.runtime.device_kernel._is_distributed", return_value=True),
patch("nkipy.runtime.device_kernel.dist", create=True) as mock_d,
):
mock_d.get_rank.return_value = 1
mock_d.get_world_size.return_value = 2
yield mock_d
Expand Down Expand Up @@ -243,33 +245,25 @@ def test_cc_enabled_without_rank_raises(
self, mock_trace_and_compile, mock_load_from_neff
):
"""cc_enabled=True without rank_id/world_size and no dist raises ValueError."""
with patch(
"nkipy.runtime.device_kernel._is_distributed", return_value=False
):
with patch("nkipy.runtime.device_kernel._is_distributed", return_value=False):
with pytest.raises(ValueError, match="rank_id and world_size are required"):
DeviceKernel.compile_and_load(_dummy_kernel, cc_enabled=True)

def test_cc_enabled_without_world_size_raises(
self, mock_trace_and_compile, mock_load_from_neff
):
"""cc_enabled=True with rank_id but no world_size and no dist raises."""
with patch(
"nkipy.runtime.device_kernel._is_distributed", return_value=False
):
with patch("nkipy.runtime.device_kernel._is_distributed", return_value=False):
with pytest.raises(ValueError, match="rank_id and world_size are required"):
DeviceKernel.compile_and_load(
_dummy_kernel, cc_enabled=True, rank_id=0
)
DeviceKernel.compile_and_load(_dummy_kernel, cc_enabled=True, rank_id=0)


class TestNonDistributed:
"""Tests for single-worker (non-distributed) mode."""

def test_no_dist_no_cc(self, mock_trace_and_compile, mock_load_from_neff):
"""Without distributed, loads without CC by default."""
with patch(
"nkipy.runtime.device_kernel._is_distributed", return_value=False
):
with patch("nkipy.runtime.device_kernel._is_distributed", return_value=False):
DeviceKernel.compile_and_load(_dummy_kernel)

mock_load_from_neff.assert_called_once_with(
Expand All @@ -278,9 +272,7 @@ def test_no_dist_no_cc(self, mock_trace_and_compile, mock_load_from_neff):

def test_no_dist_explicit_cc(self, mock_trace_and_compile, mock_load_from_neff):
"""Without torch.distributed, explicit CC params still work."""
with patch(
"nkipy.runtime.device_kernel._is_distributed", return_value=False
):
with patch("nkipy.runtime.device_kernel._is_distributed", return_value=False):
DeviceKernel.compile_and_load(
_dummy_kernel, cc_enabled=True, rank_id=0, world_size=2
)
Expand All @@ -295,9 +287,7 @@ def test_no_dist_explicit_cc(self, mock_trace_and_compile, mock_load_from_neff):

def test_no_dist_mpmd(self, mock_trace_and_compile, mock_load_from_neff):
"""MPMD without torch.distributed works with explicit CC."""
with patch(
"nkipy.runtime.device_kernel._is_distributed", return_value=False
):
with patch("nkipy.runtime.device_kernel._is_distributed", return_value=False):
DeviceKernel.compile_and_load(
_dummy_kernel,
is_spmd=False,
Expand Down