From ed1dc64209835c73e782ae32af222dfa8c9b788c Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Mon, 27 Apr 2026 12:37:48 -0700 Subject: [PATCH] feat: add kernelgen backend with NKI kernel generation support Add a complete kernelgen backend that generates NKI kernels via MLIR, as an alternative to the existing HLO backend. This includes: - Kernelgen backend with MLIR encapsulated behind a builder API - Op implementations for kernelgen (_kernelgen_impls.py) - Inplace update (dynamic_update_slice) support - Unified IR interface for HLO and kernelgen backends - Custom op interface via nki_op.nki_custom_op - Compilation delegation to nkipy_kernelgen.compile - Comprehensive unit and numerical tests --- nkipy/src/nkipy/__init__.py | 4 + nkipy/src/nkipy/core/backend/__init__.py | 87 +++- nkipy/src/nkipy/core/backend/hlo.py | 49 +- nkipy/src/nkipy/core/backend/kernelgen.py | 201 +++++++++ nkipy/src/nkipy/core/compile.py | 165 ++++--- nkipy/src/nkipy/core/knob.py | 70 +++ nkipy/src/nkipy/core/nki_op.py | 124 ++++- nkipy/src/nkipy/core/ops/_kernelgen_impls.py | 284 ++++++++++++ .../src/nkipy/core/ops/_register_kernelgen.py | 149 ++++++ nkipy/src/nkipy/core/trace.py | 326 ++++++++++---- nkipy/src/nkipy/runtime/baremetal_executor.py | 21 +- nkipy/src/nkipy/runtime/decorators.py | 8 +- nkipy/src/nkipy/runtime/device_kernel.py | 27 +- nkipy/src/nkipy/runtime/execute.py | 57 +-- .../src/nkipy/tools/kernel_agent/executor.py | 4 +- tests/test_kernelgen_numerical.py | 130 ++++++ tests/test_kernelgen_ops.py | 367 +++++++++++++++ tests/unit/test_device_kernel_cache.py | 29 +- tests/unit/test_kernelgen_backend.py | 424 ++++++++++++++++++ tests/utils.py | 29 +- 20 files changed, 2253 insertions(+), 302 deletions(-) create mode 100644 nkipy/src/nkipy/core/backend/kernelgen.py create mode 100644 nkipy/src/nkipy/core/knob.py create mode 100644 nkipy/src/nkipy/core/ops/_kernelgen_impls.py create mode 100644 nkipy/src/nkipy/core/ops/_register_kernelgen.py create mode 100644 tests/test_kernelgen_numerical.py create mode 100644 tests/test_kernelgen_ops.py create mode 100644 tests/unit/test_kernelgen_backend.py diff --git a/nkipy/src/nkipy/__init__.py b/nkipy/src/nkipy/__init__.py index 04f8b7b..42b8b4d 100644 --- a/nkipy/src/nkipy/__init__.py +++ b/nkipy/src/nkipy/__init__.py @@ -1,2 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 + +from nkipy.core.knob import knob + +__all__ = ["knob"] diff --git a/nkipy/src/nkipy/core/backend/__init__.py b/nkipy/src/nkipy/core/backend/__init__.py index 3c5e400..732a4f6 100644 --- a/nkipy/src/nkipy/core/backend/__init__.py +++ b/nkipy/src/nkipy/core/backend/__init__.py @@ -12,9 +12,94 @@ from __future__ import annotations from contextlib import contextmanager -from typing import Optional, Tuple +from dataclasses import dataclass +from typing import Dict, List, Optional, Protocol, Tuple, runtime_checkable +import numpy as np + +# --------------------------------------------------------------------------- +# Shared IR data types +# --------------------------------------------------------------------------- + + +@dataclass +class TensorPlaceholder: + """Lightweight tensor metadata used by the execution pipeline. + + Attributes: + name: Identifier used to key this tensor in input/output dicts at runtime. + shape: Static shape of the tensor. + dtype: NumPy dtype of the tensor elements. + """ + + name: str + shape: Tuple[int, ...] + dtype: np.dtype + + +@dataclass(frozen=True) +class AliasInfo: + """One input-output alias pair. + + Attributes: + output_index: Position of this alias in the IR outputs list. + param_index: Position of the aliased parameter in the IR inputs list. + param_name: Name of the aliased input parameter. + is_user_returned: True when the user's kernel explicitly returns this + tensor. False when the framework auto-appended it as an output + solely to write back an in-place mutation. + """ + + output_index: int + param_index: int + param_name: str + is_user_returned: bool + + +# --------------------------------------------------------------------------- +# IR Protocol — the interface that every backend IR must satisfy +# --------------------------------------------------------------------------- + + +@runtime_checkable +class ComputationIR(Protocol): + """Protocol satisfied by both ``HLOModule`` and ``KernelGenIR``.""" + + @property + def inputs(self) -> List[TensorPlaceholder]: ... + + @property + def outputs(self) -> List[TensorPlaceholder]: ... + + @property + def aliases(self) -> List[AliasInfo]: + """Input-output alias pairs for in-place mutations.""" + ... + + @property + def auto_aliased_indices(self) -> set[int]: + """Output indices implicitly appended for write-back, not user-returned.""" + ... + + def resolve_input_arrays( + self, original_inputs: Dict[str, np.ndarray] + ) -> Dict[str, np.ndarray]: + """Map parameter names to backend-specific input names.""" + ... + + def get_alias_input_name(self, alias: AliasInfo) -> str: + """Return the backend input name for an aliased parameter.""" + ... + + def content_hash(self, compiler_args: str) -> str: + """Deterministic hash of IR content and compiler flags for caching.""" + ... + + +# --------------------------------------------------------------------------- # Package-private active context — shared with submodules (e.g. hlo.py). +# --------------------------------------------------------------------------- + _active_ctx = None diff --git a/nkipy/src/nkipy/core/backend/hlo.py b/nkipy/src/nkipy/core/backend/hlo.py index a76b1bc..9ed0d89 100644 --- a/nkipy/src/nkipy/core/backend/hlo.py +++ b/nkipy/src/nkipy/core/backend/hlo.py @@ -8,6 +8,7 @@ from __future__ import annotations +import hashlib import struct from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple, Union @@ -15,6 +16,7 @@ import ml_dtypes import numpy as np +from nkipy.core.backend import AliasInfo, TensorPlaceholder from nkipy.third_party.xla import xla_data_pb2 from nkipy.third_party.xla.service import hlo_pb2 @@ -311,27 +313,6 @@ class HLOTensor: id: Optional[int] = None -@dataclass -class TensorPlaceholder: - """Placeholder for tensor metadata.""" - - name: str - shape: Tuple[int, ...] - dtype: np.dtype - - -@dataclass(frozen=True) -class AliasInfo: - """One input-output alias pair.""" - - output_index: int # Position in HLO output tuple - param_index: int # Position in HLO parameter list - param_name: str # Original parameter name (e.g., "a") - is_user_returned: ( - bool # False = auto-added output, True = user explicitly returned it - ) - - # ============================================================================= # HLO Module # ============================================================================= @@ -378,6 +359,26 @@ def outputs(self) -> List[TensorPlaceholder]: for r in self.results ] + def resolve_input_arrays(self, original_inputs): + """Map IR input names to numpy arrays. + + HLO input names are parameter names, possibly suffixed with + ``.must_alias_input`` for mutated parameters. Both forms are + resolved against *original_inputs* (keyed by bare parameter name). + """ + mapping = {} + for intensor in self.inputs: + if ".must_alias_input" in intensor.name: + base_name = intensor.name.split(".must_alias_input")[0] + else: + base_name = intensor.name + mapping[intensor.name] = original_inputs[base_name] + return mapping + + def get_alias_input_name(self, alias): + """Return the IR input name that an aliased output should share.""" + return f"{alias.param_name}.must_alias_input" + def add_parameter( self, shape: Tuple[int, ...], dtype: np.dtype, name: str = "" ) -> HLOTensor: @@ -407,6 +408,12 @@ def set_results(self, results: Union[HLOTensor, List[HLOTensor]]) -> None: """Set the output results of the module.""" self.results = results if isinstance(results, list) else [results] + def content_hash(self, compiler_args: str) -> str: + h = hashlib.sha256() + h.update(self.to_proto().SerializeToString()) + h.update(compiler_args.encode("utf-8")) + return h.hexdigest()[:12] + # ========================================================================= # Proto Generation # ========================================================================= diff --git a/nkipy/src/nkipy/core/backend/kernelgen.py b/nkipy/src/nkipy/core/backend/kernelgen.py new file mode 100644 index 0000000..fd9fdd0 --- /dev/null +++ b/nkipy/src/nkipy/core/backend/kernelgen.py @@ -0,0 +1,201 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""KernelGen backend for NKIPy. + +This module provides the kernelgen backend by delegating to +``nkipy_kernelgen.builder`` for all MLIR construction. No MLIR types +are imported or exposed — the builder API is the sole interface. +""" + +from __future__ import annotations + +import hashlib +from typing import List + +import numpy as np + +from nkipy.core.backend import AliasInfo, TensorPlaceholder + + +# --------------------------------------------------------------------------- +# KernelGenTensor -- analogue of HLOTensor +# --------------------------------------------------------------------------- + +class KernelGenTensor: + """Backend tensor for the kernelgen backend. + + Wraps an opaque ``TensorHandle`` from ``nkipy_kernelgen.builder`` + with the metadata that ``NKIPyTensorRef`` expects. + """ + + __slots__ = ("handle", "shape", "dtype", "is_parameter", "parameter_id", "name", "id") + + _next_id = 0 + + def __init__(self, handle, shape, dtype, *, is_parameter=False, parameter_id=None, name=""): + self.handle = handle + self.shape = tuple(shape) + self.dtype = np.dtype(dtype) if not isinstance(dtype, np.dtype) else dtype + self.is_parameter = is_parameter + self.parameter_id = parameter_id + self.name = name + self.id = KernelGenTensor._next_id + KernelGenTensor._next_id += 1 + + +# --------------------------------------------------------------------------- +# KernelGenTraceContext +# --------------------------------------------------------------------------- + +class KernelGenTraceContext: + """Trace context that delegates to ``nkipy_kernelgen.builder.IRBuilder``.""" + + backend_name = "kernelgen" + + def __init__(self): + from nkipy_kernelgen.builder import IRBuilder + self._builder = IRBuilder() + self._parameters: List[KernelGenTensor] = [] + self.current_source_location = None + + @property + def module(self): + """Return the underlying MLIR module from the builder.""" + return self._builder.module + + def set_source_location(self, location): + """Set the current source location for diagnostic tracking.""" + self.current_source_location = location + + def _begin_function(self, name, arg_shapes, arg_dtypes): + """Start an MLIR function and return parameter tensors.""" + handles = self._builder.begin_function(name, arg_shapes, arg_dtypes) + tensors = [] + for i, (h, (shape, dtype)) in enumerate( + zip(handles, zip(arg_shapes, arg_dtypes)) + ): + kt = KernelGenTensor( + h, shape, dtype, + is_parameter=True, parameter_id=i, name=f"arg{i}" + ) + self._parameters.append(kt) + tensors.append(kt) + return tensors + + def _finish_function(self, result_tensors): + """Finalize the MLIR function with the given result tensors.""" + self._builder.finish_function([t.handle for t in result_tensors]) + + def _run_canonicalize(self): + """Run MLIR canonicalization passes on the module.""" + self._builder.run_canonicalize() + + def _get_ir_text(self): + """Export the MLIR module as a text string.""" + return self._builder.get_ir_text() + + def _cleanup(self): + """Release builder resources.""" + self._builder.cleanup() + + +# --------------------------------------------------------------------------- +# Module-level context accessor +# --------------------------------------------------------------------------- + +def get_kernelgen_context() -> KernelGenTraceContext: + """Return the active ``KernelGenTraceContext``, or raise if none is active.""" + from nkipy.core.backend import _active_ctx + if _active_ctx is None or _active_ctx.backend_name != "kernelgen": + raise RuntimeError("No active kernelgen trace context") + return _active_ctx + + +# --------------------------------------------------------------------------- +# KernelGenIR -- make MLIR IR compatible with execution pipeline +# --------------------------------------------------------------------------- + + +class KernelGenIR: + """Adapter that makes an MLIR module compatible with the execution pipeline. + + Provides the same interface as ``HLOModule`` (``.inputs``, ``.outputs``, + ``.aliases``, ``.auto_aliased_indices``) so that ``compile.py`` and + ``execute.py`` can handle both backends uniformly. + """ + + def __init__(self, mlir_text, func_name, input_specs, output_specs, + alias_map=None, user_return_len=None, param_name_by_neff=None): + self._mlir_text = mlir_text + self._func_name = func_name + self._input_specs = input_specs # [(name, shape, dtype), ...] + self._output_specs = output_specs # [(name, shape, dtype), ...] + # alias_map: {output_index: (param_name, param_index)} + self._alias_map = alias_map or {} + self._user_return_len = user_return_len if user_return_len is not None else len(output_specs) + # Maps NEFF input names ("in_tensor_0") to original param names ("A") + self._param_name_by_neff = param_name_by_neff or {} + + @property + def inputs(self): + """Return input tensor metadata as ``TensorPlaceholder`` list.""" + return [TensorPlaceholder(n, tuple(s), np.dtype(d)) for n, s, d in self._input_specs] + + @property + def outputs(self): + """Return output tensor metadata as ``TensorPlaceholder`` list.""" + return [TensorPlaceholder(n, tuple(s), np.dtype(d)) for n, s, d in self._output_specs] + + @property + def aliases(self): + """Return input-output alias pairs as ``AliasInfo`` list.""" + return [ + AliasInfo( + output_index=out_idx, + param_index=pidx, + param_name=pname, + is_user_returned=out_idx < self._user_return_len, + ) + for out_idx, (pname, pidx) in self._alias_map.items() + ] + + @property + def auto_aliased_indices(self): + """Output indices that were auto-added (not user-returned).""" + return { + out_idx for out_idx in self._alias_map + if out_idx >= self._user_return_len + } + + def resolve_input_arrays(self, original_inputs): + """Map NEFF input names to numpy arrays. + + NEFF inputs use ``in_tensor_N`` names. *original_inputs* is keyed + by bare parameter names (``A``, ``B``). ``_param_name_by_neff`` + bridges the two. + """ + if len(original_inputs) != len(self._input_specs): + raise RuntimeError( + f"Expected {len(self._input_specs)} tensor arguments, " + f"got {len(original_inputs)}" + ) + mapping = {} + for intensor in self.inputs: + param_name = self._param_name_by_neff.get(intensor.name, intensor.name) + mapping[intensor.name] = original_inputs[param_name] + return mapping + + def get_alias_input_name(self, alias): + """Return the NEFF input name for an aliased parameter.""" + for neff_name, param_name in self._param_name_by_neff.items(): + if param_name == alias.param_name: + return neff_name + return alias.param_name + + def content_hash(self, compiler_args: str) -> str: + """Compute a content hash from the MLIR text and compiler args.""" + h = hashlib.sha256() + h.update(self._mlir_text.encode("utf-8")) + h.update(compiler_args.encode("utf-8")) + return h.hexdigest()[:12] + diff --git a/nkipy/src/nkipy/core/compile.py b/nkipy/src/nkipy/core/compile.py index b8a024a..8c9f5e4 100644 --- a/nkipy/src/nkipy/core/compile.py +++ b/nkipy/src/nkipy/core/compile.py @@ -112,89 +112,67 @@ class Compiler: def __init__(self, config: CompilationConfig): self.config = config - def _build_compile_command(self, mode="hlo") -> List[str]: + def _resolve_target(self) -> CompilationTarget: + if self.config.target != CompilationTarget.DEFAULT: + return self.config.target + try: + return get_platform_target() + except Exception: + logging.warning( + "Failed to detect platform target, falling back to trn2..." + ) + return CompilationTarget.TRN2 + + def _build_hlo_compile_command(self, work_dir: Path) -> List[str]: + """Build the neuronx-cc command line for HLO compilation.""" + target = self._resolve_target() + self.config.target = target + cmd = [ - "neuronx-cc", - "compile", - "--framework", - "XLA", + "neuronx-cc", "compile", + "--framework", "XLA", + str(work_dir / "hlo_module.pb"), + "--pipeline", *self.config.pipeline, + "--target", target.value, + f"--output={self.config.neff_name}", ] - if mode == "hlo": - cmd.extend(["hlo_module.pb"]) - else: - raise RuntimeError(f"Unknown mode: {mode}") - - cmd.append("--pipeline") - cmd.extend(self.config.pipeline) - - # When using default target, detect platform target - if self.config.target == CompilationTarget.DEFAULT: - try: - self.config.target = get_platform_target() - except Exception: - logging.warning( - "Failed to detect platform target, falling back to trn1..." - ) - self.config.target = CompilationTarget.TRN1 - - cmd.extend( - ["--target", self.config.target.value, f"--output={self.config.neff_name}"] - ) if self.config.additional_args: cmd.extend(shlex.split(self.config.additional_args)) return cmd - def compile( + @staticmethod + def _compilation_error(message, cmd=None, result=None): + """Build a RuntimeError with compiler output when available.""" + parts = [message] + if cmd is not None: + parts.append(f"Command: {' '.join(cmd)}") + if result is not None: + def decode(b): + return b.decode("utf-8", errors="replace") if b else "" + parts.append(f"stderr:\n{decode(result.stderr)}") + parts.append(f"stdout:\n{decode(result.stdout)}") + return RuntimeError("\n".join(parts)) + + def _compile_hlo( self, ir, work_dir: Path, output_file: str, use_neuronx_cc_python_interface: bool = False, ) -> Path: - """ - Run compilation in specified directory - - Args: - ir: The IR to compile - work_dir: Directory to compile in - output_file: Name of the output file to check for ("file.neff" or "nki.py") + """Compile an HLOModule to NEFF via neuronx-cc.""" + hlo_pb_path = work_dir / "hlo_module.pb" + proto = ir.to_proto() + with open(hlo_pb_path, "wb") as f: + f.write(proto.SerializeToString()) - Returns: - Path to the output file - """ - - mode = "hlo" if isinstance(ir, HLOModule) else "unknown" - cmd = self._build_compile_command(mode) - - def _compilation_error(message, result=None): - """Build a RuntimeError with compiler output when available.""" - parts = [message, f"Command: {' '.join(cmd)}"] - if result is not None: - - def decode(b): - return b.decode("utf-8", errors="replace") if b else "" - - parts.append(f"stderr:\n{decode(result.stderr)}") - parts.append(f"stdout:\n{decode(result.stdout)}") - return RuntimeError("\n".join(parts)) + cmd = self._build_hlo_compile_command(work_dir) current_dir = os.getcwd() try: os.chdir(work_dir) - if mode == "hlo": - hlo_pb_path = "hlo_module.pb" - proto = ir.to_proto() - with open(hlo_pb_path, "wb") as f: - f.write(proto.SerializeToString()) - else: - raise RuntimeError( - f"Unknown mode: {mode}. " - "Note: For NKI kernels, You can either embed a NKI kernel as an op" - " in NKIPy kernel or implement your own helper function to get the" - " NEFF from a NKI kernel." - ) if use_neuronx_cc_python_interface: original_argv = sys.argv.copy() sys.argv = cmd @@ -202,9 +180,9 @@ def decode(b): else: result = subprocess.run(cmd, capture_output=True) if result.returncode != 0: - raise _compilation_error( + raise self._compilation_error( f"Compilation failed (exit code {result.returncode}).", - result, + cmd, result, ) finally: if use_neuronx_cc_python_interface: @@ -213,13 +191,66 @@ def decode(b): output_path = work_dir / output_file if not output_path.exists(): - raise _compilation_error( + raise self._compilation_error( f"Compilation failed: {output_file} expected but not generated.", + cmd, result if not use_neuronx_cc_python_interface else None, ) + return output_path + + def _compile_kernelgen(self, ir, work_dir: Path, output_file: str) -> Path: + """Compile a KernelGenIR module to NEFF via nkipy_kernelgen.""" + from nkipy_kernelgen.compile import compile_to_neff + + target_str = self._resolve_target().value + cc_args = tuple(shlex.split(self.config.additional_args)) if self.config.additional_args else () + + compile_to_neff( + ir._mlir_text, + ir._func_name, + input_specs=[(s.name, s.shape, s.dtype) for s in ir.inputs], + output_specs=[(s.name, s.shape, s.dtype) for s in ir.outputs], + target=target_str, + output_path=str(work_dir / output_file), + artifacts_dir=str(work_dir), + neuronx_cc_args=cc_args, + ) + + output_path = work_dir / output_file + if not output_path.exists(): + raise self._compilation_error( + f"KernelGen compilation failed: {output_file} not generated." + ) return output_path + def compile( + self, + ir, + work_dir: Path, + output_file: str, + use_neuronx_cc_python_interface: bool = False, + ) -> Path: + """Compile an IR module to a NEFF file. + + Dispatches to ``_compile_hlo`` or ``_compile_kernelgen`` based on + the IR type. + """ + if isinstance(ir, HLOModule): + return self._compile_hlo( + ir, work_dir, output_file, use_neuronx_cc_python_interface + ) + + from nkipy.core.backend.kernelgen import KernelGenIR + + if isinstance(ir, KernelGenIR): + return self._compile_kernelgen(ir, work_dir, output_file) + + raise RuntimeError( + f"Unknown IR type: {type(ir).__name__}. " + "Expected HLOModule or KernelGenIR." + ) + def compile_in_directory( self, ir, diff --git a/nkipy/src/nkipy/core/knob.py b/nkipy/src/nkipy/core/knob.py new file mode 100644 index 0000000..b23513c --- /dev/null +++ b/nkipy/src/nkipy/core/knob.py @@ -0,0 +1,70 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Public knob() API for annotating tensors with hardware placement and tiling hints. + +Dispatches based on the active tracing backend: +- kernelgen: emits nkipy.AnnotateOp into MLIR +- hlo: warns and ignores +- cpu / no trace: no-op pass-through +""" + +from __future__ import annotations + +import warnings +from typing import List, Optional + + +def knob( + tensor, + *, + partition_dim: Optional[int] = None, + mem_space: Optional[str] = None, + tile_size: Optional[List[int]] = None, + reduction_tile: Optional[List[int]] = None, +): + """Annotate a tensor with hardware placement and tiling hints. + + Only effective when using the kernelgen backend. When used with the HLO + backend, issues a warning and returns the tensor unchanged. + + Args: + tensor: The tensor to annotate. + partition_dim: Dimension to partition (must be < tensor rank). + mem_space: Memory space ("Hbm", "Psum", "Sbuf", or "SharedHbm"). + tile_size: Tile sizes for each dimension. + reduction_tile: Tile sizes for reduction dimensions (e.g., K in matmul). + + Returns: + The same tensor, unchanged. + """ + from nkipy.core.backend import get_backend + + backend = get_backend() + + if backend == "kernelgen": + from nkipy.core.tensor import NKIPyTensorRef + + if not isinstance(tensor, NKIPyTensorRef): + return tensor + + if mem_space is None and partition_dim is None and tile_size is None and reduction_tile is None: + return tensor + + import nkipy_kernelgen.builder as B + + B.annotate( + tensor.backend_tensor.handle, + partition_dim=partition_dim, + mem_space=mem_space, + tile_size=tile_size, + reduction_tile=reduction_tile, + ) + return tensor + elif backend == "hlo": + warnings.warn( + "knob() annotations are only effective with backend='kernelgen'. " + "Ignoring annotation.", + stacklevel=2, + ) + + return tensor diff --git a/nkipy/src/nkipy/core/nki_op.py b/nkipy/src/nkipy/core/nki_op.py index 3ac472b..bc7fe21 100644 --- a/nkipy/src/nkipy/core/nki_op.py +++ b/nkipy/src/nkipy/core/nki_op.py @@ -1,8 +1,8 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -"""NKI kernel integration for NKIPy - wraps NKI kernels as HLO custom-calls +"""NKI kernel integration for NKIPy. -This module provides two ways to use NKI kernels in NKIPy: +This module provides three ways to use NKI kernels in NKIPy: 1. Direct @nki.jit support (lazy/dynamic): - Any kernel decorated with @nki.jit can be called directly during NKIPy tracing @@ -14,6 +14,10 @@ - Returns a NKICustomOp that only works with those shapes - Useful for explicit control over specialization +3. nki_custom_op for cross-backend custom ops: + - Accepts both @nki.jit (HLO backend) and kernel_builder (kernelgen backend) + - Dispatches to the correct implementation based on the active backend + Supports two NKI frontends: - Legacy frontend (neuronxcc.nki): Default, supports CPU execution - Beta 2 frontend (nki): New frontend, hardware-only (no CPU execution support) @@ -21,7 +25,7 @@ import dataclasses import inspect -from typing import Callable, Iterable, Optional, Tuple +from typing import Callable, Iterable, List, Optional, Tuple import numpy as np @@ -274,10 +278,10 @@ def _patched_beta2_generic_kernel_call(self, *args, **kwargs): class NKICustomOp: - """Backward-compatible NKI custom op class. + """HLO custom-call wrapper for a pre-traced NKI kernel. - This class provides the original API for wrapping NKI kernels. - New code should use wrap_nki_kernel() or direct @nki.jit instead. + Pre-traces the kernel at construction time for specific operand shapes. + Used by ``wrap_nki_kernel``. """ def __init__( @@ -362,3 +366,111 @@ def wrap_nki_kernel( is_nki_beta_2_version=is_nki_beta_2_version, platform_target=platform_target, ) + + +# --------------------------------------------------------------------------- +# Kernelgen custom op support +# --------------------------------------------------------------------------- + + +def _generate_kernelgen_custom_call(kernel_builder, input_specs, output_specs, *args): + """Compile a kernel_builder function and inline it during kernelgen tracing.""" + from nkipy_kernelgen.builder import apply_custom_op + + return apply_custom_op( + kernel_builder=kernel_builder, + reference_fn=None, + input_specs=input_specs, + output_specs=output_specs, + args=args, + ) + + +# --------------------------------------------------------------------------- +# Unified custom op interface +# --------------------------------------------------------------------------- + + +def nki_custom_op( + *, + nki_kernel: Optional[Callable] = None, + kernel_builder: Optional[Callable] = None, + input_specs: Optional[List[Tuple[Tuple[int, ...], str]]] = None, + output_specs: Optional[List[Tuple[Tuple[int, ...], str]]] = None, +) -> "NKICustomOpHandle": + """Create a cross-backend custom NKI op. + + Args: + nki_kernel: ``@nki.jit`` decorated kernel for the HLO backend. + kernel_builder: ``nki.compiler.kernel_builder`` function for the + kernelgen backend. Requires ``input_specs`` and ``output_specs``. + input_specs: List of ``((shape), dtype_str)`` for each input. + Required when ``kernel_builder`` is provided. + output_specs: List of ``((shape), dtype_str)`` for each output. + Required when ``kernel_builder`` is provided. + + Returns: + An ``NKICustomOpHandle`` callable that dispatches to the correct + backend at call time. + """ + if nki_kernel is None and kernel_builder is None: + raise ValueError( + "At least one of nki_kernel or kernel_builder must be provided." + ) + if kernel_builder is not None: + if input_specs is None or output_specs is None: + raise ValueError( + "input_specs and output_specs are required when kernel_builder " + "is provided." + ) + return NKICustomOpHandle( + nki_kernel=nki_kernel, + kernel_builder=kernel_builder, + input_specs=input_specs, + output_specs=output_specs, + ) + + +class NKICustomOpHandle: + """Backend-aware callable wrapping a custom NKI op definition.""" + + def __init__( + self, + *, + nki_kernel: Optional[Callable], + kernel_builder: Optional[Callable], + input_specs: Optional[List[Tuple[Tuple[int, ...], str]]], + output_specs: Optional[List[Tuple[Tuple[int, ...], str]]], + ): + self._nki_kernel = nki_kernel + self._kernel_builder = kernel_builder + self._input_specs = input_specs + self._output_specs = output_specs + + def __call__(self, *args): + backend = get_backend() + + if backend == "hlo": + if self._nki_kernel is None: + raise RuntimeError( + "nki_custom_op has no nki_kernel for the HLO backend. " + "Provide an @nki.jit decorated kernel via nki_kernel=." + ) + return _generate_nki_custom_call(self._nki_kernel, *args) + + if backend == "kernelgen": + if self._kernel_builder is None: + raise RuntimeError( + "nki_custom_op has no kernel_builder for the kernelgen " + "backend. Provide a kernel_builder function via " + "kernel_builder=." + ) + return _generate_kernelgen_custom_call( + self._kernel_builder, self._input_specs, self._output_specs, + *args, + ) + + raise RuntimeError( + f"nki_custom_op is not supported on backend '{backend}'. " + f"Use the 'hlo' or 'kernelgen' backend." + ) diff --git a/nkipy/src/nkipy/core/ops/_kernelgen_impls.py b/nkipy/src/nkipy/core/ops/_kernelgen_impls.py new file mode 100644 index 0000000..c7c1156 --- /dev/null +++ b/nkipy/src/nkipy/core/ops/_kernelgen_impls.py @@ -0,0 +1,284 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""KernelGen backend implementations for NKIPy ops. + +Trivial ops use _unary/_binary factories that delegate to the builder. +Non-trivial ops with custom logic are explicit functions below. +""" + +from __future__ import annotations + +import numpy as np + +from nkipy.core.tensor import NKIPyTensorRef +from nkipy.core.backend.kernelgen import KernelGenTensor + +_builder_module = None + + +def _builder(): + global _builder_module + if _builder_module is None: + import nkipy_kernelgen.builder as _mod + _builder_module = _mod + return _builder_module + + +def _unwrap(x): + if isinstance(x, NKIPyTensorRef): + return x.backend_tensor.handle + return x + + +def _wrap(handle): + kt = KernelGenTensor(handle, handle.shape, handle.dtype) + return NKIPyTensorRef(kt) + + +# --------------------------------------------------------------------------- +# Factories for trivial delegation to builder +# --------------------------------------------------------------------------- + +def _unary(method): + def impl(x, out=None, dtype=None): + return _wrap(getattr(_builder(), method)(_unwrap(x))) + return impl + + +def _binary(method): + def impl(x, y, out=None, dtype=None): + return _wrap(getattr(_builder(), method)(_unwrap(x), _unwrap(y))) + return impl + + +def _reduce(method): + def impl(x, axis=None, keepdims=False, **kwargs): + return _wrap(getattr(_builder(), method)(_unwrap(x), axis=axis, keepdims=keepdims)) + return impl + + +# Binary ops +add = _binary("add") +subtract = _binary("subtract") +multiply = _binary("multiply") +divide = _binary("divide") +power = _binary("power") +maximum = _binary("maximum") +minimum = _binary("minimum") +equal = _binary("equal") +not_equal = _binary("not_equal") +greater = _binary("greater") +greater_equal = _binary("greater_equal") +less = _binary("less") +less_equal = _binary("less_equal") +bitwise_and = _binary("bitwise_and") +bitwise_or = _binary("bitwise_or") +bitwise_xor = _binary("bitwise_xor") +matmul = _binary("matmul") + +# Unary ops +exp = _unary("exp") +log = _unary("log") +sqrt = _unary("sqrt") +tanh = _unary("tanh") +sin = _unary("sin") +cos = _unary("cos") +sign = _unary("sign") +abs = _unary("abs_") +ceil = _unary("ceil_") +floor = _unary("floor_") + +# Reductions +reduce_sum = _reduce("reduce_sum") +reduce_prod = _reduce("reduce_prod") +reduce_max = _reduce("reduce_max") +reduce_min = _reduce("reduce_min") +reduce_mean = _reduce("reduce_mean") +reduce_std = _reduce("reduce_std") +reduce_var = _reduce("reduce_var") + + +# --------------------------------------------------------------------------- +# Composed unary ops +# --------------------------------------------------------------------------- + +def negative(x, out=None, dtype=None): + return _wrap(_builder().subtract(_unwrap(0), _unwrap(x))) + + +def reciprocal(x, out=None, dtype=None): + return _wrap(_builder().divide(_unwrap(1.0), _unwrap(x))) + + +def square(x, out=None, dtype=None): + h = _unwrap(x) + return _wrap(_builder().multiply(h, h)) + + +def logical_not(x, out=None, dtype=None): + return _wrap(_builder().subtract(_unwrap(1), _unwrap(x))) + + +# --------------------------------------------------------------------------- +# Transform ops with custom signatures +# --------------------------------------------------------------------------- + +def transpose(x, axes=None): + return _wrap(_builder().transpose(_unwrap(x), axes=axes)) + + +def reshape(x, newshape, order='C'): + return _wrap(_builder().reshape(_unwrap(x), newshape)) + + +def expand_dims(x, axis): + return _wrap(_builder().expand_dims(_unwrap(x), axis)) + + +def copy(x, order='K', subok=True): + return _wrap(_builder().copy_(_unwrap(x))) + + +def broadcast_to(x, shape): + return _wrap(_builder().broadcast_to(_unwrap(x), tuple(shape))) + + +def astype(x, dtype): + return _wrap(_builder().astype(_unwrap(x), dtype)) + + +def concatenate(arrays, axis=0, out=None, dtype=None): + handles = [_unwrap(a) for a in arrays] + return _wrap(_builder().concatenate(handles, axis=axis)) + + +def where(condition, x, y): + return _wrap(_builder().where(_unwrap(condition), _unwrap(x), _unwrap(y))) + + +def take(a, indices, axis=0): + return _wrap(_builder().take(_unwrap(a), _unwrap(indices), axis=axis)) + + +# --------------------------------------------------------------------------- +# Creation ops +# --------------------------------------------------------------------------- + +def zeros(shape, dtype=np.float32): + return _wrap(_builder().zeros(tuple(shape), dtype)) + + +def full(shape, fill_value, dtype=np.float32): + return _wrap(_builder().full(tuple(shape), fill_value, dtype)) + + +def zeros_like(x, dtype=None): + h = _unwrap(x) + dt = dtype if dtype is not None else h.dtype + return zeros(h.shape, dt) + + +def ones_like(x, dtype=None): + h = _unwrap(x) + dt = dtype if dtype is not None else h.dtype + return full(h.shape, 1.0, dt) + + +def empty_like(x, dtype=None): + h = _unwrap(x) + dt = dtype if dtype is not None else h.dtype + return _wrap(_builder().empty(h.shape, dt)) + + +def full_like(x, fill_value, dtype=None): + h = _unwrap(x) + dt = dtype if dtype is not None else h.dtype + return full(h.shape, fill_value, dt) + + +# --------------------------------------------------------------------------- +# Squeeze / swapaxes / stack / split +# --------------------------------------------------------------------------- + +def squeeze(x, axis=None): + h = _unwrap(x) + shape = h.shape + if axis is None: + new_shape = tuple(d for d in shape if d != 1) + else: + if isinstance(axis, int): + axis = (axis,) + new_shape = tuple(d for i, d in enumerate(shape) if i not in axis) + if new_shape == shape: + return x + return reshape(x, new_shape) + + +def swapaxes(x, axis1, axis2): + h = _unwrap(x) + rank = len(h.shape) + perm = list(range(rank)) + perm[axis1], perm[axis2] = perm[axis2], perm[axis1] + return transpose(x, axes=perm) + + +def stack(arrays, axis=0): + expanded = [expand_dims(a, axis) for a in arrays] + return concatenate(expanded, axis=axis) + + +def split(x, indices_or_sections, axis=0): + h = _unwrap(x) + shape = h.shape + if isinstance(indices_or_sections, int): + sections = indices_or_sections + size = shape[axis] + section_size = size // sections + results = [] + for i in range(sections): + start = [0] * len(shape) + start[axis] = i * section_size + limit = list(shape) + limit[axis] = (i + 1) * section_size + strides = [1] * len(shape) + results.append(static_slice(x, start, limit, strides, [])) + return tuple(results) + raise NotImplementedError("split with explicit indices not yet implemented") + + +# --------------------------------------------------------------------------- +# Static slicing +# --------------------------------------------------------------------------- + +def static_slice(x, start_indices, limit_indices, strides, squeeze_dims): + return _wrap(_builder().static_slice( + _unwrap(x), start_indices, limit_indices, strides, squeeze_dims, + )) + + +# --------------------------------------------------------------------------- +# Slice assignment (dynamic_update_slice) +# --------------------------------------------------------------------------- + +def dynamic_update_slice(x, value, start_indices, update_shape): + b = _builder() + x_h = _unwrap(x) + if isinstance(value, NKIPyTensorRef): + value_h = _unwrap(value) + elif isinstance(value, (int, float)): + value_h = b.full(tuple(update_shape), value, x_h.dtype) + elif isinstance(value, np.ndarray): + raise NotImplementedError( + "Assigning a raw np.ndarray constant in kernelgen is not supported. " + "Use a traced tensor expression instead." + ) + else: + value_h = value + + if value_h.shape != tuple(update_shape): + value_h = b.reshape(value_h, tuple(update_shape)) + + sizes = list(update_shape) + strides = [1] * len(start_indices) + result_h = b.static_insert_slice(x_h, value_h, start_indices, sizes, strides) + return _wrap(result_h) diff --git a/nkipy/src/nkipy/core/ops/_register_kernelgen.py b/nkipy/src/nkipy/core/ops/_register_kernelgen.py new file mode 100644 index 0000000..e0d4773 --- /dev/null +++ b/nkipy/src/nkipy/core/ops/_register_kernelgen.py @@ -0,0 +1,149 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Register kernelgen backend implementations for all ops. + +Called lazily the first time the kernelgen backend is activated, so MLIR +imports only happen when needed. + +Composed ops (floor_divide, tan, rint, etc.) reuse the HLO implementations +since those are expressed purely in terms of Op-dispatched calls and work +on any backend that has the primitives registered. +""" + +_registered = False + + +def register_all_kernelgen_impls(): + global _registered + if _registered: + return + _registered = True + + from nkipy.core.ops import _kernelgen_impls as kernelgen_impls + + # --- Binary ops (primitives) --- + from nkipy.core.ops.binary import ( + add, subtract, multiply, divide, power, maximum, minimum, + ) + add.impl("kernelgen")(kernelgen_impls.add) + subtract.impl("kernelgen")(kernelgen_impls.subtract) + multiply.impl("kernelgen")(kernelgen_impls.multiply) + divide.impl("kernelgen")(kernelgen_impls.divide) + power.impl("kernelgen")(kernelgen_impls.power) + maximum.impl("kernelgen")(kernelgen_impls.maximum) + minimum.impl("kernelgen")(kernelgen_impls.minimum) + + # --- Binary ops (composed) — reuse HLO impls since they use Op dispatch --- + from nkipy.core.ops.binary import ( + floor_divide, remainder, logaddexp, + logical_and, logical_or, logical_xor, + ) + floor_divide.impl("kernelgen")(floor_divide._impls["hlo"]) + remainder.impl("kernelgen")(remainder._impls["hlo"]) + logaddexp.impl("kernelgen")(logaddexp._impls["hlo"]) + logical_and.impl("kernelgen")(logical_and._impls["hlo"]) + logical_or.impl("kernelgen")(logical_or._impls["hlo"]) + logical_xor.impl("kernelgen")(logical_xor._impls["hlo"]) + + # --- Comparison / bitwise ops --- + from nkipy.core.ops.binary import ( + equal, not_equal, greater, greater_equal, less, less_equal, + bitwise_and, bitwise_or, bitwise_xor, + ) + equal.impl("kernelgen")(kernelgen_impls.equal) + not_equal.impl("kernelgen")(kernelgen_impls.not_equal) + greater.impl("kernelgen")(kernelgen_impls.greater) + greater_equal.impl("kernelgen")(kernelgen_impls.greater_equal) + less.impl("kernelgen")(kernelgen_impls.less) + less_equal.impl("kernelgen")(kernelgen_impls.less_equal) + bitwise_and.impl("kernelgen")(kernelgen_impls.bitwise_and) + bitwise_or.impl("kernelgen")(kernelgen_impls.bitwise_or) + bitwise_xor.impl("kernelgen")(kernelgen_impls.bitwise_xor) + + # --- Unary ops (primitives) --- + from nkipy.core.ops.unary import ( + abs, exp, log, sqrt, sin, cos, tanh, ceil, floor, sign, + square, negative, reciprocal, logical_not, + ) + exp.impl("kernelgen")(kernelgen_impls.exp) + log.impl("kernelgen")(kernelgen_impls.log) + sqrt.impl("kernelgen")(kernelgen_impls.sqrt) + tanh.impl("kernelgen")(kernelgen_impls.tanh) + sin.impl("kernelgen")(kernelgen_impls.sin) + cos.impl("kernelgen")(kernelgen_impls.cos) + sign.impl("kernelgen")(kernelgen_impls.sign) + abs.impl("kernelgen")(kernelgen_impls.abs) + ceil.impl("kernelgen")(kernelgen_impls.ceil) + floor.impl("kernelgen")(kernelgen_impls.floor) + negative.impl("kernelgen")(kernelgen_impls.negative) + reciprocal.impl("kernelgen")(kernelgen_impls.reciprocal) + square.impl("kernelgen")(kernelgen_impls.square) + logical_not.impl("kernelgen")(kernelgen_impls.logical_not) + + # --- Unary ops (composed) — reuse HLO impls --- + from nkipy.core.ops.unary import ( + log1p, log2, expm1, tan, clip, rint, trunc, round_, isnan, isfinite, + ) + log1p.impl("kernelgen")(log1p._impls["hlo"]) + log2.impl("kernelgen")(log2._impls["hlo"]) + expm1.impl("kernelgen")(expm1._impls["hlo"]) + tan.impl("kernelgen")(tan._impls["hlo"]) + clip.impl("kernelgen")(clip._impls["hlo"]) + rint.impl("kernelgen")(rint._impls["hlo"]) + trunc.impl("kernelgen")(trunc._impls["hlo"]) + round_.impl("kernelgen")(round_._impls["hlo"]) + isnan.impl("kernelgen")(isnan._impls["hlo"]) + isfinite.impl("kernelgen")(isfinite._impls["hlo"]) + + # --- Linalg ops --- + from nkipy.core.ops.linalg import matmul + matmul.impl("kernelgen")(kernelgen_impls.matmul) + + # --- Reduction ops --- + from nkipy.core.ops.reduce import sum, prod, max, min, mean, std, var + sum.impl("kernelgen")(kernelgen_impls.reduce_sum) + prod.impl("kernelgen")(kernelgen_impls.reduce_prod) + max.impl("kernelgen")(kernelgen_impls.reduce_max) + min.impl("kernelgen")(kernelgen_impls.reduce_min) + mean.impl("kernelgen")(kernelgen_impls.reduce_mean) + std.impl("kernelgen")(kernelgen_impls.reduce_std) + var.impl("kernelgen")(kernelgen_impls.reduce_var) + + # --- Creation ops --- + from nkipy.core.ops.creation import ( + zeros as zeros_op, full as full_op, + zeros_like, ones_like, empty_like, full_like, + ) + zeros_op.impl("kernelgen")(kernelgen_impls.zeros) + full_op.impl("kernelgen")(kernelgen_impls.full) + zeros_like.impl("kernelgen")(kernelgen_impls.zeros_like) + ones_like.impl("kernelgen")(kernelgen_impls.ones_like) + empty_like.impl("kernelgen")(kernelgen_impls.empty_like) + full_like.impl("kernelgen")(kernelgen_impls.full_like) + + # --- Transform ops --- + from nkipy.core.ops.transform import ( + transpose, reshape, expand_dims, concatenate, + split, copy, broadcast_to, astype, squeeze, swapaxes, stack, + ) + transpose.impl("kernelgen")(kernelgen_impls.transpose) + reshape.impl("kernelgen")(kernelgen_impls.reshape) + expand_dims.impl("kernelgen")(kernelgen_impls.expand_dims) + concatenate.impl("kernelgen")(kernelgen_impls.concatenate) + split.impl("kernelgen")(kernelgen_impls.split) + copy.impl("kernelgen")(kernelgen_impls.copy) + broadcast_to.impl("kernelgen")(kernelgen_impls.broadcast_to) + astype.impl("kernelgen")(kernelgen_impls.astype) + squeeze.impl("kernelgen")(kernelgen_impls.squeeze) + swapaxes.impl("kernelgen")(kernelgen_impls.swapaxes) + stack.impl("kernelgen")(kernelgen_impls.stack) + + # --- Indexing ops --- + from nkipy.core.ops.indexing import ( + where as where_op, take as take_op, + static_slice, dynamic_update_slice, + ) + where_op.impl("kernelgen")(kernelgen_impls.where) + take_op.impl("kernelgen")(kernelgen_impls.take) + static_slice.impl("kernelgen")(kernelgen_impls.static_slice) + dynamic_update_slice.impl("kernelgen")(kernelgen_impls.dynamic_update_slice) diff --git a/nkipy/src/nkipy/core/trace.py b/nkipy/src/nkipy/core/trace.py index 88b908c..292bc7b 100644 --- a/nkipy/src/nkipy/core/trace.py +++ b/nkipy/src/nkipy/core/trace.py @@ -8,9 +8,8 @@ import numpy as np from nkipy.core._numpy_dispatch import register_all_numpy_apis -from nkipy.core.backend import tracing +from nkipy.core.backend import AliasInfo, tracing from nkipy.core.backend.hlo import ( - AliasInfo, HLOModule, HLOTraceContext, get_hlo_context, @@ -47,10 +46,48 @@ def _sanitize_array_dtype(arr: np.ndarray, name: str = "") -> np.ndarray: return arr.astype(target) +def _convert_args(sig, boundargs, convert_arg): + """Convert bound arguments to traced tensor refs. + + Shared by both HLO and kernelgen specialization paths. + + Each argument is passed through *convert_arg* which replaces ndarrays + with backend-specific tensor refs and returns non-tensor values unchanged. + VAR_POSITIONAL and VAR_KEYWORD arguments are expanded so each element is + converted individually. + + Args: + sig: The function's inspect.Signature. + boundargs: The BoundArguments (already defaulted). + convert_arg: ``(name, arg) -> converted_value``. + + Returns: + ``(converted_args, converted_kwargs)`` ready to call the kernel. + """ + converted_args = [] + converted_kwargs = {} + + for name, arg in boundargs.arguments.items(): + param = sig.parameters[name] + + if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD): + converted_args.append(convert_arg(name, arg)) + elif param.kind == param.KEYWORD_ONLY: + converted_kwargs[name] = convert_arg(name, arg) + elif param.kind == param.VAR_POSITIONAL: + for item in arg: + converted_args.append(convert_arg(name, item)) + elif param.kind == param.VAR_KEYWORD: + for k, v in arg.items(): + converted_kwargs[k] = convert_arg(k, v) + + return converted_args, converted_kwargs + + class NKIPyKernel: """Simplified kernel wrapper for NKIPy tracing""" - def __init__(self, func, backend, **kwargs): + def __init__(self, func, backend): self.func = func self.backend = backend self._code = None @@ -65,102 +102,78 @@ def __repr__(self): def specialize(self, *args, **kwargs): if self.backend == "hlo": return self._specialize_hlo(*args, **kwargs) + elif self.backend == "kernelgen": + return self._specialize_kernelgen(*args, **kwargs) elif self.backend == "cpu": - print("CPU backend does not require specialization") + warnings.warn( + "CPU backend does not require specialization", stacklevel=2 + ) return else: raise ValueError(f"Unknown backend {self.backend}") def _create_parameter_hlo(self, shape, dtype, name=""): - """Create an HLO parameter tensor""" + """Create an HLO parameter tensor.""" ctx = get_hlo_context() hlo_tensor = ctx.module.add_parameter(shape, dtype, name=name) return NKIPyTensorRef(hlo_tensor, name=name) def _specialize_hlo(self, *args, **kwargs): - """Trace the kernel with specific arguments""" + """Trace the kernel with specific arguments.""" code = HLOModule(name=self.func.__name__) with tracing(HLOTraceContext(code)): - # Bind arguments sig = inspect.signature(self.func) boundargs = sig.bind(*args, **kwargs) boundargs.apply_defaults() - # Convert numpy arrays to tensor references - converted_args = [] - converted_kwargs = {} - # Track parameter tensor refs: list of (param_name, tensor_ref) for arrays param_tensor_refs = [] - for name, arg in boundargs.arguments.items(): - param = sig.parameters[name] - + def _make_hlo_ref(name, arg): if isinstance(arg, np.ndarray): arg = _sanitize_array_dtype(arg, name) tensor_ref = self._create_parameter_hlo(arg.shape, arg.dtype, name) tensor_ref._original_parameter = tensor_ref.backend_tensor - converted_value = tensor_ref - param_tensor_refs.append((name, tensor_ref)) - else: - converted_value = arg - - # Determine if this should be positional or keyword - if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD): - converted_args.append(converted_value) - elif param.kind == param.KEYWORD_ONLY: - converted_kwargs[name] = converted_value - elif param.kind == param.VAR_POSITIONAL: - if isinstance(arg, (list, tuple)): - for item in arg: - if isinstance(item, np.ndarray): - item = _sanitize_array_dtype(item, f"{name}_item") - converted_args.append( - self._create_parameter_hlo( - item.shape, item.dtype, f"{name}_item" - ) - ) - else: - converted_args.append(item) - else: - converted_args.append(converted_value) - elif param.kind == param.VAR_KEYWORD: - if isinstance(arg, dict): - for k, v in arg.items(): - if isinstance(v, np.ndarray): - v = _sanitize_array_dtype(v, k) - converted_kwargs[k] = self._create_parameter_hlo( - v.shape, v.dtype, k - ) - else: - converted_kwargs[k] = v - - # Execute function + param_index = tensor_ref.backend_tensor.parameter_id + param_tensor_refs.append((name, param_index, tensor_ref)) + return tensor_ref + return arg + + converted_args, converted_kwargs = _convert_args( + sig, boundargs, _make_hlo_ref + ) + ret = self.func(*converted_args, **converted_kwargs) - # Mark outputs self._mark_hlo_outputs(code, ret, param_tensor_refs) self._code = code return code - def _mark_hlo_outputs(self, code: HLOModule, ret, param_tensor_refs): - """Mark HLO outputs using mutation tracking. + @staticmethod + def _detect_mutations(ret, param_tensor_refs): + """Detect mutated parameters and auto-append them to the return list. - Detects aliasing by checking which parameter tensor refs were mutated - (had __setitem__ called on them) during kernel execution. + Checks which parameter tensor refs were mutated (had __setitem__ + called on them) during kernel execution. Mutated parameters that + the user did not return are appended automatically so the backend + can compile them as outputs. Note: Only direct mutations on the original parameter tensor refs are detected. View aliasing (e.g. ``b = a[0]; b[x] = y``) is not tracked - because ``__getitem__`` creates a new tensor ref with no parent link. + because ``__getitem__`` creates a new NKIPyTensorRef with no parent link. Args: - code: The HLOModule being built - ret: The return value(s) from the kernel function - param_tensor_refs: List of (param_name, tensor_ref) for array parameters + ret: The return value(s) from the kernel function. + param_tensor_refs: List of (param_name, param_index, tensor_ref). + + Returns: + ``(ret, user_return_len, alias_map)`` where *ret* is the + (possibly extended) list of outputs, *user_return_len* is the + original count before auto-appending, and *alias_map* is + ``{output_index: (param_name, param_index)}``. """ - # Normalize return value to a list (may be None for mutation-only kernels) if ret is None: ret = [] elif not isinstance(ret, (list, tuple)): @@ -168,62 +181,64 @@ def _mark_hlo_outputs(self, code: HLOModule, ret, param_tensor_refs): ret = list(ret) user_return_len = len(ret) - ctx = get_hlo_context() - - # Step 1: For each mutated param, rename HLO parameter. - # Check if user returned it; if not, auto-append to output list. - aliased_return_positions = {} # output_index -> (param_name, param_index) - for name, tr in param_tensor_refs: + alias_map = {} + for name, pidx, tr in param_tensor_refs: if not tr._is_mutated: continue - - # Rename HLO parameter for compiler convention - param_index = None - for hlo_param in code.parameters: - if hlo_param.name == name: - hlo_param.name = f"{name}.must_alias_input" - param_index = hlo_param.parameter_id - break - - if param_index is None: - raise RuntimeError( - f"Mutated parameter '{name}' not found in HLO parameters" - ) - - # Check if this mutated param is in the user's return values (identity check) found_at = None for i, r in enumerate(ret): if isinstance(r, NKIPyTensorRef) and r is tr: found_at = i break - if found_at is not None: - aliased_return_positions[found_at] = (name, param_index) + alias_map[found_at] = (name, pidx) else: - # Auto-append to output list ret.append(tr) - aliased_return_positions[len(ret) - 1] = (name, param_index) + alias_map[len(ret) - 1] = (name, pidx) + + return ret, user_return_len, alias_map + + def _mark_hlo_outputs(self, code: HLOModule, ret, param_tensor_refs): + """Mark HLO outputs using mutation tracking. + + Args: + code: The HLOModule being built + ret: The return value(s) from the kernel function + param_tensor_refs: List of (param_name, param_index, tensor_ref) + """ + ret, user_return_len, alias_map = self._detect_mutations( + ret, param_tensor_refs + ) + + ctx = get_hlo_context() - # Step 2: Insert explicit copy for unmutated pass-through outputs. + # Rename mutated HLO parameters for compiler convention + for _, (param_name, _) in alias_map.items(): + for hlo_param in code.parameters: + if hlo_param.name == param_name: + hlo_param.name = f"{param_name}.must_alias_input" + break + + # Insert explicit copy for unmutated pass-through outputs. # The Neuron compiler cannot handle outputs that are raw parameter # references because inputs and outputs occupy separate memory regions. for i, r in enumerate(ret): if not isinstance(r, NKIPyTensorRef): continue - if i in aliased_return_positions: + if i in alias_map: continue bt = r.backend_tensor if bt.is_parameter: copy_tensor = ctx.build_op("copy", [bt], bt.shape, bt.dtype) ret[i] = NKIPyTensorRef(copy_tensor, name="") - # Step 3: Assign output names and build AliasInfo list + # Assign output names and build AliasInfo list for idx, r in enumerate(ret): if not isinstance(r, NKIPyTensorRef): raise RuntimeError(f"Unexpected return value type: {type(r)}") - if idx in aliased_return_positions: - param_name, param_index = aliased_return_positions[idx] + if idx in alias_map: + param_name, param_index = alias_map[idx] code.aliases.append( AliasInfo( output_index=idx, @@ -242,9 +257,132 @@ def _mark_hlo_outputs(self, code: HLOModule, ret, param_tensor_refs): result_tensors = [r.backend_tensor for r in ret] code.set_results(result_tensors) + def _specialize_kernelgen(self, *args, **kwargs): + """Trace the kernel to MLIR linalg/tensor IR via the kernelgen backend.""" + from nkipy.core.backend.kernelgen import KernelGenTraceContext + from nkipy.core.ops._register_kernelgen import register_all_kernelgen_impls + + register_all_kernelgen_impls() + + kctx = KernelGenTraceContext() + + sig = inspect.signature(self.func) + boundargs = sig.bind(*args, **kwargs) + boundargs.apply_defaults() + + arg_shapes = [] + arg_dtypes = [] + arg_names = [] + + def _collect_array(name, arg): + arg = _sanitize_array_dtype(arg, name) + arg_shapes.append(arg.shape) + arg_dtypes.append(arg.dtype) + arg_names.append(name) + return arg + + for name, arg in boundargs.arguments.items(): + param = sig.parameters[name] + if param.kind == param.VAR_POSITIONAL: + sanitized = [] + for item in arg: + sanitized.append( + _collect_array(name, item) + if isinstance(item, np.ndarray) + else item + ) + boundargs.arguments[name] = tuple(sanitized) + elif param.kind == param.VAR_KEYWORD: + for k, v in arg.items(): + if isinstance(v, np.ndarray): + arg[k] = _collect_array(k, v) + elif isinstance(arg, np.ndarray): + arg = _collect_array(name, arg) + boundargs.arguments[name] = arg + + param_tensors = kctx._begin_function(self.func.__name__, arg_shapes, arg_dtypes) + for pt, name in zip(param_tensors, arg_names): + pt.name = name + + param_tensor_refs = [] + + with tracing(kctx): + param_idx = 0 + + def _make_kg_ref(name, arg): + nonlocal param_idx + if isinstance(arg, np.ndarray): + ref = NKIPyTensorRef(param_tensors[param_idx], name=name) + param_tensor_refs.append((name, param_idx, ref)) + param_idx += 1 + return ref + return arg + + converted_args, converted_kwargs = _convert_args( + sig, boundargs, _make_kg_ref + ) + + raw_ret = self.func(*converted_args, **converted_kwargs) + + ret, user_return_len, alias_map = self._detect_mutations( + raw_ret, param_tensor_refs + ) + + result_kg_tensors = [] + for r in ret: + if isinstance(r, NKIPyTensorRef): + result_kg_tensors.append(r.backend_tensor) + else: + raise RuntimeError(f"Unexpected return type: {type(r)}") + + kctx._finish_function(result_kg_tensors) + + kctx._run_canonicalize() + + mlir_text = kctx._get_ir_text() + kctx._cleanup() + + # Use NEFF-compatible names: the kernelgen NEFF uses "in_tensor_N" + # for inputs and "output" / "output_N" for outputs (determined by the + # NKI compiler C++ pipeline from unnamed MLIR block arguments and the + # nki.output_names attribute set by the linalg-to-nisa pass). + num_outputs = len(result_kg_tensors) + input_info = [ + (f"in_tensor_{i}", shape, dtype) + for i, (shape, dtype) in enumerate(zip(arg_shapes, arg_dtypes)) + ] + output_info = [ + ( + "output" if num_outputs == 1 else f"output_{i}", + t.shape, + t.dtype, + ) + for i, t in enumerate(result_kg_tensors) + ] + + # Map NEFF input names back to original parameter names so + # resolve_input_arrays can look up the right numpy arrays. + param_name_by_neff = { + f"in_tensor_{i}": name + for i, name in enumerate(arg_names) + } + + from nkipy.core.backend.kernelgen import KernelGenIR + + self._code = KernelGenIR( + mlir_text=mlir_text, + func_name=self.func.__name__, + input_specs=input_info, + output_specs=output_info, + alias_map=alias_map, + user_return_len=user_return_len, + param_name_by_neff=param_name_by_neff, + ) + return self._code + @classmethod - def trace(cls, func=None, backend="hlo", **kwargs): - """Decorator to create traced kernel""" + def trace(cls, func=None, backend="hlo"): + """Decorator to create traced kernel.""" if func is None: - return lambda f: cls(f, backend, **kwargs) - return cls(func, backend, **kwargs) + return lambda f: cls(f, backend) + return cls(func, backend) diff --git a/nkipy/src/nkipy/runtime/baremetal_executor.py b/nkipy/src/nkipy/runtime/baremetal_executor.py index 3a01937..a0afc55 100644 --- a/nkipy/src/nkipy/runtime/baremetal_executor.py +++ b/nkipy/src/nkipy/runtime/baremetal_executor.py @@ -71,25 +71,22 @@ def _prepare_io_tensors( } # Prepare inputs using DeviceTensor - inputs = {} - for intensor in compiled_kernel.ir.inputs: - real_name = ( - intensor.name.split(".must_alias_input")[0] - if ".must_alias_input" in intensor.name - else intensor.name - ) - np_tensor = original_inputs.get(real_name, boundargs.arguments[real_name]) - inputs[intensor.name] = DeviceTensor.from_numpy(np_tensor) + ir = compiled_kernel.ir + input_arrays = ir.resolve_input_arrays(original_inputs) + inputs = { + name: DeviceTensor.from_numpy(arr) + for name, arr in input_arrays.items() + } # Prepare outputs — aliased outputs share the input device buffer outputs = device_kernel.allocate_output_tensors() outputs_dict = {t.name: t for t in outputs} - alias_by_output = {a.output_index: a for a in compiled_kernel.ir.aliases} - for i, outtensor in enumerate(compiled_kernel.ir.outputs): + alias_by_output = {a.output_index: a for a in ir.aliases} + for i, outtensor in enumerate(ir.outputs): if i in alias_by_output: alias = alias_by_output[i] - input_name = f"{alias.param_name}.must_alias_input" + input_name = ir.get_alias_input_name(alias) outputs_dict[outtensor.name] = inputs[input_name] return inputs, outputs_dict, original_inputs diff --git a/nkipy/src/nkipy/runtime/decorators.py b/nkipy/src/nkipy/runtime/decorators.py index 65fd000..7bf7785 100644 --- a/nkipy/src/nkipy/runtime/decorators.py +++ b/nkipy/src/nkipy/runtime/decorators.py @@ -10,6 +10,7 @@ def baremetal_jit( kernel_func=None, *, + backend="hlo", additional_compiler_args="", target=compile.CompilationTarget.DEFAULT, ): @@ -21,6 +22,7 @@ def baremetal_jit( Args: kernel_func: The kernel function to decorate (when used without parentheses) + backend: Compilation backend ("hlo" or "kernelgen") additional_compiler_args: Additional arguments to pass to the compiler target: Compilation target (default: CompilationTarget.DEFAULT) @@ -35,8 +37,8 @@ def my_kernel(A, B): # Compiles on first call with this signature result = my_kernel(input_a, input_b) - # Or with compiler args: - @baremetal_jit(additional_compiler_args="--lnc 1") + # Or with kernelgen backend: + @baremetal_jit(backend="kernelgen") def my_kernel(A, B): return A @ B """ @@ -45,7 +47,7 @@ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): # Trace the kernel - traced_kernel = trace(func) + traced_kernel = trace(func, backend=backend) # Use baremetal_run_traced_kernel for execution return baremetal_run_traced_kernel( traced_kernel, diff --git a/nkipy/src/nkipy/runtime/device_kernel.py b/nkipy/src/nkipy/runtime/device_kernel.py index 2960fe9..2737fd2 100644 --- a/nkipy/src/nkipy/runtime/device_kernel.py +++ b/nkipy/src/nkipy/runtime/device_kernel.py @@ -1,14 +1,12 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import atexit -import hashlib import os import shutil import time import types from nkipy.core import compile -from nkipy.core.backend.hlo import HLOModule from nkipy.core.compile import CompilationTarget, _get_build_dir, compile_to_neff, trace from nkipy.core.logger import get_logger from nkipy.core.trace import NKIPyKernel @@ -35,24 +33,6 @@ def _cleanup_kernels(): atexit.register(_cleanup_kernels) -def _hlo_content_hash(hlo_module: HLOModule, compiler_args: str) -> str: - """Compute a content hash from the HLO protobuf and compiler args. - - Hashing the HLO (instead of source code) ensures that different input - shapes/dtypes produce different cache entries, even when the kernel - source is identical. - - The HLO proto uses only ``repeated`` fields (no ``map`` fields), so - ``SerializeToString()`` is deterministic for the same computation graph. - """ - h = hashlib.sha256() - - # TODO: this SerializeToString can be slow for large HLO - h.update(hlo_module.to_proto().SerializeToString()) - h.update(compiler_args.encode("utf-8")) - return h.hexdigest()[:12] - - def _is_distributed() -> bool: """Check if running in a multi-worker torch.distributed setting.""" return ( @@ -255,12 +235,7 @@ def _trace_and_compile( traced_kernel.specialize(*numpy_args, **numpy_kwargs) - # Compute content hash from HLO - hlo_module = traced_kernel._code - if not isinstance(hlo_module, HLOModule): - raise NotImplementedError("Only HLOModule is supported for content hashing") - - content_hash = _hlo_content_hash(hlo_module, compiler_args) + content_hash = traced_kernel._code.content_hash(compiler_args) cache_key = f"{name}_{content_hash}" # Determine output paths diff --git a/nkipy/src/nkipy/runtime/execute.py b/nkipy/src/nkipy/runtime/execute.py index da20063..c75efe3 100644 --- a/nkipy/src/nkipy/runtime/execute.py +++ b/nkipy/src/nkipy/runtime/execute.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 """Execution wrappers for NKIPy kernels""" +from __future__ import annotations + import inspect import os import shutil @@ -9,7 +11,7 @@ import numpy as np from nkipy.core import compile -from nkipy.core.trace import _sanitize_array_dtype +from nkipy.core.backend import ComputationIR try: from nkipy.runtime.device_kernel import DeviceKernel @@ -30,39 +32,21 @@ def _compile_kernel( ): """Specialize and compile a traced kernel to NEFF. - Returns (neff_path, kernel_name, ir, boundargs). + Returns (neff_path, kernel_name, ir, original_inputs). """ - # Sanitize unsupported dtypes (float64/int64/uint64) before tracing - args = tuple( - _sanitize_array_dtype(a, f"arg{i}") if isinstance(a, np.ndarray) else a - for i, a in enumerate(args) - ) - kwargs = { - k: _sanitize_array_dtype(v, k) if isinstance(v, np.ndarray) else v - for k, v in kwargs.items() - } - - # Trace the kernel with the provided arguments kernel.specialize(*args, **kwargs) ir = kernel._code - # Bind arguments for input/output mapping sig = inspect.signature(kernel.func) boundargs = sig.bind(*args, **kwargs) boundargs.apply_defaults() - # Save original input arrays before output allocation may overwrite them original_inputs = { name: arr for name, arr in boundargs.arguments.items() if isinstance(arr, np.ndarray) } - # Allocate output tensors based on IR outputs - for outtensor in ir.outputs: - output_array = np.empty(outtensor.shape, dtype=outtensor.dtype) - boundargs.arguments[outtensor.name] = output_array - name = kernel.__name__ build_dir = artifacts_dir if artifacts_dir else f"{compile._get_build_dir()}/{name}" @@ -88,10 +72,10 @@ def _compile_kernel( target=target, ) - return neff, name, ir, boundargs, original_inputs + return neff, name, ir, original_inputs -def _execute_neff(neff, name, ir, boundargs, original_inputs, save_trace=False): +def _execute_neff(neff, name, ir: ComputationIR, original_inputs, save_trace=False): """Load a compiled NEFF and run it on hardware. Returns output numpy array(s), with auto-aliased outputs filtered out. @@ -107,21 +91,17 @@ def _execute_neff(neff, name, ir, boundargs, original_inputs, save_trace=False): # Build alias lookup: output_index -> AliasInfo alias_by_output = {a.output_index: a for a in ir.aliases} - device_inputs = {} - for intensor in ir.inputs: - if "must_alias_input" in intensor.name: - base_name = intensor.name.split(".must_alias_input")[0] - np_tensor = original_inputs[base_name] - else: - np_tensor = boundargs.arguments[intensor.name] - device_inputs[intensor.name] = DeviceTensor.from_numpy(np_tensor) + input_arrays = ir.resolve_input_arrays(original_inputs) + device_inputs = { + input_name: DeviceTensor.from_numpy(arr) + for input_name, arr in input_arrays.items() + } device_outputs = {} for i, outtensor in enumerate(ir.outputs): if i in alias_by_output: - # Aliased output shares the same device buffer as the input alias = alias_by_output[i] - input_name = f"{alias.param_name}.must_alias_input" + input_name = ir.get_alias_input_name(alias) device_outputs[outtensor.name] = device_inputs[input_name] else: np_output = np.zeros(outtensor.shape, dtype=outtensor.dtype) @@ -129,21 +109,18 @@ def _execute_neff(neff, name, ir, boundargs, original_inputs, save_trace=False): device_kernel(inputs=device_inputs, outputs=device_outputs, save_trace=save_trace) + output_arrays = {} for i, outtensor in enumerate(ir.outputs): result = device_outputs[outtensor.name].numpy() if i in alias_by_output: alias = alias_by_output[i] np.copyto(dst=original_inputs[alias.param_name], src=result) - # Point boundargs at the same array so the return logic can find it - boundargs.arguments[outtensor.name] = original_inputs[alias.param_name] - else: - dst = boundargs.arguments[outtensor.name] - np.copyto(dst=dst, src=result) + output_arrays[outtensor.name] = result # Filter out auto-aliased outputs (not user-returned) auto_indices = ir.auto_aliased_indices user_outputs = [ - boundargs.arguments[out.name] + output_arrays[out.name] for i, out in enumerate(ir.outputs) if i not in auto_indices ] @@ -165,7 +142,7 @@ def baremetal_run_traced_kernel( **kwargs, ): """Compile and run a traced kernel on hardware.""" - neff, name, ir, boundargs, original_inputs = _compile_kernel( + neff, name, ir, original_inputs = _compile_kernel( kernel, *args, artifacts_dir=artifacts_dir, @@ -174,5 +151,5 @@ def baremetal_run_traced_kernel( **kwargs, ) return _execute_neff( - neff, name, ir, boundargs, original_inputs, save_trace=save_trace + neff, name, ir, original_inputs, save_trace=save_trace ) diff --git a/nkipy/src/nkipy/tools/kernel_agent/executor.py b/nkipy/src/nkipy/tools/kernel_agent/executor.py index 74f4f62..36a1f6b 100644 --- a/nkipy/src/nkipy/tools/kernel_agent/executor.py +++ b/nkipy/src/nkipy/tools/kernel_agent/executor.py @@ -80,7 +80,7 @@ def run_kernel( from nkipy.runtime.execute import _compile_kernel traced = NKIPyKernel.trace(kernel_fn) - neff, kname, ir, boundargs, original_inputs = _compile_kernel( + neff, kname, ir, original_inputs = _compile_kernel( traced, *args, artifacts_dir=artifacts_dir ) result.compile = StageResult(success=True) @@ -93,7 +93,7 @@ def run_kernel( try: from nkipy.runtime.execute import _execute_neff - out = _execute_neff(neff, kname, ir, boundargs, original_inputs) + out = _execute_neff(neff, kname, ir, original_inputs) result.hardware = StageResult(success=True, output=np.asarray(out)) except Exception as e: result.hardware = StageResult(success=False, error=str(e)) diff --git a/tests/test_kernelgen_numerical.py b/tests/test_kernelgen_numerical.py new file mode 100644 index 0000000..a2a4061 --- /dev/null +++ b/tests/test_kernelgen_numerical.py @@ -0,0 +1,130 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Numerical correctness and full-pipeline tests for the kernelgen backend. + +Two levels of verification: +1. LLVM JIT smoke test — trace via nkipy, run through MLIR passes, execute + via LLVM JIT to verify numerical correctness without hardware. +2. NEFF compilation — trace via nkipy with knob() annotations, compile all + the way to NEFF to catch pass-pipeline and mem_space enum issues. + +Requires nkipy_kernelgen (pass pipeline, LLVM JIT infrastructure). +""" + +import numpy as np +import pytest + +try: + from nkipy_kernelgen.llvm import LLVMModule + + HAS_KERNELGEN = True +except ImportError: + HAS_KERNELGEN = False + +from nkipy.core.trace import NKIPyKernel +from nkipy.core.knob import knob + +pytestmark = pytest.mark.skipif( + not HAS_KERNELGEN, reason="nkipy-kernelgen not installed" +) + + +def _trace_and_run_llvm(func, *np_args): + """Trace via nkipy kernelgen, execute via LLVM JIT, return result.""" + kernel = NKIPyKernel.trace(func, backend="kernelgen") + ir = kernel.specialize(*np_args) + mod = LLVMModule(ir._mlir_text, ir._func_name) + return mod(*np_args) + + +def _trace_and_compile_to_neff(func, *np_args): + """Trace a kernelgen kernel and compile all the way to NEFF. + + Exercises the full nkipy.core.knob -> builder.annotate() -> MLIR pass + pipeline -> NISA -> neuronx-cc -> NEFF path. Raises on any failure. + """ + import shutil + import tempfile + + from nkipy.core import compile as nkipy_compile + + kernel = NKIPyKernel.trace(func, backend="kernelgen") + kernel.specialize(*np_args) + + artifacts_dir = tempfile.mkdtemp(prefix="kernelgen_neff_test_") + try: + nkipy_compile.compile_to_neff( + kernel, + artifacts_dir, + additional_compiler_args=nkipy_compile.nkipy_compiler_args, + ) + finally: + shutil.rmtree(artifacts_dir, ignore_errors=True) + + +class TestNumericalLLVMJIT: + """Smoke tests: trace through nkipy, verify numerics via LLVM JIT.""" + + def test_matmul_add(self): + def kernel(a, b, bias): + return np.matmul(a, b) + bias + + a = np.random.randn(4, 8).astype(np.float32) + b = np.random.randn(8, 4).astype(np.float32) + bias = np.random.randn(4).astype(np.float32) + result = _trace_and_run_llvm(kernel, a, b, bias) + np.testing.assert_allclose(result, a @ b + bias, rtol=1e-4, atol=1e-4) + + def test_sigmoid(self): + def kernel(x): + return np.reciprocal(1.0 + np.exp(-x)) + + x = np.random.randn(4, 8).astype(np.float32) + result = _trace_and_run_llvm(kernel, x) + expected = 1.0 / (1.0 + np.exp(-x)) + np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + + +class TestNumericalFullPipeline: + """Compile to NEFF end-to-end via nkipy trace+compile. + + Exercises the nkipy.core.knob -> builder.annotate() -> MLIR -> NISA -> + NEFF path, catching issues like mem_space enum mismatches between the + Python builder and the MLIR dialect definition. + """ + + def test_add_full_pipeline(self): + def kernel(a, b): + C = np.add(a, b) + C = knob(C, mem_space="SharedHbm", tile_size=[128, 128]) + return C + + a = np.random.randn(128, 128).astype(np.float32) + b = np.random.randn(128, 128).astype(np.float32) + _trace_and_compile_to_neff(kernel, a, b) + + def test_matmul_full_pipeline(self): + def kernel(a, b): + C = np.matmul(a, b) + C = knob(C, mem_space="SharedHbm", tile_size=[128, 128], + reduction_tile=[128]) + return C + + a = np.random.randn(128, 256).astype(np.float32) + b = np.random.randn(256, 128).astype(np.float32) + _trace_and_compile_to_neff(kernel, a, b) + + def test_sigmoid_full_pipeline(self): + def kernel(x): + neg_x = -x + neg_x = knob(neg_x, mem_space="Sbuf", tile_size=[128, 128]) + exp_neg = np.exp(neg_x) + exp_neg = knob(exp_neg, mem_space="Sbuf", tile_size=[128, 128]) + denom = 1.0 + exp_neg + denom = knob(denom, mem_space="Sbuf", tile_size=[128, 128]) + result = 1.0 / denom + result = knob(result, mem_space="SharedHbm", tile_size=[128, 128]) + return result + + x = np.random.randn(128, 256).astype(np.float32) + _trace_and_compile_to_neff(kernel, x) diff --git a/tests/test_kernelgen_ops.py b/tests/test_kernelgen_ops.py new file mode 100644 index 0000000..ae723fb --- /dev/null +++ b/tests/test_kernelgen_ops.py @@ -0,0 +1,367 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the kernelgen backend integration. + +Each test traces a kernel with backend="kernelgen", compiles to NEFF, +runs on Neuron device, and compares the numerical result against NumPy. +When no device is available, falls back to compile-only validation. +""" + +import numpy as np +import pytest + +try: + import nkipy_kernelgen # noqa: F401 + + HAS_KERNELGEN = True +except ImportError: + HAS_KERNELGEN = False + +from utils import ( + NEURON_AVAILABLE, + baremetal_assert_allclose, + on_device_test, + trace_and_compile, +) + +pytestmark = pytest.mark.skipif( + not HAS_KERNELGEN, reason="nkipy-kernelgen not installed" +) + +TRACE_MODE = "kernelgen" + + +def _run_kernel(kernel_fn, *args): + """Run a kernel on device if available, else compile-only. Returns result or None.""" + if NEURON_AVAILABLE: + return on_device_test(kernel_fn, TRACE_MODE, *args) + else: + trace_and_compile(kernel_fn, TRACE_MODE, *args) + return None + + +class TestKernelGenBasicOps: + """Test basic arithmetic operations compile and run correctly.""" + + def test_add(self): + def kernel(A, B): + return np.add(A, B) + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + baremetal_assert_allclose(result, A + B) + + def test_subtract(self): + def kernel(A, B): + return np.subtract(A, B) + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + baremetal_assert_allclose(result, A - B) + + def test_multiply(self): + def kernel(A, B): + return np.multiply(A, B) + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + baremetal_assert_allclose(result, A * B) + + def test_scalar_add(self): + def kernel(A): + return A + 1.0 + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, A + 1.0) + + def test_matmul(self): + def kernel(A, B): + return np.matmul(A, B) + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + baremetal_assert_allclose(result, A @ B) + + def test_matmul_batched(self): + def kernel(A, B): + return np.matmul(A, B) + + A = np.random.randn(2, 128, 128).astype(np.float32) + B = np.random.randn(2, 128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + baremetal_assert_allclose(result, A @ B) + + +class TestKernelGenUnaryOps: + """Test unary operations compile and run correctly.""" + + def test_exp(self): + def kernel(A): + return np.exp(A) + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, np.exp(A)) + + def test_sqrt(self): + def kernel(A): + return np.sqrt(A) + + A = np.abs(np.random.randn(128, 128)).astype(np.float32) + 0.01 + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, np.sqrt(A)) + + def test_tanh(self): + def kernel(A): + return np.tanh(A) + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, np.tanh(A)) + + def test_negative(self): + def kernel(A): + return -A + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, -A) + + +class TestKernelGenTransformOps: + """Test transform operations compile and run correctly.""" + + @pytest.mark.xfail(reason="linalg.transpose not lowered to NISA yet", run=True, strict=False) + def test_transpose(self): + def kernel(A): + return np.transpose(A) + + A = np.random.randn(128, 256).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, A.T) + + def test_reshape(self): + def kernel(A): + return np.reshape(A, (256, 64)) + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, A.reshape(256, 64)) + + def test_squeeze(self): + def kernel(A): + return np.squeeze(A, axis=1) + + A = np.random.randn(128, 1, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, A.squeeze(axis=1)) + + @pytest.mark.xfail(reason="linalg.transpose not lowered to NISA yet", run=True, strict=False) + def test_swapaxes(self): + def kernel(A): + return np.swapaxes(A, 0, 1) + + A = np.random.randn(128, 256).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, np.swapaxes(A, 0, 1)) + + @pytest.mark.xfail(reason="tensor.insert_slice stack lowering produces incorrect NISA", run=True, strict=False) + def test_stack(self): + def kernel(A, B): + return np.stack([A, B], axis=0) + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + baremetal_assert_allclose(result, np.stack([A, B], axis=0)) + + +class TestKernelGenReductions: + """Test reduction operations compile and run correctly.""" + + def test_sum(self): + def kernel(A): + return np.sum(A, axis=1, keepdims=True) + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, np.sum(A, axis=1, keepdims=True)) + + @pytest.mark.xfail(reason="mean reduction missing memory space annotation in NISA lowering", run=True, strict=False) + def test_mean(self): + def kernel(A): + return np.mean(A, axis=0, keepdims=True) + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, np.mean(A, axis=0, keepdims=True)) + + +class TestKernelGenComparisonOps: + """Test comparison and logical operations compile and run correctly.""" + + def test_equal(self): + def kernel(A, B): + return np.equal(A, B) + + A = np.random.randn(128, 128).astype(np.float32) + B = A.copy() + B[::2, :] = 0.0 + result = _run_kernel(kernel, A, B) + if result is not None: + expected = np.equal(A, B).astype(np.float32) + baremetal_assert_allclose(result, expected) + + def test_greater(self): + def kernel(A, B): + return np.greater(A, B) + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + expected = np.greater(A, B).astype(np.float32) + baremetal_assert_allclose(result, expected) + + def test_less_scalar(self): + def kernel(A): + return np.less(A, 0.5) + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + expected = np.less(A, 0.5).astype(np.float32) + baremetal_assert_allclose(result, expected) + + def test_logical_not(self): + def kernel(A): + return np.logical_not(A) + + A = np.array(np.random.rand(128, 128) > 0.5, dtype=np.float32) + result = _run_kernel(kernel, A) + if result is not None: + expected = np.logical_not(A).astype(np.float32) + baremetal_assert_allclose(result, expected) + + def test_bitwise_and(self): + def kernel(A, B): + return np.bitwise_and(A, B) + + A = np.array(np.random.rand(128, 128) > 0.5, dtype=np.float32) + B = np.array(np.random.rand(128, 128) > 0.5, dtype=np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + expected = np.bitwise_and( + A.astype(np.int32), B.astype(np.int32) + ).astype(np.float32) + baremetal_assert_allclose(result, expected) + + +class TestKernelGenWhere: + """Test np.where compiles and runs correctly.""" + + def test_where_same_type(self): + def kernel(A, B, C): + return np.where(A, B, C) + + A = np.array(np.random.rand(128, 128) > 0.5, dtype=np.float32) + B = np.random.randn(128, 128).astype(np.float32) + C = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B, C) + if result is not None: + baremetal_assert_allclose(result, np.where(A, B, C)) + + def test_where_with_comparison(self): + def kernel(A, B): + mask = np.greater(A, 0.0) + return np.where(mask, A, B) + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + baremetal_assert_allclose(result, np.where(A > 0.0, A, B)) + + +class TestKernelGenComposedKernel: + """Test non-trivial kernels that compose multiple ops.""" + + @pytest.mark.xfail(reason="broadcast add with rank-1 bias not lowered to NISA yet", run=True, strict=False) + def test_matmul_add_relu(self): + def kernel(A, B, bias): + C = np.matmul(A, B) + C = C + bias + return np.maximum(C, 0.0) + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + bias = np.random.randn(128).astype(np.float32) + result = _run_kernel(kernel, A, B, bias) + if result is not None: + baremetal_assert_allclose(result, np.maximum(A @ B + bias, 0.0)) + + @pytest.mark.xfail(reason="composed mean/sqrt/broadcast not lowered to NISA yet", run=True, strict=False) + def test_rmsnorm(self): + def kernel(x, weight): + variance = np.mean(x * x, axis=-1, keepdims=True) + x_norm = x / np.sqrt(variance + 1e-6) + return x_norm * weight + + x = np.random.randn(128, 128).astype(np.float32) + w = np.random.randn(128).astype(np.float32) + result = _run_kernel(kernel, x, w) + if result is not None: + variance = np.mean(x * x, axis=-1, keepdims=True) + expected = (x / np.sqrt(variance + 1e-6)) * w + baremetal_assert_allclose(result, expected) + + @pytest.mark.xfail(reason="clip not lowered to NISA yet", run=True, strict=False) + def test_clip(self): + def kernel(A): + return np.clip(A, 0.0, 1.0) + + A = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A) + if result is not None: + baremetal_assert_allclose(result, np.clip(A, 0.0, 1.0)) + + +class TestKernelGenAnnotations: + """Test knob() annotations compile to NEFF and run correctly.""" + + def test_knob_mem_space(self): + from nkipy.core.knob import knob + + def kernel(A, B): + C = np.matmul(A, B) + C = knob(C, mem_space="SharedHbm", tile_size=[128, 128], + reduction_tile=[128]) + return C + + A = np.random.randn(128, 128).astype(np.float32) + B = np.random.randn(128, 128).astype(np.float32) + result = _run_kernel(kernel, A, B) + if result is not None: + baremetal_assert_allclose(result, A @ B) diff --git a/tests/unit/test_device_kernel_cache.py b/tests/unit/test_device_kernel_cache.py index 07b9949..2cc4643 100644 --- a/tests/unit/test_device_kernel_cache.py +++ b/tests/unit/test_device_kernel_cache.py @@ -1,14 +1,13 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -"""Unit tests for HLO-based kernel cache hashing.""" +"""Unit tests for IR content hashing.""" import numpy as np from nkipy.core.compile import trace -from nkipy.runtime.device_kernel import _hlo_content_hash def _trace_and_specialize(kernel_fn, *args, **kwargs): - """Helper: trace a kernel, specialize with given args, return the HLOModule.""" + """Helper: trace a kernel, specialize with given args, return the IR.""" traced = trace(kernel_fn) traced.specialize(*args, **kwargs) return traced._code @@ -20,18 +19,18 @@ def test_hlo_hash_varies_with_shape(): def add_kernel(x, y): return np.add(x, y) - hlo_small = _trace_and_specialize( + ir_small = _trace_and_specialize( add_kernel, np.zeros((2, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32), ) - hlo_large = _trace_and_specialize( + ir_large = _trace_and_specialize( add_kernel, np.zeros((4, 4), dtype=np.float32), np.zeros((4, 4), dtype=np.float32), ) - assert _hlo_content_hash(hlo_small, "") != _hlo_content_hash(hlo_large, "") + assert ir_small.content_hash("") != ir_large.content_hash("") def test_hlo_hash_deterministic(): @@ -45,10 +44,10 @@ def add_kernel(x, y): np.zeros((2, 2), dtype=np.float32), ) - hlo1 = _trace_and_specialize(add_kernel, *inputs) - hlo2 = _trace_and_specialize(add_kernel, *inputs) + ir1 = _trace_and_specialize(add_kernel, *inputs) + ir2 = _trace_and_specialize(add_kernel, *inputs) - assert _hlo_content_hash(hlo1, "--lnc 1") == _hlo_content_hash(hlo2, "--lnc 1") + assert ir1.content_hash("--lnc 1") == ir2.content_hash("--lnc 1") def test_hlo_hash_varies_with_dtype(): @@ -57,18 +56,18 @@ def test_hlo_hash_varies_with_dtype(): def add_kernel(x, y): return np.add(x, y) - hlo_f32 = _trace_and_specialize( + ir_f32 = _trace_and_specialize( add_kernel, np.zeros((2, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32), ) - hlo_f16 = _trace_and_specialize( + ir_f16 = _trace_and_specialize( add_kernel, np.zeros((2, 2), dtype=np.float16), np.zeros((2, 2), dtype=np.float16), ) - assert _hlo_content_hash(hlo_f32, "") != _hlo_content_hash(hlo_f16, "") + assert ir_f32.content_hash("") != ir_f16.content_hash("") def test_hlo_hash_varies_with_compiler_args(): @@ -77,12 +76,12 @@ def test_hlo_hash_varies_with_compiler_args(): def add_kernel(x, y): return np.add(x, y) - hlo = _trace_and_specialize( + ir = _trace_and_specialize( add_kernel, np.zeros((2, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32), ) - hash1 = _hlo_content_hash(hlo, "--lnc 1") - hash2 = _hlo_content_hash(hlo, "--lnc 2") + hash1 = ir.content_hash("--lnc 1") + hash2 = ir.content_hash("--lnc 2") assert hash1 != hash2 diff --git a/tests/unit/test_kernelgen_backend.py b/tests/unit/test_kernelgen_backend.py new file mode 100644 index 0000000..3379333 --- /dev/null +++ b/tests/unit/test_kernelgen_backend.py @@ -0,0 +1,424 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the unified kernelgen backend integration. + +Ported from NeuronPy/tests/unit/test_kernelgen_backend.py with adaptations +for the current nkipy implementation. +""" + +import warnings + +import numpy as np +import pytest + +from nkipy import knob +from nkipy.core.nki_op import nki_custom_op +from nkipy.core.backend import get_backend, tracing +from nkipy.core.backend.kernelgen import KernelGenIR, KernelGenTraceContext + + +class TestKnobDispatch: + """Test knob() backend-aware dispatch.""" + + def test_knob_cpu_passthrough(self): + """knob() is a no-op pass-through in cpu mode (no trace).""" + arr = np.ones((4, 4), dtype=np.float32) + result = knob(arr, mem_space="Sbuf") + assert result is arr + + def test_knob_cpu_no_params(self): + """knob() with no params is always a no-op.""" + arr = np.ones((4, 4), dtype=np.float32) + result = knob(arr) + assert result is arr + + def test_knob_hlo_warns(self): + """knob() issues a warning under the HLO backend.""" + from nkipy.core.backend.hlo import HLOModule, HLOTraceContext + + code = HLOModule(name="test") + arr = np.ones((4, 4), dtype=np.float32) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + with tracing(HLOTraceContext(code)): + result = knob(arr, mem_space="Sbuf") + assert len(w) == 1 + assert "only effective with backend='kernelgen'" in str(w[0].message) + assert result is arr + + +def _make_silu_kernel_builder(M, N, tile_p=128, tile_f=128): + """Return a real NKI kernel_builder function that computes SiLU activation.""" + def silu_kernel(input_0, output_0): + import nki.compiler.kernel_builder as nb + import nki.language as nl + + n_row_tiles = M // tile_p + n_col_tiles = N // tile_f + for r in nl.affine_range(n_row_tiles): + for t in nl.affine_range(n_col_tiles): + x_sbuf = nb.ndarray((tile_p, tile_f), input_0.dtype, nb.sbuf) + nb.isa.dma_copy( + dst=x_sbuf, + src=input_0[ + r * tile_p : (r + 1) * tile_p, + t * tile_f : (t + 1) * tile_f, + ], + ) + + out_sbuf = nb.ndarray((tile_p, tile_f), input_0.dtype, nb.sbuf) + bias = nb.ndarray((tile_p, 1), input_0.dtype, nb.sbuf) + nb.isa.memset(dst=bias, value=0.0) + scale = nb.ndarray((tile_p, 1), input_0.dtype, nb.sbuf) + nb.isa.memset(dst=scale, value=1.0) + + nb.isa.activation( + dst=out_sbuf, + src=x_sbuf, + bias=bias, + scale=scale, + op=nb.isa.activation_function.silu, + ) + + nb.isa.dma_copy( + dst=output_0[ + r * tile_p : (r + 1) * tile_p, + t * tile_f : (t + 1) * tile_f, + ], + src=out_sbuf, + ) + return silu_kernel + + +class TestNKICustomOpDispatch: + """Test nki_custom_op() factory and dispatch.""" + + def test_requires_at_least_one(self): + """nki_custom_op raises if neither nki_kernel nor kernel_builder given.""" + with pytest.raises(ValueError, match="At least one"): + nki_custom_op() + + def test_kernel_builder_requires_specs(self): + """nki_custom_op raises if kernel_builder given without specs.""" + with pytest.raises(ValueError, match="input_specs and output_specs"): + nki_custom_op(kernel_builder=lambda: None) + + def test_cpu_raises(self): + """nki_custom_op raises on cpu backend.""" + op = nki_custom_op( + kernel_builder=lambda: None, + input_specs=[((4, 4), "f32")], + output_specs=[((4, 4), "f32")], + ) + with pytest.raises(RuntimeError, match="not supported on backend 'cpu'"): + op(np.ones((4, 4), dtype=np.float32)) + + def test_hlo_without_nki_kernel_raises(self): + """nki_custom_op with only kernel_builder raises on HLO.""" + + class _FakeHLOCtx: + backend_name = "hlo" + + op = nki_custom_op( + kernel_builder=lambda: None, + input_specs=[((128, 128), "f32")], + output_specs=[((128, 128), "f32")], + ) + with tracing(_FakeHLOCtx()): + with pytest.raises(RuntimeError, match="no nki_kernel"): + op(np.ones((128, 128), dtype=np.float32)) + + def test_kernelgen_without_kernel_builder_raises(self): + """nki_custom_op with only nki_kernel raises on kernelgen.""" + + class _FakeKernelgenCtx: + backend_name = "kernelgen" + + op = nki_custom_op(nki_kernel=lambda: None) + with tracing(_FakeKernelgenCtx()): + with pytest.raises(RuntimeError, match="no kernel_builder"): + op(np.ones((128, 128), dtype=np.float32)) + + +class TestKernelGenTraceContext: + """Test KernelGenTraceContext basics.""" + + @pytest.fixture(autouse=True) + def _skip_if_no_kernelgen(self): + try: + import nkipy_kernelgen # noqa: F401 + except ImportError: + pytest.skip("nkipy-kernelgen not installed") + + def test_backend_name(self): + ctx = KernelGenTraceContext() + assert ctx.backend_name == "kernelgen" + ctx._cleanup() + + def test_tracing_context_activates(self): + ctx = KernelGenTraceContext() + assert get_backend() == "cpu" + with tracing(ctx): + assert get_backend() == "kernelgen" + assert get_backend() == "cpu" + ctx._cleanup() + + +class TestSpecializeKernelgen: + """Test NKIPyKernel._specialize_kernelgen with device compilation and execution.""" + + @pytest.fixture(autouse=True) + def _skip_if_no_kernelgen(self): + try: + import nkipy_kernelgen # noqa: F401 + except ImportError: + pytest.skip("nkipy-kernelgen not installed") + + @staticmethod + def _run(func, *np_args): + from utils import NEURON_AVAILABLE, on_device_test, trace_and_compile + if NEURON_AVAILABLE: + return on_device_test(func, "kernelgen", *np_args) + else: + trace_and_compile(func, "kernelgen", *np_args) + return None + + def test_with_knob(self): + from utils import baremetal_assert_allclose + + def kernel_with_knob(a, b): + result = a + b + knob(result, mem_space="SharedHbm", tile_size=[128, 128]) + return result + + a = np.random.randn(128, 128).astype("float32") + b = np.random.randn(128, 128).astype("float32") + + result = self._run(kernel_with_knob, a, b) + if result is not None: + baremetal_assert_allclose(result, a + b) + + def test_multi_output(self): + from utils import baremetal_assert_allclose + + def multi_out(a, b): + return a + b, a - b + + a = np.random.randn(128, 128).astype("float32") + b = np.random.randn(128, 128).astype("float32") + + result = self._run(multi_out, a, b) + if result is not None: + baremetal_assert_allclose(result[0], a + b) + baremetal_assert_allclose(result[1], a - b) + + def test_dtype_downcast(self): + """float64 inputs should be auto-downcast to float32.""" + from nkipy.core.trace import NKIPyKernel + + def add_kernel(a, b): + return a + b + + kernel = NKIPyKernel.trace(add_kernel, backend="kernelgen") + a = np.random.randn(64, 64) # float64 + b = np.random.randn(64, 64) # float64 + + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + ir = kernel.specialize(a, b) + assert ir.inputs[0].dtype == np.dtype("float32") + + @pytest.mark.xfail( + reason="custom_op kernel_builder tracing requires TracedArray, " + "not yet wired through NKIPyTensorRef path" + ) + def test_custom_op_with_kernel_builder(self): + """nki_custom_op with real kernel_builder traces through kernelgen backend.""" + from nkipy.core.trace import NKIPyKernel + + silu_op = nki_custom_op( + kernel_builder=_make_silu_kernel_builder(256, 256), + input_specs=[((256, 256), "f32")], + output_specs=[((256, 256), "f32")], + ) + + def kernel(x): + return silu_op(x) + + k = NKIPyKernel.trace(kernel, backend="kernelgen") + ir = k.specialize(np.random.randn(256, 256).astype("float32")) + assert isinstance(ir, KernelGenIR) + assert "__custom_op__silu_kernel" in ir._mlir_text + assert "nkipy.custom_op_bodies" in ir._mlir_text + + +class TestKernelgenInplaceUpdate: + """Test in-place update (dynamic_update_slice) support for kernelgen. + + Each test traces → compiles → runs on device and compares + numerical results against NumPy. Alias metadata is verified as well. + """ + + @pytest.fixture(autouse=True) + def _skip_if_no_kernelgen(self): + try: + import nkipy_kernelgen # noqa: F401 + except ImportError: + pytest.skip("nkipy-kernelgen not installed") + + @staticmethod + def _trace_and_run(func, *np_args): + """Trace a kernelgen kernel, return (ir, device_result_or_None).""" + from nkipy.core.trace import NKIPyKernel + from utils import NEURON_AVAILABLE, on_device_test, trace_and_compile + + kernel = NKIPyKernel.trace(func, backend="kernelgen") + ir = kernel.specialize(*np_args) + if NEURON_AVAILABLE: + result = on_device_test(func, "kernelgen", *np_args) + else: + trace_and_compile(func, "kernelgen", *np_args) + result = None + return ir, result + + def test_single_alias(self): + """Mutate one parameter and return it — verify numerical result.""" + from utils import baremetal_assert_allclose + + def kernel(a, b): + a[0:1, :] = b[1:2, :] + return a + + a = np.random.randn(128, 128).astype("float32") + b = np.random.randn(128, 128).astype("float32") + + expected = a.copy() + expected[0:1, :] = b[1:2, :] + + ir, result = self._trace_and_run(kernel, a, b) + if result is not None: + baremetal_assert_allclose(result, expected) + + assert isinstance(ir, KernelGenIR) + assert len(ir.aliases) == 1 + assert ir.aliases[0].param_name == "a" + assert ir.aliases[0].param_index == 0 + assert ir.aliases[0].is_user_returned is True + assert ir.auto_aliased_indices == set() + + def test_multi_slice_update(self): + """Update multiple disjoint slices of the same tensor.""" + from utils import baremetal_assert_allclose + + def kernel(a, b): + a[0:2, :] = b[0:2, :] + a[4:6, :] = b[4:6, :] + return a + + a = np.random.randn(128, 128).astype("float32") + b = np.random.randn(128, 128).astype("float32") + + expected = a.copy() + expected[0:2, :] = b[0:2, :] + expected[4:6, :] = b[4:6, :] + + ir, result = self._trace_and_run(kernel, a, b) + if result is not None: + baremetal_assert_allclose(result, expected) + assert len(ir.aliases) == 1 + + def test_multi_alias(self): + """Mutate two parameters and return both.""" + from utils import baremetal_assert_allclose + + def kernel(a, b, c): + a[0:1, :] = b[0:1, :] + c[2:3, :] = b[2:3, :] + return a, c + + a = np.random.randn(128, 128).astype("float32") + b = np.random.randn(128, 128).astype("float32") + c = np.random.randn(128, 128).astype("float32") + + expected_a = a.copy() + expected_a[0:1, :] = b[0:1, :] + expected_c = c.copy() + expected_c[2:3, :] = b[2:3, :] + + ir, result = self._trace_and_run(kernel, a, b, c) + if result is not None: + baremetal_assert_allclose(result[0], expected_a) + baremetal_assert_allclose(result[1], expected_c) + + assert len(ir.aliases) == 2 + alias_names = {al.param_name for al in ir.aliases} + assert alias_names == {"a", "c"} + assert all(al.is_user_returned for al in ir.aliases) + + def test_no_return_auto_alias(self): + """Mutate without returning — auto-append to outputs.""" + from utils import baremetal_assert_allclose + + def kernel(a, b): + a[0:1, :] = b[1:2, :] + + a = np.random.randn(128, 128).astype("float32") + b = np.random.randn(128, 128).astype("float32") + + expected = a.copy() + expected[0:1, :] = b[1:2, :] + + ir, result = self._trace_and_run(kernel, a, b) + if result is not None: + baremetal_assert_allclose(result, expected) + + assert len(ir.aliases) == 1 + assert ir.aliases[0].param_name == "a" + assert ir.aliases[0].is_user_returned is False + assert ir.auto_aliased_indices == {0} + + def test_mixed_return_alias(self): + """Mutate a parameter but return a different computed value.""" + from utils import baremetal_assert_allclose + + def kernel(a, b): + a[0:1, :] = b[1:2, :] + return a + b + + a = np.random.randn(128, 128).astype("float32") + b = np.random.randn(128, 128).astype("float32") + + expected_a = a.copy() + expected_a[0:1, :] = b[1:2, :] + expected_sum = expected_a + b + + ir, result = self._trace_and_run(kernel, a, b) + if result is not None: + baremetal_assert_allclose(result, expected_sum) + + assert len(ir.aliases) == 1 + assert ir.aliases[0].param_name == "a" + assert ir.aliases[0].is_user_returned is False + assert len(ir.outputs) == 2 + assert ir.auto_aliased_indices == {1} + + def test_update_with_computation(self): + """Assign a computed expression into a slice.""" + from utils import baremetal_assert_allclose + + def kernel(a, b): + a[0:2, :] = b[0:2, :] * 2.0 + return a + + a = np.random.randn(128, 128).astype("float32") + b = np.random.randn(128, 128).astype("float32") + + expected = a.copy() + expected[0:2, :] = b[0:2, :] * 2.0 + + ir, result = self._trace_and_run(kernel, a, b) + if result is not None: + baremetal_assert_allclose(result, expected) + + diff --git a/tests/utils.py b/tests/utils.py index 8cc8baa..faa1cec 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -29,28 +29,30 @@ NEURON_AVAILABLE = is_neuron_compatible() +def _trace_mode_to_backend(trace_mode): + if trace_mode in ("hlo", "kernelgen"): + return trace_mode + raise ValueError(f"Unknown trace mode: {trace_mode}") + + def trace_and_compile(kernel_fn, trace_mode, *args, **kwargs): """ - Validate kernel is traceable to HLO and compilable to NEFF. + Validate kernel is traceable and compilable to NEFF. - Traces the kernel to HLO IR and compiles it using the Neuron compiler, + Traces the kernel to IR and compiles it using the Neuron compiler, but does not execute on device. Args: kernel_fn: The kernel function to test - trace_mode: "hlo" or other supported tracing mode + trace_mode: "hlo" or "kernelgen" *args: Input arrays for the kernel **kwargs: Additional arguments """ - if trace_mode == "hlo": - traced_kernel = NKIPyKernel.trace(kernel_fn, backend="hlo") - else: - raise ValueError(f"Unknown trace mode: {trace_mode}") + backend = _trace_mode_to_backend(trace_mode) + traced_kernel = NKIPyKernel.trace(kernel_fn, backend=backend) - # Trace to HLO traced_kernel.specialize(*args, **kwargs) - # Compile to NEFF worker_id = os.environ.get("PYTEST_XDIST_WORKER", "main") artifacts_dir = os.path.join(tempfile.gettempdir(), f"nkipy_artifacts_{worker_id}") if os.path.exists(artifacts_dir): @@ -70,7 +72,7 @@ def on_device_test(kernel_fn, trace_mode, *args, artifacts_dir=None, **kwargs): Args: kernel_fn: The kernel function to execute - trace_mode: "hlo" or other supported tracing mode + trace_mode: "hlo" or "kernelgen" *args: Input arrays for the kernel artifacts_dir: Directory for compilation artifacts (for parallel test isolation) **kwargs: Additional arguments @@ -78,17 +80,14 @@ def on_device_test(kernel_fn, trace_mode, *args, artifacts_dir=None, **kwargs): Returns: Device execution output """ - # Auto-generate worker-specific artifacts_dir if not provided if artifacts_dir is None: worker_id = os.environ.get("PYTEST_XDIST_WORKER", "main") artifacts_dir = os.path.join( tempfile.gettempdir(), f"nkipy_artifacts_{worker_id}" ) - if trace_mode == "hlo": - traced_kernel = NKIPyKernel.trace(kernel_fn, backend="hlo") - else: - raise ValueError(f"Unknown trace mode: {trace_mode}") + backend = _trace_mode_to_backend(trace_mode) + traced_kernel = NKIPyKernel.trace(kernel_fn, backend=backend) return baremetal_run_traced_kernel( traced_kernel, *args, artifacts_dir=artifacts_dir, **kwargs