From 3517bcbb225e170d62f388ba9faa67aca958ca5f Mon Sep 17 00:00:00 2001
From: Shaojie Xiang
Date: Mon, 27 Apr 2026 18:33:44 +0000
Subject: [PATCH] feat(kernelgen): import NKIPyKernelGen as a subfolder
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Import the open_source branch of NKIPyKernelGen into `kernelgen/` as a
self-contained subpackage.

NKIPyKernelGen is a compiler that traces NumPy functions and lowers them to
NISA (Neuron Instruction Set Architecture) for AWS Neuron hardware. Users
write kernels in Python with `@trace` and `knob.knob()` annotations; the
compiler handles tiling, memory placement, layout legalization, and NISA
lowering.

What's included
---------------
- `kernelgen/nkipy_kernelgen/` — Python tracing frontend:
  - `trace.py` (@trace decorator)
  - `knob.py` (tensor annotations: mem_space, tile_size, reduction_tile, partition_dim)
  - `traced_array.py` (TracedArray wrapping MLIR SSA values)
  - `op_vtable.py` (NumPy op → MLIR lowering table)
  - `transforms/nkipy_opt.py` (pipeline orchestration, shells out to `nkipy-opt`)
- `kernelgen/mlir/` — MLIR dialect + C++ passes:
  - `nkipy.annotate` op (target, mem_space, partition_dim, tile_size, reduction_tile)
  - 20+ transformation passes under `mlir/lib/Transforms/` implementing the
    24-pass compilation pipeline (InferLayout, KnobDrivenTiling,
    AnnotateMemorySpace, LegalizeLayout, InsertSpillReload, LinalgToNisa, etc.)
- `kernelgen/tests/` — test suite:
  - `passes/` — per-pass FileCheck tests
  - `e2e/` — end-to-end tests (trace → NISA → BIR sim / HW)
  - `unit/` — Python-level unit tests
  - `harness.py` — unified test harness with LLVM/BIR_SIM/HW/FileCheck modes
- `kernelgen/examples/` — example kernels
- `kernelgen/compiler_explorer/` — Compiler Explorer wrapper for inspecting IR
  at any pipeline stage
- `kernelgen/setup.py`, `pyproject.toml`, `pytest.ini`, `requirements.txt` —
  build + test configuration (`pip install -e kernelgen/` builds the C++
  passes via CMake)
- `kernelgen/CLAUDE.md`, `README.md` — pipeline docs and usage notes

Architecture notes
------------------
NKIPyKernelGen depends on the NISA dialect defined in private-nki-staging
(the `nki` wheel). NKIPyKernelGen's `nkipy-opt` binary performs the
tensor-level and bufferization phases; lowering to BIR then runs through the
upstream `nki-opt-pipeline`. This import does not bring in the NISA dialect
sources — only NKIPyKernelGen's own passes and frontend.

Ignore rules
------------
Added a `!mlir/lib/` override in `kernelgen/.gitignore` so the parent nkipy
repo's `lib/` rule (intended for Python venv `lib/` dirs) does not silently
exclude the MLIR C++ pass sources under `kernelgen/mlir/lib/`.

Source
------
Imported from NKIPyKernelGen `open_source` branch @ commit 973c1be
("fix: correct mem_space enum values in builder.annotate()"). Internal git
history is not preserved — this is a single squash import for the
open-source release.
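For reviewers, the user-facing surface in brief (condensed from the
`kernelgen/README.md` included in this patch; a sketch rather than the
authoritative API reference):

    import numpy as np
    from nkipy_kernelgen import trace, knob

    @trace(input_specs=[((256, 256), "f32"), ((256, 256), "f32")])
    def kernel(A, B):
        C = np.matmul(A, B)  # traced into linalg ops
        # tiling / memory-placement hint, propagated by infer-layout
        knob.knob(C, mem_space="Sbuf", tile_size=[128, 128, 128])
        return np.exp(C)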
--- .../skills/build_nkipykernelgen/SKILL.md | 23 + .../build_nkipykernelgen/scripts/build.sh | 12 + .../.claude/skills/debug_nisa_ir/SKILL.md | 121 + .../skills/run_nkipykernelgen_tests/SKILL.md | 28 + .../scripts/run_tests.sh | 36 + kernelgen/.gitignore | 47 + kernelgen/CLAUDE.md | 268 ++ kernelgen/README.md | 99 + kernelgen/compiler_explorer/README.md | 118 + .../config/c.nkipy.properties | 31 + .../compiler_explorer/config/example.nkipy | 28 + .../config/nkipy.local.properties | 17 + .../examples/attention_scores.py | 45 + .../examples/attention_scores_loop.py | 51 + kernelgen/compiler_explorer/examples/bmm.py | 14 + .../compiler_explorer/examples/custom.py | 23 + .../compiler_explorer/examples/feedforward.py | 61 + .../compiler_explorer/examples/matmul_add.py | 22 + .../compiler_explorer/examples/qwen3_layer.py | 306 ++ .../compiler_explorer/examples/reduce_sum.py | 30 + .../compiler_explorer/examples/reshape.py | 42 + .../compiler_explorer/examples/rmsnorm.py | 56 + kernelgen/compiler_explorer/examples/rope.py | 41 + .../compiler_explorer/examples/softmax.py | 27 + .../compiler_explorer/examples/transpose.py | 42 + .../compiler_explorer/nkipy_ce_wrapper.sh | 15 + kernelgen/compiler_explorer/nkipy_compiler.py | 620 ++++ kernelgen/compiler_explorer/setup.sh | 158 + kernelgen/examples/custom_op.py | 187 ++ kernelgen/examples/qwen3_layer.py | 311 ++ kernelgen/mlir/CMakeLists.txt | 48 + kernelgen/mlir/include/CMakeLists.txt | 1 + .../mlir/include/nkipy-c/Dialect/Dialects.h | 17 + .../include/nkipy-c/Dialect/NkipyAttributes.h | 24 + .../include/nkipy-c/Dialect/Registration.h | 25 + .../include/nkipy/Bindings/CMakeLists.txt | 96 + .../mlir/include/nkipy/Bindings/NkipyModule.h | 17 + .../include/nkipy/Bindings/nkipy/__init__.py | 10 + .../Bindings/nkipy/dialects/NkipyBinding.td | 6 + .../Bindings/nkipy/dialects/_ods_common.py | 11 + .../nkipy/Bindings/nkipy/dialects/nkipy.py | 2 + .../nkipy/Bindings/nkipy/exceptions.py | 220 ++ kernelgen/mlir/include/nkipy/CMakeLists.txt | 4 + .../mlir/include/nkipy/Dialect/CMakeLists.txt | 20 + .../mlir/include/nkipy/Dialect/NkipyAttrs.h | 11 + .../mlir/include/nkipy/Dialect/NkipyAttrs.td | 33 + .../mlir/include/nkipy/Dialect/NkipyDialect.h | 9 + .../include/nkipy/Dialect/NkipyDialect.td | 29 + .../mlir/include/nkipy/Dialect/NkipyOps.h | 26 + .../mlir/include/nkipy/Dialect/NkipyOps.td | 117 + .../include/nkipy/TransformOps/CMakeLists.txt | 19 + .../nkipy/TransformOps/NkipyTransformOps.h | 28 + .../nkipy/TransformOps/NkipyTransformOps.td | 53 + .../include/nkipy/Transforms/CMakeLists.txt | 3 + .../nkipy/Transforms/HardwareConstants.h | 36 + .../mlir/include/nkipy/Transforms/IRHelpers.h | 75 + .../nkipy/Transforms/OpClassification.h | 48 + .../mlir/include/nkipy/Transforms/Passes.h | 38 + .../mlir/include/nkipy/Transforms/Passes.td | 478 +++ .../mlir/lib/Bindings/NkipyAttributes.cpp | 67 + kernelgen/mlir/lib/Bindings/NkipyModule.cpp | 112 + kernelgen/mlir/lib/CAPI/CMakeLists.txt | 1 + .../mlir/lib/CAPI/Dialect/CMakeLists.txt | 35 + kernelgen/mlir/lib/CAPI/Dialect/Dialects.cpp | 6 + .../mlir/lib/CAPI/Dialect/NkipyAttributes.cpp | 20 + .../mlir/lib/CAPI/Dialect/Registration.cpp | 58 + kernelgen/mlir/lib/CMakeLists.txt | 4 + kernelgen/mlir/lib/Dialect/CMakeLists.txt | 20 + kernelgen/mlir/lib/Dialect/NkipyDialect.cpp | 42 + kernelgen/mlir/lib/Dialect/NkipyOps.cpp | 256 ++ .../mlir/lib/TransformOps/CMakeLists.txt | 18 + .../lib/TransformOps/NkipyTransformOps.cpp | 145 + .../lib/Transforms/AnnotateMemorySpace.cpp | 285 ++ .../Transforms/ApplyAndStripTransforms.cpp 
| 92 + .../mlir/lib/Transforms/AssignLinalgOpIds.cpp | 74 + kernelgen/mlir/lib/Transforms/CMakeLists.txt | 43 + .../lib/Transforms/CanonicalizeLoopStep.cpp | 117 + .../Transforms/CanonicalizePartitionDim.cpp | 979 ++++++ .../lib/Transforms/CanonicalizeReshape.cpp | 367 +++ .../Transforms/EliminateSameMemSpaceCopy.cpp | 413 +++ .../EliminateUninitializedCopies.cpp | 136 + kernelgen/mlir/lib/Transforms/IRHelpers.cpp | 50 + kernelgen/mlir/lib/Transforms/InferLayout.cpp | 825 +++++ .../lib/Transforms/InlineNkipyReference.cpp | 140 + .../lib/Transforms/InsertMemRefDealloc.cpp | 189 ++ .../mlir/lib/Transforms/InsertSpillReload.cpp | 536 ++++ .../mlir/lib/Transforms/KnobDrivenTiling.cpp | 687 +++++ .../mlir/lib/Transforms/LegalizeLayout.cpp | 2691 ++++++++++++++++ .../mlir/lib/Transforms/OpClassification.cpp | 69 + kernelgen/mlir/lib/Transforms/PassGen.h | 19 + kernelgen/mlir/lib/Transforms/Passes.cpp | 15 + .../mlir/lib/Transforms/PrepareArithmetic.cpp | 301 ++ .../Transforms/RemoveRedundantZeroFill.cpp | 138 + .../mlir/lib/Transforms/SimplifyLinalg.cpp | 778 +++++ kernelgen/mlir/tools/CMakeLists.txt | 1 + kernelgen/mlir/tools/nkipy-opt/CMakeLists.txt | 20 + kernelgen/mlir/tools/nkipy-opt/README.md | 148 + kernelgen/mlir/tools/nkipy-opt/nkipy-opt.cpp | 51 + .../mlir/tools/nkipy-opt/test_example.mlir | 27 + kernelgen/nkipy_kernelgen/__init__.py | 27 + kernelgen/nkipy_kernelgen/apis.py | 11 + kernelgen/nkipy_kernelgen/builder.py | 1802 +++++++++++ kernelgen/nkipy_kernelgen/compile.py | 90 + kernelgen/nkipy_kernelgen/control_flow.py | 163 + kernelgen/nkipy_kernelgen/custom_op.py | 228 ++ kernelgen/nkipy_kernelgen/execution.py | 27 + kernelgen/nkipy_kernelgen/knob.py | 222 ++ kernelgen/nkipy_kernelgen/llvm.py | 642 ++++ kernelgen/nkipy_kernelgen/mlir_utils.py | 122 + kernelgen/nkipy_kernelgen/op_vtable.py | 411 +++ kernelgen/nkipy_kernelgen/pass_manager.py | 119 + kernelgen/nkipy_kernelgen/trace.py | 164 + kernelgen/nkipy_kernelgen/traced_array.py | 350 +++ .../nkipy_kernelgen/transforms/__init__.py | 36 + .../transforms/linalg_to_nisa_py.py | 2720 +++++++++++++++++ .../nkipy_kernelgen/transforms/nkipy_opt.py | 419 +++ kernelgen/nkipy_kernelgen/utils.py | 218 ++ kernelgen/pyproject.toml | 3 + kernelgen/pytest.ini | 23 + kernelgen/requirements.txt | 4 + kernelgen/setup.py | 141 + kernelgen/tests/README.md | 139 + kernelgen/tests/__init__.py | 6 + kernelgen/tests/conftest.py | 59 + kernelgen/tests/debug/.gitignore | 1 + kernelgen/tests/debug/README.md | 79 + kernelgen/tests/debug/__init__.py | 0 kernelgen/tests/debug/bmm/buggy.mlir | 47 + .../tests/debug/bmm/fix_3d_dma_indices.mlir | 74 + kernelgen/tests/debug/bmm/kernel.py | 14 + kernelgen/tests/debug/qwen3_layer/README.md | 240 ++ kernelgen/tests/debug/qwen3_layer/buggy.mlir | 808 +++++ .../fix_rope_vector_partition.mlir | 782 +++++ kernelgen/tests/debug/qwen3_layer/kernel.py | 305 ++ kernelgen/tests/debug/run.sh | 35 + kernelgen/tests/debug/run_sim.py | 195 ++ kernelgen/tests/e2e/__init__.py | 0 kernelgen/tests/e2e/conftest.py | 10 + kernelgen/tests/e2e/nkipy_tests/__init__.py | 0 .../tests/e2e/nkipy_tests/test_attention.py | 141 + .../tests/e2e/nkipy_tests/test_binary_ops.py | 181 ++ .../nkipy_tests/test_composite_patterns.py | 121 + .../tests/e2e/nkipy_tests/test_embedding.py | 49 + .../tests/e2e/nkipy_tests/test_indexing.py | 28 + .../e2e/nkipy_tests/test_llama_decoder.py | 187 ++ .../e2e/nkipy_tests/test_matmul_shapes.py | 106 + kernelgen/tests/e2e/nkipy_tests/test_mlp.py | 119 + .../tests/e2e/nkipy_tests/test_reductions.py | 106 + 
kernelgen/tests/e2e/nkipy_tests/test_rope.py | 54 + .../tests/e2e/nkipy_tests/test_simple_add.py | 25 + .../tests/e2e/nkipy_tests/test_softmax.py | 29 + .../nkipy_tests/test_tensor_manipulation.py | 178 ++ .../tests/e2e/nkipy_tests/test_unary_ops.py | 123 + kernelgen/tests/e2e/test_3d_elementwise.py | 90 + kernelgen/tests/e2e/test_attention.py | 274 ++ kernelgen/tests/e2e/test_auto_layout.py | 132 + kernelgen/tests/e2e/test_custom_op.py | 317 ++ kernelgen/tests/e2e/test_feedforward.py | 141 + kernelgen/tests/e2e/test_head_deconcat.py | 88 + kernelgen/tests/e2e/test_matmul_add.py | 119 + kernelgen/tests/e2e/test_multi_output.py | 114 + kernelgen/tests/e2e/test_partition_dim.py | 94 + kernelgen/tests/e2e/test_qwen3_layer.py | 309 ++ kernelgen/tests/e2e/test_reduce.py | 144 + kernelgen/tests/e2e/test_rmsnorm.py | 87 + kernelgen/tests/e2e/test_rope.py | 238 ++ kernelgen/tests/e2e/test_sigmoid.py | 172 ++ kernelgen/tests/harness.py | 1020 +++++++ kernelgen/tests/passes/__init__.py | 1 + .../passes/annotate_memory_space/__init__.py | 1 + .../annotate_memory_space/test_basic.py | 204 ++ .../passes/canonicalize_loop_step/__init__.py | 1 + .../test_elementwise.py | 154 + .../canonicalize_loop_step/test_matmul.py | 91 + .../canonicalize_loop_step/test_multi_op.py | 177 ++ .../canonicalize_partition_dim/__init__.py | 0 .../canonicalize_partition_dim/test_basic.py | 325 ++ .../canonicalize_partition_dim/test_reduce.py | 119 + .../__init__.py | 1 + .../test_basic.py | 78 + kernelgen/tests/passes/conftest.py | 61 + .../eliminate_same_memspace_copy/__init__.py | 1 + .../test_basic.py | 129 + .../__init__.py | 1 + .../test_basic.py | 131 + .../tests/passes/infer_layout/__init__.py | 1 + .../test_infer_layout_broadcast.py | 97 + .../test_infer_layout_elementwise.py | 420 +++ .../infer_layout/test_infer_layout_matmul.py | 474 +++ .../infer_layout/test_infer_layout_reduce.py | 135 + .../insert_spill_reload/test_basic_spill.py | 350 +++ .../passes/knob_driven_tiling/__init__.py | 1 + .../knob_driven_tiling/test_elementwise.py | 369 +++ .../passes/knob_driven_tiling/test_matmul.py | 272 ++ .../knob_driven_tiling/test_multi_op.py | 218 ++ .../tests/passes/legalize_layout/__init__.py | 1 + .../passes/legalize_layout/test_basic.py | 188 ++ .../legalize_layout/test_fold_reshape_copy.py | 431 +++ .../tests/passes/linalg_to_nisa/__init__.py | 0 .../tests/passes/linalg_to_nisa/test_basic.py | 105 + .../test_multi_non_unit_collapse.py | 205 ++ kernelgen/tests/passes/pass_utils.py | 363 +++ .../passes/prepare_arithmetic/__init__.py | 0 .../passes/prepare_arithmetic/test_basic.py | 258 ++ .../remove_linalg_zero_fill/__init__.py | 0 .../passes/resolve_custom_ops/test_basic.py | 193 ++ kernelgen/tests/python/__init__.py | 0 kernelgen/tests/python/lit.cfg.py | 38 + kernelgen/tests/python/passes/__init__.py | 0 .../python/passes/test_knob_annotations.py | 187 ++ kernelgen/tests/python/rewrites/__init__.py | 0 kernelgen/tests/unit/__init__.py | 0 kernelgen/tests/unit/conftest.py | 9 + kernelgen/tests/unit/test_broadcast_ops.py | 232 ++ kernelgen/tests/unit/test_custom_op.py | 355 +++ kernelgen/tests/unit/test_elementwise_ops.py | 435 +++ kernelgen/tests/unit/test_execution_engine.py | 79 + kernelgen/tests/unit/test_for_loops.py | 268 ++ kernelgen/tests/unit/test_gather_ops.py | 92 + .../tests/unit/test_import_compatibility.py | 11 + kernelgen/tests/unit/test_matrix_ops.py | 128 + kernelgen/tests/unit/test_reduction_ops.py | 370 +++ 222 files changed, 37640 insertions(+) create mode 100644 
kernelgen/.claude/skills/build_nkipykernelgen/SKILL.md create mode 100755 kernelgen/.claude/skills/build_nkipykernelgen/scripts/build.sh create mode 100644 kernelgen/.claude/skills/debug_nisa_ir/SKILL.md create mode 100644 kernelgen/.claude/skills/run_nkipykernelgen_tests/SKILL.md create mode 100755 kernelgen/.claude/skills/run_nkipykernelgen_tests/scripts/run_tests.sh create mode 100644 kernelgen/.gitignore create mode 100644 kernelgen/CLAUDE.md create mode 100644 kernelgen/README.md create mode 100644 kernelgen/compiler_explorer/README.md create mode 100644 kernelgen/compiler_explorer/config/c.nkipy.properties create mode 100644 kernelgen/compiler_explorer/config/example.nkipy create mode 100644 kernelgen/compiler_explorer/config/nkipy.local.properties create mode 100644 kernelgen/compiler_explorer/examples/attention_scores.py create mode 100644 kernelgen/compiler_explorer/examples/attention_scores_loop.py create mode 100644 kernelgen/compiler_explorer/examples/bmm.py create mode 100644 kernelgen/compiler_explorer/examples/custom.py create mode 100644 kernelgen/compiler_explorer/examples/feedforward.py create mode 100644 kernelgen/compiler_explorer/examples/matmul_add.py create mode 100644 kernelgen/compiler_explorer/examples/qwen3_layer.py create mode 100644 kernelgen/compiler_explorer/examples/reduce_sum.py create mode 100644 kernelgen/compiler_explorer/examples/reshape.py create mode 100644 kernelgen/compiler_explorer/examples/rmsnorm.py create mode 100644 kernelgen/compiler_explorer/examples/rope.py create mode 100644 kernelgen/compiler_explorer/examples/softmax.py create mode 100644 kernelgen/compiler_explorer/examples/transpose.py create mode 100755 kernelgen/compiler_explorer/nkipy_ce_wrapper.sh create mode 100755 kernelgen/compiler_explorer/nkipy_compiler.py create mode 100755 kernelgen/compiler_explorer/setup.sh create mode 100644 kernelgen/examples/custom_op.py create mode 100644 kernelgen/examples/qwen3_layer.py create mode 100644 kernelgen/mlir/CMakeLists.txt create mode 100644 kernelgen/mlir/include/CMakeLists.txt create mode 100644 kernelgen/mlir/include/nkipy-c/Dialect/Dialects.h create mode 100644 kernelgen/mlir/include/nkipy-c/Dialect/NkipyAttributes.h create mode 100644 kernelgen/mlir/include/nkipy-c/Dialect/Registration.h create mode 100644 kernelgen/mlir/include/nkipy/Bindings/CMakeLists.txt create mode 100644 kernelgen/mlir/include/nkipy/Bindings/NkipyModule.h create mode 100644 kernelgen/mlir/include/nkipy/Bindings/nkipy/__init__.py create mode 100644 kernelgen/mlir/include/nkipy/Bindings/nkipy/dialects/NkipyBinding.td create mode 100644 kernelgen/mlir/include/nkipy/Bindings/nkipy/dialects/_ods_common.py create mode 100644 kernelgen/mlir/include/nkipy/Bindings/nkipy/dialects/nkipy.py create mode 100644 kernelgen/mlir/include/nkipy/Bindings/nkipy/exceptions.py create mode 100644 kernelgen/mlir/include/nkipy/CMakeLists.txt create mode 100644 kernelgen/mlir/include/nkipy/Dialect/CMakeLists.txt create mode 100644 kernelgen/mlir/include/nkipy/Dialect/NkipyAttrs.h create mode 100644 kernelgen/mlir/include/nkipy/Dialect/NkipyAttrs.td create mode 100644 kernelgen/mlir/include/nkipy/Dialect/NkipyDialect.h create mode 100644 kernelgen/mlir/include/nkipy/Dialect/NkipyDialect.td create mode 100644 kernelgen/mlir/include/nkipy/Dialect/NkipyOps.h create mode 100644 kernelgen/mlir/include/nkipy/Dialect/NkipyOps.td create mode 100644 kernelgen/mlir/include/nkipy/TransformOps/CMakeLists.txt create mode 100644 kernelgen/mlir/include/nkipy/TransformOps/NkipyTransformOps.h create mode 
100644 kernelgen/mlir/include/nkipy/TransformOps/NkipyTransformOps.td create mode 100644 kernelgen/mlir/include/nkipy/Transforms/CMakeLists.txt create mode 100644 kernelgen/mlir/include/nkipy/Transforms/HardwareConstants.h create mode 100644 kernelgen/mlir/include/nkipy/Transforms/IRHelpers.h create mode 100644 kernelgen/mlir/include/nkipy/Transforms/OpClassification.h create mode 100644 kernelgen/mlir/include/nkipy/Transforms/Passes.h create mode 100644 kernelgen/mlir/include/nkipy/Transforms/Passes.td create mode 100644 kernelgen/mlir/lib/Bindings/NkipyAttributes.cpp create mode 100644 kernelgen/mlir/lib/Bindings/NkipyModule.cpp create mode 100644 kernelgen/mlir/lib/CAPI/CMakeLists.txt create mode 100644 kernelgen/mlir/lib/CAPI/Dialect/CMakeLists.txt create mode 100644 kernelgen/mlir/lib/CAPI/Dialect/Dialects.cpp create mode 100644 kernelgen/mlir/lib/CAPI/Dialect/NkipyAttributes.cpp create mode 100644 kernelgen/mlir/lib/CAPI/Dialect/Registration.cpp create mode 100644 kernelgen/mlir/lib/CMakeLists.txt create mode 100644 kernelgen/mlir/lib/Dialect/CMakeLists.txt create mode 100644 kernelgen/mlir/lib/Dialect/NkipyDialect.cpp create mode 100644 kernelgen/mlir/lib/Dialect/NkipyOps.cpp create mode 100644 kernelgen/mlir/lib/TransformOps/CMakeLists.txt create mode 100644 kernelgen/mlir/lib/TransformOps/NkipyTransformOps.cpp create mode 100644 kernelgen/mlir/lib/Transforms/AnnotateMemorySpace.cpp create mode 100644 kernelgen/mlir/lib/Transforms/ApplyAndStripTransforms.cpp create mode 100644 kernelgen/mlir/lib/Transforms/AssignLinalgOpIds.cpp create mode 100644 kernelgen/mlir/lib/Transforms/CMakeLists.txt create mode 100644 kernelgen/mlir/lib/Transforms/CanonicalizeLoopStep.cpp create mode 100644 kernelgen/mlir/lib/Transforms/CanonicalizePartitionDim.cpp create mode 100644 kernelgen/mlir/lib/Transforms/CanonicalizeReshape.cpp create mode 100644 kernelgen/mlir/lib/Transforms/EliminateSameMemSpaceCopy.cpp create mode 100644 kernelgen/mlir/lib/Transforms/EliminateUninitializedCopies.cpp create mode 100644 kernelgen/mlir/lib/Transforms/IRHelpers.cpp create mode 100644 kernelgen/mlir/lib/Transforms/InferLayout.cpp create mode 100644 kernelgen/mlir/lib/Transforms/InlineNkipyReference.cpp create mode 100644 kernelgen/mlir/lib/Transforms/InsertMemRefDealloc.cpp create mode 100644 kernelgen/mlir/lib/Transforms/InsertSpillReload.cpp create mode 100644 kernelgen/mlir/lib/Transforms/KnobDrivenTiling.cpp create mode 100644 kernelgen/mlir/lib/Transforms/LegalizeLayout.cpp create mode 100644 kernelgen/mlir/lib/Transforms/OpClassification.cpp create mode 100644 kernelgen/mlir/lib/Transforms/PassGen.h create mode 100644 kernelgen/mlir/lib/Transforms/Passes.cpp create mode 100644 kernelgen/mlir/lib/Transforms/PrepareArithmetic.cpp create mode 100644 kernelgen/mlir/lib/Transforms/RemoveRedundantZeroFill.cpp create mode 100644 kernelgen/mlir/lib/Transforms/SimplifyLinalg.cpp create mode 100644 kernelgen/mlir/tools/CMakeLists.txt create mode 100644 kernelgen/mlir/tools/nkipy-opt/CMakeLists.txt create mode 100644 kernelgen/mlir/tools/nkipy-opt/README.md create mode 100644 kernelgen/mlir/tools/nkipy-opt/nkipy-opt.cpp create mode 100644 kernelgen/mlir/tools/nkipy-opt/test_example.mlir create mode 100644 kernelgen/nkipy_kernelgen/__init__.py create mode 100644 kernelgen/nkipy_kernelgen/apis.py create mode 100644 kernelgen/nkipy_kernelgen/builder.py create mode 100644 kernelgen/nkipy_kernelgen/compile.py create mode 100644 kernelgen/nkipy_kernelgen/control_flow.py create mode 100644 kernelgen/nkipy_kernelgen/custom_op.py 
create mode 100644 kernelgen/nkipy_kernelgen/execution.py create mode 100644 kernelgen/nkipy_kernelgen/knob.py create mode 100644 kernelgen/nkipy_kernelgen/llvm.py create mode 100644 kernelgen/nkipy_kernelgen/mlir_utils.py create mode 100644 kernelgen/nkipy_kernelgen/op_vtable.py create mode 100644 kernelgen/nkipy_kernelgen/pass_manager.py create mode 100644 kernelgen/nkipy_kernelgen/trace.py create mode 100644 kernelgen/nkipy_kernelgen/traced_array.py create mode 100644 kernelgen/nkipy_kernelgen/transforms/__init__.py create mode 100644 kernelgen/nkipy_kernelgen/transforms/linalg_to_nisa_py.py create mode 100644 kernelgen/nkipy_kernelgen/transforms/nkipy_opt.py create mode 100644 kernelgen/nkipy_kernelgen/utils.py create mode 100644 kernelgen/pyproject.toml create mode 100644 kernelgen/pytest.ini create mode 100644 kernelgen/requirements.txt create mode 100644 kernelgen/setup.py create mode 100644 kernelgen/tests/README.md create mode 100644 kernelgen/tests/__init__.py create mode 100644 kernelgen/tests/conftest.py create mode 100644 kernelgen/tests/debug/.gitignore create mode 100644 kernelgen/tests/debug/README.md create mode 100644 kernelgen/tests/debug/__init__.py create mode 100644 kernelgen/tests/debug/bmm/buggy.mlir create mode 100644 kernelgen/tests/debug/bmm/fix_3d_dma_indices.mlir create mode 100644 kernelgen/tests/debug/bmm/kernel.py create mode 100644 kernelgen/tests/debug/qwen3_layer/README.md create mode 100644 kernelgen/tests/debug/qwen3_layer/buggy.mlir create mode 100644 kernelgen/tests/debug/qwen3_layer/fix_rope_vector_partition.mlir create mode 100644 kernelgen/tests/debug/qwen3_layer/kernel.py create mode 100755 kernelgen/tests/debug/run.sh create mode 100755 kernelgen/tests/debug/run_sim.py create mode 100644 kernelgen/tests/e2e/__init__.py create mode 100644 kernelgen/tests/e2e/conftest.py create mode 100644 kernelgen/tests/e2e/nkipy_tests/__init__.py create mode 100644 kernelgen/tests/e2e/nkipy_tests/test_attention.py create mode 100644 kernelgen/tests/e2e/nkipy_tests/test_binary_ops.py create mode 100644 kernelgen/tests/e2e/nkipy_tests/test_composite_patterns.py create mode 100644 kernelgen/tests/e2e/nkipy_tests/test_embedding.py create mode 100644 kernelgen/tests/e2e/nkipy_tests/test_indexing.py create mode 100644 kernelgen/tests/e2e/nkipy_tests/test_llama_decoder.py create mode 100644 kernelgen/tests/e2e/nkipy_tests/test_matmul_shapes.py create mode 100644 kernelgen/tests/e2e/nkipy_tests/test_mlp.py create mode 100644 kernelgen/tests/e2e/nkipy_tests/test_reductions.py create mode 100644 kernelgen/tests/e2e/nkipy_tests/test_rope.py create mode 100644 kernelgen/tests/e2e/nkipy_tests/test_simple_add.py create mode 100644 kernelgen/tests/e2e/nkipy_tests/test_softmax.py create mode 100644 kernelgen/tests/e2e/nkipy_tests/test_tensor_manipulation.py create mode 100644 kernelgen/tests/e2e/nkipy_tests/test_unary_ops.py create mode 100644 kernelgen/tests/e2e/test_3d_elementwise.py create mode 100644 kernelgen/tests/e2e/test_attention.py create mode 100644 kernelgen/tests/e2e/test_auto_layout.py create mode 100644 kernelgen/tests/e2e/test_custom_op.py create mode 100644 kernelgen/tests/e2e/test_feedforward.py create mode 100644 kernelgen/tests/e2e/test_head_deconcat.py create mode 100644 kernelgen/tests/e2e/test_matmul_add.py create mode 100644 kernelgen/tests/e2e/test_multi_output.py create mode 100644 kernelgen/tests/e2e/test_partition_dim.py create mode 100644 kernelgen/tests/e2e/test_qwen3_layer.py create mode 100644 kernelgen/tests/e2e/test_reduce.py create mode 100644 
kernelgen/tests/e2e/test_rmsnorm.py create mode 100644 kernelgen/tests/e2e/test_rope.py create mode 100644 kernelgen/tests/e2e/test_sigmoid.py create mode 100644 kernelgen/tests/harness.py create mode 100644 kernelgen/tests/passes/__init__.py create mode 100644 kernelgen/tests/passes/annotate_memory_space/__init__.py create mode 100644 kernelgen/tests/passes/annotate_memory_space/test_basic.py create mode 100644 kernelgen/tests/passes/canonicalize_loop_step/__init__.py create mode 100644 kernelgen/tests/passes/canonicalize_loop_step/test_elementwise.py create mode 100644 kernelgen/tests/passes/canonicalize_loop_step/test_matmul.py create mode 100644 kernelgen/tests/passes/canonicalize_loop_step/test_multi_op.py create mode 100644 kernelgen/tests/passes/canonicalize_partition_dim/__init__.py create mode 100644 kernelgen/tests/passes/canonicalize_partition_dim/test_basic.py create mode 100644 kernelgen/tests/passes/canonicalize_partition_dim/test_reduce.py create mode 100644 kernelgen/tests/passes/cleanup_bufferization_artifacts/__init__.py create mode 100644 kernelgen/tests/passes/cleanup_bufferization_artifacts/test_basic.py create mode 100644 kernelgen/tests/passes/conftest.py create mode 100644 kernelgen/tests/passes/eliminate_same_memspace_copy/__init__.py create mode 100644 kernelgen/tests/passes/eliminate_same_memspace_copy/test_basic.py create mode 100644 kernelgen/tests/passes/eliminate_uninitialized_copies/__init__.py create mode 100644 kernelgen/tests/passes/eliminate_uninitialized_copies/test_basic.py create mode 100644 kernelgen/tests/passes/infer_layout/__init__.py create mode 100644 kernelgen/tests/passes/infer_layout/test_infer_layout_broadcast.py create mode 100644 kernelgen/tests/passes/infer_layout/test_infer_layout_elementwise.py create mode 100644 kernelgen/tests/passes/infer_layout/test_infer_layout_matmul.py create mode 100644 kernelgen/tests/passes/infer_layout/test_infer_layout_reduce.py create mode 100644 kernelgen/tests/passes/insert_spill_reload/test_basic_spill.py create mode 100644 kernelgen/tests/passes/knob_driven_tiling/__init__.py create mode 100644 kernelgen/tests/passes/knob_driven_tiling/test_elementwise.py create mode 100644 kernelgen/tests/passes/knob_driven_tiling/test_matmul.py create mode 100644 kernelgen/tests/passes/knob_driven_tiling/test_multi_op.py create mode 100644 kernelgen/tests/passes/legalize_layout/__init__.py create mode 100644 kernelgen/tests/passes/legalize_layout/test_basic.py create mode 100644 kernelgen/tests/passes/legalize_layout/test_fold_reshape_copy.py create mode 100644 kernelgen/tests/passes/linalg_to_nisa/__init__.py create mode 100644 kernelgen/tests/passes/linalg_to_nisa/test_basic.py create mode 100644 kernelgen/tests/passes/linalg_to_nisa/test_multi_non_unit_collapse.py create mode 100644 kernelgen/tests/passes/pass_utils.py create mode 100644 kernelgen/tests/passes/prepare_arithmetic/__init__.py create mode 100644 kernelgen/tests/passes/prepare_arithmetic/test_basic.py create mode 100644 kernelgen/tests/passes/remove_linalg_zero_fill/__init__.py create mode 100644 kernelgen/tests/passes/resolve_custom_ops/test_basic.py create mode 100644 kernelgen/tests/python/__init__.py create mode 100644 kernelgen/tests/python/lit.cfg.py create mode 100644 kernelgen/tests/python/passes/__init__.py create mode 100644 kernelgen/tests/python/passes/test_knob_annotations.py create mode 100644 kernelgen/tests/python/rewrites/__init__.py create mode 100644 kernelgen/tests/unit/__init__.py create mode 100644 kernelgen/tests/unit/conftest.py 
create mode 100644 kernelgen/tests/unit/test_broadcast_ops.py create mode 100644 kernelgen/tests/unit/test_custom_op.py create mode 100644 kernelgen/tests/unit/test_elementwise_ops.py create mode 100644 kernelgen/tests/unit/test_execution_engine.py create mode 100644 kernelgen/tests/unit/test_for_loops.py create mode 100644 kernelgen/tests/unit/test_gather_ops.py create mode 100644 kernelgen/tests/unit/test_import_compatibility.py create mode 100644 kernelgen/tests/unit/test_matrix_ops.py create mode 100644 kernelgen/tests/unit/test_reduction_ops.py diff --git a/kernelgen/.claude/skills/build_nkipykernelgen/SKILL.md b/kernelgen/.claude/skills/build_nkipykernelgen/SKILL.md new file mode 100644 index 0000000..734dbd3 --- /dev/null +++ b/kernelgen/.claude/skills/build_nkipykernelgen/SKILL.md @@ -0,0 +1,23 @@ +--- +name: build_nkipykernelgen +description: Rebuild NKIPyKernelGen (C++ passes and Python package) +user-invocable: true +--- + +## Usage + +`/build_nkipykernelgen` + +## Instructions + +Run the build script. Use `bash` (not `sh`) since it uses `source`. Use a timeout of 300000ms. + +```bash +bash .claude/skills/build_nkipykernelgen/scripts/build.sh +``` + +Note: Run this from the NKIPyKernelGen repo root. + +## Important + +`pip install -e .` builds BOTH the C++ passes (nkipy-opt binary) AND the Python package in one step. There is NO need to run cmake separately — the pyproject.toml build system handles the full C++ compilation via cmake internally. diff --git a/kernelgen/.claude/skills/build_nkipykernelgen/scripts/build.sh b/kernelgen/.claude/skills/build_nkipykernelgen/scripts/build.sh new file mode 100755 index 0000000..f43735a --- /dev/null +++ b/kernelgen/.claude/skills/build_nkipykernelgen/scripts/build.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Rebuild NKIPyKernelGen (C++ passes and Python package). +set -e + +# Derive repo root from script location: scripts/ -> build_nkipykernelgen/ -> skills/ -> .claude/ -> repo root +REPO_ROOT="$(cd "$(dirname "$0")/../../../.." && pwd)" + +cd "$REPO_ROOT" + +echo "=== Rebuilding NKIPyKernelGen ===" +pip install -e . 2>&1 | tail -5 +echo "=== Build complete ===" diff --git a/kernelgen/.claude/skills/debug_nisa_ir/SKILL.md b/kernelgen/.claude/skills/debug_nisa_ir/SKILL.md new file mode 100644 index 0000000..f1f81e5 --- /dev/null +++ b/kernelgen/.claude/skills/debug_nisa_ir/SKILL.md @@ -0,0 +1,121 @@ +--- +name: debug_nisa_ir +description: Debug NISA MLIR that fails BIRSim. Creates a debug case under tests/debug/ with buggy.mlir, kernel.py, iterative fixes, and a README proposing compiler pass changes. +user-invocable: true +--- + +## Usage + +`/debug_nisa_ir <bug_name> [kernel.py path] [buggy NISA MLIR path or inline]` + +- `bug_name`: Short snake_case name for the debug case (e.g., `rope_partition_oob`) +- `kernel.py path`: Path to the Python source that was fed into `nkipy_opt`. If omitted, ask the user. +- `buggy NISA MLIR`: Path to the `.mlir` file that `nkipy_opt` produced, or the user may paste it inline. If omitted, ask the user. + +## Instructions + +You are debugging a NISA-level MLIR kernel that `nkipy_opt` generated but that fails BIRSim verification or produces incorrect numerical results. Follow this systematic workflow. + +### Step 1: Set up the debug case directory + +Create `tests/debug/<bug_name>/` with: + +``` +tests/debug/<bug_name>/ + kernel.py # Copy of the input Python kernel + buggy.mlir # The failing NISA MLIR from nkipy_opt + README.md # Will be populated in Step 6 +``` + +Copy the user-provided `kernel.py` and `buggy.mlir` into this directory. 
Ensure `kernel.py` contains a function whose name matches the `sym_name` in the MLIR (this is required by `run_sim.py`). + +### Step 2: Reproduce the failure + +Run the buggy MLIR through BIRSim: + +```bash +cd tests/debug && source ./run.sh <bug_name>/buggy.mlir +``` + +Record the exact error output. Common failure modes: +- **BIR verification error**: `Invalid access of N partitions starting at partition M` or `Access pattern out of bounds` +- **BIRSim runtime error**: `NCC_ISIM*` errors (e.g., uninitialized PSUM read) +- **Numerical mismatch**: `SIMULATION FAILED (max_diff=...)` -- BIRSim runs but output doesn't match kernel.py + +### Step 3: Analyze the bug + +Read the MLIR carefully and identify the root cause. Common patterns: + +1. **Multi-partition SBUF with vector engine**: `tensor_tensor_arith` (engine=vector) reading from a loop-indexed partition of a multi-partition SBUF tensor. The vector engine processes all 128 partitions simultaneously and cannot address partition N selectively. + +2. **Wrong reshape/transpose lowering**: Column-by-column transposes that conflate head and head_dim dimensions. Often manifests as `<128|2>` tile on a dim of size 2 (OOB), or silent numerical corruption. + +3. **Missing accumulate flags**: Matmul K-loops without `psum_accumulate_flags`, causing PSUM overwrite instead of accumulate. + +4. **SBUF OOM**: Too many live SBUF tensors. Check if intermediates can be fused or freed earlier. + +Focus on understanding: +- Which MLIR lines are problematic (cite line numbers) +- What the pass *intended* to generate vs what it actually generated +- Why the hardware rejects it (BIR rules violated) + +### Step 4: Create iterative fixes + +For each fix attempt, create a new MLIR file: + +``` +fix_<NN>_<short_description>.mlir +``` + +For example: +- `fix_01_fuse_rope_elementwise.mlir` +- `fix_02_reshape_head_granularity.mlir` + +Edit the MLIR by hand to correct the identified issue. Then run: + +```bash +cd tests/debug && source ./run.sh <bug_name>/fix_01_<short_description>.mlir +``` + +If it still fails, analyze the new error, create another fix file, and iterate. Keep each attempt as a separate file so the progression is visible. + +### Step 5: Verify the final fix + +The last `fix_*.mlir` should produce: + +``` +BIRSim PASSED +SIMULATION PASSED +``` + +Confirm that the numerical output matches `kernel.py` within tolerance (atol=1e-2, rtol=1e-2). + +### Step 6: Write the README + +Create `tests/debug/<bug_name>/README.md` documenting: + +1. **Overview**: One paragraph summarizing what `buggy.mlir` is (which kernel, what it does) and what goes wrong. + +2. **How to reproduce**: The exact `source ../run.sh` commands for buggy and fixed versions. + +3. **Bug analysis**: For each bug found: + - **Symptom**: The exact error message + - **Location in MLIR**: Line numbers and what the code does + - **What happens**: Why the hardware rejects it or produces wrong results + - **Fix**: What was changed in the MLIR (with code snippets) + +4. **Root cause summary**: Table mapping each bug to the compiler pass responsible and whether it causes a compilation error or silent corruption. + +5. **Proposed compiler pass fixes**: For each bug, describe: + - Which pass to fix (e.g., `simplify-linalg`, `linalg-to-nisa`, tiling) + - The root cause *in the pass* (not just the MLIR symptom) + - A concrete proposed change (pseudocode or description of the algorithm change) + +Use the format from existing debug cases (see `tests/debug/qwen3_layer/README.md` for reference). 
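+
+For reference, the Step 5 tolerance check amounts to the following NumPy
+comparison (a minimal sketch; the authoritative input generation, NEFF
+compilation, and comparison logic live in `run_sim.py`, and the helper name
+`outputs_match` below is illustrative only):
+
+```python
+import numpy as np
+
+def outputs_match(reference: np.ndarray, birsim_output: np.ndarray) -> bool:
+    # Same tolerances as Step 5: atol=1e-2, rtol=1e-2.
+    return np.allclose(birsim_output, reference, atol=1e-2, rtol=1e-2)
+```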
+ +### Tips + +- The debug harness (`run.sh` / `run_sim.py`) automatically sets up the NKI environment, generates random inputs (seed=42), compiles to NEFF with BIRSim, and compares against `kernel.py`. +- Artifacts (NEFF, BIR) are written to `artifacts_<mlir_name>/` next to each MLIR file (git-ignored). +- When editing MLIR, keep changes minimal and targeted. Change only the ops/loops related to the bug. +- If you're unsure which pass generated a problematic pattern, check the pass pipeline in `nkipy_opt` or ask the user. diff --git a/kernelgen/.claude/skills/run_nkipykernelgen_tests/SKILL.md b/kernelgen/.claude/skills/run_nkipykernelgen_tests/SKILL.md new file mode 100644 index 0000000..33991d6 --- /dev/null +++ b/kernelgen/.claude/skills/run_nkipykernelgen_tests/SKILL.md @@ -0,0 +1,28 @@ +--- +name: run_nkipykernelgen_tests +description: Run NKIPyKernelGen tests (without rebuilding) +user-invocable: true +--- + +## Usage + +`/run_nkipykernelgen_tests [scope]` + +Where `scope` is: `all` (default), `passes`, `e2e`, or a specific path like `passes/infer_layout` or `e2e/nkipy_tests`. + +## Instructions + +1. Run the script at `.claude/skills/run_nkipykernelgen_tests/scripts/run_tests.sh` with the requested scope as the argument. Use `bash` to invoke it (not `sh`) since it uses `source`. Use a timeout of 600000ms. + +```bash +bash .claude/skills/run_nkipykernelgen_tests/scripts/run_tests.sh [scope] +``` + +Note: Run this from the NKIPyKernelGen repo root. + +2. The script saves full test output to `/tmp/nkipykernelgen_test_results.txt`. After the script finishes, use the Read tool to read that file for the complete results. This avoids context window issues with long test output. + +3. When reporting results, summarize: + - Total passed/failed/xfailed/xpassed/skipped counts + - List any unexpected failures (FAILED, not XFAIL) + - Note any XPASS (unexpected passes) that indicate xfail markers should be removed diff --git a/kernelgen/.claude/skills/run_nkipykernelgen_tests/scripts/run_tests.sh b/kernelgen/.claude/skills/run_nkipykernelgen_tests/scripts/run_tests.sh new file mode 100755 index 0000000..88f1106 --- /dev/null +++ b/kernelgen/.claude/skills/run_nkipykernelgen_tests/scripts/run_tests.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# Run NKIPyKernelGen tests with proper environment setup. +# Usage: run_tests.sh [scope] +# scope: all (default), passes, e2e, or a specific path like passes/infer_layout + +SCOPE="${1:-all}" +RESULTS_FILE="/tmp/nkipykernelgen_test_results.txt" + +# Derive repo root from script location: scripts/ -> run_nkipykernelgen_tests/ -> skills/ -> .claude/ -> repo root +REPO_ROOT="$(cd "$(dirname "$0")/../../../.." 
&& pwd)" + +cd "$REPO_ROOT" + +# Run tests, capturing full output to file +echo "=== Running tests (scope: $SCOPE) ===" +echo "Results will be saved to: $RESULTS_FILE" + +case "$SCOPE" in + all) + python -m pytest tests/ -v --tb=short 2>&1 | tee "$RESULTS_FILE" + ;; + passes) + python -m pytest tests/passes/ -v --tb=short 2>&1 | tee "$RESULTS_FILE" + ;; + e2e) + python -m pytest tests/e2e/ -v --tb=short 2>&1 | tee "$RESULTS_FILE" + ;; + *) + python -m pytest "tests/$SCOPE" -v --tb=short 2>&1 | tee "$RESULTS_FILE" + ;; +esac +EXIT_CODE=${PIPESTATUS[0]} + +echo "" +echo "=== Full results saved to: $RESULTS_FILE ===" +exit $EXIT_CODE diff --git a/kernelgen/.gitignore b/kernelgen/.gitignore new file mode 100644 index 0000000..b6bcf21 --- /dev/null +++ b/kernelgen/.gitignore @@ -0,0 +1,47 @@ +# Override parent nkipy/.gitignore's `lib/` rule so MLIR C++ sources in +# mlir/lib/ are tracked (the parent rule is aimed at Python venv lib/ dirs). +!mlir/lib/ +!mlir/lib/** + +# Python +__pycache__/ +*.py[cod] +*.so + +# Distribution / packaging +build/ +dist/ +*.egg-info/ +.eggs/ +*.whl + +# Built MLIR bindings (generated during build) +nkipy_kernelgen/_mlir/ + +# Virtual environments +venv/ +.env + +# Testing +.pytest_cache/ +.coverage +tests/**/outputs/ +tests/**/artifacts/ + +# IDE +.vscode/ +.idea/ + +# OS +.DS_Store +Thumbs.db + +# Logs +*.log + +# LLVM lit test outputs +.lit_test_times.txt +Output/ + +# Compiler Explorer (cloned repo) +compiler_explorer/compiler-explorer/ diff --git a/kernelgen/CLAUDE.md b/kernelgen/CLAUDE.md new file mode 100644 index 0000000..c2f9a92 --- /dev/null +++ b/kernelgen/CLAUDE.md @@ -0,0 +1,268 @@ +# CLAUDE.md + +> **Keep this file up to date.** After any change to the codebase (new passes, renamed files, pipeline changes, new test patterns), update the relevant sections of this file so future sessions start with accurate context. + +## Project Overview + +NKIPyKernelGen is a compiler that traces NumPy functions and lowers them to NISA (Neuron Instruction Set Architecture) for AWS Neuron hardware. Users write kernels in Python with `@trace` and `knob.knob()` annotations, and the compiler handles tiling, memory placement, layout legalization, and NISA lowering. + +### Git Commit Policy + +- Do not add a `Co-Authored-By` line in commit messages. + +### Design Philosophy + +- **Fast prototyping.** Prioritize speed of development. Keep code simple, ship quickly, iterate. +- **Composable passes.** Each pass does one thing well and produces valid IR. The IR after any pass should be functional and simulatable (before NISA lowering: via LLVM JIT; after: via BIR simulation). +- **Testable at every stage.** Pass tests verify individual transformations with FileCheck. E2E tests verify numerical correctness through the full pipeline. + +## Build & Environment Setup + +### Environment setup + +Before building or running tests, activate your Python virtual environment: + +```bash +source /bin/activate +``` + +### Install / rebuild + +`pip install -e .` builds the C++ passes (`nkipy-opt`) via CMake under the hood — no manual CMake invocation needed: + +```bash +pip install -e . # builds C++ and installs Python package in editable mode +``` + +If you only changed Python code (no C++ pass changes), the editable install means changes take effect immediately — no rebuild needed. 
+ +### Run tests + +```bash +python3 -m pytest tests/passes/ -v # pass-level tests +python3 -m pytest tests/e2e/ -v # end-to-end tests +python3 -m pytest tests/unit/ -v # unit tests +``` + +## Compilation Pipeline (24 passes) + +Defined in `nkipy_kernelgen/transforms/nkipy_opt.py` → `apply_complete_knob_pipeline()`. +C++ pass implementations in `mlir/lib/Transforms/`. Pass 24 (`py:linalg-to-nisa`) is Python. + +### Phase 0: Arithmetic Preparation +| # | Pass | Source | What it does | +|---|------|--------|--------------| +| 1 | `remove-redundant-zero-fill` | RemoveRedundantZeroFill.cpp | Remove `linalg.fill(0)` ops whose only users are matmul-like ops. NISA matmul auto-zeros PSUM, so the fill is redundant. Must run before tiling to prevent the fill from becoming a `nisa.memset` with out-of-bounds access. | +| 2 | `prepare-arithmetic` | PrepareArithmetic.cpp | Rewrite `linalg.div(A,B)` → `linalg.mul(A, linalg.reciprocal(B))`. NISA has no divide. Runs before tiling so reciprocal gets tiled normally. | + +### Phase 1: Layout Inference, Partition Dim Canonicalization, and Tiling (on tensor IR) +| # | Pass | Source | What it does | +|---|------|--------|--------------| +| 3 | `infer-layout` | InferLayout.cpp | Auto-infer tile_size, mem_space, partition_dim, and reduction_tile for all linalg ops. Seeds from user annotations, then matmul hardware rules, then elementwise fallback defaults. Uses bidirectional BFS propagation with tile-size divisibility conflict checks. Return values get SharedHbm; intermediates default to SBUF. | +| 4 | `canonicalize-partition-dim` | CanonicalizePartitionDim.cpp | Insert `linalg.transpose` ops at boundaries where `partition_dim != 0`, so all downstream ops see `partition_dim=0`. Annotates inserted transposes with tile_size/mem_space for tiling. | +| 5 | `assign-linalg-op-ids` | AssignLinalgOpIds.cpp | Stamp unique `nkipy.op_id` on each linalg op (including the transposes inserted by `canonicalize-partition-dim` in pass 4) so transform dialect can match individual instances. | +| 6 | `knob-driven-tiling` | KnobDrivenTiling.cpp | Read `nkipy.annotate` ops and emit a `transform.named_sequence @__transform_main` that tiles each linalg op according to its tile_size/reduction_tile. Supports arbitrary-rank tensors and `linalg.transpose`. | +| 7 | `apply-and-strip-transforms` | ApplyAndStripTransforms.cpp | Fused pass: runs `@__transform_main` (same semantics as upstream `--transform-interpreter`) and then erases the transform module (NamedSequenceOps + `transform.with_named_sequence` attr). Leaves the IR free of any transform-dialect ops, which is what the Python linalg→NISA phase (parsed via upstream MLIR bindings) needs. | +| 8 | `canonicalize-loop-step` | CanonicalizeLoopStep.cpp | Rewrite `scf.for %i = 0 to N step S` → `scf.for %i = 0 to N/S step 1` with `%orig = %i * S`. Simplifies downstream index math. | + +### Phase 2: Bufferization +| # | Pass | Source | What it does | +|---|------|--------|--------------| +| 9 | `one-shot-bufferize` | (upstream MLIR) | Convert tensor IR to memref IR (tensor.extract_slice → memref.subview, etc.) | +| 10 | `canonicalize` | (upstream MLIR) | Fold and simplify memref operations | + +### Phase 3: Memory Space Annotation + Reshape Canonicalization +| # | Pass | Source | What it does | +|---|------|--------|--------------| +| 11 | `eliminate-uninitialized-copies` | EliminateUninitializedCopies.cpp | Remove `memref.copy` from never-written allocs (e.g. PSUM accumulator init — matmul zeros PSUM via `psum_zero_region`). 
| +| 12 | `canonicalize` | (upstream MLIR) | Clean up dead subview chains from eliminated copies | +| 13 | `annotate-memory-space` | AnnotateMemorySpace.cpp | Read `nkipy.annotate` ops, apply NISA memory space attrs (`#nisa.mem<...>` for SBUF, PSUM, and HBM) to memref types, mark function args as SharedHbm, then erase all `nkipy.annotate` ops. | +| 14 | `canonicalize-reshape` | CanonicalizeReshape.cpp | Classify `expand_shape`/`collapse_shape` by mem_space and partition_dim. HBM reshapes and SBUF non-pdim reshapes stay as views. SBUF partition dim splits get alloc+copy (NISA has no modulo). Returned expand_shape views of func args and direct returns of func args get alloc+copy (NISA needs separate output allocations). | +| 15 | `eliminate-same-memspace-copy` | EliminateSameMemSpaceCopy.cpp | Eliminate redundant copies within the same memory space (e.g. SBUF→SBUF when data is already in SBUF from a previous op). Rewires uses to read the source directly. | +| 16 | `canonicalize` | (upstream MLIR) | DCE dead allocs from eliminated copies | + +### Phase 4: Layout Legalization + Spill/Reload +| # | Pass | Source | What it does | +|---|------|--------|--------------| +| 17 | `legalize-layout` | LegalizeLayout.cpp | Transform SBUF allocs from R-D to (R+2)-D physical layout (`[partTile, numBlocks..., freeTile]`), rewrite subviews to (R+2)-D indexing, collapse everything to 2D for NISA compute. Reconstruct linalg named ops (AddOp, SubOp, ReciprocalOp, etc.) with 2D operands. | +| 18 | `canonicalize` | (upstream MLIR) | Fold collapse_shape chains, simplify affine maps | +| 19 | `simplify-linalg` | SimplifyLinalg.cpp | Decompose high-rank transposes to loops of 2D, collapse >2D SBUF transpose to 2D, canonicalize trivial-broadcast generics to named ops. Runs before insert-spill-reload so any SBUF temps it creates are accounted for in spill/reload memory budgeting. | +| 20 | `insert-spill-reload` | InsertSpillReload.cpp | Analyze per-partition SBUF memory pressure and insert spill (SBUF→HBM) and reload (HBM→SBUF) `memref.copy` ops when capacity is exceeded. Uses Belady's MIN heuristic. Per-partition size = `total_size / shape[0]` on physical layout. | +| 21 | `insert-memref-dealloc` | InsertMemRefDealloc.cpp | Insert `memref.dealloc` for SBUF/PSUM allocs at the end of their enclosing scope (loop body or function). These become `nisa.release` after lowering. Skips HBM/SharedHBM (externally managed). | +| 22 | `cse` | (upstream MLIR) | Common subexpression elimination | +| 23 | `canonicalize` | (upstream MLIR) | DCE unused subviews and cleanup | + +### Phase 5: NISA Lowering + Finalization (Python) + +NISA lowering is implemented in Python using the `nki` wheel's Python bindings. + +| # | Pass | Source | What it does | +|---|------|--------|--------------| +| 24 | `py:linalg-to-nisa` | linalg_to_nisa_py.py | Python reimplementation combining the old linalg-to-nisa, resolve-custom-ops, and prepare-for-nki passes. Pattern-matches linalg/memref ops to NISA: `linalg.add/sub/mul` → `nisa.tensor_tensor_arith`; `linalg.matmul` → `nisa.matmul`; `memref.copy(HBM↔SBUF)` → `nisa.dma_copy`; `linalg.exp` → `nisa.activation(op=exp)`; scalar broadcast ops → `nisa.tensor_scalar_arith`; `linalg.reciprocal` → `nisa.reciprocal`; `linalg.transpose` → `nisa.dma_transpose` or copy (for trivial reshapes); `memref.dealloc` → `nisa.release`. Also inlines custom op bodies and adds `nisa.target` hardware attribute. 
| + +## Key Source Files + +### Python (tracing & frontend) +| File | Role | +|------|------| +| `nkipy_kernelgen/trace.py` | `@trace` decorator — traces NumPy functions to MLIR | +| `nkipy_kernelgen/knob.py` | `knob.knob()` API — annotate tensors with mem_space, tile_size, reduction_tile | +| `nkipy_kernelgen/op_vtable.py` | NumPy op → MLIR lowering table | +| `nkipy_kernelgen/traced_array.py` | TracedArray wrapping MLIR SSA values with NumPy-like interface | +| `nkipy_kernelgen/transforms/nkipy_opt.py` | Pipeline orchestration — shells out to `nkipy-opt` binary | +| `nkipy_kernelgen/transforms/linalg_to_nisa_py.py` | Python NISA lowering — reimplements linalg-to-nisa, resolve-custom-ops, prepare-for-nki | + +### C++ (MLIR dialect & passes) +| File | Role | +|------|------| +| `mlir/include/nkipy/Dialect/NkipyOps.td` | `nkipy.annotate` op (target, mem_space, partition_dim, tile_size, reduction_tile) | +| `mlir/lib/Transforms/*.cpp` | All pass implementations (see pipeline table) | +| `mlir/lib/Transforms/OpClassification.cpp` | Shared helpers: classify linalg ops as unary/binary elementwise, matmul, etc. | +| `mlir/lib/Transforms/IRHelpers.cpp` | Shared IR utilities: constant extraction, memref helpers | +| `mlir/lib/Transforms/InlineNkipyReference.cpp` | Inline reference bodies into post-bufferization IR | + +## Test Structure + +``` +tests/ +├── conftest.py # Root conftest: centralizes sys.path setup +├── harness.py # Unified test harness: run_kernel_test(), Mode flags +├── passes/ # Per-pass unit tests +│ ├── pass_utils.py # Shared: run_passes(), compile_through_passes(), run_filecheck() +│ ├── annotate_memory_space/ # Memory space annotation tests +│ ├── canonicalize_loop_step/ # Loop step normalization tests +│ ├── canonicalize_partition_dim/# Partition dim canonicalization tests +│ ├── cleanup_bufferization_artifacts/ # Bufferization cleanup tests +│ ├── eliminate_same_memspace_copy/ # Same-memspace copy elimination tests +│ ├── eliminate_uninitialized_copies/ # Uninitialized copy elimination tests +│ ├── infer_layout/ # Layout inference propagation tests +│ ├── insert_spill_reload/ # SBUF spill/reload tests +│ ├── knob_driven_tiling/ # Tiling tests (2D + 3D) +│ ├── legalize_layout/ # Layout legalization tests (2D + 3D) +│ ├── linalg_to_nisa/ # NISA lowering tests +│ ├── prepare_arithmetic/ # Div-to-reciprocal rewrite tests +│ ├── remove_linalg_zero_fill/ # Zero fill removal tests +│ ├── resolve_custom_ops/ # Custom op inlining tests +│ └── ... # One directory per pass, each with test_*.py +├── e2e/ # End-to-end: trace → NISA → BIR simulation / HW execution +│ ├── test_3d_elementwise.py # 3D tensor tests for rank-R generalization +│ ├── test_attention.py # Multi-head attention +│ ├── test_auto_layout.py # Auto-inferred layouts (no user annotations) +│ ├── test_custom_op.py # Custom NISA op inlining +│ ├── test_feedforward.py # Feedforward network (matmul + SiLU + split) +│ ├── test_head_deconcat.py # Head deconcatenation +│ ├── test_matmul_add.py # Matmul + add fusion patterns +│ ├── test_multi_output.py # Multiple output tensors +│ ├── test_partition_dim.py # partition_dim != 0 with canonicalize-partition-dim +│ ├── test_qwen3_layer.py # Qwen3 transformer layer +│ ├── test_reduce.py # Reduce sum/mean operations +│ ├── test_rmsnorm.py # RMS normalization +│ ├── test_rope.py # Rotary positional embedding +│ ├── test_sigmoid.py # Sigmoid, exp, scalar arithmetic, reciprocal +│ ├── nkipy_tests/ # Additional e2e tests (add, binary/unary ops, MLP, softmax, etc.) +│ └── ... 
+├── python/ # Python-level pass and rewrite tests (lit-style) +└── unit/ # Python-level unit tests +``` + +### Test modes (from `harness.py`) +- `Mode.LLVM` — LLVM JIT execution, compare to NumPy (requires `stop_after`) +- `Mode.BIR_SIM` — Full pipeline to NISA → BIR simulation via neuron-cc +- `Mode.HW` — Full pipeline → run on Trainium hardware (auto-skips if no device) +- `Mode.STRING_CHECK` — Assert compiled IR contains/excludes specific strings +- `Mode.FILECHECK` — Run LLVM FileCheck on compiled IR + +### Test artifacts +When `request` fixture is passed to `run_kernel_test`, intermediate IR is dumped to `tests/<test_dir>/outputs/<test_name>/`. Each pass produces a numbered file (e.g., `01_assign_linalg_op_ids.mlir`, `15_legalize_layout.mlir`). + +Pass tests explicitly set `dump_dir` via their `utils.py` helpers, typically saving to an `outputs/` directory next to the test file. + +## Debugging + +### `--dump-ir`: Dump IR from any test (recommended) + +Add `--dump-ir` to any pytest invocation to save intermediate MLIR after every compiler pass: + +```bash +python3 -m pytest tests/e2e/test_rope.py::test_rope --dump-ir -v -s +``` + +Output: +``` +[dump-ir] IR will be saved to: /tmp/nkipy_dump_ir_abc123 +[dump-ir] 25 IR files saved to: /tmp/nkipy_dump_ir_abc123 + 00_input.mlir (3,532 bytes) + 01_remove-redundant-zero-fill.mlir (3,532 bytes) + 02_prepare-arithmetic.mlir (3,533 bytes) + ... + 23_canonicalize.mlir (8,044 bytes) + 24_py_linalg-to-nisa.mlir (11,260 bytes) +``` + +If a test passes `request` to `run_kernel_test`, IR goes to `tests/<test_dir>/outputs/<test_name>/` instead of `/tmp`. + +**Auto-dump on failure:** Even without `--dump-ir`, if compilation crashes, the harness automatically re-runs pass-by-pass into a temp directory and prints the path so you always get IR context. + +### Compiler Explorer + +For standalone kernel files (not embedded in tests), use the Compiler Explorer wrapper: + +```bash +cd compiler_explorer +python3 nkipy_compiler.py examples/qwen3_layer.py --stop=24 --raw # clean IR after linalg-to-nisa +python3 nkipy_compiler.py examples/qwen3_layer.py --sim # full pipeline + BIR simulation +python3 nkipy_compiler.py examples/qwen3_layer.py --stop=6 --sim # IR at stop point + LLVM JIT verify +``` + +Use `--raw` for clean MLIR output (no `.loc`/`.file` annotations). Without `--raw`, output includes Compiler Explorer source-location annotations. + +Or launch the web UI: `./setup.sh` → open http://localhost:10240. + +### Inspect intermediate IR directly + +```bash +# Run single pass directly: +nkipy-opt --legalize-layout input.mlir + +# Chain passes: +nkipy-opt --annotate-memory-space --legalize-layout input.mlir + +# See IR after every pass: +nkipy-opt --mlir-print-ir-after-all --legalize-layout input.mlir 2>&1 +``` + +### From Python + +```python +from nkipy_kernelgen.transforms.nkipy_opt import apply_complete_knob_pipeline + +# Dump all intermediate files: +apply_complete_knob_pipeline(mlir_str, dump_dir="debug_outputs/") + +# Stop at a specific pass: +apply_complete_knob_pipeline(mlir_str, stop_after="legalize-layout", dump_dir="debug_outputs/") + +# Stop after pass N (1-indexed): +apply_complete_knob_pipeline(mlir_str, stop_after=10) # stop after bufferize (canonicalize) + +# For repeated passes like canonicalize, use "name:N" for the Nth occurrence: +apply_complete_knob_pipeline(mlir_str, stop_after="canonicalize:3") # 3rd canonicalize (pass 16) +``` + +### `tests/debug/`: NISA MLIR debug cases + +The `tests/debug/` directory contains standalone NISA MLIR test cases for debugging BIRSim failures. 
Each subdirectory is a self-contained repro with: + +- `kernel.py` — NumPy reference implementation (used to compare BIRSim output) +- `buggy.mlir` — the broken NISA IR +- `fix_*.mlir` — proposed fixes (one or more iterations) +- `README.md` — (optional) root cause analysis and proposed compiler changes + +Run a debug case: +```bash +cd tests/debug +source ./run.sh bmm/buggy.mlir # run buggy version +source ./run.sh bmm/fix_3d_dma_indices.mlir # run fixed version +``` + +This compiles the MLIR to NEFF, runs BIRSim, and compares against `kernel.py`. Artifacts (NEFF, BIR JSON, logs) are saved to `artifacts_<mlir_name>/` next to the MLIR file. diff --git a/kernelgen/README.md b/kernelgen/README.md new file mode 100644 index 0000000..5f1b5ad --- /dev/null +++ b/kernelgen/README.md @@ -0,0 +1,99 @@ + +# NKIPy KernelGen + +A Python-to-NISA MLIR compiler for Trainium. Trace NumPy functions, annotate +with knobs, and lower through a pipeline of MLIR passes to NKI/NISA dialect. + +## Features + +- **`@trace` decorator** — trace Python functions with NumPy operations into linalg MLIR +- **`knob()` API** — annotate tensors with tiling, memory placement, and partitioning hints (similar to OpenMP pragmas) +- **MLIR pass pipeline** — prepare-arithmetic, assign-linalg-op-ids, infer-layout, knob-driven-tiling, legalize-layout, linalg-to-nisa, and more +- **Compiler Explorer** — interactive web UI for inspecting IR at each compilation stage + +## Setup + +NKIPyKernelGen depends on LLVM/MLIR. Install LLVM with MLIR support first. + +```bash +# Activate your Python virtual environment +source <venv>/bin/activate + +# Install the Python package (editable) — builds C++ passes via CMake +pip install -e . +``` + +## Quick Start + +```python +import numpy as np +from nkipy_kernelgen import trace, knob + +@trace(input_specs=[((256, 256), "f32"), ((256, 256), "f32")]) +def matmul_add(A, B): + C = np.matmul(A, B) + knob.knob(C, mem_space="Sbuf", tile_size=[128, 128, 128]) + result = np.exp(C) + knob.knob(result, mem_space="Sbuf", tile_size=[128, 128]) + return result +``` + +The `knob()` API injects `nkipy.annotate` ops into the IR to guide tiling and +buffer placement. The `infer-layout` pass propagates annotations to unannotated +intermediate ops. + +## Project Structure + +``` +NKIPyKernelGen/ +├── nkipy_kernelgen/ # Python package (tracer, knob API, transforms) +├── mlir/ # MLIR dialects (NKIPy) and C++ passes +│ └── lib/Transforms/ # C++ pass implementations (phases 0–4) +├── compiler_explorer/ # Compiler Explorer web UI integration +├── examples/ # Example kernels and usage patterns +└── tests/ # Test suite + ├── unit/ # Unit tests for ops and execution engine + ├── passes/ # MLIR pass tests (tiling, layout, etc.) + ├── e2e/ # End-to-end compilation tests + └── python/ # Python-level pass and rewrite tests (lit-style) +``` + +## MLIR Passes + +The compilation pipeline runs 24 passes in 5 phases. Key passes: + +| Phase | Key Passes | Description | +|-------|------------|-------------| +| 0. Arithmetic prep | `remove-redundant-zero-fill`, `prepare-arithmetic` | Remove redundant zero fills, rewrite div → mul+reciprocal | +| 1. Layout & tiling | `infer-layout`, `canonicalize-partition-dim`, `knob-driven-tiling`, `apply-and-strip-transforms` | Infer tile sizes/mem spaces, canonicalize partition dims, tile via transform dialect | +| 2. Bufferization | `one-shot-bufferize` | Convert tensor IR to memref IR | +| 3. 
Memory | `annotate-memory-space`, `canonicalize-reshape`, `eliminate-same-memspace-copy` | Apply NISA memory spaces, canonicalize reshapes, eliminate redundant copies | +| 4. Layout & spill | `legalize-layout`, `simplify-linalg`, `insert-spill-reload`, `insert-memref-dealloc` | Physical layout legalization, SBUF spill/reload, dealloc insertion | +| 5. NISA lowering | `py:linalg-to-nisa` | Lower linalg/memref to NISA ops (Python, via `nki` wheel) | + +See `CLAUDE.md` for the full 24-pass pipeline with detailed descriptions. + +## Testing + +```bash +# Run all tests +python -m pytest tests/ -v + +# Run a specific test category +python -m pytest tests/passes/infer_layout/ -v +python -m pytest tests/e2e/ -v +python -m pytest tests/unit/ -v +``` + +## Compiler Explorer + +Interactive web UI for inspecting IR at each compilation stage: + +```bash +cd compiler_explorer +./setup.sh # Clone Compiler Explorer and start server +# Open http://localhost:10240, select NKIPy MLIR as the compiler +``` + +Use `--stop=N` to view IR after each pass, `--sim` for BIR simulation +verification, `--hw` for hardware compilation to NEFF. diff --git a/kernelgen/compiler_explorer/README.md b/kernelgen/compiler_explorer/README.md new file mode 100644 index 0000000..da847fa --- /dev/null +++ b/kernelgen/compiler_explorer/README.md @@ -0,0 +1,118 @@ +# Compiler Explorer Integration for NKIPy + +This directory contains the configuration to run [Compiler Explorer](https://github.com/compiler-explorer/compiler-explorer) locally with NKIPy as a compiler backend. + +## Quick Start + +```bash +# Run the setup script (clones CE, installs deps, configures, and starts) +./setup.sh +``` + +Then open http://localhost:10240 in your browser. + +## Usage in Compiler Explorer + +1. Select **Python** as the language (left panel) +2. Select **NKIPy MLIR** as the compiler (right panel dropdown) +3. Write your traced kernel: + +```python +import numpy as np +from nkipy_kernelgen import trace, knob + +@trace(input_specs=[ + ((128, 128), "f32"), + ((128, 128), "f32"), +]) +def matmul_kernel(A, B): + C = np.matmul(A, B) + knob.knob(C, mem_space="Sbuf", tile_size=[64, 64, 64]) + return C +``` + +4. The right panel shows the compiled MLIR output + +## Compiler Options + +Add these in the "Compiler options" box to control output: + +| Option | Description | +|--------|-------------| +| `--stop=0` | Trace only — initial MLIR before any passes | +| `--stop=2` | After `prepare-arithmetic` (div → mul+reciprocal) | +| `--stop=7` | After `apply-and-strip-transforms` (tiling applied) | +| `--stop=10` | After bufferization + canonicalize | +| `--stop=16` | After memory space annotation + cleanup | +| `--stop=17` | After layout legalization | +| `--stop=24` | After NISA lowering (same as omitting `--stop`) | +| `--sim` | Run BIR simulation and verify against NumPy reference | +| `--sim --stop=N` | Run LLVM JIT simulation on intermediate IR at stop point | +| `--hw` | Compile to NEFF and execute on Trainium hardware (requires device) | +| `--target=trn1\|trn2\|trn3` | Target hardware (default: `trn2`) | +| `--raw` | Clean MLIR output without `.loc`/`.file` annotations | + +## Example Workflow + +To debug how tiling transforms your kernel: + +1. Write your kernel in the left panel +2. Set compiler options to `--stop=0` to see the initial traced MLIR +3. Change to `--stop=7` to see loops introduced by tiling +4. Change to `--stop=10` to see bufferized memref IR +5. Change to `--stop=16` for memory space annotations +6. 
Remove `--stop` to see the final NISA MLIR + +## Files + +``` +compiler_explorer/ +├── nkipy_compiler.py # Main compiler wrapper +├── nkipy_ce_wrapper.sh # Shell wrapper for CE +├── setup.sh # Setup and run script +├── README.md # This file +├── config/ +│ ├── nkipy.local.properties # CE config (Python language) +│ ├── c.nkipy.properties # Alternative: custom language +│ └── example.nkipy # Example kernel +``` + +## Manual Setup + +If you prefer to set up manually: + +```bash +# 1. Clone Compiler Explorer +git clone https://github.com/compiler-explorer/compiler-explorer.git +cd compiler-explorer + +# 2. Install dependencies +npm install + +# 3. Copy config (edit paths first!) +cp /path/to/NKIPyKernelGen/compiler_explorer/config/nkipy.local.properties etc/config/ + +# 4. Edit the config to fix the wrapper path +vim etc/config/nkipy.local.properties + +# 5. Start +npm run dev +``` + +## Troubleshooting + +**"No @trace decorated function found"** +- Ensure your function has the `@trace(input_specs=[...])` decorator + +**Import errors** +- The wrapper sets PYTHONPATH automatically +- Check that nkipy_kernelgen is installed or accessible + +**Node.js version error** +- Compiler Explorer requires Node.js 18+ +- Use `nvm` to manage Node versions + +## Limitations + +- Compiler Explorer shows source → single output (one `--stop` level at a time) +- For side-by-side comparison of passes, use the dumped artifacts with `diff` or `vimdiff` diff --git a/kernelgen/compiler_explorer/config/c.nkipy.properties b/kernelgen/compiler_explorer/config/c.nkipy.properties new file mode 100644 index 0000000..a10505c --- /dev/null +++ b/kernelgen/compiler_explorer/config/c.nkipy.properties @@ -0,0 +1,31 @@ +# Alternative configuration using a custom language definition +# This gives more control over syntax highlighting and behavior + +# Register NKIPy as a language (instead of using Python) +languages=&nkipy + +# Language definition +language.nkipy.name=NKIPy +language.nkipy.monaco=python +language.nkipy.extensions=.py +language.nkipy.alias=nkipy +language.nkipy.example=nkipy/example.nkipy + +# Compiler definition +compilers=&nkipy-mlir + +compiler.nkipy-mlir.exe=/path/to/NKIPyKernelGen/compiler_explorer/nkipy_ce_wrapper.sh +compiler.nkipy-mlir.name=NKIPy to MLIR +compiler.nkipy-mlir.lang=nkipy +compiler.nkipy-mlir.compilerType= +compiler.nkipy-mlir.options= +compiler.nkipy-mlir.supportsBinary=false +compiler.nkipy-mlir.supportsExecute=false + +# Define available compiler options that show in dropdown +compiler.nkipy-mlir.options.--stop=0=Trace only (no passes) +compiler.nkipy-mlir.options.--stop=7=Stop after tiling +compiler.nkipy-mlir.options.--stop=10=Stop after bufferization +compiler.nkipy-mlir.options.--stop=16=Stop after memory space annotation +compiler.nkipy-mlir.options.--stop=17=Stop after layout legalization +compiler.nkipy-mlir.options.--stop=24=Full compilation (NISA lowering) diff --git a/kernelgen/compiler_explorer/config/example.nkipy b/kernelgen/compiler_explorer/config/example.nkipy new file mode 100644 index 0000000..7dc48b6 --- /dev/null +++ b/kernelgen/compiler_explorer/config/example.nkipy @@ -0,0 +1,28 @@ +""" +Example NKIPy kernel for Compiler Explorer: Matmul + Add + +This kernel computes: result = matmul(A, B) + bias +- A: 256x256 (HBM input) +- B: 256x256 (HBM input) +- bias: 256x256 (HBM input) +- matmul output: SBUF intermediate +- final result: HBM output +""" +import numpy as np +from nkipy_kernelgen import trace, knob + +M, N, K = 256, 256, 256 +matmul_tile = [128, 128, 128] # 
TILE_M, TILE_N, TILE_K +add_tile = [128, 128] # TILE_M, TILE_N + +@trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((M, N), "f32")]) +def matmul_add_kernel(a, b, bias): + # Matmul outputs to SBUF for reuse in the add + c = np.matmul(a, b) + knob.knob(c, mem_space="Sbuf", tile_size=matmul_tile) + + # Add outputs to SharedHbm (returned from kernel) + result = c + bias + knob.knob(result, mem_space="SharedHbm", tile_size=add_tile) + + return result diff --git a/kernelgen/compiler_explorer/config/nkipy.local.properties b/kernelgen/compiler_explorer/config/nkipy.local.properties new file mode 100644 index 0000000..b94bb1a --- /dev/null +++ b/kernelgen/compiler_explorer/config/nkipy.local.properties @@ -0,0 +1,17 @@ +# Compiler Explorer local configuration for NKIPy MLIR compiler +# +# Place this file in compiler-explorer/etc/config/ directory +# or merge with your existing local configuration + +# Define NKIPy as a compiler for Python language +compilers=&nkipy + +# NKIPy compiler definition +compiler.nkipy.exe=/path/to/NKIPyKernelGen/compiler_explorer/nkipy_ce_wrapper.sh +compiler.nkipy.name=NKIPy MLIR +compiler.nkipy.lang=python +compiler.nkipy.compilerType= +compiler.nkipy.notification=NKIPy compiles Python+NumPy to NISA MLIR +compiler.nkipy.options=--pass=final +compiler.nkipy.supportsBinary=false +compiler.nkipy.supportsExecute=false diff --git a/kernelgen/compiler_explorer/examples/attention_scores.py b/kernelgen/compiler_explorer/examples/attention_scores.py new file mode 100644 index 0000000..d2de933 --- /dev/null +++ b/kernelgen/compiler_explorer/examples/attention_scores.py @@ -0,0 +1,45 @@ +import numpy as np +from nkipy_kernelgen import trace, knob + +# Hardcoded dimensions +batch = 2 +n_heads = 4 +seq_len = 256 +head_dim = 256 +tile_size = [1, 128, 128] + +scale = 1.0 / np.sqrt(head_dim).item() + +@trace(input_specs=[ + ((batch * n_heads, seq_len, head_dim), "f32"), + ((batch * n_heads, head_dim, seq_len), "f32"), +]) +def attention_kernel(q, k_transposed): + # Score computation (K is pre-transposed to avoid np.transpose) + bmm_result = np.matmul(q, k_transposed) + # knob.knob(bmm_result, mem_space="SharedHbm", tile_size=tile_size, reduction_tile=[128]) + + # When placed to SBUF, the M-dim is the partition dimension + knob.knob(bmm_result, mem_space="Sbuf", tile_size=tile_size, reduction_tile=[128]) + + scores = bmm_result * scale + knob.knob(scores, mem_space="Sbuf", tile_size=tile_size, partition_dim=1) + + # Softmax + scores_fp32 = scores.astype(np.float32) + + scores_max = np.max(scores_fp32, axis=-1, keepdims=True) + knob.knob(scores_max, mem_space="Sbuf", tile_size=[1, 128], reduction_tile=[128], partition_dim=1) + + shifted = scores_fp32 - scores_max + knob.knob(shifted, mem_space="Sbuf", tile_size=tile_size, partition_dim=1) + + exp_s = np.exp(shifted) + knob.knob(exp_s, mem_space="Sbuf", tile_size=tile_size, partition_dim=1) + + sum_exp = np.sum(exp_s, axis=-1, keepdims=True) + knob.knob(sum_exp, mem_space="Sbuf", tile_size=[1, 128], reduction_tile=[128], partition_dim=1) + + result = exp_s / sum_exp + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size, partition_dim=1) + return result diff --git a/kernelgen/compiler_explorer/examples/attention_scores_loop.py b/kernelgen/compiler_explorer/examples/attention_scores_loop.py new file mode 100644 index 0000000..661e861 --- /dev/null +++ b/kernelgen/compiler_explorer/examples/attention_scores_loop.py @@ -0,0 +1,51 @@ +import numpy as np +from nkipy_kernelgen import trace, knob +from nkipy_kernelgen.apis import 
fori_loop + +# Hardcoded dimensions +batch = 2 +n_heads = 4 +seq_len = 256 +head_dim = 256 +tile_size = [128, 128] + +scale = 1.0 / np.sqrt(head_dim).item() + +@trace(input_specs=[ + ((batch * n_heads, seq_len, head_dim), "f32"), + ((batch * n_heads, head_dim, seq_len), "f32"), +]) +def attention_kernel_loop(q, k_transposed): + init_result = np.empty((batch * n_heads, seq_len, seq_len), dtype=np.float32) + + def body(i, acc): + q_i = q[i] + k_i = k_transposed[i] + + # Score computation (K is pre-transposed to avoid np.transpose) + scores = np.matmul(q_i, k_i) * scale + knob.knob(scores, mem_space="Sbuf", tile_size=tile_size, reduction_tile=[128]) + + # Softmax + scores_fp32 = scores.astype(np.float32) + + scores_max = np.max(scores_fp32, axis=-1, keepdims=True) + knob.knob(scores_max, mem_space="Sbuf", tile_size=[128], reduction_tile=[128]) + + shifted = scores_fp32 - scores_max + knob.knob(shifted, mem_space="Sbuf", tile_size=tile_size) + + exp_s = np.exp(shifted) + knob.knob(exp_s, mem_space="Sbuf", tile_size=tile_size) + + sum_exp = np.sum(exp_s, axis=-1, keepdims=True) + knob.knob(sum_exp, mem_space="Sbuf", tile_size=[128], reduction_tile=[128]) + + softmax_out = exp_s / sum_exp + knob.knob(softmax_out, mem_space="SharedHbm", tile_size=tile_size) + + acc[i] = softmax_out + return acc + + results = fori_loop(0, batch * n_heads, body, init_result) + return results diff --git a/kernelgen/compiler_explorer/examples/bmm.py b/kernelgen/compiler_explorer/examples/bmm.py new file mode 100644 index 0000000..59aaa6c --- /dev/null +++ b/kernelgen/compiler_explorer/examples/bmm.py @@ -0,0 +1,14 @@ +import numpy as np +from nkipy_kernelgen import trace, knob + +batch = 8 +M, N, K = 256, 256, 256 + +@trace(input_specs=[ + ((batch, M, K), "f32"), + ((batch, K, N), "f32"), +]) +def bmm_kernel(a, b): + result = np.matmul(a, b) + knob.knob(result, mem_space="SharedHbm", tile_size=[16, 128, 128], reduction_tile=[128]) + return result diff --git a/kernelgen/compiler_explorer/examples/custom.py b/kernelgen/compiler_explorer/examples/custom.py new file mode 100644 index 0000000..fb7c4eb --- /dev/null +++ b/kernelgen/compiler_explorer/examples/custom.py @@ -0,0 +1,23 @@ +""" +Example NKIPy kernel: Embedding lookup with bias + +Gathers rows from an embedding table using np.take, then adds a bias. +This exercises the nkipy.gather op (lowered to nisa.dma_copy_indirect). +""" +import numpy as np +from nkipy_kernelgen import trace, knob + +VOCAB = 1024 +EMBED = 512 +SEQ = 256 + +@trace(input_specs=[((VOCAB, EMBED), "f32"), ((SEQ,), "i32"), ((SEQ, EMBED), "f32")]) +def embedding_lookup(table, token_ids, bias): + + gathered = np.take(table, token_ids, axis=0) + knob.knob(gathered, mem_space="Sbuf", tile_size=[128, 128]) + + result = np.add(gathered, bias) + knob.knob(result, mem_space="SharedHbm", tile_size=[128, 128]) + + return result diff --git a/kernelgen/compiler_explorer/examples/feedforward.py b/kernelgen/compiler_explorer/examples/feedforward.py new file mode 100644 index 0000000..7d63edd --- /dev/null +++ b/kernelgen/compiler_explorer/examples/feedforward.py @@ -0,0 +1,61 @@ +""" +Example NKIPy kernel: Feedforward Network (SwiGLU) + +This kernel implements a feedforward layer with SwiGLU activation: +1. Gate+Up projection: x @ gate_up_weight -> split into gate and up +2. SwiGLU activation: SiLU(gate) * up +3. 
Down projection: result @ down_weight +""" +import numpy as np +from nkipy_kernelgen import trace, knob + +# Hardcoded dimensions +batch_size = 256 +hidden_size = 256 +intermediate_size = 256 + +# Tile sizes +matmul_tile = [128, 128, 128] # TILE_M, TILE_N, TILE_K +elementwise_tile = [128, 128] # TILE_M, TILE_N + +@trace(input_specs=[ + ((batch_size, hidden_size), "f32"), # x + ((hidden_size, 2 * intermediate_size), "f32"), # gate_up_weight + ((intermediate_size, hidden_size), "f32"), # down_weight +]) +def feedforward_kernel(x, gate_up_weight, down_weight): + """Feedforward network: Gate+Up projection -> SwiGLU -> Down projection""" + # Gate and Up projection + mm_gup = np.matmul(x, gate_up_weight) + knob.knob(mm_gup, mem_space="Sbuf", tile_size=matmul_tile) + + # Split into gate and up components + split_axis = mm_gup.ndim - 1 + gate, up = np.split(mm_gup, 2, axis=split_axis) + + # Apply SiLU activation to gate: sigmoid(gate) * gate + # Break down sigmoid into individual ops so each can be tiled + neg_gate = -gate + knob.knob(neg_gate, mem_space="Sbuf", tile_size=elementwise_tile) + + exp_neg_gate = np.exp(neg_gate) + knob.knob(exp_neg_gate, mem_space="Sbuf", tile_size=elementwise_tile) + + one_plus_exp = exp_neg_gate + 1.0 + knob.knob(one_plus_exp, mem_space="Sbuf", tile_size=elementwise_tile) + + sigmoid_gate = 1.0 / one_plus_exp + knob.knob(sigmoid_gate, mem_space="Sbuf", tile_size=elementwise_tile) + + swish_gate = gate * sigmoid_gate + knob.knob(swish_gate, mem_space="Sbuf", tile_size=elementwise_tile) + + # Element-wise multiplication (gating) + gated = swish_gate * up + knob.knob(gated, mem_space="Sbuf", tile_size=elementwise_tile) + + # Down projection + output = np.matmul(gated, down_weight) + knob.knob(output, mem_space="SharedHbm", tile_size=matmul_tile) + + return output diff --git a/kernelgen/compiler_explorer/examples/matmul_add.py b/kernelgen/compiler_explorer/examples/matmul_add.py new file mode 100644 index 0000000..2eae6f5 --- /dev/null +++ b/kernelgen/compiler_explorer/examples/matmul_add.py @@ -0,0 +1,22 @@ +""" +Example NKIPy kernel: Basic element-wise operations + +This kernel computes: result = (a + b) * c - d +""" +import numpy as np +from nkipy_kernelgen import trace, knob + +M, N, K = 256, 256, 256 +add_tile = [128, 128] # TILE_M, TILE_N + +@trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((M, N), "f32")]) +def matmul_add_kernel(a, b, bias): + # Matmul outputs to SBUF for reuse in the add + c = np.matmul(a, b) + knob.knob(c, mem_space="Sbuf", tile_size=[128, 128], reduction_tile=[128]) + + # Add outputs to SharedHbm (returned from kernel) + result = c + bias + knob.knob(result, mem_space="SharedHbm", tile_size=add_tile) + + return result \ No newline at end of file diff --git a/kernelgen/compiler_explorer/examples/qwen3_layer.py b/kernelgen/compiler_explorer/examples/qwen3_layer.py new file mode 100644 index 0000000..8601e20 --- /dev/null +++ b/kernelgen/compiler_explorer/examples/qwen3_layer.py @@ -0,0 +1,306 @@ +""" +Qwen3 Transformer Decoder Layer (inlined for readability). + +All sub-kernels (RMSNorm, RoPE, softmax, SiLU) are inlined so the full +data flow is visible in one place. + +Shape convention: + - 2D projections use (BS, hidden_size) where BS = batch * seq_len. + This flattens batch and sequence into one "token" dimension so the + matmuls are plain 2D. The reshape to (batch, seq_len, n_heads, head_dim) + recovers the sequence dimension when needed for multi-head attention. + - 3D attention tensors are (BH, seq_len, X) where BH = batch * n_heads. 
+ The partition dimension for these is dim 1 (seq_len), NOT dim 0 (BH). +""" +import numpy as np +from nkipy_kernelgen import trace, knob + +# ---------------------------------------------------------------- +# Model hyperparameters +# ---------------------------------------------------------------- +batch = 2 +seq_len = 128 +hidden_size = 256 +n_heads = 2 +head_dim = hidden_size // n_heads # 128 +intermediate_size = 256 +half_dim = head_dim // 2 # 64 +eps = 1e-6 +scale = 1.0 / np.sqrt(head_dim).item() + +# Derived (flattened dimensions) +BS = batch * seq_len # 256 (tokens = batch * seq_len) +BH = batch * n_heads # 4 (heads = batch * n_heads) + +# ---------------------------------------------------------------- +# Tile sizes +# ---------------------------------------------------------------- +matmul_tile_2d = [128, 128] +matmul_reduction_2d = [128] +attn_tile = [1, 128, 128] # (BH, seq_len, seq_len/head_dim) +attn_reduction = [128] +rope_tile = [1, 128, 64] # (BH, seq_len, half_dim) +elem_tile_2d = [128, 128] + + +@trace(input_specs=[ + # hidden_states: (BS, hidden_size) = (256, 256) + # BS = batch * seq_len, flattened for 2D matmul projections. + # Reshape to (batch, seq_len, n_heads, head_dim) recovers seq_len + # for multi-head attention. + ((BS, hidden_size), "f32"), + # RMSNorm weights — (hidden_size, 1) so broadcast is over the free dim + # ([P, 1] pattern), which maps to nisa.tensor_scalar_arith. + ((hidden_size, 1), "f32"), # ln1_weight + ((hidden_size, 1), "f32"), # ln2_weight + # Attention projection weights + ((hidden_size, hidden_size), "f32"), # w_q + ((hidden_size, hidden_size), "f32"), # w_k + ((hidden_size, hidden_size), "f32"), # w_v + ((hidden_size, hidden_size), "f32"), # w_o + # RoPE frequencies (position-dependent, broadcast over BH) + ((1, seq_len, half_dim), "f32"), # freqs_cos + ((1, seq_len, half_dim), "f32"), # freqs_sin + # FFN weights + ((hidden_size, intermediate_size), "f32"), # w_gate + ((hidden_size, intermediate_size), "f32"), # w_up + ((intermediate_size, hidden_size), "f32"), # w_down +]) +def qwen3_layer(hidden_states, + ln1_weight, ln2_weight, + w_q, w_k, w_v, w_o, + freqs_cos, freqs_sin, + w_gate, w_up, w_down): + + residual = hidden_states # (BS, hidden_size) + + # ================================================================ + # 1. Pre-attention RMSNorm + # norm(x) = x / sqrt(mean(x^2) + eps) * weight + # ================================================================ + x_fp32 = hidden_states.astype(np.float32) + w_fp32 = ln1_weight.astype(np.float32) + + sq = np.square(x_fp32) # (256, 256) + knob.knob(sq, mem_space="Sbuf", tile_size=elem_tile_2d) + + sum_sq = np.sum(sq, axis=-1, keepdims=True) # (256, 1) + knob.knob(sum_sq, mem_space="Sbuf", tile_size=[128], reduction_tile=[128]) + + mean_sq = sum_sq * np.float32(1.0 / hidden_size) # (256, 1) + knob.knob(mean_sq, mem_space="Sbuf", tile_size=[128, 1]) + + normed = x_fp32 / np.sqrt(mean_sq + eps) # (256, 256) + knob.knob(normed, mem_space="Sbuf", tile_size=elem_tile_2d) + + normed = normed * w_fp32 # (256, 256) + knob.knob(normed, mem_space="Sbuf", tile_size=elem_tile_2d) + + # ================================================================ + # 2. 
QKV projections (2D matmuls on flattened BS dimension) + # SharedHbm = sub-kernel boundary (results flow through reshape) + # ================================================================ + q = np.matmul(normed, w_q) # (256, 256) + knob.knob(q, mem_space="SharedHbm", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + k = np.matmul(normed, w_k) # (256, 256) + knob.knob(k, mem_space="SharedHbm", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + v = np.matmul(normed, w_v) # (256, 256) + knob.knob(v, mem_space="SharedHbm", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + # ================================================================ + # 3. Reshape to multi-head format + # (BS, hidden) -> (batch, seq_len, n_heads, head_dim) + # -> transpose to (batch, n_heads, seq_len, head_dim) + # -> flatten to (BH, seq_len, head_dim) + # ================================================================ + q = np.reshape(q, (batch, seq_len, n_heads, head_dim)) # (2, 128, 2, 128) + q = np.transpose(q, (0, 2, 1, 3)) # (2, 2, 128, 128) + q = np.reshape(q, (BH, seq_len, head_dim)) # (4, 128, 128) + + k = np.reshape(k, (batch, seq_len, n_heads, head_dim)) + k = np.transpose(k, (0, 2, 1, 3)) + k = np.reshape(k, (BH, seq_len, head_dim)) # (4, 128, 128) + + v = np.reshape(v, (batch, seq_len, n_heads, head_dim)) + v = np.transpose(v, (0, 2, 1, 3)) + v = np.reshape(v, (BH, seq_len, head_dim)) # (4, 128, 128) + # V is a sub-kernel boundary: keep in SharedHbm so the 4D transpose + # intermediate stays in HBM (avoids a 4D SBUF alloc that legalize-layout + # cannot tile). + knob.knob(v, mem_space="SharedHbm", tile_size=attn_tile) + + # ================================================================ + # 4. RoPE on Q and K (not V) + # Split head_dim in half, rotate: [x0, x1] -> [x0*cos - x1*sin, + # x0*sin + x1*cos] + # freqs_cos/sin are (1, 128, 64) — broadcast over BH dim + # ================================================================ + # --- RoPE on Q --- + q0 = q[:, :, :half_dim] # (4, 128, 64) + q1 = q[:, :, half_dim:] # (4, 128, 64) + + q0_cos = q0 * freqs_cos # (4, 128, 64) + knob.knob(q0_cos, mem_space="SharedHbm", tile_size=rope_tile) + q1_sin = q1 * freqs_sin # (4, 128, 64) + knob.knob(q1_sin, mem_space="SharedHbm", tile_size=rope_tile) + q_rot0 = q0_cos - q1_sin # (4, 128, 64) + knob.knob(q_rot0, mem_space="SharedHbm", tile_size=rope_tile) + + q0_sin = q0 * freqs_sin # (4, 128, 64) + knob.knob(q0_sin, mem_space="SharedHbm", tile_size=rope_tile) + q1_cos = q1 * freqs_cos # (4, 128, 64) + knob.knob(q1_cos, mem_space="SharedHbm", tile_size=rope_tile) + q_rot1 = q0_sin + q1_cos # (4, 128, 64) + knob.knob(q_rot1, mem_space="SharedHbm", tile_size=rope_tile) + + q = np.concatenate([q_rot0, q_rot1], axis=-1) # (4, 128, 128) + knob.knob(q, mem_space="SharedHbm", tile_size=attn_tile) + + # --- RoPE on K --- + k0 = k[:, :, :half_dim] # (4, 128, 64) + k1 = k[:, :, half_dim:] # (4, 128, 64) + + k0_cos = k0 * freqs_cos # (4, 128, 64) + knob.knob(k0_cos, mem_space="SharedHbm", tile_size=rope_tile) + k1_sin = k1 * freqs_sin # (4, 128, 64) + knob.knob(k1_sin, mem_space="SharedHbm", tile_size=rope_tile) + k_rot0 = k0_cos - k1_sin # (4, 128, 64) + knob.knob(k_rot0, mem_space="SharedHbm", tile_size=rope_tile) + + k0_sin = k0 * freqs_sin # (4, 128, 64) + knob.knob(k0_sin, mem_space="SharedHbm", tile_size=rope_tile) + k1_cos = k1 * freqs_cos # (4, 128, 64) + knob.knob(k1_cos, mem_space="SharedHbm", tile_size=rope_tile) + k_rot1 = k0_sin + k1_cos # (4, 128, 64) + 
knob.knob(k_rot1, mem_space="SharedHbm", tile_size=rope_tile) + + k = np.concatenate([k_rot0, k_rot1], axis=-1) # (4, 128, 128) + knob.knob(k, mem_space="SharedHbm", tile_size=attn_tile) + + # K^T for attention scores + k_t = np.transpose(k, (0, 2, 1)) # (4, 128, 128) + knob.knob(k_t, mem_space="SharedHbm", tile_size=attn_tile) + + # ================================================================ + # 5. Scaled dot-product attention + # scores = (Q @ K^T) * scale + # weights = softmax(scores) + # context = weights @ V + # ================================================================ + scores = np.matmul(q, k_t) # (4, 128, 128) + knob.knob(scores, mem_space="Sbuf", tile_size=attn_tile, reduction_tile=attn_reduction) + + scores = scores * scale # (4, 128, 128) + knob.knob(scores, mem_space="Sbuf", tile_size=attn_tile, partition_dim=1) + + # --- softmax (numerically stable) --- + scores_fp32 = scores.astype(np.float32) + + s_max = np.max(scores_fp32, axis=-1, keepdims=True) # (4, 128, 1) + knob.knob(s_max, mem_space="SharedHbm", tile_size=[1, 128], + reduction_tile=[128], partition_dim=1) + + shifted = scores_fp32 - s_max # (4, 128, 128) + knob.knob(shifted, mem_space="Sbuf", tile_size=attn_tile, partition_dim=1) + + exp_s = np.exp(shifted) # (4, 128, 128) + knob.knob(exp_s, mem_space="Sbuf", tile_size=attn_tile, partition_dim=1) + + sum_exp = np.sum(exp_s, axis=-1, keepdims=True) # (4, 128, 1) + knob.knob(sum_exp, mem_space="SharedHbm", tile_size=[1, 128], + reduction_tile=[128], partition_dim=1) + + attn_weights = exp_s / sum_exp # (4, 128, 128) + knob.knob(attn_weights, mem_space="SharedHbm", tile_size=attn_tile) + + # --- context = attn_weights @ V --- + context = np.matmul(attn_weights, v) # (4, 128, 128) + knob.knob(context, mem_space="SharedHbm", tile_size=attn_tile, reduction_tile=attn_reduction) + + # ================================================================ + # 6. Concat heads + output projection + # (BH, seq_len, head_dim) -> (batch, n_heads, seq_len, head_dim) + # -> transpose to (batch, seq_len, n_heads, head_dim) + # -> flatten to (BS, hidden_size) + # ================================================================ + context = np.reshape(context, (batch, n_heads, seq_len, head_dim)) + context = np.transpose(context, (0, 2, 1, 3)) + context = np.reshape(context, (BS, hidden_size)) # (256, 256) + knob.knob(context, mem_space="SharedHbm", tile_size=matmul_tile_2d) + + attn_out = np.matmul(context, w_o) # (256, 256) + knob.knob(attn_out, mem_space="Sbuf", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + # ================================================================ + # 7. First residual connection + # ================================================================ + hidden_states = residual + attn_out # (256, 256) + knob.knob(hidden_states, mem_space="Sbuf", tile_size=elem_tile_2d) + + residual = hidden_states + + # ================================================================ + # 8. 
Post-attention RMSNorm + # ================================================================ + x_fp32 = hidden_states.astype(np.float32) + w_fp32 = ln2_weight.astype(np.float32) + + sq = np.square(x_fp32) # (256, 256) + knob.knob(sq, mem_space="Sbuf", tile_size=elem_tile_2d) + + sum_sq = np.sum(sq, axis=-1, keepdims=True) # (256, 1) + knob.knob(sum_sq, mem_space="Sbuf", tile_size=[128], reduction_tile=[128]) + + mean_sq = sum_sq * np.float32(1.0 / hidden_size) # (256, 1) + knob.knob(mean_sq, mem_space="Sbuf", tile_size=[128, 1]) + + normed = x_fp32 / np.sqrt(mean_sq + eps) # (256, 256) + knob.knob(normed, mem_space="Sbuf", tile_size=elem_tile_2d) + + normed = normed * w_fp32 # (256, 256) + knob.knob(normed, mem_space="Sbuf", tile_size=elem_tile_2d) + + # ================================================================ + # 9. SwiGLU FFN + # gate = SiLU(normed @ w_gate) + # up = normed @ w_up + # out = (gate * up) @ w_down + # ================================================================ + gate = np.matmul(normed, w_gate) # (256, 256) + knob.knob(gate, mem_space="Sbuf", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + up = np.matmul(normed, w_up) # (256, 256) + knob.knob(up, mem_space="Sbuf", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + # --- SiLU(gate) = gate * sigmoid(gate) --- + neg_gate = -gate # (256, 256) + knob.knob(neg_gate, mem_space="Sbuf", tile_size=elem_tile_2d) + + exp_neg = np.exp(neg_gate) # (256, 256) + knob.knob(exp_neg, mem_space="Sbuf", tile_size=elem_tile_2d) + + one_plus = exp_neg + 1.0 # (256, 256) + knob.knob(one_plus, mem_space="Sbuf", tile_size=elem_tile_2d) + + sigmoid = 1.0 / one_plus # (256, 256) + knob.knob(sigmoid, mem_space="Sbuf", tile_size=elem_tile_2d) + + gate = gate * sigmoid # (256, 256) + knob.knob(gate, mem_space="Sbuf", tile_size=elem_tile_2d) + + # --- gated output --- + gated = gate * up # (256, 256) + knob.knob(gated, mem_space="Sbuf", tile_size=elem_tile_2d) + + ffn_out = np.matmul(gated, w_down) # (256, 256) + knob.knob(ffn_out, mem_space="Sbuf", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + # ================================================================ + # 10. 
Second residual connection + # ================================================================ + output = residual + ffn_out # (256, 256) + knob.knob(output, mem_space="SharedHbm", tile_size=elem_tile_2d) + + return output # (256, 256) diff --git a/kernelgen/compiler_explorer/examples/reduce_sum.py b/kernelgen/compiler_explorer/examples/reduce_sum.py new file mode 100644 index 0000000..1aea655 --- /dev/null +++ b/kernelgen/compiler_explorer/examples/reduce_sum.py @@ -0,0 +1,30 @@ +from nkipy_kernelgen import trace, knob +import numpy as np + + +M, N = 256, 256 +tile_size = [128, 128] +reduction_tile_size = [128] + + +@trace(input_specs=[((M, N), "f32")]) +def kernel(x): + sq = np.square(x.astype(np.float32)) + knob.knob(sq, mem_space="Sbuf", tile_size=tile_size) + + # np.mean(sq, axis=-1) == np.sum(sq, axis=-1) * (1/N) + sm = np.sum(sq, axis=-1, keepdims=True) + knob.knob( + sm, + mem_space="SharedHbm", + tile_size=[128], + reduction_tile=[128], + ) + + result = sm * np.float32(1.0 / N) + knob.knob( + result, + mem_space="SharedHbm", + tile_size=[128, 1], + ) + return result diff --git a/kernelgen/compiler_explorer/examples/reshape.py b/kernelgen/compiler_explorer/examples/reshape.py new file mode 100644 index 0000000..15cbd2a --- /dev/null +++ b/kernelgen/compiler_explorer/examples/reshape.py @@ -0,0 +1,42 @@ +""" +Example NKIPy kernel: Reshape operations (contiguous dim merge/split). + +Uncomment one (input_shape, output_shape) pair at a time to try different cases. +""" +import numpy as np +from nkipy_kernelgen import trace + +# PASS +# -- Merge dims: (2, 128, 256) -> (256, 256) -- +input_shape = (2, 128, 256) +output_shape = (256, 256) + +# PASS +# -- Split dim: (256, 256) -> (2, 128, 256) -- +input_shape = (256, 256) +output_shape = (2, 128, 256) + +# PASS +# -- Insert unit dim: (128, 256) -> (128, 1, 256) -- +input_shape = (128, 256) +output_shape = (128, 1, 256) + +# PASS +# -- Remove unit dim (squeeze): (128, 1, 256) -> (128, 256) -- +input_shape = (128, 1, 256) +output_shape = (128, 256) + +# PASS +# -- Infer dim with -1: (2, 128, 256) -> (-1, 256) -- +input_shape = (2, 128, 256) +output_shape = (-1, 256) + +# PASS +# -- Identity reshape (no-op): (128, 256) -> (128, 256) -- +input_shape = (128, 256) +output_shape = (128, 256) + + +@trace(input_specs=[(input_shape, "f32")]) +def kernel(x): + return np.reshape(x, output_shape) diff --git a/kernelgen/compiler_explorer/examples/rmsnorm.py b/kernelgen/compiler_explorer/examples/rmsnorm.py new file mode 100644 index 0000000..a7e67fc --- /dev/null +++ b/kernelgen/compiler_explorer/examples/rmsnorm.py @@ -0,0 +1,56 @@ +""" +Example NKIPy kernel: RMSNorm + +RMSNorm: output = (x / sqrt(mean(x^2) + eps)) * weight + +This exercises: +1. Element-wise square (multiply) +2. Mean reduction over last axis +3. Addition with scalar epsilon +4. Reciprocal square root (rsqrt) +5. 
Element-wise multiply with weight +""" +import numpy as np +from nkipy_kernelgen import trace, knob + +# Hardcoded dimensions +M = 256 +N = 256 + +# Tile sizes +tile_size = [128, 128] # TILE_M, TILE_N + +# RMSNorm epsilon +eps = 1e-6 + +@trace(input_specs=[((M, N), "f32"), ((N, 1), "f32")]) +def rmsnorm_kernel(x, weight): + """RMSNorm: x / sqrt(mean(x^2) + eps) * weight""" + x_fp32 = x.astype(np.float32) + w_fp32 = weight.astype(np.float32) + + sq = np.square(x_fp32) + knob.knob(sq, mem_space="Sbuf", tile_size=tile_size) + + sum_sq = np.sum(sq, axis=-1, keepdims=True) + knob.knob( + sum_sq, + mem_space="Sbuf", + tile_size=[128], + reduction_tile=[128], + ) + + mean_sq = sum_sq * np.float32(1.0 / N) + knob.knob( + mean_sq, + mem_space="Sbuf", + tile_size=[128, 1], + ) + + normed = x_fp32 / np.sqrt(mean_sq + eps) + knob.knob(normed, mem_space="Sbuf", tile_size=tile_size) + + result = normed * w_fp32 + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + + return result diff --git a/kernelgen/compiler_explorer/examples/rope.py b/kernelgen/compiler_explorer/examples/rope.py new file mode 100644 index 0000000..7e4d527 --- /dev/null +++ b/kernelgen/compiler_explorer/examples/rope.py @@ -0,0 +1,41 @@ +import numpy as np +from nkipy_kernelgen import trace, knob + +# Hardcoded dimensions +batch = 2 +seq_len = 128 +n_heads = 4 +head_dim = 128 +half_h = head_dim // 2 +bs = batch * seq_len +tile_size = [128, 1, 64] + +@trace(input_specs=[ + ((bs, n_heads, head_dim), "f32"), + ((bs, half_h), "f32"), + ((bs, half_h), "f32"), +]) +def rope_kernel(x, freqs_cos, freqs_sin): + # Broadcast cos/sin to (bs, 1, half_h) + cos = np.expand_dims(freqs_cos, axis=1) + sin = np.expand_dims(freqs_sin, axis=1) + + knob.knob(cos, mem_space="Sbuf") + knob.knob(sin, mem_space="Sbuf") + + # Split input into two halves along head_dim + x0 = x[:, :, :half_h] + x1 = x[:, :, half_h:] + + # Apply rotation + out_0 = x0 * cos - x1 * sin + knob.knob(out_0, mem_space="Sbuf", tile_size=tile_size) + + out_1 = x0 * sin + x1 * cos + knob.knob(out_1, mem_space="Sbuf", tile_size=tile_size) + + # Concatenate back along head_dim axis + result = np.concatenate([out_0, out_1], axis=-1) + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + + return result diff --git a/kernelgen/compiler_explorer/examples/softmax.py b/kernelgen/compiler_explorer/examples/softmax.py new file mode 100644 index 0000000..e693426 --- /dev/null +++ b/kernelgen/compiler_explorer/examples/softmax.py @@ -0,0 +1,27 @@ +import numpy as np +from nkipy_kernelgen import trace, knob + +# Hardcoded dimensions +M = 256 +N = 256 +tile_size = [128, 128] + +@trace(input_specs=[((M, N), "f32")]) +def softmax_kernel(x): + x_fp32 = x.astype(np.float32) + + x_max = np.max(x_fp32, axis=-1, keepdims=True) + knob.knob(x_max, mem_space="Sbuf", tile_size=[128], reduction_tile=[128]) + + shifted = x_fp32 - x_max + knob.knob(shifted, mem_space="Sbuf", tile_size=tile_size) + + exp_x = np.exp(shifted) + knob.knob(exp_x, mem_space="Sbuf", tile_size=tile_size) + + sum_exp = np.sum(exp_x, axis=-1, keepdims=True) + knob.knob(sum_exp, mem_space="Sbuf", tile_size=[128], reduction_tile=[128]) + + result = exp_x / sum_exp + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + return result \ No newline at end of file diff --git a/kernelgen/compiler_explorer/examples/transpose.py b/kernelgen/compiler_explorer/examples/transpose.py new file mode 100644 index 0000000..6633ebd --- /dev/null +++ b/kernelgen/compiler_explorer/examples/transpose.py @@ -0,0 +1,42 @@ +""" 
+Example NKIPy kernel: Transpose operations. + +Uncomment one (input_shape, axes) pair at a time to try different cases. +""" +import numpy as np +from nkipy_kernelgen import trace +from nkipy_kernelgen.knob import knob + +# PASS +# -- 2D transpose: (128, 256) with axes [1, 0] -- +# Output shape: (256, 128), tile_size: [128, 128] +input_shape = (128, 256) +axes = [1, 0] +tile_size = [128, 128] + +# PASS +# -- 3D transpose (swap last two dims): (2, 128, 256) with axes [0, 2, 1] -- +# Output shape: (2, 256, 128), tile_size: [1, 128, 128] (batch dim tiled to 1) +input_shape = (2, 128, 256) +axes = [0, 2, 1] +tile_size = [1, 128, 128] + +# -- 3D transpose (rotate dims): (2, 128, 256) with axes [1, 2, 0] -- +# Output shape: (128, 256, 2), tile_size: [128, 128, 1] +# (Non-unit dims keep order → effectively a reshape, emits copy not transpose) +# input_shape = (2, 128, 256) +# axes = [1, 2, 0] +# tile_size = [128, 128, 1] + +# PASS +# -- 3D transpose (reverse dims): (2, 128, 256) with axes [2, 1, 0] -- +# Output shape: (256, 128, 2), tile_size: [128, 128, 1] +input_shape = (2, 128, 256) +axes = [2, 1, 0] +tile_size = [128, 128, 1] + + +@trace(input_specs=[(input_shape, "f32")]) +def kernel(x): + result = np.transpose(x, axes) + return knob(result, tile_size=tile_size) diff --git a/kernelgen/compiler_explorer/nkipy_ce_wrapper.sh b/kernelgen/compiler_explorer/nkipy_ce_wrapper.sh new file mode 100755 index 0000000..7faf688 --- /dev/null +++ b/kernelgen/compiler_explorer/nkipy_ce_wrapper.sh @@ -0,0 +1,15 @@ +#!/bin/bash +# Wrapper script for Compiler Explorer to invoke nkipy_compiler.py +# +# Compiler Explorer passes: +# $1 = input file path +# Additional args passed via compiler options in CE UI + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +NKIPY_ROOT="$(dirname "$SCRIPT_DIR")" + +# Set PYTHONPATH to include nkipy_kernelgen +export PYTHONPATH="$NKIPY_ROOT:$PYTHONPATH" + +# Run the compiler +exec python3 "$SCRIPT_DIR/nkipy_compiler.py" "$@" diff --git a/kernelgen/compiler_explorer/nkipy_compiler.py b/kernelgen/compiler_explorer/nkipy_compiler.py new file mode 100755 index 0000000..2b2ac1a --- /dev/null +++ b/kernelgen/compiler_explorer/nkipy_compiler.py @@ -0,0 +1,620 @@ +#!/usr/bin/env python3 +""" +NKIPy Compiler wrapper for Compiler Explorer. + +This script takes Python source code with @trace decorated functions, +traces them to MLIR, and runs the full compilation pipeline. + +Usage: + python nkipy_compiler.py [options] + +Options: + --stop=N Stop after pass N (0 = trace only, 1-24 = after that pass) + --sim Run simulation (BIR sim for full pipeline, LLVM JIT with --stop) + --hw Compile to NEFF and execute on Trainium hardware + --target= Target hardware: trn1, trn2, trn3 (default: trn2) +""" + +import sys +import os +import argparse +import tempfile +import re +import time + +# Add parent directory to path for imports +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from nkipy_kernelgen import trace, apply_passes +from nkipy_kernelgen.transforms.nkipy_opt import apply_complete_knob_pipeline + + +def add_loc_comments(mlir_text: str) -> str: + """ + Convert MLIR location attributes to .loc comments for Compiler Explorer. + + MLIR locations look like: + - Inline: `linalg.add ... 
loc("file.py":6:5)` + - Aliased: `#loc1 = loc("file.py":6:5)` at bottom, referenced as `loc(#loc1)` + - Unknown: `loc(unknown)` or `#loc1 = loc(unknown)` + + CE expects: + - `.file 1 "file.py"` at the top + - `.loc 1 6 5` comment before each operation + + Args: + mlir_text: MLIR text with location attributes (from --mlir-print-debuginfo) + + Returns: + MLIR text with .loc comments inserted + """ + # Parse location aliases from the bottom of the file + # Format: #loc1 = loc("file.py":line:col) + loc_alias_pattern = re.compile(r'^(#loc\d*)\s*=\s*loc\("([^"]+)":(\d+):(\d+)\)', re.MULTILINE) + # Also match unknown locations: #loc1 = loc(unknown) + unknown_loc_pattern = re.compile(r'^(#loc\d*)\s*=\s*loc\(unknown\)', re.MULTILINE) + + loc_aliases = {} + unknown_aliases = set() + + for match in loc_alias_pattern.finditer(mlir_text): + alias, filename, line, col = match.groups() + loc_aliases[alias] = (filename, int(line), int(col)) + + for match in unknown_loc_pattern.finditer(mlir_text): + unknown_aliases.add(match.group(1)) + + # Build file number mapping + files = {} + file_counter = 1 + for alias, (filename, _, _) in loc_aliases.items(): + if filename not in files: + files[filename] = file_counter + file_counter += 1 + + # Also scan for inline locations to build file mapping + inline_loc_pattern = re.compile(r'loc\("([^"]+)":(\d+):(\d+)\)') + for match in inline_loc_pattern.finditer(mlir_text): + filename = match.group(1) + if filename not in files: + files[filename] = file_counter + file_counter += 1 + + if not files: + # No locations found, return unchanged but strip loc attributes + lines = mlir_text.split('\n') + result = [] + for line in lines: + # Skip unknown location alias lines + if unknown_loc_pattern.match(line.strip()): + continue + # Remove loc(unknown) and loc(#locN) references + clean = re.sub(r'\s*loc\(unknown\)', '', line) + clean = re.sub(r'\s*loc\(#loc\d*\)', '', clean) + result.append(clean) + return '\n'.join(result) + + # Build .file directives + file_directives = [] + for filename, num in sorted(files.items(), key=lambda x: x[1]): + file_directives.append(f'.file {num} "{filename}"') + + # Process each line and insert .loc comments + lines = mlir_text.split('\n') + result_lines = [] + + # Add file directives at the top (after any initial comments) + file_header_added = False + + for line in lines: + # Skip location alias definitions at the bottom (both known and unknown) + if loc_alias_pattern.match(line.strip()) or unknown_loc_pattern.match(line.strip()): + continue + + # Add file header before first non-empty, non-comment line + if not file_header_added and line.strip() and not line.strip().startswith('//'): + result_lines.extend(file_directives) + result_lines.append('') # blank line after file directives + file_header_added = True + + # Check for location on this line + loc_info = None + + # Try aliased location first: loc(#loc1) + alias_ref_match = re.search(r'loc\((#loc\d*)\)', line) + if alias_ref_match: + alias = alias_ref_match.group(1) + if alias in loc_aliases: + loc_info = loc_aliases[alias] + + # Try inline location: loc("file":line:col) + if not loc_info: + inline_match = re.search(r'loc\("([^"]+)":(\d+):(\d+)\)', line) + if inline_match: + filename, line_num, col = inline_match.groups() + loc_info = (filename, int(line_num), int(col)) + + # Remove all loc(...) 
from the line for cleaner output + clean_line = re.sub(r'\s*loc\([^)]+\)', '', line) + + # Insert .loc comment before the line if we found location info + if loc_info: + filename, line_num, col = loc_info + file_num = files.get(filename, 1) + result_lines.append(f'.loc {file_num} {line_num} {col}') + result_lines.append(clean_line) + else: + # Still use clean_line to remove any remaining loc(...) references + result_lines.append(clean_line) + + return '\n'.join(result_lines) + + +def find_traced_function(module): + """Find the first @trace decorated function in a module.""" + for name in dir(module): + obj = getattr(module, name) + if hasattr(obj, 'to_mlir') and callable(getattr(obj, 'to_mlir')): + print(f"DEBUG: Found traced function: {name}", file=sys.stderr) + return obj + print(f"DEBUG: No traced function found. Checking for callable objects...", file=sys.stderr) + for name in dir(module): + obj = getattr(module, name) + if callable(obj) and not name.startswith('_'): + print(f"DEBUG: Callable '{name}' has attrs: {[a for a in dir(obj) if not a.startswith('_')][:10]}", file=sys.stderr) + return None + + +def load_source_file(filepath): + """Dynamically load a Python source file as a module.""" + # Read the source code + with open(filepath, 'r') as f: + source_code = f.read() + + # Debug: show what we're compiling + print(f"DEBUG: Compiling file: {filepath}", file=sys.stderr) + print(f"DEBUG: Source code ({len(source_code)} chars):", file=sys.stderr) + print(source_code[:500], file=sys.stderr) + if len(source_code) > 500: + print("...(truncated)", file=sys.stderr) + + # Create a module and execute the code in it + import types + module = types.ModuleType("user_kernel") + module.__file__ = filepath + + # Compile with filepath so inspect.getsourcefile works for source locations + code = compile(source_code, filepath, 'exec') + exec(code, module.__dict__) + + # Debug: show what's in the module + print(f"DEBUG: Module contents: {[n for n in dir(module) if not n.startswith('_')]}", file=sys.stderr) + + return module + + +def compile_to_mlir(source_path, stop_after=None, target="trn2", raw=False): + """ + Compile Python source to MLIR. 
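+
+    A hypothetical invocation (the source path is illustrative; the return
+    values follow the tuple documented below):
+
+        ce_output, raw_mlir, traced_func = compile_to_mlir("my_kernel.py", stop_after=7)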
+ + Args: + source_path: Path to Python file with @trace decorated function + stop_after: Pass number to stop after (0 = trace only, 1-19 = after that pass, None = all) + target: Target hardware + raw: If True, omit source location debug info for cleaner CLI output + + Returns: + Tuple of (ce_output, raw_mlir, traced_func): + - ce_output: MLIR with .loc comments for Compiler Explorer + - raw_mlir: Valid MLIR string (usable for simulation) + - traced_func: The traced function object + """ + # Load and execute the source file + module = load_source_file(source_path) + + # Find the traced function + traced_func = find_traced_function(module) + if traced_func is None: + print("Error: No @trace decorated function found in input file", file=sys.stderr) + sys.exit(1) + + # Step 1: Trace to MLIR + mlir_module = traced_func.to_mlir() + + # Include debug info for CE source-line mapping; skip for --raw CLI output + include_debuginfo = not raw + current_mlir = mlir_module.operation.get_asm(enable_debug_info=include_debuginfo) + + if stop_after == 0: + return add_loc_comments(current_mlir), current_mlir, traced_func + + # Run the compilation pipeline (optionally stopping after a specific pass) + result = apply_complete_knob_pipeline( + current_mlir, target=target, stop_after=stop_after, print_debuginfo=include_debuginfo + ) + return add_loc_comments(result), result, traced_func + + +DTYPE_MAP = { + "f32": "float32", + "f16": "float16", + "bf16": "bfloat16", + "f64": "float64", + "i32": "int32", + "i64": "int64", +} + + +def run_simulation(raw_mlir, traced_func): + """Run BIR simulation on compiled MLIR and verify against NumPy reference.""" + import numpy as np + + tests_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "tests" + ) + if tests_path not in sys.path: + sys.path.insert(0, tests_path) + from harness import simulate_mlir + + # Generate test inputs from input_specs + np.random.seed(42) + test_inputs = [] + for shape, dtype_str in traced_func.input_specs: + np_dtype = DTYPE_MAP.get(dtype_str, dtype_str) + test_inputs.append(np.random.uniform(-1, 1, size=shape).astype(np_dtype)) + + # Compute NumPy reference using the original unwrapped function + expected = traced_func.__wrapped__(*test_inputs) + print(f"NumPy reference: shape={expected.shape}, dtype={expected.dtype}", file=sys.stderr) + + func_name = traced_func.__wrapped__.__name__ + + # Run BIR simulation and compare. + # Redirect stdout to stderr so harness diagnostics appear in the CE log + # instead of polluting the MLIR output pane. 
+ import io + captured = io.StringIO() + old_stdout = sys.stdout + try: + sys.stdout = captured + success, max_diff, artifacts = simulate_mlir( + mlir_str=raw_mlir, + func_name=func_name, + test_inputs=test_inputs, + expected_output=expected, + rtol=1e-3, + atol=1e-3, + verbose=True, + keep_artifacts=True, + ) + except (RuntimeError, AssertionError) as e: + sys.stdout = old_stdout + harness_output = captured.getvalue() + if harness_output: + print(harness_output, file=sys.stderr) + print(f"SIMULATION FAILED: {e}", file=sys.stderr) + return False + finally: + sys.stdout = old_stdout + + # Print harness output (including any error details) to stderr + harness_output = captured.getvalue() + if harness_output: + print(harness_output, file=sys.stderr) + + if success: + print(f"SIMULATION PASSED (max_diff={max_diff:.2e})") + else: + print(f"SIMULATION FAILED (max_diff={max_diff:.2e})") + if artifacts: + print(f"Artifacts: {artifacts}") + # Dump neuronx-cc log so the actual error is visible + ncc_log = os.path.join(artifacts, "log-neuron-cc.txt") + if os.path.exists(ncc_log): + with open(ncc_log, 'r') as f: + log_lines = f.readlines() + # Print lines containing errors/warnings, plus last 20 lines for context + error_lines = [l.rstrip() for l in log_lines + if any(k in l for k in ('ERROR', 'error', 'FAIL', 'fail', 'Exception'))] + if error_lines: + print("neuronx-cc errors:") + for line in error_lines[:30]: + print(f" {line}") + tail = [l.rstrip() for l in log_lines[-20:]] + print(f"neuronx-cc log (last 20 lines):") + for line in tail: + print(f" {line}") + + return success + + +def run_llvm_simulation(raw_mlir, traced_func): + """Run LLVM JIT simulation on intermediate MLIR and verify against NumPy reference.""" + tests_passes_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + "tests", "passes" + ) + if tests_passes_path not in sys.path: + sys.path.insert(0, tests_passes_path) + from pass_utils import verify_tiled_mlir_with_numpy + + func_name = traced_func.__wrapped__.__name__ + verify_tiled_mlir_with_numpy( + raw_mlir, traced_func, rtol=1e-4, atol=1e-4, func_name=func_name + ) + + +def run_hw_execution(raw_mlir, traced_func, target): + """Compile MLIR to NEFF and execute on Trainium hardware, verify against NumPy reference.""" + import numpy as np + import tempfile + + from nki.compiler.ncc_driver import CompileOptions, compile_mlir_to_neff + from nki.compiler._internal import ir, register_all_dialects + from nki.runtime import SpikeModel, SpikeTensor + + func_name = traced_func.__wrapped__.__name__ + + # Generate test inputs from input_specs + np.random.seed(42) + test_inputs = [] + for shape, dtype_str in traced_func.input_specs: + np_dtype = DTYPE_MAP.get(dtype_str, dtype_str) + test_inputs.append(np.random.uniform(-1, 1, size=shape).astype(np_dtype)) + + # Compute NumPy reference + expected = traced_func.__wrapped__(*test_inputs) + print(f"NumPy reference: shape={expected.shape}, dtype={expected.dtype}", file=sys.stderr) + + # Compile MLIR to NEFF + debug_dir = tempfile.mkdtemp(prefix="hw_exec_") + opts = CompileOptions( + target=target, + verbose=False, + output_path=os.path.join(debug_dir, "kernel.neff"), + neuronx_cc_args=("--lnc=1",), + artifacts_dir=debug_dir, + enable_simulation=False, + ) + + ctx = ir.Context() + register_all_dialects(ctx) + with ctx: + mlir = ir.Module.parse(raw_mlir, ctx) + + input_names = [f"in_tensor_{i}" for i in range(len(test_inputs))] + output_name = "out_tensor" + output_placeholder = np.zeros_like(expected) + + all_arrays 
= list(test_inputs) + [output_placeholder] + argument_names = input_names + [output_name] + output_arg_names = [output_name] + + compile_result = compile_mlir_to_neff( + mlir, + func_name, + all_arrays, + argument_names, + output_arg_names, + opts, + ) + + neff_path = compile_result.output_path + print(f"NEFF compiled: {neff_path}", file=sys.stderr) + + # Load and execute on hardware + model = SpikeModel.load_from_neff(neff_path) + + # Use the model's actual tensor names (NEFF compiler may rename them, e.g. _0 suffix) + neff_input_names = list(model.input_tensors_info.keys()) + neff_output_names = list(model.output_tensors_info.keys()) + print(f"NEFF inputs: {neff_input_names}, outputs: {neff_output_names}", file=sys.stderr) + + # Map compile-time names to arrays, then look up by NEFF name + # (NEFF preserves names but may reorder them) + compile_input_map = dict(zip(input_names, test_inputs)) + spike_inputs = { + name: SpikeTensor.from_numpy(compile_input_map[name], name=name) + for name in neff_input_names + } + + # Let the model auto-allocate output tensors with correct names + spike_outputs = model(inputs=spike_inputs, outputs=None) + + # Read back the first output + result_tensor = list(spike_outputs.values())[0] + result = np.frombuffer( + result_tensor.numpy(), dtype=expected.dtype + ).reshape(expected.shape) + + # Compare against NumPy reference + max_diff = np.max(np.abs(result - expected)) + success = np.allclose(result, expected, rtol=1e-3, atol=1e-3) + + if success: + print(f"HW EXECUTION PASSED (max_diff={max_diff:.2e})") + else: + print(f"HW EXECUTION FAILED (max_diff={max_diff:.2e})") + print(f"Artifacts: {debug_dir}") + + return success + + +def main(): + # Debug: log all arguments received from Compiler Explorer + print(f"DEBUG: sys.argv = {sys.argv}", file=sys.stderr) + + # Handle version request early (Compiler Explorer uses this during setup) + if "--version" in sys.argv or "-v" in sys.argv: + print("NKIPy MLIR Compiler 0.1.0") + sys.exit(0) + + parser = argparse.ArgumentParser( + description="NKIPy Compiler for Compiler Explorer" + ) + parser.add_argument("input", nargs='?', help="Input Python file with @trace function") + parser.add_argument( + "--stop", + dest="stop_after", + type=int, + default=None, + help="Stop after pass N (0 = trace only, 1-24 = after that pass, omit for all passes)" + ) + parser.add_argument( + "--target", + choices=["trn1", "trn2", "trn3"], + default="trn2", + help="Target hardware" + ) + # Compiler Explorer typically passes these + parser.add_argument("-o", "--outputfile", help="Output file (ignored, we write to stdout)") + parser.add_argument("-S", action="store_true", help="Compile to assembly (CE flag, ignored)") + parser.add_argument( + "--sim", action="store_true", + help="Run simulation and verify against NumPy (BIR sim for full pipeline, LLVM JIT when used with --stop)" + ) + parser.add_argument( + "--hw", action="store_true", + help="Compile to NEFF and execute on Trainium hardware (requires neuron device, uses --target)" + ) + parser.add_argument( + "--raw", action="store_true", + help="Output clean MLIR without .loc/.file annotations (for CLI use; CE UI needs them)" + ) + + # Use parse_known_args to ignore Compiler Explorer's extra flags (-I, etc.) 
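+    # Unknown args are re-scanned below in case CE puts the input file there.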
+ args, unknown = parser.parse_known_args() + + print(f"DEBUG: parsed args = {args}", file=sys.stderr) + print(f"DEBUG: unknown args = {unknown}", file=sys.stderr) + + # Find the input file - CE may pass it as a positional arg or as last unknown arg + # CE uses "" as placeholder when source is passed via stdin + input_file = args.input + if input_file == "": + # CE passes source via stdin - read it to a temp file + source_code = sys.stdin.read() + print(f"DEBUG: Read {len(source_code)} chars from stdin", file=sys.stderr) + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(source_code) + input_file = f.name + print(f"DEBUG: Wrote source to temp file: {input_file}", file=sys.stderr) + elif not input_file or not os.path.exists(input_file): + # Try to find a .py file in unknown args + for arg in unknown: + if arg.endswith('.py') and os.path.exists(arg): + input_file = arg + break + # Also check if any unknown arg is an existing file + if not input_file: + for arg in unknown: + if os.path.exists(arg) and not arg.endswith('.s'): + input_file = arg + break + + if not input_file or not os.path.exists(input_file): + print(f"Error: Input file not found: {input_file}", file=sys.stderr) + print(f"DEBUG: Searched in args.input={args.input} and unknown={unknown}", file=sys.stderr) + sys.exit(1) + + SEPARATOR = "=" * 60 + + try: + # ---- Compilation ---- + pipeline_desc = f"stop={args.stop_after}" if args.stop_after is not None else "full pipeline" + print(f"\n{SEPARATOR}", file=sys.stderr) + print(f" Compilation started ({pipeline_desc}, target={args.target})", file=sys.stderr) + print(SEPARATOR, file=sys.stderr) + + t0 = time.time() + ce_output, raw_mlir, traced_func = compile_to_mlir( + input_file, + stop_after=args.stop_after, + target=args.target, + raw=args.raw, + ) + elapsed = time.time() - t0 + + print(f"\n{SEPARATOR}", file=sys.stderr) + print(f" Compilation finished -- {elapsed:.2f}s", file=sys.stderr) + print(f"{SEPARATOR}\n", file=sys.stderr) + + # --raw: print clean MLIR (no .loc annotations); default: CE-style with .loc + output = raw_mlir if args.raw else ce_output + + # CE expects output in the file specified by -o + if args.outputfile: + with open(args.outputfile, 'w') as f: + f.write(output) + print(f"DEBUG: Wrote output to {args.outputfile}", file=sys.stderr) + else: + print(output) + + # Run simulation based on --sim and --stop flags: + # --sim only: full pipeline + BIR simulation + # --stop only: print IR, no simulation + # --sim + --stop: print IR + LLVM JIT execution of intermediate IR + if args.sim: + if args.stop_after is None: + # ---- BIR simulation ---- + print(f"\n{SEPARATOR}", file=sys.stderr) + print(f" BIR simulation started", file=sys.stderr) + print(SEPARATOR, file=sys.stderr) + + t0 = time.time() + success = run_simulation(raw_mlir, traced_func) + elapsed = time.time() - t0 + status = "PASSED" if success else "FAILED" + + print(f"\n{SEPARATOR}", file=sys.stderr) + print(f" BIR simulation {status} -- {elapsed:.2f}s", file=sys.stderr) + print(f"{SEPARATOR}\n", file=sys.stderr) + if not success: + sys.exit(1) + else: + print(raw_mlir) + # ---- LLVM JIT simulation ---- + print(f"\n{SEPARATOR}", file=sys.stderr) + print(f" LLVM JIT simulation started", file=sys.stderr) + print(SEPARATOR, file=sys.stderr) + + t0 = time.time() + run_llvm_simulation(raw_mlir, traced_func) + elapsed = time.time() - t0 + + print(f"\n{SEPARATOR}", file=sys.stderr) + print(f" LLVM JIT simulation finished -- {elapsed:.2f}s", file=sys.stderr) + 
print(f"{SEPARATOR}\n", file=sys.stderr) + + # Run on Trainium hardware (requires full pipeline, i.e. no --stop) + if args.hw: + if args.stop_after is not None: + print("Error: --hw requires full pipeline (cannot be used with --stop)", file=sys.stderr) + sys.exit(1) + + # ---- HW execution ---- + print(f"\n{SEPARATOR}", file=sys.stderr) + print(f" HW execution started (target={args.target})", file=sys.stderr) + print(SEPARATOR, file=sys.stderr) + + t0 = time.time() + success = run_hw_execution(raw_mlir, traced_func, target=args.target) + elapsed = time.time() - t0 + status = "PASSED" if success else "FAILED" + + print(f"\n{SEPARATOR}", file=sys.stderr) + print(f" HW execution {status} -- {elapsed:.2f}s", file=sys.stderr) + print(f"{SEPARATOR}\n", file=sys.stderr) + if not success: + sys.exit(1) + if not success: + sys.exit(1) + except Exception as e: + print(f"Compilation error: {e}", file=sys.stderr) + import traceback + traceback.print_exc(file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/kernelgen/compiler_explorer/setup.sh b/kernelgen/compiler_explorer/setup.sh new file mode 100755 index 0000000..a8406d5 --- /dev/null +++ b/kernelgen/compiler_explorer/setup.sh @@ -0,0 +1,158 @@ +#!/bin/bash +# Setup script for Compiler Explorer with NKIPy integration +# +# Usage: ./setup.sh [example_file.py] +# example_file.py - Optional path to a Python file to use as the default example +# +# This script: +# 1. Clones Compiler Explorer if not present +# 2. Installs dependencies +# 3. Configures NKIPy as a compiler backend (disables all other compilers) +# 4. Starts the server + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +NKIPY_ROOT="$(dirname "$SCRIPT_DIR")" +CE_DIR="$SCRIPT_DIR/compiler-explorer" + +# Convert example file to absolute path (needed since we cd later) +if [ -n "$1" ]; then + EXAMPLE_FILE="$(cd "$(dirname "$1")" && pwd)/$(basename "$1")" +else + EXAMPLE_FILE="$SCRIPT_DIR/config/example.nkipy" +fi + +echo "=== NKIPy Compiler Explorer Setup ===" +echo "Script directory: $SCRIPT_DIR" +echo "NKIPy root: $NKIPY_ROOT" +echo "Example file: $EXAMPLE_FILE" + +# If user provided a file path, make sure it exists +if [ -n "$1" ] && [ ! -f "$EXAMPLE_FILE" ]; then + echo "Error: File not found: $EXAMPLE_FILE" + exit 1 +fi + +# Check for Node.js +if ! command -v node &> /dev/null; then + echo "Error: Node.js is required. Please install Node.js 20+." + echo " Ubuntu: sudo apt install nodejs npm" + echo " Or use nvm: https://github.com/nvm-sh/nvm" + exit 1 +fi + +NODE_VERSION=$(node -v | cut -d'v' -f2 | cut -d'.' -f1) +if [ "$NODE_VERSION" -lt 18 ]; then + echo "Error: Node.js 18+ required. Found: $(node -v)" + exit 1 +fi + +echo "Node.js version: $(node -v)" + +# Clone Compiler Explorer if not present +if [ ! 
-d "$CE_DIR" ]; then + echo "" + echo "=== Cloning Compiler Explorer ===" + git clone --depth 1 https://github.com/compiler-explorer/compiler-explorer.git "$CE_DIR" +fi + +cd "$CE_DIR" + +# Install dependencies +echo "" +echo "=== Installing dependencies ===" +npm install + +# Create config directory and clean up default configs +echo "" +echo "=== Configuring for NKIPy only ===" +mkdir -p etc/config +mkdir -p examples/python + +# Remove all default config files to avoid warnings +rm -f etc/config/*.amazon.properties etc/config/*.defaults.properties 2>/dev/null || true + +# Copy example file +if [ -f "$EXAMPLE_FILE" ]; then + cp "$EXAMPLE_FILE" examples/python/default.py + echo "Copied example from: $EXAMPLE_FILE" +else + # Create default example + cat > examples/python/default.py << 'PYEXAMPLE' +import numpy as np +from nkipy_kernelgen import trace, knob + +M, N, K = 256, 256, 256 +matmul_tile = [128, 128, 128] +add_tile = [128, 128] + +@trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((M, N), "f32")]) +def matmul_add_kernel(a, b, bias): + c = np.matmul(a, b) + knob.knob(c, mem_space="Sbuf", tile_size=matmul_tile) + result = c + bias + knob.knob(result, mem_space="SharedHbm", tile_size=add_tile) + return result +PYEXAMPLE + echo "Created default matmul_add example" +fi + +# Update the wrapper path in config and copy +WRAPPER_PATH="$SCRIPT_DIR/nkipy_ce_wrapper.sh" +chmod +x "$WRAPPER_PATH" +chmod +x "$SCRIPT_DIR/nkipy_compiler.py" + +# Generate minimal CE config (Python only) +cat > etc/config/compiler-explorer.local.properties << EOF +# Minimal Compiler Explorer config - only NKIPy +languages=python +defaultLanguage=python +noRemoteFetch=true + +# Increase compile timeout for --sim/--hw modes (NEFF compilation + execution) +# Default is 7500ms which is too short +compileTimeoutMs=600000 +EOF + +# Generate Python config with NKIPy only +# Read the example file and escape for Java properties format (newlines -> \n) +EXAMPLE_CONTENT=$(cat "$EXAMPLE_FILE" | sed 's/\\/\\\\/g' | sed ':a;N;$!ba;s/\n/\\n/g') + +cat > etc/config/python.local.properties << EOF +# NKIPy-only Python config +compilers=nkipy +defaultCompiler=nkipy +defaultSource=$EXAMPLE_CONTENT + +compiler.nkipy.exe=$WRAPPER_PATH +compiler.nkipy.name=NKIPy MLIR +compiler.nkipy.supportsBinary=false +compiler.nkipy.supportsExecute=false +compiler.nkipy.notification=Compiles Python+NumPy to NISA MLIR for Neuron hardware +EOF + +# Create sponsors.yaml with required format +echo "levels: []" > etc/config/sponsors.yaml + +echo "" +echo "=== Configuration complete ===" +echo "" +echo "Wrapper script: $WRAPPER_PATH" +echo "Example: examples/python/default.py" +echo "" +echo "=== Starting Compiler Explorer ===" +echo "Access at: http://localhost:10240" +echo "" +echo "Compiler options (add in the options box):" +echo " --stop=0 Trace only (initial MLIR before any passes)" +echo " --stop=7 Stop after tiling (apply-and-strip-transforms)" +echo " --stop=10 Stop after bufferization" +echo " --stop=16 Stop after memory space annotation + cleanup" +echo " --stop=24 Full compilation (same as omitting --stop)" +echo " --sim Run simulation (BIR sim or LLVM JIT with --stop)" +echo " --raw Clean MLIR without source annotations" +echo "" + +# Start the server +npm run dev diff --git a/kernelgen/examples/custom_op.py b/kernelgen/examples/custom_op.py new file mode 100644 index 0000000..c42c490 --- /dev/null +++ b/kernelgen/examples/custom_op.py @@ -0,0 +1,187 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +""" +Custom Op Integration — using pre-compiled NISA kernels in NKIPy. + +Demonstrates how to wrap a kernel_builder-compiled NISA function as a CustomOp +and call it from an @trace-decorated kernel. This lets you mix high-level NumPy +tracing (matmul, elementwise) with hand-optimized NISA code (activations, custom +DMA patterns). + +The flow: + 1. Write a NISA kernel using kernel_builder (nki.compiler.kernel_builder). + 2. Wrap it as a CustomOp via CustomOp.from_kernel_builder(). + 3. Call the CustomOp inside an @trace kernel — it emits a func.call during + tracing and falls back to a reference NumPy function outside of tracing. + 4. The compiler pipeline inlines the NISA body at each call site via the + resolve-custom-ops pass. + +Key concepts shown: + - kernel_builder API for writing NISA kernels (nb.ndarray, nb.isa.*) + - CustomOp.from_kernel_builder() for wrapping kernel_builder output + - Mixing custom ops with traced NumPy operations + - reference_fn for NumPy-level testing without hardware + +Usage: + # Compile and dump intermediate IR: + python examples/custom_op.py + + # Or run the full test suite: + python -m pytest tests/e2e/test_custom_op.py -v +""" + +import numpy as np + +from nkipy_kernelgen import trace, knob +from nkipy_kernelgen.custom_op import CustomOp +from nkipy_kernelgen.transforms.nkipy_opt import apply_complete_knob_pipeline + +import nki.compiler.kernel_builder as nb + + +# --------------------------------------------------------------------------- +# Step 1: Define a NISA kernel using kernel_builder +# --------------------------------------------------------------------------- + +def make_silu_custom_op(M, N, tile_p=128, tile_f=128): + """Create a CustomOp that computes SiLU activation using real NISA ops. + + The kernel tiles the (M, N) input into (tile_p, tile_f) chunks and + processes each tile: DMA in -> activation -> DMA out. + + Args: + M, N: Input/output tensor dimensions. + tile_p: Partition tile size (max 128 for SBUF). + tile_f: Free-dimension tile size. + + Returns: + A CustomOp callable in @trace kernels and in plain NumPy. 
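+
+    Example (mirrors the usage further below; the input array is illustrative):
+        silu_op = make_silu_custom_op(256, 256, tile_p=128, tile_f=128)
+        y = silu_op(np.ones((256, 256), dtype=np.float32))  # outside @trace, falls back to silu_reference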
+ """ + + def silu_kernel(x_hbm, out_hbm): + """NISA implementation of SiLU, written with kernel_builder.""" + import nki.language as nl + + n_row_tiles = M // tile_p + n_col_tiles = N // tile_f + for r in nl.affine_range(n_row_tiles): + for t in nl.affine_range(n_col_tiles): + # Load a tile from HBM to SBUF + x_sbuf = nb.ndarray((tile_p, tile_f), x_hbm.dtype, nb.sbuf) + nb.isa.dma_copy( + dst=x_sbuf, + src=x_hbm[ + r * tile_p : (r + 1) * tile_p, + t * tile_f : (t + 1) * tile_f, + ], + ) + + # Allocate output tile and bias/scale for activation + out_sbuf = nb.ndarray((tile_p, tile_f), x_hbm.dtype, nb.sbuf) + bias = nb.ndarray((tile_p, 1), x_hbm.dtype, nb.sbuf) + nb.isa.memset(dst=bias, value=0.0) + scale = nb.ndarray((tile_p, 1), x_hbm.dtype, nb.sbuf) + nb.isa.memset(dst=scale, value=1.0) + + # Apply SiLU activation via NISA hardware instruction + nb.isa.activation( + dst=out_sbuf, + src=x_sbuf, + bias=bias, + scale=scale, + op=nb.isa.activation_function.silu, + ) + + # Store result from SBUF back to HBM + nb.isa.dma_copy( + dst=out_hbm[ + r * tile_p : (r + 1) * tile_p, + t * tile_f : (t + 1) * tile_f, + ], + src=out_sbuf, + ) + + def silu_reference(x): + """NumPy reference for SiLU: x * sigmoid(x).""" + return x / (1.0 + np.exp(-x)) + + # Compile the kernel_builder function and wrap as CustomOp + return CustomOp.from_kernel_builder( + kernel_func=silu_kernel, + input_specs={"x_hbm": nb.Tensor((M, N), nb.float32, nb.shared_hbm)}, + output_specs={"out_hbm": nb.Tensor((M, N), nb.float32, nb.shared_hbm)}, + reference_fn=silu_reference, + ) + + +# --------------------------------------------------------------------------- +# Step 2: Create the custom op instance +# --------------------------------------------------------------------------- + +# Build a SiLU custom op for 256x256 tensors with 128x128 internal tiling. +# SBUF partition dim max is 128, so tile_p must be <= 128. +custom_silu = make_silu_custom_op(256, 256, tile_p=128, tile_f=128) + + +# --------------------------------------------------------------------------- +# Step 3: Use the custom op in a traced kernel +# --------------------------------------------------------------------------- + +@trace( + input_specs=[ + ((256, 256), "f32"), # x + ((256, 256), "f32"), # weight + ] +) +def matmul_silu_kernel(x, weight): + """Matrix multiply followed by custom SiLU activation. + + The matmul is compiled by the NKIPy pipeline (tiling, bufferization, + NISA lowering), while the SiLU is a pre-compiled NISA function that + gets inlined at the call site by resolve-custom-ops. + """ + mm_out = np.matmul(x, weight) + knob.knob( + mm_out, mem_space="SharedHbm", tile_size=[128, 128], reduction_tile=[128] + ) + + # Call the custom op — during tracing this emits a func.call; + # during NumPy execution it falls back to silu_reference. 
+ output = custom_silu(mm_out) + + return output + + +# --------------------------------------------------------------------------- +# Compile, print IR, and verify correctness when run as a script +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + import sys + import os + import tempfile + + # Trace and compile + module = matmul_silu_kernel.to_mlir() + traced_ir = str(module) + + dump_dir = tempfile.mkdtemp(prefix="custom_op_") + compiled_ir = apply_complete_knob_pipeline(traced_ir, dump_dir=dump_dir) + print(compiled_ir) + + # Verify numerical correctness via BIR simulation + sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "tests")) + from harness import simulate_mlir, generate_inputs, compute_reference + + inputs = generate_inputs(matmul_silu_kernel.input_specs) + reference = compute_reference(matmul_silu_kernel, inputs) + + success, max_diff, artifacts = simulate_mlir( + compiled_ir, + func_name="matmul_silu_kernel", + test_inputs=inputs, + expected_output=reference, + rtol=1e-3, + atol=1e-3, + ) + print(f"\nBIR simulation: {'PASS' if success else 'FAIL'} (max_diff={max_diff:.2e})") diff --git a/kernelgen/examples/qwen3_layer.py b/kernelgen/examples/qwen3_layer.py new file mode 100644 index 0000000..a1d5188 --- /dev/null +++ b/kernelgen/examples/qwen3_layer.py @@ -0,0 +1,311 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Qwen3 Transformer Decoder Layer — full single-layer inference kernel. + +Demonstrates how to express a complete transformer decoder layer in NKIPy: + 1. Pre-attention RMSNorm + 2. QKV projection (hidden -> q, k, v per head) + 3. Reshape + transpose to multi-head format + 4. RoPE on Q and K + 5. Scaled dot-product attention (with softmax) + 6. Concat heads + output projection + 7. Residual connection + 8. Post-attention RMSNorm + 9. SwiGLU feedforward (gate_up projection, SiLU, down projection) + 10. Residual connection + +Key concepts shown: + - `@trace` turns a NumPy function into a compilable kernel. + - `knob.knob()` annotates tensors with memory placement (Sbuf, SharedHbm) + and tiling hints (tile_size, reduction_tile, partition_dim). + - Sub-kernel boundaries (values flowing through reshape/transpose or between + independent compute stages) use SharedHbm so the compiler can freely + reshape/transpose without partition_dim constraints. + +Usage: + # Compile and dump intermediate IR: + python examples/qwen3_layer.py + + # Or use the compiler explorer: + cd compiler_explorer + python nkipy_compiler.py ../examples/qwen3_layer.py --stop=22 --raw +""" + +import numpy as np + +from nkipy_kernelgen import trace, knob +from nkipy_kernelgen.transforms.nkipy_opt import apply_complete_knob_pipeline + +# --------------------------------------------------------------------------- +# Model hyperparameters (small config for demonstration) +# --------------------------------------------------------------------------- +batch = 2 +seq_len = 128 +hidden_size = 256 +n_heads = 2 +head_dim = hidden_size // n_heads # 128 +intermediate_size = 256 +half_dim = head_dim // 2 # 64 +eps = 1e-6 +scale = 1.0 / np.sqrt(head_dim).item() + +BS = batch * seq_len # 256 (tokens = batch * seq_len) +BH = batch * n_heads # 4 (heads = batch * n_heads) + +# --------------------------------------------------------------------------- +# Tile sizes — must divide the corresponding tensor dimensions evenly. +# NISA hardware processes data in 128-element partitions. 
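+# For example, with BS = 256 and hidden_size = 256, matmul_tile_2d = [128, 128]
+# below splits each (256, 256) matmul output into a 2x2 grid of 128x128 tiles
+# (illustrative arithmetic; the requirement is only that tiles divide evenly).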
+# --------------------------------------------------------------------------- +matmul_tile_2d = [128, 128] +matmul_reduction_2d = [128] +attn_tile = [1, 128, 128] +attn_reduction = [128] +rope_tile = [1, 128, 64] # (BH, seq_len, half_dim) +elem_tile_2d = [128, 128] + + +# --------------------------------------------------------------------------- +# Helper functions — reusable building blocks +# --------------------------------------------------------------------------- + +def rmsnorm(x, weight): + """RMSNorm: x / sqrt(mean(x^2) + eps) * weight.""" + x_fp32 = x.astype(np.float32) + w_fp32 = weight.astype(np.float32) + + sq = np.square(x_fp32) + knob.knob(sq, mem_space="Sbuf", tile_size=elem_tile_2d) + + sum_sq = np.sum(sq, axis=-1, keepdims=True) + knob.knob(sum_sq, mem_space="Sbuf", tile_size=[128], reduction_tile=[128]) + + mean_sq = sum_sq * np.float32(1.0 / hidden_size) + knob.knob(mean_sq, mem_space="Sbuf", tile_size=[128, 1]) + + normed = x_fp32 / np.sqrt(mean_sq + eps) + knob.knob(normed, mem_space="Sbuf", tile_size=elem_tile_2d) + + result = normed * w_fp32 + knob.knob(result, mem_space="Sbuf", tile_size=elem_tile_2d) + + return result + + +def softmax_3d(x): + """Numerically-stable softmax over the last axis of a 3D tensor.""" + x_fp32 = x.astype(np.float32) + + # Reduction accumulators use SharedHbm to avoid 5D SBUF allocs + # that legalize-layout cannot tile. + x_max = np.max(x_fp32, axis=-1, keepdims=True) + knob.knob(x_max, mem_space="SharedHbm", tile_size=[1, 128], + reduction_tile=[128], partition_dim=1) + + shifted = x_fp32 - x_max + knob.knob(shifted, mem_space="Sbuf", tile_size=attn_tile, partition_dim=1) + + exp_s = np.exp(shifted) + knob.knob(exp_s, mem_space="Sbuf", tile_size=attn_tile, partition_dim=1) + + sum_exp = np.sum(exp_s, axis=-1, keepdims=True) + knob.knob(sum_exp, mem_space="SharedHbm", tile_size=[1, 128], + reduction_tile=[128], partition_dim=1) + + result = exp_s / sum_exp + knob.knob(result, mem_space="SharedHbm", tile_size=attn_tile) + + return result + + +def silu(x): + """SiLU activation: x * sigmoid(x).""" + neg_x = -x + knob.knob(neg_x, mem_space="Sbuf", tile_size=elem_tile_2d) + + exp_neg = np.exp(neg_x) + knob.knob(exp_neg, mem_space="Sbuf", tile_size=elem_tile_2d) + + one_plus = exp_neg + 1.0 + knob.knob(one_plus, mem_space="Sbuf", tile_size=elem_tile_2d) + + sigmoid = 1.0 / one_plus + knob.knob(sigmoid, mem_space="Sbuf", tile_size=elem_tile_2d) + + result = x * sigmoid + knob.knob(result, mem_space="Sbuf", tile_size=elem_tile_2d) + + return result + + +def apply_rope(x, freqs_cos, freqs_sin): + """Rotary positional embedding: rotate x by cos/sin frequencies.""" + x0 = x[:, :, :half_dim] + x1 = x[:, :, half_dim:] + + # Each intermediate uses SharedHbm to avoid 3D SBUF allocs with + # non-partition dim 0. 
+ x0_cos = x0 * freqs_cos + knob.knob(x0_cos, mem_space="SharedHbm", tile_size=rope_tile) + x1_sin = x1 * freqs_sin + knob.knob(x1_sin, mem_space="SharedHbm", tile_size=rope_tile) + out_0 = x0_cos - x1_sin + knob.knob(out_0, mem_space="SharedHbm", tile_size=rope_tile) + + x0_sin = x0 * freqs_sin + knob.knob(x0_sin, mem_space="SharedHbm", tile_size=rope_tile) + x1_cos = x1 * freqs_cos + knob.knob(x1_cos, mem_space="SharedHbm", tile_size=rope_tile) + out_1 = x0_sin + x1_cos + knob.knob(out_1, mem_space="SharedHbm", tile_size=rope_tile) + + result = np.concatenate([out_0, out_1], axis=-1) + knob.knob(result, mem_space="SharedHbm", tile_size=attn_tile) + + return result + + +# --------------------------------------------------------------------------- +# Kernel definition +# --------------------------------------------------------------------------- + +@trace(input_specs=[ + ((BS, hidden_size), "f32"), # hidden_states + ((hidden_size, 1), "f32"), # ln1_weight + ((hidden_size, 1), "f32"), # ln2_weight + ((hidden_size, hidden_size), "f32"), # w_q + ((hidden_size, hidden_size), "f32"), # w_k + ((hidden_size, hidden_size), "f32"), # w_v + ((hidden_size, hidden_size), "f32"), # w_o + ((1, seq_len, half_dim), "f32"), # freqs_cos + ((1, seq_len, half_dim), "f32"), # freqs_sin + ((hidden_size, intermediate_size), "f32"), # w_gate + ((hidden_size, intermediate_size), "f32"), # w_up + ((intermediate_size, hidden_size), "f32"), # w_down +]) +def qwen3_layer(hidden_states, + ln1_weight, ln2_weight, + w_q, w_k, w_v, w_o, + freqs_cos, freqs_sin, + w_gate, w_up, w_down): + residual = hidden_states + + # 1. Pre-attention RMSNorm + normed = rmsnorm(hidden_states, ln1_weight) + + # 2. QKV projections — SharedHbm boundary (results flow through reshape) + q = np.matmul(normed, w_q) + knob.knob(q, mem_space="SharedHbm", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + k = np.matmul(normed, w_k) + knob.knob(k, mem_space="SharedHbm", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + v = np.matmul(normed, w_v) + knob.knob(v, mem_space="SharedHbm", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + # 3. Reshape to multi-head: (BS, hidden) -> (BH, seq_len, head_dim) + q = np.reshape(q, (batch, seq_len, n_heads, head_dim)) + q = np.transpose(q, (0, 2, 1, 3)) + q = np.reshape(q, (BH, seq_len, head_dim)) + + k = np.reshape(k, (batch, seq_len, n_heads, head_dim)) + k = np.transpose(k, (0, 2, 1, 3)) + k = np.reshape(k, (BH, seq_len, head_dim)) + + v = np.reshape(v, (batch, seq_len, n_heads, head_dim)) + v = np.transpose(v, (0, 2, 1, 3)) + v = np.reshape(v, (BH, seq_len, head_dim)) + knob.knob(v, mem_space="SharedHbm", tile_size=attn_tile) + + # 4. RoPE on Q and K + q = apply_rope(q, freqs_cos, freqs_sin) + k = apply_rope(k, freqs_cos, freqs_sin) + + k_t = np.transpose(k, (0, 2, 1)) + knob.knob(k_t, mem_space="SharedHbm", tile_size=attn_tile) + + # 5. Scaled dot-product attention + scores = np.matmul(q, k_t) + knob.knob(scores, mem_space="Sbuf", tile_size=attn_tile, reduction_tile=attn_reduction) + + scores = scores * scale + knob.knob(scores, mem_space="Sbuf", tile_size=attn_tile, partition_dim=1) + + attn_weights = softmax_3d(scores) + + context = np.matmul(attn_weights, v) + knob.knob(context, mem_space="SharedHbm", tile_size=attn_tile, reduction_tile=attn_reduction) + + # 6. 
Concat heads + output projection + context = np.reshape(context, (batch, n_heads, seq_len, head_dim)) + context = np.transpose(context, (0, 2, 1, 3)) + context = np.reshape(context, (BS, hidden_size)) + knob.knob(context, mem_space="SharedHbm", tile_size=matmul_tile_2d) + + attn_out = np.matmul(context, w_o) + knob.knob(attn_out, mem_space="Sbuf", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + # 7. Residual connection + hidden_states = residual + attn_out + knob.knob(hidden_states, mem_space="Sbuf", tile_size=elem_tile_2d) + + residual = hidden_states + + # 8. Post-attention RMSNorm + normed = rmsnorm(hidden_states, ln2_weight) + + # 9. SwiGLU FFN + gate = np.matmul(normed, w_gate) + knob.knob(gate, mem_space="Sbuf", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + up = np.matmul(normed, w_up) + knob.knob(up, mem_space="Sbuf", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + gate = silu(gate) + + gated = gate * up + knob.knob(gated, mem_space="Sbuf", tile_size=elem_tile_2d) + + ffn_out = np.matmul(gated, w_down) + knob.knob(ffn_out, mem_space="Sbuf", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + # 10. Residual connection + output = residual + ffn_out + knob.knob(output, mem_space="SharedHbm", tile_size=elem_tile_2d) + + return output + + +# --------------------------------------------------------------------------- +# Compile, print IR, and verify correctness when run as a script +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + import sys + import os + import tempfile + + # Trace and compile + module = qwen3_layer.to_mlir() + traced_ir = str(module) + + dump_dir = tempfile.mkdtemp(prefix="qwen3_layer_") + compiled_ir = apply_complete_knob_pipeline(traced_ir, dump_dir=dump_dir) + print(compiled_ir) + + # Verify numerical correctness via BIR simulation + sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "tests")) + from harness import simulate_mlir, generate_inputs, compute_reference + + inputs = generate_inputs(qwen3_layer.input_specs) + reference = compute_reference(qwen3_layer, inputs) + + success, max_diff, artifacts = simulate_mlir( + compiled_ir, + func_name="qwen3_layer", + test_inputs=inputs, + expected_output=reference, + rtol=1e-3, + atol=1e-3, + ) + print(f"\nBIR simulation: {'PASS' if success else 'FAIL'} (max_diff={max_diff:.2e})") diff --git a/kernelgen/mlir/CMakeLists.txt b/kernelgen/mlir/CMakeLists.txt new file mode 100644 index 0000000..3d3a163 --- /dev/null +++ b/kernelgen/mlir/CMakeLists.txt @@ -0,0 +1,48 @@ +cmake_minimum_required(VERSION 3.20.0) +cmake_policy(SET CMP0116 NEW) + +project(nkipy-kg LANGUAGES CXX C) + +set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON) + +set(CMAKE_CXX_STANDARD 17 CACHE STRING "C++ standard to conform to") +set(CMAKE_CXX_FLAGS "-Wfatal-errors -std=c++17") +add_compile_options ( -w ) + +set(CMAKE_BUILD_TYPE Debug) + +# Define NDEBUG to match Release-built MLIR libraries (avoids undefined +# reference to debug-only functions like checkImplementsTransformOpInterface) +add_definitions(-DNDEBUG) + +find_package(MLIR REQUIRED CONFIG) + +message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") +message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") + +set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) +set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") 
+include(TableGen) +include(AddLLVM) +include(AddMLIR) +include(HandleLLVMOptions) + +include_directories(${LLVM_INCLUDE_DIRS}) +include_directories(${MLIR_INCLUDE_DIRS}) +include_directories(${PROJECT_SOURCE_DIR}/include) +include_directories(${PROJECT_BINARY_DIR}/include) + +link_directories(${LLVM_BUILD_LIBRARY_DIR}) +add_definitions(${LLVM_DEFINITIONS}) + +message(STATUS "Using Python binding") +include(MLIRDetectPythonEnv) +mlir_configure_python_dev_packages() + +add_subdirectory(include) +add_subdirectory(lib) +add_subdirectory(tools) diff --git a/kernelgen/mlir/include/CMakeLists.txt b/kernelgen/mlir/include/CMakeLists.txt new file mode 100644 index 0000000..9e7f82a --- /dev/null +++ b/kernelgen/mlir/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(nkipy) \ No newline at end of file diff --git a/kernelgen/mlir/include/nkipy-c/Dialect/Dialects.h b/kernelgen/mlir/include/nkipy-c/Dialect/Dialects.h new file mode 100644 index 0000000..3024d28 --- /dev/null +++ b/kernelgen/mlir/include/nkipy-c/Dialect/Dialects.h @@ -0,0 +1,17 @@ + +#ifndef NKIPY_C_DIALECT__H +#define NKIPY_C_DIALECT__H + +#include "mlir-c/RegisterEverything.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(NKIPY, nkipy); + +#ifdef __cplusplus +} +#endif + +#endif // NKIPY_C_DIALECT__H \ No newline at end of file diff --git a/kernelgen/mlir/include/nkipy-c/Dialect/NkipyAttributes.h b/kernelgen/mlir/include/nkipy-c/Dialect/NkipyAttributes.h new file mode 100644 index 0000000..4a8ed8a --- /dev/null +++ b/kernelgen/mlir/include/nkipy-c/Dialect/NkipyAttributes.h @@ -0,0 +1,24 @@ +#ifndef NKIPY_MLIR_C_ATTRIBUTES__H +#define NKIPY_MLIR_C_ATTRIBUTES__H + +#include "mlir-c/IR.h" +#include "mlir-c/IntegerSet.h" +#include "mlir-c/Support.h" +#include "mlir/CAPI/IntegerSet.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr); +MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set); + +MLIR_CAPI_EXPORTED bool mlirAttributeIsAMemSpace(MlirAttribute attr); +MLIR_CAPI_EXPORTED MlirAttribute mlirMemSpaceGet(MlirContext ctx, + MlirAttribute space); + +#ifdef __cplusplus +} +#endif + +#endif // NKIPY_MLIR_C_ATTRIBUTES__H diff --git a/kernelgen/mlir/include/nkipy-c/Dialect/Registration.h b/kernelgen/mlir/include/nkipy-c/Dialect/Registration.h new file mode 100644 index 0000000..eb34603 --- /dev/null +++ b/kernelgen/mlir/include/nkipy-c/Dialect/Registration.h @@ -0,0 +1,25 @@ + +#ifndef NKIPY_MLIR_C_REGISTRATION_H +#define NKIPY_MLIR_C_REGISTRATION_H + +#include "mlir/CAPI/IR.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** Registers all dialects with a context. + * This is needed before creating IR for these Dialects. + */ +MLIR_CAPI_EXPORTED void nkipyMlirRegisterAllDialects(MlirContext context); + +/** Registers all passes for symbolic access with the global registry. */ +MLIR_CAPI_EXPORTED void nkipyMlirRegisterAllPasses(); + +#ifdef __cplusplus +} +#endif + +#endif // NKIPY_MLIR_C_REGISTRATION_H \ No newline at end of file diff --git a/kernelgen/mlir/include/nkipy/Bindings/CMakeLists.txt b/kernelgen/mlir/include/nkipy/Bindings/CMakeLists.txt new file mode 100644 index 0000000..e6e828d --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Bindings/CMakeLists.txt @@ -0,0 +1,96 @@ +include(AddMLIRPython) + +# The directory at which the Python import tree begins. +# See documentation for `declare_mlir_python_sources`'s ROOT_DIR +# argument. 
+set(NKIPY_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/nkipy") +set(NKIPY_MLIR_PYTHON_PACKAGES_DIR "${PROJECT_BINARY_DIR}/tools/nkipy") +set(MLIR_PYTHON_SOURCE_DIR "${MLIR_MAIN_SRC_DIR}/lib/Bindings") +set(NKIPY_PYTHON_SOURCE_DIR "${PROJECT_SOURCE_DIR}/lib/Bindings") + +include_directories(${MLIR_PYTHON_SOURCE_DIR}) + +# Use the system MLIR package prefix to ensure capsule compatibility +add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=mlir.") + +################################################################################ +# Sources +################################################################################ + +declare_mlir_python_sources(NkipyMLIRPythonSources) +declare_mlir_python_sources(NkipyMLIRPythonExtensions) + +declare_mlir_python_sources(NkipyMLIRPythonSources.Dialects + ROOT_DIR "${NKIPY_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT NkipyMLIRPythonSources +) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT NkipyMLIRPythonSources.Dialects + ROOT_DIR "${NKIPY_MLIR_PYTHON_ROOT_DIR}" + TD_FILE dialects/NkipyBinding.td + SOURCES + dialects/nkipy.py + dialects/_ods_common.py + exceptions.py + __init__.py + DIALECT_NAME nkipy +) + +################################################################################ +# Extensions +################################################################################ + +declare_mlir_python_extension(NkipyMLIRPythonExtensions.Main + MODULE_NAME _nkipy + ADD_TO_PARENT NkipyMLIRPythonExtensions + ROOT_DIR "/" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + ${NKIPY_PYTHON_SOURCE_DIR}/NkipyModule.cpp + ${NKIPY_PYTHON_SOURCE_DIR}/NkipyAttributes.cpp + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPIDebug + MLIRNkipyCAPI + PRIVATE_LINK_LIBS + MLIRPass + MLIRFuncDialect + MLIRMemRefDialect + MLIRAffineDialect + LLVMSupport +) + +################################################################################ +# Generate packages and shared library +# Downstreams typically will not use these, but they are useful for local +# testing. +################################################################################ + +# Only build NKIPy custom dialect extensions, not full MLIR bindings +# This avoids nanobind type conflicts with system mlir/nki.compiler._internal +set(_source_components + NkipyMLIRPythonSources + NkipyMLIRPythonExtensions + # MLIRPythonSources - REMOVED: Use system mlir package instead + # MLIRPythonExtension.RegisterEverything - REMOVED: Use system mlir package instead +) + +add_mlir_python_common_capi_library(NkipyMLIRAggregateCAPI + INSTALL_COMPONENT NkipyMLIRPythonModules + INSTALL_DESTINATION _mlir + OUTPUT_DIRECTORY "${NKIPY_MLIR_PYTHON_PACKAGES_DIR}/_mlir" + RELATIVE_INSTALL_ROOT "../.." 
+ DECLARED_HEADERS + MLIRPythonCAPI.HeaderSources + DECLARED_SOURCES + ${_source_components} +) + +add_mlir_python_modules(NkipyMLIRPythonModules + ROOT_PREFIX "${NKIPY_MLIR_PYTHON_PACKAGES_DIR}/_mlir" + INSTALL_PREFIX "_mlir" + DECLARED_SOURCES ${_source_components} + COMMON_CAPI_LINK_LIBS + NkipyMLIRAggregateCAPI + ) diff --git a/kernelgen/mlir/include/nkipy/Bindings/NkipyModule.h b/kernelgen/mlir/include/nkipy/Bindings/NkipyModule.h new file mode 100644 index 0000000..0c5e29b --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Bindings/NkipyModule.h @@ -0,0 +1,17 @@ +#ifndef NKIPY_BINDINGS_PYTHON_IRMODULES_H +#define NKIPY_BINDINGS_PYTHON_IRMODULES_H + +// #include "NanobindUtils.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" + +namespace mlir { +namespace python { + +// void populateNkipyIRTypes(nanobind::module_ &m); +void populateNkipyAttributes(nanobind::module_ &m); + +} // namespace python +} // namespace mlir + +#endif // NKIPY_BINDINGS_PYTHON_IRMODULES_H diff --git a/kernelgen/mlir/include/nkipy/Bindings/nkipy/__init__.py b/kernelgen/mlir/include/nkipy/Bindings/nkipy/__init__.py new file mode 100644 index 0000000..1e29678 --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Bindings/nkipy/__init__.py @@ -0,0 +1,10 @@ +# Import from system MLIR package (not bundled) +# This avoids nanobind type conflicts with nki.compiler._internal +from mlir import ir +from mlir import dialects + +# Import NKIPy-specific extensions +from ._mlir_libs._nkipy import nkipy + +# Re-export for convenience +__all__ = ['ir', 'dialects', 'nkipy'] \ No newline at end of file diff --git a/kernelgen/mlir/include/nkipy/Bindings/nkipy/dialects/NkipyBinding.td b/kernelgen/mlir/include/nkipy/Bindings/nkipy/dialects/NkipyBinding.td new file mode 100644 index 0000000..deb1d8a --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Bindings/nkipy/dialects/NkipyBinding.td @@ -0,0 +1,6 @@ +#ifndef BINDINGS_PYTHON_NKIPY_OPS_TD +#define BINDINGS_PYTHON_NKIPY_OPS_TD + +include "nkipy/Dialect/NkipyOps.td" + +#endif // BINDINGS_PYTHON_NKIPY_OPS_TD \ No newline at end of file diff --git a/kernelgen/mlir/include/nkipy/Bindings/nkipy/dialects/_ods_common.py b/kernelgen/mlir/include/nkipy/Bindings/nkipy/dialects/_ods_common.py new file mode 100644 index 0000000..ed1d617 --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Bindings/nkipy/dialects/_ods_common.py @@ -0,0 +1,11 @@ +# Re-export from system MLIR to avoid conflicts +from mlir.dialects._ods_common import * +from mlir.dialects._ods_common import ( + _cext, + segmented_accessor, + equally_sized_accessor, + get_default_loc_context, + get_op_result_or_value, + get_op_results_or_values, + get_op_result_or_op_results, +) diff --git a/kernelgen/mlir/include/nkipy/Bindings/nkipy/dialects/nkipy.py b/kernelgen/mlir/include/nkipy/Bindings/nkipy/dialects/nkipy.py new file mode 100644 index 0000000..385251b --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Bindings/nkipy/dialects/nkipy.py @@ -0,0 +1,2 @@ +from ._nkipy_ops_gen import * +from .._mlir_libs._nkipy.nkipy import * \ No newline at end of file diff --git a/kernelgen/mlir/include/nkipy/Bindings/nkipy/exceptions.py b/kernelgen/mlir/include/nkipy/Bindings/nkipy/exceptions.py new file mode 100644 index 0000000..41c989c --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Bindings/nkipy/exceptions.py @@ -0,0 +1,220 @@ +import warnings +from contextvars import ContextVar + +# By default, Python ignores deprecation warnings. +# we have to enable it to see the warning. 
+warnings.simplefilter("always", DeprecationWarning) + +PrintLog = ContextVar("PrintLog", default=False) + + +class bcolors: + """ANSI color escape codes for terminal output.""" + + HEADER = "\033[95m" + OKBLUE = "\033[94m" + OKCYAN = "\033[96m" + OKGREEN = "\033[92m" + WARNING = "\033[93m" + FAIL = "\033[91m" + ENDC = "\033[0m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + + +class NkipyException(Exception): + """Base class for all Nkipy exceptions. + + Exception is the base class for warnings and errors. + Developers can subclass this class to provide additional information + + Parameters + ---------- + message : str + The error message. + """ + + def __init__(self, message): + Exception.__init__(self, message) + self.message = message + + +class NkipyWarning(NkipyException): + """Base class for all Nkipy warnings. + + Warning is the base class for all warnings. + Developers can subclass this class to provide additional information + + Parameters + ---------- + message : str + The warning message. + + line : int, optional + The line number of the warning. + + category_str : str, optional + The warning category string. + + category : Warning, optional + The warning category. + """ + + def __init__(self, message, line=None, category_str=None, category=None): + message = bcolors.OKBLUE + message + bcolors.ENDC + if category_str is not None: + message = "\n{} {}".format(category_str, message) + if line is not None: + message += bcolors.BOLD + " (line {})".format(line) + bcolors.ENDC + NkipyException.__init__(self, message) + self.category = category + + def warn(self): + warnings.warn(self.message, category=self.category) + + def log(self): + if PrintLog.get(): + print(self.message) + + +class NkipyError(NkipyException): + """Base class for all Nkipy errors. + + Error is the base class for all errors. + Developers can subclass this class to provide additional information + + Parameters + ---------- + message : str + The error message. + + line: int, optional + The line number of the error. + + category_str : str, optional + The error category string. 
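+
+    Examples
+    --------
+    Illustrative use of a subclass defined later in this module::
+
+        raise DTypeError("expected f32 operand", line=12)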
+ """ + + def __init__(self, message, line=None, category_str=None): + message = bcolors.OKBLUE + message + bcolors.ENDC + if category_str is not None: + message = "{} {}".format(category_str, message) + if line is not None: + message += bcolors.BOLD + " (line {})".format(line) + bcolors.ENDC + NkipyException.__init__(self, message) + + def error(self): + raise self.message + + +""" Inherited Error subclasses """ + + +class DTypeError(NkipyError): + """A subclass for specifying data type related exception""" + + def __init__(self, msg, line=None): + category_str = bcolors.FAIL + "[Data Type]" + bcolors.ENDC + NkipyError.__init__(self, msg, line, category_str) + + +class APIError(NkipyError): + """A subclass for specifying API related exception""" + + def __init__(self, msg, line=None): + category_str = bcolors.FAIL + "[API]" + bcolors.ENDC + NkipyError.__init__(self, msg, line, category_str) + + +class DSLError(NkipyError): + """A subclass for specifying imperative DSL related exception""" + + def __init__(self, msg, line=None): + category_str = bcolors.FAIL + "[Imperative]" + bcolors.ENDC + NkipyError.__init__(self, msg, line, category_str) + + +class TensorError(NkipyError): + """A subclass for specifying tensor related exception""" + + def __init__(self, msg, line=None): + category_str = bcolors.FAIL + "[Tensor]" + bcolors.ENDC + NkipyError.__init__(self, msg, line, category_str) + + +class DeviceError(NkipyError): + """A subclass for specifying device related exception""" + + def __init__(self, msg, line=None): + category_str = bcolors.FAIL + "[Device]" + bcolors.ENDC + NkipyError.__init__(self, msg, line, category_str) + + +class AssertError(NkipyError): + """A subclass for specifying assert related exception""" + + def __init__(self, msg, line=None): + category_str = bcolors.FAIL + "[Assert]" + bcolors.ENDC + NkipyError.__init__(self, msg, line, category_str) + + +""" New Error subclasses """ + + +class NkipyNotImplementedError(NkipyError): + """A subclass for specifying not implemented exception""" + + def __init__(self, msg, line=None): + category_str = bcolors.FAIL + "[Not Implemented]" + bcolors.ENDC + NkipyError.__init__(self, msg, line, category_str) + + +class MLIRLimitationError(NkipyError): + """A subclass for specifying MLIR limitation exception""" + + def __init__(self, msg, line=None): + category_str = bcolors.FAIL + "[MLIR Limitation]" + bcolors.ENDC + NkipyError.__init__(self, msg, line, category_str) + + +class NkipyValueError(NkipyError): + """A subclass for specifying Nkipy value exception""" + + def __init__(self, msg, line=None): + category_str = bcolors.FAIL + "[Value Error]" + bcolors.ENDC + NkipyError.__init__(self, msg, line, category_str) + + +""" New Warning subclasses """ + + +class DTypeWarning(NkipyWarning): + """A subclass for specifying data type related warning""" + + def __init__(self, msg, line=None): + category_str = bcolors.WARNING + "[Data Type]" + bcolors.ENDC + NkipyWarning.__init__(self, msg, line, category_str, RuntimeWarning) + + +class NkipyDeprecationWarning(NkipyWarning): + """A subclass for specifying deprecation warning""" + + def __init__(self, msg, line=None): + category_str = bcolors.WARNING + "[Deprecation]" + bcolors.ENDC + NkipyWarning.__init__(self, msg, line, category_str, DeprecationWarning) + + +class APIWarning(NkipyWarning): + """A subclass for specifying API related warning""" + + def __init__(self, msg, line=None): + category_str = bcolors.WARNING + "[API]" + bcolors.ENDC + NkipyWarning.__init__(self, msg, line, 
category_str, RuntimeWarning) + + +class PassWarning(NkipyWarning): + """A subclass for specifying pass related warning""" + + def __init__(self, msg, line=None): + category_str = bcolors.WARNING + "[Pass]" + bcolors.ENDC + NkipyWarning.__init__(self, msg, line, category_str, RuntimeWarning) diff --git a/kernelgen/mlir/include/nkipy/CMakeLists.txt b/kernelgen/mlir/include/nkipy/CMakeLists.txt new file mode 100644 index 0000000..f659bbf --- /dev/null +++ b/kernelgen/mlir/include/nkipy/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(Transforms) +add_subdirectory(Bindings) +add_subdirectory(Dialect) +add_subdirectory(TransformOps) \ No newline at end of file diff --git a/kernelgen/mlir/include/nkipy/Dialect/CMakeLists.txt b/kernelgen/mlir/include/nkipy/Dialect/CMakeLists.txt new file mode 100644 index 0000000..9f4e85d --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Dialect/CMakeLists.txt @@ -0,0 +1,20 @@ +set(LLVM_TARGET_DEFINITIONS NkipyOps.td) +mlir_tablegen(NkipyOps.h.inc -gen-op-decls) +mlir_tablegen(NkipyOps.cpp.inc -gen-op-defs) +mlir_tablegen(NkipyDialect.h.inc -gen-dialect-decls -dialect=nkipy) +mlir_tablegen(NkipyDialect.cpp.inc -gen-dialect-defs -dialect=nkipy) +add_public_tablegen_target(MLIRNkipyOpsIncGen) +add_dependencies(mlir-headers MLIRNkipyOpsIncGen) + +set(LLVM_TARGET_DEFINITIONS NkipyAttrs.td) +mlir_tablegen(NkipyAttrs.h.inc -gen-attrdef-decls) +mlir_tablegen(NkipyAttrs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRNkipyAttrsIncGen) + +set(LLVM_TARGET_DEFINITIONS NkipyAttrs.td) +mlir_tablegen(NkipyEnums.h.inc -gen-enum-decls) +mlir_tablegen(NkipyEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRNkipyEnumsIncGen) + +add_mlir_doc(NkipyDialect NkipyDialect Nkipy/ -gen-dialect-doc) +add_mlir_doc(NkipyOps NkipyOps Nkipy/ -gen-op-doc) \ No newline at end of file diff --git a/kernelgen/mlir/include/nkipy/Dialect/NkipyAttrs.h b/kernelgen/mlir/include/nkipy/Dialect/NkipyAttrs.h new file mode 100644 index 0000000..f433c9b --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Dialect/NkipyAttrs.h @@ -0,0 +1,11 @@ +#ifndef NKIPY_ATTRS_H +#define NKIPY_ATTRS_H + +#include "mlir/IR/BuiltinAttributes.h" + +#include "nkipy/Dialect/NkipyEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "nkipy/Dialect/NkipyAttrs.h.inc" + +#endif // NKIPY_ATTRS_H \ No newline at end of file diff --git a/kernelgen/mlir/include/nkipy/Dialect/NkipyAttrs.td b/kernelgen/mlir/include/nkipy/Dialect/NkipyAttrs.td new file mode 100644 index 0000000..b99cd26 --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Dialect/NkipyAttrs.td @@ -0,0 +1,33 @@ +#ifndef NKIPY_ATTRS +#define NKIPY_ATTRS + +include "nkipy/Dialect/NkipyDialect.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/AttrTypeBase.td" + +// https://mlir.llvm.org/docs/OpDefinitions/#enum-attributes +// Memory space enumeration for annotate operation. +// +// IMPORTANT: All values must be NON-ZERO. MemSpaceEnumAttr is built on +// IntegerAttr, and mlir::MemRefType::get() treats a memorySpace attribute +// whose integer value is 0 as "no memory space" — it is silently dropped +// during uniquing. That would make `root.setType(memref<..., Hbm>)` a no-op +// for HBM=0, which in turn breaks the fixpoint loop in +// AnnotateMemorySpace::inferHbmForCopySources (it keeps re-applying HBM +// every iteration and never converges). +// Values 1..5 keep all memspaces representable on memref types. 
+def Hbm: I32EnumAttrCase<"Hbm", 1>; +def Psum: I32EnumAttrCase<"Psum", 2>; +def Sbuf: I32EnumAttrCase<"Sbuf", 3>; +def SharedHbm: I32EnumAttrCase<"SharedHbm", 4>; +def Constant: I32EnumAttrCase<"Constant", 5>; // Scalar broadcast constants (marker, not real memory) + +def MemSpaceEnum: I32EnumAttr<"MemSpaceEnum", + "Memory space enumeration", + [Hbm, Psum, Sbuf, SharedHbm, Constant]> { + let cppNamespace = "mlir::nkipy"; + let stringToSymbolFnName = "ConvertToMemSpaceEnum"; + let symbolToStringFnName = "ConvertToMemSpaceString"; +} + +#endif // NKIPY_ATTRS diff --git a/kernelgen/mlir/include/nkipy/Dialect/NkipyDialect.h b/kernelgen/mlir/include/nkipy/Dialect/NkipyDialect.h new file mode 100644 index 0000000..799b06d --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Dialect/NkipyDialect.h @@ -0,0 +1,9 @@ + +#ifndef NKIPY_DIALECT_H +#define NKIPY_DIALECT_H + +#include "mlir/IR/Dialect.h" + +#include "nkipy/Dialect/NkipyDialect.h.inc" + +#endif // NKIPY_DIALECT_H diff --git a/kernelgen/mlir/include/nkipy/Dialect/NkipyDialect.td b/kernelgen/mlir/include/nkipy/Dialect/NkipyDialect.td new file mode 100644 index 0000000..649faa6 --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Dialect/NkipyDialect.td @@ -0,0 +1,29 @@ +#ifndef NKIPY_DIALECT +#define NKIPY_DIALECT + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + + +def Nkipy_Dialect : Dialect { + let name = "nkipy"; + let summary = "A nkipy out-of-tree MLIR dialect."; + let description = [{ + This dialect is an example of an out-of-tree MLIR dialect designed to + illustrate the basic setup required to develop MLIR-based tools without + working inside of the LLVM source tree. + }]; + let useDefaultTypePrinterParser = 1; + let cppNamespace = "::mlir::nkipy"; +} + +class Nkipy_Op traits = []> : + Op; + +// class Nkipy_Type traits = []> : +// TypeDef; + +class Nkipy_Attr traits = []> : + AttrDef; + +#endif // NKIPY_DIALECT diff --git a/kernelgen/mlir/include/nkipy/Dialect/NkipyOps.h b/kernelgen/mlir/include/nkipy/Dialect/NkipyOps.h new file mode 100644 index 0000000..1d9b0ff --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Dialect/NkipyOps.h @@ -0,0 +1,26 @@ +#ifndef NKIPY_OPS_H +#define NKIPY_OPS_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" + +#include "nkipy/Dialect/NkipyDialect.h" +#include "nkipy/Dialect/NkipyAttrs.h" + +#define GET_OP_CLASSES +#include "nkipy/Dialect/NkipyOps.h.inc" + +#endif // NKIPY_OPS_H \ No newline at end of file diff --git a/kernelgen/mlir/include/nkipy/Dialect/NkipyOps.td b/kernelgen/mlir/include/nkipy/Dialect/NkipyOps.td new file mode 100644 index 0000000..82a0477 --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Dialect/NkipyOps.td @@ -0,0 +1,117 @@ +#ifndef NKIPY_OPS +#define NKIPY_OPS + +include "nkipy/Dialect/NkipyDialect.td" +include "nkipy/Dialect/NkipyAttrs.td" + +include "mlir/IR/EnumAttr.td" +include "mlir/IR/BuiltinTypes.td" +include "mlir/IR/OpAsmInterface.td" 
+include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/TilingInterface.td" +include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" +include "mlir/IR/SymbolInterfaces.td" + + +def Nkipy_AnnotateOp : Nkipy_Op<"annotate", [ + DeclareOpInterfaceMethods]> +{ + let summary = "annotate"; + let description = [{ + nkipy_d.annotate(target, mem_space=None, partition_dim=None, tile_size=None, reduction_tile=None) + + Annotate a Tensor with memory space and/or partition dimension information. + + Users can optionally specify the memory space where the tensor should reside + (Hbm, Psum, Sbuf, SharedHbm), the partition dimension, and tile sizes. + + Parameters + * target (Tensor) - The tensor to be annotated + * mem_space ({Hbm, Psum, Sbuf, SharedHbm}, optional) - The memory space for the tensor + * partition_dim (int, optional) - The dimension to be partitioned (must be 0 for NISA) + * tile_size (array, optional) - The tile sizes for each dimension. + Must have exactly the same number of elements as the tensor rank. + * reduction_tile (array, optional) - Tile sizes for reduction dimensions. + Used for contraction ops (e.g., matmul) where the iteration space has more + dimensions than the output tensor. For matmul, this is the K tile size. + }]; + + let arguments = (ins AnyTypeOf<[AnyTensor, AnyMemRef]>:$target, + OptionalAttr:$mem_space, + OptionalAttr:$partition_dim, + OptionalAttr:$tile_size, + OptionalAttr:$reduction_tile); + let results = (outs ); + // All optional attributes go in attr-dict so they can appear in any order + let assemblyFormat = [{ + `(` $target `:` type($target) `)` attr-dict + }]; +} + +def Nkipy_YieldOp : Nkipy_Op<"yield", [Pure, ReturnLike, Terminator]> +{ + let summary = "Yield values from an nkipy region"; + let description = [{ + Terminates the reference implementation region of an nkipy op. + Yields the computed result values back to the enclosing op. + }]; + + let arguments = (ins Variadic:$values); + let assemblyFormat = "attr-dict ($values^ `:` type($values))?"; +} + +def Nkipy_GatherOp : Nkipy_Op<"gather", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> +{ + let summary = "Gather rows from a tensor using an index tensor"; + let description = [{ + Gathers rows from `source` at positions given by `indices` along axis 0. + + Semantics: result[i, ...] = source[indices[i], ...] + + This op is emitted during tracing for: + - table[indices] (fancy indexing) + - np.take(table, indices, axis=0) + + It is lowered to nisa.dma_copy_indirect during the linalg-to-nisa pass. + + The source must be a 2D tensor/memref. The indices must be a 1D integer + tensor/memref. The result shape is [indices.shape[0], source.shape[1]]. + + The `output` operand is the DPS init (destination-passing style): an + uninitialized buffer of the same shape as the result. Bufferization + allocates the backing storage via tensor.empty -> memref.alloc. + + The op carries an optional reference implementation region containing + standard linalg/tensor ops. This region is used for LLVM CPU simulation + (inlined before JIT) and ignored during NISA lowering. 
+ }]; + + let arguments = (ins + AnyTypeOf<[AnyTensor, AnyMemRef]>:$source, + AnyTypeOf<[AnyTensor, AnyMemRef]>:$indices, + AnyTypeOf<[AnyTensor, AnyMemRef]>:$output + ); + let results = (outs AnyTypeOf<[AnyTensor, AnyMemRef]>:$result); + let regions = (region AnyRegion:$reference_impl); + + let assemblyFormat = [{ + `(` $source `,` $indices `)` `outs` `(` $output `)` ($reference_impl^)? attr-dict `:` `(` type($source) `,` type($indices) `)` `outs` `(` type($output) `)` `->` type($result) + }]; +} + +#endif // NKIPY_OPS diff --git a/kernelgen/mlir/include/nkipy/TransformOps/CMakeLists.txt b/kernelgen/mlir/include/nkipy/TransformOps/CMakeLists.txt new file mode 100644 index 0000000..9e82e9e --- /dev/null +++ b/kernelgen/mlir/include/nkipy/TransformOps/CMakeLists.txt @@ -0,0 +1,19 @@ +set(LLVM_TARGET_DEFINITIONS NkipyTransformOps.td) +mlir_tablegen(NkipyTransformOps.h.inc -gen-op-decls) +mlir_tablegen(NkipyTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRNkipyTransformOpsIncGen) +add_dependencies(mlir-headers MLIRNkipyTransformOpsIncGen) + +# Fix namespace issue in generated files (older MLIR uses mlir:: instead of ::mlir::) +add_custom_command( + TARGET MLIRNkipyTransformOpsIncGen POST_BUILD + COMMAND sed -i 's/std::optional/std::optional<::mlir::Attribute>/g' + ${CMAKE_CURRENT_BINARY_DIR}/NkipyTransformOps.h.inc + COMMAND sed -i 's/, mlir::Attribute/, ::mlir::Attribute/g' + ${CMAKE_CURRENT_BINARY_DIR}/NkipyTransformOps.h.inc + COMMAND sed -i 's/std::optional/std::optional<::mlir::Attribute>/g' + ${CMAKE_CURRENT_BINARY_DIR}/NkipyTransformOps.cpp.inc + COMMAND sed -i 's/, mlir::Attribute/, ::mlir::Attribute/g' + ${CMAKE_CURRENT_BINARY_DIR}/NkipyTransformOps.cpp.inc + COMMENT "Fixing namespace qualifications in generated transform ops" +) diff --git a/kernelgen/mlir/include/nkipy/TransformOps/NkipyTransformOps.h b/kernelgen/mlir/include/nkipy/TransformOps/NkipyTransformOps.h new file mode 100644 index 0000000..78efc0e --- /dev/null +++ b/kernelgen/mlir/include/nkipy/TransformOps/NkipyTransformOps.h @@ -0,0 +1,28 @@ +//===- NkipyTransformOps.h - Nkipy Transform Operations ---------*- C++ -*-===// +// +// Custom transform dialect operations for NKIPyKernelGen. +// +//===----------------------------------------------------------------------===// + +#ifndef NKIPY_TRANSFORMOPS_NKIPYTRANSFORMOPS_H +#define NKIPY_TRANSFORMOPS_NKIPYTRANSFORMOPS_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" + +namespace mlir { +namespace nkipy { + +/// Registers the Nkipy transform ops extension with the transform dialect. 
+void registerTransformDialectExtension(DialectRegistry ®istry); + +} // namespace nkipy +} // namespace mlir + +#define GET_OP_CLASSES +#include "nkipy/TransformOps/NkipyTransformOps.h.inc" + +#endif // NKIPY_TRANSFORMOPS_NKIPYTRANSFORMOPS_H diff --git a/kernelgen/mlir/include/nkipy/TransformOps/NkipyTransformOps.td b/kernelgen/mlir/include/nkipy/TransformOps/NkipyTransformOps.td new file mode 100644 index 0000000..9a827a9 --- /dev/null +++ b/kernelgen/mlir/include/nkipy/TransformOps/NkipyTransformOps.td @@ -0,0 +1,53 @@ +#ifndef NKIPY_TRANSFORM_OPS +#define NKIPY_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// PromoteTensorOp +//===----------------------------------------------------------------------===// + +def Nkipy_PromoteTensorOp : Op, + DeclareOpInterfaceMethods]> { + let summary = "Request a tensor value to live in a specific memory space " + "after bufferization"; + let description = [{ + Requests that a tensor value lives in a specific memory space for its + lifetime. This is achieved by allocating a new tensor in the desired + memory space with `bufferization.alloc_tensor` and optionally materializing + the source value into that allocation with + `bufferization.materialize_in_destination`. All uses of the original value + are then redirected to the promoted value. + + The generated code for promoting tensor value %0 resembles the following: + + %1 = bufferization.alloc_tensor() + { memory_space = memory_space } + // Note: the materialization is omitted if %0 is never read and is only + // written into (i.e., it behaves as a result tensor). + %2 = bufferization.materialize_in_destination %0 in %1 + // ... + + + Deallocation is not handled by this transform. + + Return modes: + - Produces a silenceable failure if the given handle does not point to + tensor-typed values. + - Succeeds otherwise and returns a handle to the promoted value(s), i.e., + the result of materialization if present and the allocation otherwise. + }]; + + let arguments = (ins TransformValueHandleTypeInterface:$tensor, + OptionalAttr:$memory_space); + let results = (outs TransformValueHandleTypeInterface:$promoted); + + let assemblyFormat = + "(`to` $memory_space^)? 
$tensor attr-dict `:` functional-type($tensor, $promoted)"; +} + +#endif // NKIPY_TRANSFORM_OPS diff --git a/kernelgen/mlir/include/nkipy/Transforms/CMakeLists.txt b/kernelgen/mlir/include/nkipy/Transforms/CMakeLists.txt new file mode 100644 index 0000000..febcbd6 --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(MLIRNkipyPassesIncGen) \ No newline at end of file diff --git a/kernelgen/mlir/include/nkipy/Transforms/HardwareConstants.h b/kernelgen/mlir/include/nkipy/Transforms/HardwareConstants.h new file mode 100644 index 0000000..56fe196 --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Transforms/HardwareConstants.h @@ -0,0 +1,36 @@ +//===- HardwareConstants.h - NeuronCore hardware limits ---------*- C++ -*-===// + +#ifndef NKIPY_TRANSFORMS_HARDWARECONSTANTS_H +#define NKIPY_TRANSFORMS_HARDWARECONSTANTS_H + +#include "llvm/ADT/StringRef.h" +#include +#include + +namespace mlir { +namespace nkipy { + +/// Maximum partition dimension size for NeuronCore hardware. +static constexpr int64_t MAX_PARTITION_DIM = 128; + +/// Maximum free dimension size for matmul operands. +static constexpr int64_t MAX_FREE_DIM_MATMUL = 512; + +/// Usable SBUF partition size in bytes for `target`. Values mirror +/// build-tools/j2gen/target_info.py in private-nki-staging. Kept local so +/// nkipy-opt does not need to link against NISA-internal target-info libs. +inline std::optional +getSbufPartitionUsableSize(llvm::StringRef target) { + if (target == "trn1") + return static_cast(192 * 1024 - 16384); + if (target == "trn2") + return static_cast(224 * 1024 - 16384 - 8); + if (target == "trn3") + return static_cast(256 * 1024 - 16384 - 8); + return std::nullopt; +} + +} // namespace nkipy +} // namespace mlir + +#endif // NKIPY_TRANSFORMS_HARDWARECONSTANTS_H diff --git a/kernelgen/mlir/include/nkipy/Transforms/IRHelpers.h b/kernelgen/mlir/include/nkipy/Transforms/IRHelpers.h new file mode 100644 index 0000000..b1aa829 --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Transforms/IRHelpers.h @@ -0,0 +1,75 @@ +//===- IRHelpers.h - Shared IR utility functions ----------------*- C++ -*-===// +// +// Small helpers for querying MLIR values, shared across multiple passes. +// +//===----------------------------------------------------------------------===// + +#ifndef NKIPY_TRANSFORMS_IRHELPERS_H +#define NKIPY_TRANSFORMS_IRHELPERS_H + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "nkipy/Dialect/NkipyAttrs.h" +#include + +namespace mlir { +namespace nkipy { + +/// Return true if `op` is nested inside an nkipy dialect region (e.g., the +/// reference_impl body of nkipy.gather). These regions exist only for CPU +/// simulation and should be skipped by NISA-path passes. +inline bool isInsideNkipyRegion(Operation *op) { + for (Operation *parent = op->getParentOp(); parent; + parent = parent->getParentOp()) { + if (parent->getDialect() && + parent->getDialect()->getNamespace() == "nkipy") + return true; + } + return false; +} + +/// Get the constant integer value from a Value, or std::nullopt if not constant. +/// Works for arith.constant with IntegerAttr and arith.constant_index. +std::optional getConstantInt(Value v); + +/// Walk through view chains (subview, collapse_shape, expand_shape, etc.) +/// to find the base memref allocation. Uses ViewLikeOpInterface. 
+Value getBaseMemRef(Value v); + +/// Extract the nkipy memory space kind from a memref type, if present. +/// Returns std::nullopt if the type is not a memref or has no nkipy mem space. +std::optional getNkipyMemSpace(Type type); + +/// Convenience predicates on memref memory space. Return false if the type +/// is not a memref or has no nkipy memory space attribute. +inline bool isHbm(Type type) { + auto ms = getNkipyMemSpace(type); + return ms && *ms == nkipy::MemSpaceEnum::Hbm; +} +inline bool isSharedHbm(Type type) { + auto ms = getNkipyMemSpace(type); + return ms && *ms == nkipy::MemSpaceEnum::SharedHbm; +} +inline bool isSbuf(Type type) { + auto ms = getNkipyMemSpace(type); + return ms && *ms == nkipy::MemSpaceEnum::Sbuf; +} +inline bool isPsum(Type type) { + auto ms = getNkipyMemSpace(type); + return ms && *ms == nkipy::MemSpaceEnum::Psum; +} +inline bool isAnyHbm(Type type) { + auto ms = getNkipyMemSpace(type); + return ms && + (*ms == nkipy::MemSpaceEnum::Hbm || *ms == nkipy::MemSpaceEnum::SharedHbm); +} + +/// Walk up the parent chain from `op` until finding an ancestor that lives +/// directly in `block`. Returns nullptr if `op` is not nested under `block`. +Operation *getAncestorInBlock(Operation *op, Block *block); + +} // namespace nkipy +} // namespace mlir + +#endif // NKIPY_TRANSFORMS_IRHELPERS_H diff --git a/kernelgen/mlir/include/nkipy/Transforms/OpClassification.h b/kernelgen/mlir/include/nkipy/Transforms/OpClassification.h new file mode 100644 index 0000000..4d1178d --- /dev/null +++ b/kernelgen/mlir/include/nkipy/Transforms/OpClassification.h @@ -0,0 +1,48 @@ +//===- OpClassification.h - Shared op classification helpers ----*- C++ -*-===// +// +// Utility functions for classifying linalg ops (elementwise, reduction, matmul) +// shared across multiple passes. +// +//===----------------------------------------------------------------------===// + +#ifndef NKIPY_TRANSFORMS_OPCLASSIFICATION_H +#define NKIPY_TRANSFORMS_OPCLASSIFICATION_H + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { +namespace nkipy { + +/// Named unary elementwise ops (exp, log, tanh, negf, abs, ceil, floor, +/// sqrt, reciprocal, square, copy). +bool isNamedUnaryElementwiseOp(StringRef opName); + +/// Named binary elementwise ops (add, sub, mul, div). +bool isNamedBinaryElementwiseOp(StringRef opName); + +/// Any named elementwise op (unary or binary). +bool isNamedElementwiseOp(StringRef opName); + +/// linalg.generic with all-parallel iterator types. +bool isElementwiseGeneric(linalg::LinalgOp linalgOp); + +/// Named elementwise op OR all-parallel generic. +bool isElementwiseOp(linalg::LinalgOp linalgOp); + +/// linalg.generic with at least one reduction iterator type. +bool isReductionGeneric(linalg::LinalgOp linalgOp); + +/// linalg.matmul or linalg.batch_matmul (by op name). +bool isMatmulOp(StringRef opName); + +/// linalg.matmul or linalg.batch_matmul (from a LinalgOp). +bool isMatmulOp(linalg::LinalgOp linalgOp); + +/// Elementwise, reduction, or matmul — ops that receive layout annotations. 
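+/// Typical use in a pass (sketch; `op` is any operation being visited):
+///   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op))
+///     if (isAnnotatableOp(linalgOp))
+///       ...; // op may carry, or receive, an nkipy.annotate knob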
+bool isAnnotatableOp(linalg::LinalgOp linalgOp);
+
+} // namespace nkipy
+} // namespace mlir
+
+#endif // NKIPY_TRANSFORMS_OPCLASSIFICATION_H
diff --git a/kernelgen/mlir/include/nkipy/Transforms/Passes.h b/kernelgen/mlir/include/nkipy/Transforms/Passes.h
new file mode 100644
index 0000000..30a8c05
--- /dev/null
+++ b/kernelgen/mlir/include/nkipy/Transforms/Passes.h
@@ -0,0 +1,38 @@
+
+#ifndef NKIPY_TRANSFORMS_PASSES_H
+#define NKIPY_TRANSFORMS_PASSES_H
+
+#include "mlir/CAPI/IR.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+
+namespace mlir {
+namespace nkipy {
+
+std::unique_ptr<OperationPass<func::FuncOp>> createAnnotateMemorySpacePass();
+std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeReshapePass();
+std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizePartitionDimPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createAssignLinalgOpIdsPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createInferLayoutPass();
+std::unique_ptr<OperationPass<ModuleOp>> createKnobDrivenTilingPass();
+std::unique_ptr<OperationPass<ModuleOp>> createApplyAndStripTransformsPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createEliminateUninitializedCopiesPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createEliminateSameMemSpaceCopyPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createInsertSpillReloadPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createInsertMemRefDeallocPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeLoopStepPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeLayoutPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createSimplifyLinalgPass();
+std::unique_ptr<OperationPass<ModuleOp>> createPrepareArithmeticPass();
+std::unique_ptr<OperationPass<ModuleOp>> createRemoveRedundantZeroFillPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createInlineNkipyReferencePass();
+
+
+/// Registers all transformation passes
+void registerNkipyPasses();
+
+} // namespace nkipy
+} // namespace mlir
+
+#endif // NKIPY_TRANSFORMS_PASSES_H
diff --git a/kernelgen/mlir/include/nkipy/Transforms/Passes.td b/kernelgen/mlir/include/nkipy/Transforms/Passes.td
new file mode 100644
index 0000000..6cbde3d
--- /dev/null
+++ b/kernelgen/mlir/include/nkipy/Transforms/Passes.td
@@ -0,0 +1,478 @@
+
+#ifndef NKIPY_MLIR_PASSES
+#define NKIPY_MLIR_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def AnnotateMemorySpace : Pass<"annotate-memory-space", "func::FuncOp"> {
+  let summary = "Annotate memrefs with memory space attributes from knobs";
+  let description = [{
+    This pass annotates memref types with memory space attributes based on
+    nkipy.annotate operations. It performs the following tasks:
+
+    1. Annotates function inputs/outputs with SharedHbm memory space
+    2. Processes nkipy.annotate ops to apply memory spaces (Sbuf, Psum, etc.)
+       to internal memrefs
+    3. Propagates memory space attributes through memref.subview operations
+    4. Removes nkipy.annotate ops after processing
+
+    This prepares memrefs for subsequent lowering to NISA dialect.
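+
+    For illustration: after this pass a kernel argument typed
+    memref<128x512xf32> carries the SharedHbm memory space in its memref
+    type, while an internal alloc targeted by a knob with mem_space = Sbuf
+    carries the Sbuf memory space instead (the exact attribute syntax is
+    defined in NkipyAttrs.td).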
+ }]; + let constructor = "mlir::nkipy::createAnnotateMemorySpacePass()"; +} + + +def CanonicalizeReshape : Pass<"canonicalize-reshape", "func::FuncOp"> { + let summary = "Classify and canonicalize memref reshape ops by mem_space and partition_dim"; + let description = [{ + Post-bufferization pass that classifies expand_shape / collapse_shape ops + based on memory space and partition dim involvement: + + - HBM reshapes: always views (no partition concept) + - SBUF collapse (merge): always views (contiguous in memory) + - SBUF expand (split) of non-partition dims: views + - SBUF expand (split) of partition dim (dim 0): needs alloc+copy + (NISA has no modulo op), except trivial splits N->(N,1)/(1,N) + - Returned expand_shape views of func args: alloc+copy (NISA requires + separate output allocations) + + Runs after annotate-memory-space so all memrefs have explicit memory + spaces and partition_dim is guaranteed to be 0. + }]; + let constructor = "mlir::nkipy::createCanonicalizeReshapePass()"; +} + +def CanonicalizePartitionDim : Pass<"canonicalize-partition-dim", "func::FuncOp"> { + let summary = "Insert transposes to ensure partition_dim=0 everywhere"; + let description = [{ + This pass inserts linalg.transpose operations so that partition_dim=0 + holds for all annotated tensors. NISA hardware assumes dimension 0 is + always the partition dimension, and every downstream pass relies on this. + + For each connected component of elementwise ops sharing a non-zero + partition_dim, the pass: + 1. Inserts transposes on inputs entering the component + 2. Rewrites elementwise ops with permuted shapes + 3. Inserts inverse transposes on outputs leaving the component + 4. Updates nkipy.annotate ops: partition_dim -> 0, permutes tile_size + + Non-elementwise ops (matmul, reduction) with partition_dim != 0 are + rejected with an error — the user must fix their annotation. + + This pass runs after infer-layout (so partition_dim is propagated to all + ops in the chain) and before assign-linalg-op-ids (so new transposes + get op IDs for tiling). + }]; + let constructor = "mlir::nkipy::createCanonicalizePartitionDimPass()"; +} + +def AssignLinalgOpIds : Pass<"assign-linalg-op-ids", "func::FuncOp"> { + let summary = "Assign unique nkipy.op_id attributes to all linalg operations"; + let description = [{ + This pass assigns unique nkipy.op_id attributes to ALL linalg operations. + + The op_id enables per-instance matching during transform dialect application, + allowing different tile sizes to be applied to different instances of the + same linalg operation type. + + Example transformation: + Before: + %0 = linalg.matmul ins(%A, %B) outs(%C) -> tensor<256x256xf32> + %1 = linalg.add ins(%0, %D) outs(%E) -> tensor<256x256xf32> + + After: + %0 = linalg.matmul {nkipy.op_id = 0 : i64} ins(%A, %B) outs(%C) -> tensor<256x256xf32> + %1 = linalg.add {nkipy.op_id = 1 : i64} ins(%0, %D) outs(%E) -> tensor<256x256xf32> + }]; + let constructor = "mlir::nkipy::createAssignLinalgOpIdsPass()"; +} + +def InferLayout : Pass<"infer-layout", "func::FuncOp"> { + let summary = "Infer layout annotations (tiling and placement) for unannotated elementwise ops"; + let description = [{ + This pass infers layout information (tile_size and mem_space placement) for + elementwise operations that lack explicit annotations. It propagates these + layout attributes from annotated elementwise ops to adjacent unannotated + ones along the SSA use-def chain. 
+ + Tiling and placement (e.g., SBUF vs PSUM) are the minimum layout info + required to make code runnable on hardware. Without them, ops remain + untiled and unplaced after bufferization, causing failures in linalg-to-nisa. + + This enables users to write compound expressions like: + gated = gate / (1.0 + exp(-gate)) * up + knob(gated, mem_space="Sbuf", tile_size=[128, 128]) + and have the intermediate ops (negate, exp, add, reciprocal, mul) + automatically inherit the tiling and placement from the annotated output. + + The pass walks backwards from each annotated elementwise op through its + input operands. For each predecessor that is also elementwise and has the + same output shape, it creates a new nkipy.annotate op with the same + tile_size and mem_space. + + Propagation rules: + - Only propagates to elementwise ops (named or generic with all-parallel) + - Only propagates when producer output shape matches consumer output shape + - Does not override existing annotations + - Only infers tile_size and mem_space (not other advanced knobs) + + This pass runs after prepare-arithmetic and assign-linalg-op-ids, but + before knob-driven-tiling. + }]; + let constructor = "mlir::nkipy::createInferLayoutPass()"; +} + +def KnobDrivenTiling : Pass<"knob-driven-tiling", "ModuleOp"> { + let summary = "Generate Transform dialect sequence for tiling linalg operations with blocking"; + let description = [{ + This pass generates Transform dialect operations to tile linalg operations using + tile sizes specified in nkipy.annotate knobs. It uses a two-level blocking strategy + for improved data reuse. + + The pass adds a `transform.named_sequence @__transform_main` to the module that + contains the tiling schedule. The transforms are NOT applied by this pass - use + the `apply-and-strip-transforms` pass afterwards to apply them and erase the + transform module. + + Blocking strategy (TILES_IN_BLOCK = 2 for both M and N): + - BLOCK_M = TILE_M × 2 + - BLOCK_N = TILE_N × 2 + + Generated loop structure for matmul: + for block_m in [0, M, BLOCK_M): // Level 1: Block-M loop + LOAD LHS to SBUF // Reused across all N-blocks + for block_n in [0, N, BLOCK_N): // Level 1: Block-N loop + LOAD RHS to SBUF // Reused within block + for tile_m in [0, BLOCK_M, TILE_M): // Level 2: Tile-M loop + for tile_n in [0, BLOCK_N, TILE_N): // Level 2: Tile-N loop + ALLOC psum buffer + for k in [0, K, TILE_K): // Level 2: K loop + psum += matmul(lhs_tile, rhs_tile) + STORE result tile + + Usage: + 1. Run `--knob-driven-tiling` to generate transform sequence + 2. Run `--apply-and-strip-transforms` to apply and strip the transforms + + Example: + nkipy-opt input.mlir --knob-driven-tiling --apply-and-strip-transforms + }]; + let constructor = "mlir::nkipy::createKnobDrivenTilingPass()"; +} + +def ApplyAndStripTransforms : Pass<"apply-and-strip-transforms", "ModuleOp"> { + let summary = "Run @__transform_main, then erase the transform module"; + let description = [{ + Fused replacement for `--transform-interpreter` followed by manual + stripping of the transform module. Applies the `@__transform_main` + named sequence to the enclosing ModuleOp (same semantics as upstream + `--transform-interpreter`), then erases every top-level op in the + transform dialect and clears the `transform.with_named_sequence` + module attribute. + + This lives in our pipeline in place of `--transform-interpreter`. 
+ After this pass runs, the IR contains no transform-dialect ops, so + the downstream Python linalg→NISA phase can parse the module through + upstream MLIR bindings (which don't know about our custom transform + op `transform.nkipy.promote_tensor`). + + No-op if no `@__transform_main` entry point is present. + }]; + let constructor = "mlir::nkipy::createApplyAndStripTransformsPass()"; +} + +def EliminateUninitializedCopies : Pass<"eliminate-uninitialized-copies", "func::FuncOp"> { + let summary = "Eliminate copies from uninitialized allocations"; + let description = [{ + This pass eliminates memref.copy operations where the source is a freshly + allocated buffer that has never been written to (contains undefined values). + Such copies are effectively no-ops and can be safely eliminated. + + This commonly occurs after buffer promotion when the original tensor was + freshly allocated (e.g., for accumulator initialization). The promoted + buffer copies from the original uninitialized HBM allocation, which is + unnecessary since operations like matmul will zero their PSUM accumulator + via psum_zero_region anyway. + + Example transformation: + Before: + %uninit = memref.alloc() : memref<128x128xf32> // Never written to + memref.copy %uninit, %dst : memref<128x128xf32> to memref<128x128xf32> + linalg.matmul ins(%a, %b) outs(%dst) + + After: + %uninit = memref.alloc() : memref<128x128xf32> // Never written to + // Copy eliminated - %uninit was uninitialized + linalg.matmul ins(%a, %b) outs(%dst) + }]; + let constructor = "mlir::nkipy::createEliminateUninitializedCopiesPass()"; +} + +def EliminateSameMemSpaceCopy : Pass<"eliminate-same-memspace-copy", "func::FuncOp"> { + let summary = "Eliminate redundant copies between same memory space"; + let description = [{ + This pass eliminates redundant memref.copy operations where both source and + destination are in the same memory space (e.g., SBUF to SBUF). + + When the destination is a fresh allocation that is only used after the copy, + we can directly use the source instead, eliminating both the copy and the + unnecessary allocation. + + This is particularly useful after bufferization + memory space annotation, + where operand promotion may create intermediate SBUF buffers that are + immediately copied from another SBUF location. + + Example transformation: + Before: + %src = memref.subview ... : memref<128x128xf32, #sbuf> + %dst = memref.alloc() : memref<128x128xf32, #sbuf> + memref.copy %src, %dst : memref<128x128xf32, #sbuf> to memref<128x128xf32, #sbuf> + linalg.add ins(%dst, ...) outs(...) + + After: + %src = memref.subview ... : memref<128x128xf32, #sbuf> + linalg.add ins(%src, ...) outs(...) + + Note: PSUM to SBUF copies are NOT eliminated as they transfer data between + different memory spaces (accumulator to scratch buffer). + }]; + let constructor = "mlir::nkipy::createEliminateSameMemSpaceCopyPass()"; +} + +def InsertSpillReload : Pass<"insert-spill-reload", "func::FuncOp"> { + let summary = "Insert spill/reload operations for SBUF memory pressure"; + let description = [{ + This pass analyzes per-partition SBUF memory pressure and automatically + inserts spill (SBUF→HBM) and reload (HBM→SBUF) operations when capacity + is exceeded. + + SBUF has 128 partitions; usable size per partition is ~176 KB on TRN1, + ~208 KB on TRN2, and ~240 KB on TRN3. This pass runs after + legalize-layout, so SBUF memrefs are in physical layout + [partTile, numBlocks..., freeTile]. Per-partition size is computed as + total_size / shape[0] (partTile). 
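+    For example, an SBUF alloc of memref<128x4x4x128xf32> holds
+    128*4*4*128*4 B = 1 MiB in total, i.e. 1 MiB / 128 = 8 KiB per
+    partition; that per-partition figure is what gets compared against the
+    capacity above.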
+ + When a kernel's live SBUF usage exceeds the per-partition capacity: + 1. Spill: Copy less-frequently-used data from SBUF → HBM + 2. Reload: Copy it back when needed HBM → SBUF + + Algorithm: + 1. Collect all SBUF allocations and compute their sizes + 2. Perform liveness analysis to find peak memory pressure points + 3. At high-pressure points, select victims to spill using a heuristic + (size-based, LRU, or Belady's MIN) + 4. Insert memref.copy operations for spill/reload + + These copies are later lowered to nisa.dma_copy in the linalg-to-nisa pass. + + Example transformation: + Before (3 × 1 MB SBUF allocations, exceed per-partition capacity): + %a = memref.alloc() : memref<512x512xf32, #sbuf> // 1MB + %b = memref.alloc() : memref<512x512xf32, #sbuf> // 1MB + %c = memref.alloc() : memref<512x512xf32, #sbuf> // 1MB + linalg.copy ins(%arg0) outs(%a) + linalg.exp ins(%a) outs(%b) + linalg.mul ins(%b, %b) outs(%c) + + After (with %a spilled after use): + %a = memref.alloc() : memref<512x512xf32, #sbuf> + %a_spill = memref.alloc() : memref<512x512xf32, #hbm> // Spill slot + %b = memref.alloc() : memref<512x512xf32, #sbuf> + %c = memref.alloc() : memref<512x512xf32, #sbuf> + linalg.copy ins(%arg0) outs(%a) + memref.copy %a, %a_spill // SPILL: %a to HBM + linalg.exp ins(%a) outs(%b) + linalg.mul ins(%b, %b) outs(%c) // Peak now 2MB (only %b, %c) + + Pipeline Position: + - Insert after legalize-layout + canonicalize (pass #17) + - Before insert-memref-dealloc (pass #19) + - SBUF allocs are in physical per-partition layout; sizes are per-partition + }]; + let constructor = "mlir::nkipy::createInsertSpillReloadPass()"; + let options = [ + Option<"target", "target", "std::string", "\"trn2\"", + "Target hardware (trn1, trn2, trn3). Used to query SBUF partition " + "usable size from NisaTargetInfo.">, + Option<"sbufCapacityOverride", "sbuf-capacity", "int64_t", "-1", + "Override SBUF capacity in bytes (-1 = query from target). " + "Useful for testing."> + ]; +} + +def InsertMemRefDealloc : Pass<"insert-memref-dealloc", "func::FuncOp"> { + let summary = "Insert memref.dealloc operations to mark allocation lifetime ends"; + let description = [{ + Analyzes lifetime of memref.alloc operations with NISA memory space and + inserts memref.dealloc at appropriate points. These are later lowered to + nisa.release by the linalg-to-nisa pass. + + - Only deallocates memrefs with NISA memory space (SBUF, PSUM, HBM) + - Skips SHAREDHBM (externally managed) + - Skips allocations that escape (returned from function) + - Uses scope-based release: dealloc at end of allocation's scope + }]; + let constructor = "mlir::nkipy::createInsertMemRefDeallocPass()"; +} + +def CanonicalizeLoopStep : Pass<"canonicalize-loop-step", "func::FuncOp"> { + let summary = "Canonicalize scf.for loop steps to 1"; + let description = [{ + This pass transforms scf.for loops to have step=1, simplifying index computation. + This is essential for NISA lowering because the NISA dialect cannot handle + arith.divsi and arith.remsi operations in index computations. + + The transformation: + Before: + scf.for %i = %lb to %ub step %step { + // uses %i + } + + After: + scf.for %i_new = 0 to (%ub - %lb) / %step step 1 { + %i = %lb + %i_new * %step + // uses %i (now computed from %i_new) + } + + For the common case where lb=0 and step is a constant power of 2: + Before: + scf.for %i = 0 to 256 step 128 { + // %slice = tensor.extract_slice %tensor[%i, %j] ... 
+ } + + After: + scf.for %i_idx = 0 to 2 step 1 { + %i = %i_idx * 128 + // %slice = tensor.extract_slice %tensor[%i, %j] ... + } + + This enables subsequent passes to work with simple loop indices [0, 1, 2, ...] + rather than scaled indices [0, 128, 256, ...]. + }]; + let constructor = "mlir::nkipy::createCanonicalizeLoopStepPass()"; +} + +def LegalizeLayout : Pass<"legalize-layout", "func::FuncOp"> { + let summary = "Legalize SBUF tensor layouts to satisfy NKI hardware constraints"; + let description = [{ + This pass transforms SBUF tensor layouts to satisfy NKI hardware constraints + where the first dimension (partition dimension) must be ≤128. + + The pass identifies SBUF tensors via two mechanisms: + 1. bufferization.alloc_tensor with memory_space = #nisa.mem + 2. nkipy.annotate ops with mem_space = Sbuf (traced back to tensor.empty) + + For each SBUF tensor needing legalization, the pass: + 1. Computes the target 4D shape based on tile_size annotation + 2. Propagates the shape change through the entire use-def chain via BFS + 3. Updates scf.for init_args, block args, and results + 4. Transforms extract_slice to 4D indexing + collapse_shape + 5. Transforms insert_slice with expand_shape + 4D indexing + + Example transformation for tensor<512x512xf32> with tile_size=[128,128,128]: + Before: + %0 = tensor.empty() : tensor<512x512xf32> + %1 = scf.for ... iter_args(%arg = %0) { + %e = tensor.extract_slice %arg[%i*128, %j*128] [128, 128] + ... + %ins = tensor.insert_slice %tile into %arg[%i*128, %j*128] + scf.yield %ins + } + + After: + %0 = tensor.empty() : tensor<128x4x4x128xf32> + %1 = scf.for ... iter_args(%arg = %0) : tensor<128x4x4x128xf32> { + %e_4d = tensor.extract_slice %arg[0, %i, %j, 0] [128, 1, 1, 128] + %e = tensor.collapse_shape %e_4d [[0,1],[2,3]] + ... + %tile_4d = tensor.expand_shape %tile [[0,1],[2,3]] + %ins = tensor.insert_slice %tile_4d into %arg[0, %i, %j, 0] + scf.yield %ins + } + + Prerequisites: + - Runs after knob-driven-tiling + transform-interpreter + - Runs after canonicalize-loop-step (loops have step=1) + + This pass replaces the separate legalize-sbuf-outputs and legalize-sbuf-inputs + passes by performing both transformations atomically. + }]; + let constructor = "mlir::nkipy::createLegalizeLayoutPass()"; +} + +def SimplifyLinalg : Pass<"simplify-linalg", "func::FuncOp"> { + let summary = "Simplify linalg operations for NISA lowering"; + let description = [{ + Simplification pass that rewrites linalg operations before linalg-to-nisa. + + 1. Rewrites >2D SBUF linalg.transpose with unit dims to 2D. + NISA dma_transpose only supports [1,0] (2D) or [2,1,0] (3D full reverse). + Collapses SBUF allocs to 2D + expand_shape views, rewrites transpose or + emits copy (when non-unit dims keep order, i.e. just a reshape). + + 2. Converts trivial-broadcast linalg.generic ops to named linalg ops. + After tiling, broadcasts become same-shape operations. This converts + them to named ops (linalg.mul, etc.) so LinalgToNisa patterns can match. + }]; + let constructor = "mlir::nkipy::createSimplifyLinalgPass()"; +} + +def PrepareArithmetic : Pass<"prepare-arithmetic", "ModuleOp"> { + let summary = "Prepare arithmetic operations for NISA lowering"; + let description = [{ + This pass prepares arithmetic operations for NISA lowering by transforming + operations that don't have direct NISA equivalents. 
+ + Transformations: + - linalg.div(A, B) -> linalg.mul(A, linalg.reciprocal(B)) + NISA's tensor_tensor_arith doesn't support DIVIDE, so we convert division + to multiplication by reciprocal. + + This pass runs before tiling so that the generated reciprocal operations + get tiled and bufferized normally. + }]; + let constructor = "mlir::nkipy::createPrepareArithmeticPass()"; +} + +def RemoveRedundantZeroFill : Pass<"remove-redundant-zero-fill", "ModuleOp"> { + let summary = "Remove linalg.fill ops with zero values when only used by matmul-like ops"; + let description = [{ + NISA matmul hardware initializes PSUM accumulators to zero automatically + (psum_zero_region), so linalg.fill ops that fill with zero and feed only + into matmul-like operations are redundant and can be removed. + + This runs on tensor IR before tiling/bufferization. Removing the fill early + prevents it from becoming a memref.copy chain that would generate an + unnecessary nisa.memset instruction downstream. + + Example: + Before: + %cst = arith.constant 0.0 : f32 + %empty = tensor.empty() : tensor<10x30xf32> + %filled = linalg.fill ins(%cst) outs(%empty) -> tensor<10x30xf32> + %result = linalg.matmul ins(%a, %b) outs(%filled) -> tensor<10x30xf32> + + After: + %empty = tensor.empty() : tensor<10x30xf32> + %result = linalg.matmul ins(%a, %b) outs(%empty) -> tensor<10x30xf32> + }]; + let constructor = "mlir::nkipy::createRemoveRedundantZeroFillPass()"; +} + +def InlineNkipyReference : Pass<"inline-nkipy-reference", "func::FuncOp"> { + let summary = "Inline reference_impl regions from nkipy ops"; + let description = [{ + nkipy ops like nkipy.gather carry an optional reference_impl region + containing standard linalg/tensor ops that express the same computation. + + This pass replaces each such op with the inlined contents of its region, + enabling LLVM CPU simulation on IR that would otherwise contain custom + dialect ops the LLVM JIT cannot lower. + + The pass is NOT part of the NISA compilation pipeline. It is used only + for the LLVM JIT verification path. + }]; + let constructor = "mlir::nkipy::createInlineNkipyReferencePass()"; +} + +#endif // NKIPY_MLIR_PASSES diff --git a/kernelgen/mlir/lib/Bindings/NkipyAttributes.cpp b/kernelgen/mlir/lib/Bindings/NkipyAttributes.cpp new file mode 100644 index 0000000..e617cea --- /dev/null +++ b/kernelgen/mlir/lib/Bindings/NkipyAttributes.cpp @@ -0,0 +1,67 @@ +#include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/CAPI/IR.h" + +#include "nkipy-c/Dialect/NkipyAttributes.h" +#include "nkipy/Bindings/NkipyModule.h" + +#include +#include + +namespace nb = nanobind; +using namespace mlir::python::nanobind_adaptors; + +using namespace mlir; +using namespace mlir::python; + +namespace nanobind { +namespace detail { + +/// Casts object <-> MlirIntegerSet. 
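+/// This lets functions bound from this module accept and return
+/// mlir.ir.IntegerSet objects directly: from_python unwraps the Python
+/// object through the MLIR capsule protocol, and from_cpp rebuilds the
+/// Python object via the IntegerSet CAPI factory.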
+template <> struct type_caster { + NB_TYPE_CASTER(MlirIntegerSet, const_name("MlirIntegerSet")); + bool from_python(handle src, uint8_t, cleanup_list *) noexcept { + nb::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToIntegerSet(capsule.ptr()); + if (mlirIntegerSetIsNull(value)) { + return false; + } + return true; + } + static handle from_cpp(MlirIntegerSet v, rv_policy, cleanup_list *) noexcept { + nb::object capsule = + nb::steal(mlirPythonIntegerSetToCapsule(v)); + return nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("IntegerSet") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +} // namespace detail +} // namespace nanobind + +void mlir::python::populateNkipyAttributes(nb::module_ &m) { + mlir_attribute_subclass(m, "IntegerSetAttr", mlirAttributeIsAIntegerSet) + .def_classmethod( + "get", + [](nb::object cls, MlirIntegerSet IntegerSet, MlirContext ctx) { + return cls(mlirIntegerSetAttrGet(IntegerSet)); + }, + nb::arg("cls"), nb::arg("integer_set"), + nb::arg("context").none() = nb::none(), + "Gets an attribute wrapping an IntegerSet."); + + mlir_attribute_subclass(m, "MemSpaceEnum", mlirAttributeIsAMemSpace) + .def_classmethod( + "get", + [](nb::object cls, MlirAttribute space, MlirContext ctx) { + return cls(mlirMemSpaceGet(ctx, space)); + }, + nb::arg("cls"), nb::arg("space"), nb::arg("context").none() = nb::none(), + "Gets an attribute wrapping a memory space."); + +} diff --git a/kernelgen/mlir/lib/Bindings/NkipyModule.cpp b/kernelgen/mlir/lib/Bindings/NkipyModule.cpp new file mode 100644 index 0000000..c545647 --- /dev/null +++ b/kernelgen/mlir/lib/Bindings/NkipyModule.cpp @@ -0,0 +1,112 @@ +#include "nkipy/Transforms/Passes.h" +#include "nkipy-c/Dialect/Registration.h" +#include "nkipy-c/Dialect/Dialects.h" + +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/CAPI/IR.h" +#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" +#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" + +// #include "mlir/Dialect/PDL/IR/PDLOps.h" +// #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" + +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + + +namespace nb = nanobind; +using namespace mlir::python::nanobind_adaptors; + +using namespace mlir; +using namespace mlir::python; +using namespace nkipy; + +// Helper to extract MlirContext from capsules with different package prefixes +static MlirContext extractContextFromCapsule(PyObject *capsule) { + MlirContext context; + + // Try standard mlir.ir.Context capsule name + void *ptr = PyCapsule_GetPointer(capsule, "mlir.ir.Context._CAPIPtr"); + if (ptr) { + context.ptr = ptr; + return context; + } + + // Clear error and try NKI's package prefix + PyErr_Clear(); + ptr = PyCapsule_GetPointer(capsule, "nki.compiler._internal.ir.Context._CAPIPtr"); + if (ptr) { + context.ptr = ptr; + return context; + } + + // Return null context if neither worked + context.ptr = nullptr; + return context; +} + +NB_MODULE(_nkipy, m) { + m.doc() = "Nkipy Python Native Extension"; + + // register passes + nkipyMlirRegisterAllPasses(); + + auto nkipy_m = m.def_submodule("nkipy"); + + nkipy_m.def( + "register_dialect", + [](nb::handle contextObj) { + // Get the _CAPIPtr attribute from any Context object (mlir or nki) + MlirContext context; + if (contextObj.is_none()) { + context 
= mlirContextCreate(); + } else { + // Get the _CAPIPtr capsule attribute + PyObject *capsule = PyObject_GetAttrString(contextObj.ptr(), MLIR_PYTHON_CAPI_PTR_ATTR); + if (!capsule) { + throw nb::type_error("Expected an MLIR Context object with _CAPIPtr attribute"); + } + // Try both mlir.ir.Context and nki.compiler._internal.ir.Context capsule names + context = extractContextFromCapsule(capsule); + Py_DECREF(capsule); + if (mlirContextIsNull(context)) { + throw nb::type_error("Invalid MLIR Context capsule - expected mlir.ir.Context or nki.compiler._internal.ir.Context"); + } + } + + // Register all dialects including Transform and extensions (nkipy transform ops) + nkipyMlirRegisterAllDialects(context); + + // Register and load the nkipy dialect + MlirDialectHandle nkipy = mlirGetDialectHandle__nkipy__(); + mlirDialectHandleRegisterDialect(nkipy, context); + mlirDialectHandleLoadDialect(nkipy, context); + }, + nb::arg("context") = nb::none(), + "Register the nkipy dialect with the given context"); + + // Apply transform to a design. + nkipy_m.def("apply_passes", [](MlirModule &mlir_mod) { + ModuleOp module = unwrap(mlir_mod); + + // Simplify the loop structure after the transform. + PassManager pm(module.getContext()); + pm.addNestedPass( + mlir::affine::createSimplifyAffineStructuresPass()); + pm.addPass(createCanonicalizerPass()); + if (failed(pm.run(module))) + throw nb::value_error("failed to apply the post-transform optimization"); + }); + + // Utility pass APIs - COMMENTED OUT: Requires MLIRNkipyPasses + // nkipy_m.def("memref_dce", &memRefDCE); + + // NOTE: Pass functions removed from Python bindings + // This pass requires NISA dialect which has global constructors that cause segfaults in Python + // Use nkipy-opt CLI tool instead via subprocess (see nkipy_kernelgen/transforms/nkipy_opt.py) + +} diff --git a/kernelgen/mlir/lib/CAPI/CMakeLists.txt b/kernelgen/mlir/lib/CAPI/CMakeLists.txt new file mode 100644 index 0000000..e6f347c --- /dev/null +++ b/kernelgen/mlir/lib/CAPI/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) \ No newline at end of file diff --git a/kernelgen/mlir/lib/CAPI/Dialect/CMakeLists.txt b/kernelgen/mlir/lib/CAPI/Dialect/CMakeLists.txt new file mode 100644 index 0000000..9206be3 --- /dev/null +++ b/kernelgen/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -0,0 +1,35 @@ + +add_mlir_public_c_api_library(MLIRNkipyCAPI + Dialects.cpp + NkipyAttributes.cpp + Registration.cpp + # ${PROJECT_SOURCE_DIR}/lib/Transforms/Passes.cpp # COMMENTED: Pass registration not needed, use nkipy-opt CLI + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir-c + + LINK_LIBS PUBLIC + MLIRIR + MLIRCAPIIR + MLIRSupport + MLIRNkipy + MLIRNkipyTransformOps + # MLIRNkipyPasses # COMMENTED: Contains NISA dependencies, use nkipy-opt CLI instead + MLIRFuncDialect + MLIRArithDialect + MLIRTensorDialect + MLIRAffineDialect + MLIRMathDialect + MLIRMemRefDialect + MLIRPDLDialect + MLIRTransformDialect + MLIRLinalgDialect + MLIRLinalgTransforms + MLIRTransforms + MLIRAffineTransforms + MLIRArithTransforms + MLIRMemRefTransforms + MLIRFuncTransforms + MLIRLLVMDialect + MLIRLLVMIRTransforms + ) diff --git a/kernelgen/mlir/lib/CAPI/Dialect/Dialects.cpp b/kernelgen/mlir/lib/CAPI/Dialect/Dialects.cpp new file mode 100644 index 0000000..9f84463 --- /dev/null +++ b/kernelgen/mlir/lib/CAPI/Dialect/Dialects.cpp @@ -0,0 +1,6 @@ +#include "nkipy-c/Dialect/Dialects.h" + +#include "nkipy/Dialect/NkipyDialect.h" +#include "mlir/CAPI/Registration.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Nkipy, nkipy, 
mlir::nkipy::NkipyDialect) \ No newline at end of file diff --git a/kernelgen/mlir/lib/CAPI/Dialect/NkipyAttributes.cpp b/kernelgen/mlir/lib/CAPI/Dialect/NkipyAttributes.cpp new file mode 100644 index 0000000..d685c1d --- /dev/null +++ b/kernelgen/mlir/lib/CAPI/Dialect/NkipyAttributes.cpp @@ -0,0 +1,20 @@ +#include "nkipy-c/Dialect/NkipyAttributes.h" +#include "nkipy/Dialect/NkipyAttrs.h" +#include "nkipy/Dialect/NkipyDialect.h" + +#include "mlir/CAPI/Registration.h" +#include "mlir/IR/Attributes.h" + +using namespace mlir; +using namespace nkipy; + +bool mlirAttributeIsAMemSpace(MlirAttribute attr) { + return mlir::isa(unwrap(attr)); +} + +MlirAttribute mlirMemSpaceGet(MlirContext ctx, MlirAttribute space) { + auto attr = llvm::cast(unwrap(space)); + MemSpaceEnum spaceEnum = + static_cast(attr.getInt()); + return wrap(MemSpaceEnumAttr::get(unwrap(ctx), spaceEnum)); +} diff --git a/kernelgen/mlir/lib/CAPI/Dialect/Registration.cpp b/kernelgen/mlir/lib/CAPI/Dialect/Registration.cpp new file mode 100644 index 0000000..cb392e8 --- /dev/null +++ b/kernelgen/mlir/lib/CAPI/Dialect/Registration.cpp @@ -0,0 +1,58 @@ + +#include "nkipy-c/Dialect/Registration.h" +// #include "nkipy/Transforms/Passes.h" // COMMENTED: Not using passes in CAPI +#include "nkipy/Dialect/NkipyDialect.h" +#include "nkipy/TransformOps/NkipyTransformOps.h" + +#include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/InitAllDialects.h" + +void nkipyMlirRegisterAllDialects(MlirContext context) { + mlir::DialectRegistry registry; + registry.insert(); + + // Register Transform dialect extensions (including nkipy transform ops) + mlir::nkipy::registerTransformDialectExtension(registry); + + unwrap(context)->appendDialectRegistry(registry); + unwrap(context)->loadAllAvailableDialects(); +} + +void nkipyMlirRegisterAllPasses() { + // General passes + mlir::registerTransformsPasses(); + + // Conversion passes + // Note: registerConversionPasses() registers ALL conversion passes, + // many of which require additional libraries we don't need. + // Comment out for now - add specific conversion passes as needed. 
+ // mlir::registerConversionPasses(); + + // Dialect passes + mlir::affine::registerAffinePasses(); + mlir::arith::registerArithPasses(); + // mlir::LLVM::registerLLVMPasses(); // Requires NVVM and other GPU libraries + mlir::memref::registerMemRefPasses(); + mlir::registerLinalgPasses(); + + // mlir::nkipy::registerNkipyPasses(); +} diff --git a/kernelgen/mlir/lib/CMakeLists.txt b/kernelgen/mlir/lib/CMakeLists.txt new file mode 100644 index 0000000..7b77922 --- /dev/null +++ b/kernelgen/mlir/lib/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(CAPI) +add_subdirectory(Transforms) +add_subdirectory(Dialect) +add_subdirectory(TransformOps) \ No newline at end of file diff --git a/kernelgen/mlir/lib/Dialect/CMakeLists.txt b/kernelgen/mlir/lib/Dialect/CMakeLists.txt new file mode 100644 index 0000000..0123f5a --- /dev/null +++ b/kernelgen/mlir/lib/Dialect/CMakeLists.txt @@ -0,0 +1,20 @@ +add_mlir_dialect_library(MLIRNkipy + NkipyDialect.cpp + NkipyOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/nkipy + + DEPENDS + MLIRNkipyOpsIncGen + MLIRNkipyAttrsIncGen + MLIRNkipyEnumsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRBufferizationDialect + MLIRDestinationStyleOpInterface + MLIRLinalgDialect + MLIRTensorDialect + MLIRTilingInterface + ) diff --git a/kernelgen/mlir/lib/Dialect/NkipyDialect.cpp b/kernelgen/mlir/lib/Dialect/NkipyDialect.cpp new file mode 100644 index 0000000..5ead60b --- /dev/null +++ b/kernelgen/mlir/lib/Dialect/NkipyDialect.cpp @@ -0,0 +1,42 @@ +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/StringExtras.h" + +#include "llvm/ADT/TypeSwitch.h" + +#include "nkipy/Dialect/NkipyDialect.h" +#include "nkipy/Dialect/NkipyAttrs.h" +#include "nkipy/Dialect/NkipyOps.h" + +using namespace mlir; +using namespace mlir::nkipy; + +#include "nkipy/Dialect/NkipyDialect.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "nkipy/Dialect/NkipyAttrs.cpp.inc" + +#include "nkipy/Dialect/NkipyEnums.cpp.inc" + + +void NkipyDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "nkipy/Dialect/NkipyOps.cpp.inc" + >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "nkipy/Dialect/NkipyAttrs.cpp.inc" + >(); +} + +mlir::Type NkipyDialect::parseType(DialectAsmParser &parser) const { + parser.emitError(parser.getCurrentLocation(), + "nkipy dialect has no custom types"); + return mlir::Type(); +} + +void NkipyDialect::printType(Type type, DialectAsmPrinter &printer) const { + llvm_unreachable("nkipy dialect has no custom types"); +} \ No newline at end of file diff --git a/kernelgen/mlir/lib/Dialect/NkipyOps.cpp b/kernelgen/mlir/lib/Dialect/NkipyOps.cpp new file mode 100644 index 0000000..d9619b6 --- /dev/null +++ b/kernelgen/mlir/lib/Dialect/NkipyOps.cpp @@ -0,0 +1,256 @@ +#include "nkipy/Dialect/NkipyOps.h" +#include "nkipy/Dialect/NkipyDialect.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/TilingInterface.h" + +using namespace mlir; + 
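+// This file implements the interface methods declared on the nkipy ops:
+//  - AnnotateOp: BufferizableOpInterface, so knob annotations survive
+//    one-shot bufferization as memref-typed annotate ops.
+//  - GatherOp:   BufferizableOpInterface, DestinationStyleOpInterface, and
+//    TilingInterface, so upstream tiling can slice gathers like linalg ops.
+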
+//===----------------------------------------------------------------------===// +// AnnotateOp — BufferizableOpInterface +//===----------------------------------------------------------------------===// + +bool nkipy::AnnotateOp::bufferizesToMemoryRead( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + return false; +} + +bool nkipy::AnnotateOp::bufferizesToMemoryWrite( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + return false; +} + +bufferization::AliasingValueList nkipy::AnnotateOp::getAliasingValues( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + return {}; +} + +LogicalResult nkipy::AnnotateOp::bufferize( + RewriterBase &rewriter, const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) { + Value target = getTarget(); + // If the target is already a memref, nothing to do. + if (isa(target.getType())) + return success(); + + FailureOr buffer = + bufferization::getBuffer(rewriter, target, options, state); + if (failed(buffer)) + return failure(); + + rewriter.create( + getLoc(), *buffer, getMemSpaceAttr(), getPartitionDimAttr(), + getTileSizeAttr(), getReductionTileAttr()); + rewriter.eraseOp(getOperation()); + return success(); +} + +//===----------------------------------------------------------------------===// +// GatherOp — BufferizableOpInterface +//===----------------------------------------------------------------------===// + +bool nkipy::GatherOp::bufferizesToMemoryRead( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + // Source and indices are read; output (DPS init) is only written. + return !isDpsInit(&opOperand); +} + +bool nkipy::GatherOp::bufferizesToMemoryWrite( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + return isDpsInit(&opOperand); +} + +bufferization::AliasingValueList nkipy::GatherOp::getAliasingValues( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + // DPS: the output buffer aliases the result. + if (isDpsInit(&opOperand)) + return {{getResult(), bufferization::BufferRelation::Equivalent}}; + return {}; +} + +LogicalResult nkipy::GatherOp::bufferize( + RewriterBase &rewriter, const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) { + FailureOr srcBuf = + bufferization::getBuffer(rewriter, getSource(), options, state); + FailureOr idxBuf = + bufferization::getBuffer(rewriter, getIndices(), options, state); + FailureOr outBuf = + bufferization::getBuffer(rewriter, getOutput(), options, state); + if (failed(srcBuf) || failed(idxBuf) || failed(outBuf)) + return failure(); + + // Create memref-based gather that writes into the output buffer. + auto newGather = rewriter.create( + getLoc(), (*outBuf).getType(), *srcBuf, *idxBuf, *outBuf); + + // Move the reference_impl region. The body retains tensor-typed block args; + // inline-nkipy-reference handles the memref→tensor conversion at inline time. + rewriter.inlineRegionBefore(getReferenceImpl(), newGather.getReferenceImpl(), + newGather.getReferenceImpl().begin()); + + // DPS: replace the tensor result with the output buffer. 
+ bufferization::replaceOpWithBufferizedValues(rewriter, getOperation(), + *outBuf); + return success(); +} + +//===----------------------------------------------------------------------===// +// GatherOp — DestinationStyleOpInterface +//===----------------------------------------------------------------------===// + +MutableOperandRange nkipy::GatherOp::getDpsInitsMutable() { + return getOutputMutable(); +} + +//===----------------------------------------------------------------------===// +// GatherOp — TilingInterface +//===----------------------------------------------------------------------===// + +SmallVector nkipy::GatherOp::getLoopIteratorTypes() { + auto resultType = cast(getResult().getType()); + return SmallVector( + resultType.getRank(), utils::IteratorType::parallel); +} + +SmallVector nkipy::GatherOp::getIterationDomain(OpBuilder &b) { + auto resultType = cast(getResult().getType()); + SmallVector domain; + for (int64_t i = 0; i < resultType.getRank(); ++i) { + domain.push_back(Range{b.getIndexAttr(0), + b.getIndexAttr(resultType.getDimSize(i)), + b.getIndexAttr(1)}); + } + return domain; +} + +FailureOr +nkipy::GatherOp::getTiledImplementation( + OpBuilder &b, ArrayRef offsets, + ArrayRef sizes) { + Location loc = getLoc(); + + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + int64_t rank = resultType.getRank(); + + // --- Slice indices: indices[i_off : i_off + tN] --- + SmallVector idxOffsets = {offsets[0]}; + SmallVector idxSizes = {sizes[0]}; + SmallVector idxStrides = {b.getIndexAttr(1)}; + Value indicesTile = b.create( + loc, getIndices(), idxOffsets, idxSizes, idxStrides); + + // --- Slice source: source[0:V, j_off : j_off+tH] --- + // All V rows are needed (indices can reference any row); only the + // embedding/free dimensions are sliced. + SmallVector srcOffsets(rank, b.getIndexAttr(0)); + SmallVector srcSizes; + SmallVector srcStrides(rank, b.getIndexAttr(1)); + srcSizes.push_back(b.getIndexAttr(sourceType.getDimSize(0))); // V (all rows) + for (int64_t d = 1; d < rank; ++d) { + srcOffsets[d] = offsets[d]; + srcSizes.push_back(sizes[d]); + } + Value sourceTile = b.create( + loc, getSource(), srcOffsets, srcSizes, srcStrides); + + // --- Slice output (DPS init): output[i_off:, j_off:] --- + SmallVector outOffsets(offsets.begin(), offsets.end()); + SmallVector outSizes(sizes.begin(), sizes.end()); + SmallVector outStrides(rank, b.getIndexAttr(1)); + Value outputTile = b.create( + loc, getOutput(), outOffsets, outSizes, outStrides); + + // --- Build tiled result type --- + SmallVector tiledShape; + for (auto s : sizes) { + if (auto attr = getConstantIntValue(s)) + tiledShape.push_back(*attr); + else + tiledShape.push_back(ShapedType::kDynamic); + } + auto tiledResultType = RankedTensorType::get( + tiledShape, resultType.getElementType()); + + // --- Create tiled gather --- + auto tiledGather = b.create( + loc, tiledResultType, sourceTile, indicesTile, outputTile); + + // --- Clone reference_impl into the tiled gather --- + // The reference body is used by InlineNkipyReference for LLVM CPU + // simulation. We clone the original region, adjusting block-arg types + // and fixing up shape-dependent ops (tensor.empty, linalg result types). + Region &origRegion = getReferenceImpl(); + if (!origRegion.empty()) { + Region &newRegion = tiledGather.getReferenceImpl(); + Block &origBlock = origRegion.front(); + + // Create new block with tiled operand types (source_tile, indices_tile). 
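+    // The reference body only takes the source and indices; the tiled
+    // output shape is re-materialized below by rewriting tensor.empty to
+    // the tiled result type.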
+ Block *newBlock = new Block(); + newRegion.push_back(newBlock); + newBlock->addArgument(sourceTile.getType(), loc); + newBlock->addArgument(indicesTile.getType(), loc); + + // Map original block args → new block args. + IRMapping mapping; + mapping.map(origBlock.getArgument(0), newBlock->getArgument(0)); + mapping.map(origBlock.getArgument(1), newBlock->getArgument(1)); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(newBlock); + + for (Operation &op : origBlock) { + if (isa(&op)) { + // Replace tensor.empty with the tiled output shape. + auto newEmpty = b.create( + loc, tiledResultType.getShape(), + tiledResultType.getElementType()); + mapping.map(op.getResult(0), newEmpty.getResult()); + } else { + Operation *cloned = b.clone(op, mapping); + // Fix result types for DPS linalg ops: the cloned result type is + // still the original shape, but the (remapped) init operand has the + // tiled shape. Align result types with init types. + if (auto linalgOp = dyn_cast(cloned)) { + auto inits = linalgOp.getDpsInits(); + for (unsigned i = 0; i < cloned->getNumResults(); ++i) { + if (i < inits.size()) + cloned->getResult(i).setType(inits[i].getType()); + } + } + } + } + } + + return TilingResult{{tiledGather.getOperation()}, + {tiledGather.getResult()}, + {}}; +} + +LogicalResult nkipy::GatherOp::getResultTilePosition( + OpBuilder &b, unsigned resultNumber, + ArrayRef offsets, ArrayRef sizes, + SmallVector &resultOffsets, + SmallVector &resultSizes) { + if (resultNumber != 0) + return failure(); + resultOffsets.assign(offsets.begin(), offsets.end()); + resultSizes.assign(sizes.begin(), sizes.end()); + return success(); +} + +#define GET_OP_CLASSES +#include "nkipy/Dialect/NkipyOps.cpp.inc" \ No newline at end of file diff --git a/kernelgen/mlir/lib/TransformOps/CMakeLists.txt b/kernelgen/mlir/lib/TransformOps/CMakeLists.txt new file mode 100644 index 0000000..b6e0fc7 --- /dev/null +++ b/kernelgen/mlir/lib/TransformOps/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_library(MLIRNkipyTransformOps + NkipyTransformOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/nkipy/TransformOps + + DEPENDS + MLIRNkipyTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRBufferizationDialect + MLIRLinalgDialect + MLIRTensorDialect + MLIRTransformDialect + MLIRTransformDialectInterfaces + MLIRIR +) diff --git a/kernelgen/mlir/lib/TransformOps/NkipyTransformOps.cpp b/kernelgen/mlir/lib/TransformOps/NkipyTransformOps.cpp new file mode 100644 index 0000000..542a426 --- /dev/null +++ b/kernelgen/mlir/lib/TransformOps/NkipyTransformOps.cpp @@ -0,0 +1,145 @@ +//===- NkipyTransformOps.cpp - Nkipy Transform Operations -----------------===// +// +// Implementation of custom transform dialect operations for NKIPyKernelGen. 
+// +//===----------------------------------------------------------------------===// + +#include "nkipy/TransformOps/NkipyTransformOps.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +using namespace mlir; + +#define GET_OP_CLASSES +#include "nkipy/TransformOps/NkipyTransformOps.cpp.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// PromoteTensorOp helpers +//===----------------------------------------------------------------------===// + +/// Return true if the operand may be read from by its owner. This is currently +/// very conservative and only looks inside linalg operations to prevent +/// unintentional data loss. +static bool mayBeRead(OpOperand &operand) { + auto linalgOp = dyn_cast(operand.getOwner()); + + // Be conservative about ops we cannot analyze deeper. + if (!linalgOp) + return true; + + // Look inside linalg ops. + Value blockArgument = linalgOp.getMatchingBlockArgument(&operand); + return !blockArgument.use_empty(); +} + +/// Return true if the value may be read through any of its uses. +static bool mayBeRead(Value value) { + // If the value has a reference semantics, it + // may be read through any alias... + if (!isa(value.getType())) + return true; + return llvm::any_of(value.getUses(), + static_cast(mayBeRead)); +} + +} // namespace + +//===----------------------------------------------------------------------===// +// PromoteTensorOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + SmallVector promoted; + for (Value tensor : state.getPayloadValues(getTensor())) { + auto type = dyn_cast(tensor.getType()); + if (!type) { + return emitSilenceableError() << "non-tensor type: " << tensor; + } + + Operation *definingOp = tensor.getDefiningOp(); + if (definingOp) + rewriter.setInsertionPointAfter(definingOp); + else + rewriter.setInsertionPointToStart(cast(tensor).getOwner()); + + // Check this before we emit operations using this value. + bool needsMaterialization = mayBeRead(tensor); + + SmallVector dynamicDims; + llvm::SmallPtrSet preservedOps; + for (auto [pos, dim] : llvm::enumerate(type.getShape())) { + if (!ShapedType::isDynamic(dim)) + continue; + Value cst = + rewriter.create(tensor.getLoc(), static_cast(pos)); + auto dimOp = + rewriter.create(tensor.getLoc(), tensor, cst); + preservedOps.insert(dimOp); + dynamicDims.push_back(dimOp); + } + auto allocation = rewriter.create( + tensor.getLoc(), type, dynamicDims); + // Set memory space if provided. + if (getMemorySpaceAttr()) + allocation.setMemorySpaceAttr(getMemorySpaceAttr()); + Value allocated = allocation; + + // Only insert a materialization (typically bufferizes to a copy) when the + // value may be read from. 
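+    // Sketch of the rewrite for a value %0 that is read (see the op
+    // description in NkipyTransformOps.td):
+    //   %1 = bufferization.alloc_tensor() { memory_space = ... }
+    //   %2 = bufferization.materialize_in_destination %0 in %1
+    // Every other use of %0 is redirected to %2 (or directly to %1 when %0
+    // is never read).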
+    if (needsMaterialization) {
+      auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
+          tensor.getLoc(), tensor, allocated);
+      preservedOps.insert(copy);
+      promoted.push_back(copy.getResult());
+    } else {
+      promoted.push_back(allocated);
+    }
+    rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
+  }
+  results.setValues(cast<OpResult>(getPromoted()), promoted);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::PromoteTensorOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTensorMutable(), effects);
+  transform::producesHandle(getOperation()->getOpResults(), effects);
+  transform::modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform dialect extension registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class NkipyTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          NkipyTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NkipyTransformDialectExtension)
+
+  NkipyTransformDialectExtension() {
+    registerTransformOps<
+#define GET_OP_LIST
+#include "nkipy/TransformOps/NkipyTransformOps.cpp.inc"
+        >();
+  }
+};
+
+} // namespace
+
+void mlir::nkipy::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<NkipyTransformDialectExtension>();
+}
diff --git a/kernelgen/mlir/lib/Transforms/AnnotateMemorySpace.cpp b/kernelgen/mlir/lib/Transforms/AnnotateMemorySpace.cpp
new file mode 100644
index 0000000..17e8f7a
--- /dev/null
+++ b/kernelgen/mlir/lib/Transforms/AnnotateMemorySpace.cpp
@@ -0,0 +1,285 @@
+#include "PassGen.h"
+#include "nkipy/Transforms/Passes.h"
+#include "nkipy/Transforms/IRHelpers.h"
+#include "nkipy/Dialect/NkipyAttrs.h"
+#include "nkipy/Dialect/NkipyDialect.h"
+#include "nkipy/Dialect/NkipyOps.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include
+
+#define DEBUG_TYPE "annotate-memory-space"
+
+using namespace mlir;
+using namespace nkipy;
+
+namespace mlir {
+namespace nkipy {
+
+namespace {
+
+/// Verify that a memref annotated with CONSTANT was filled with a scalar constant.
+/// Returns true if valid (is a scalar broadcast), false otherwise.
+/// The pattern we look for is:
+///   %buffer = memref.alloc()
+///   %cst = arith.constant
+///   linalg.fill ins(%cst) outs(%buffer)
+///   nkipy.annotate(%buffer, CONSTANT)
+static bool isScalarBroadcast(Value memref) {
+  Value base = getBaseMemRef(memref);
+  for (Operation *user : base.getUsers()) {
+    if (auto fillOp = dyn_cast<linalg::FillOp>(user)) {
+      if (fillOp.getOutputs()[0] == base &&
+          fillOp.getInputs()[0].getDefiningOp<arith::ConstantOp>())
+        return true;
+    }
+  }
+  return false;
+}
+
+struct NkipyAnnotateMemorySpacePass
+    : public AnnotateMemorySpaceBase<NkipyAnnotateMemorySpacePass> {
+
+  /// Rewrite a list of types, adding memSpaceAttr to any bare MemRefType.
+  /// Returns true if any type was changed.
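+  /// Non-memref types (e.g. index operands) are passed through unchanged;
+  /// memrefs are expected to reach this pass without a memory space yet.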
+ static bool addMemSpaceToTypes(ArrayRef types, Attribute memSpaceAttr, + SmallVectorImpl &out) { + bool changed = false; + for (Type ty : types) { + auto memrefType = dyn_cast(ty); + if (!memrefType) { + out.push_back(ty); + continue; + } + assert(!memrefType.getMemorySpace() && + "memrefs should not have memory space before this pass"); + out.push_back(MemRefType::get(memrefType.getShape(), + memrefType.getElementType(), + memrefType.getLayout(), memSpaceAttr)); + changed = true; + } + return changed; + } + + /// Annotate function inputs and outputs with SharedHbm memory space. + void annotateInputOutput(func::FuncOp func) { + MLIRContext *ctx = func.getContext(); + FunctionType oldType = func.getFunctionType(); + auto sharedHbm = + nkipy::MemSpaceEnumAttr::get(ctx, nkipy::MemSpaceEnum::SharedHbm); + + SmallVector newInputs, newResults; + bool changed = addMemSpaceToTypes(oldType.getInputs(), sharedHbm, newInputs); + changed |= addMemSpaceToTypes(oldType.getResults(), sharedHbm, newResults); + if (!changed) + return; + + func.setType(FunctionType::get(ctx, newInputs, newResults)); + if (func.isDeclaration()) + return; + + // Update block arg types and return operand types. + Block &entry = func.getBody().front(); + for (auto [i, arg] : llvm::enumerate(entry.getArguments())) + arg.setType(newInputs[i]); + auto returnOp = cast(entry.getTerminator()); + for (auto [i, operand] : llvm::enumerate(returnOp.getOperands())) + if (i < newResults.size()) + operand.setType(newResults[i]); + } + + /// Apply nkipy.annotate mem_space to memref types, then erase annotations. + void applyAnnotations(func::FuncOp func) { + MLIRContext *ctx = func.getContext(); + + SmallVector annotateOps; + func.walk([&](nkipy::AnnotateOp op) { annotateOps.push_back(op); }); + + for (auto annotateOp : annotateOps) { + Value target = annotateOp.getTarget(); + auto memSpace = annotateOp.getMemSpace(); + + if (memSpace) { + auto memrefType = dyn_cast(target.getType()); + if (!memrefType) { + LLVM_DEBUG(llvm::dbgs() << "Warning: annotate target is not a memref\n"); + annotateOp.erase(); + continue; + } + + // CONSTANT is a marker; verify it matches a scalar broadcast pattern. + if (*memSpace == nkipy::MemSpaceEnum::Constant && + !isScalarBroadcast(target)) { + annotateOp.emitError() + << "CONSTANT memory space requires a scalar broadcast " + << "(linalg.fill with arith.constant)"; + signalPassFailure(); + return; + } + Attribute memSpaceAttr = + nkipy::MemSpaceEnumAttr::get(ctx, *memSpace); + + target.setType(MemRefType::get(memrefType.getShape(), + memrefType.getElementType(), + memrefType.getLayout(), memSpaceAttr)); + } + annotateOp.erase(); + } + } + + /// Propagate memory space from one value to another. + /// Returns true if the target type was changed. + static bool propagateMemSpace(Value from, Value to) { + auto fromType = cast(from.getType()); + auto toType = cast(to.getType()); + if (fromType.getMemorySpace() && !toType.getMemorySpace()) { + to.setType(MemRefType::get(toType.getShape(), toType.getElementType(), + toType.getLayout(), fromType.getMemorySpace())); + return true; + } + return false; + } + + /// Resolve memory space conflicts on SubView ops by inserting copies. + /// When a subview's source and result have different memory spaces + /// (e.g., SBUF source but SharedHbm result from return type annotation), + /// replace uses of the subview with a new alloc + copy. 
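+  /// Sketch of the rewrite (types illustrative):
+  ///   %sv = memref.subview %src ...        // keeps the source (SBUF) space
+  ///   %alloc = memref.alloc()              // in the annotated target space
+  ///   memref.copy %sv, %alloc
+  /// and all other uses of %sv are redirected to %alloc.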
+ void resolveSubViewConflicts(func::FuncOp func) { + MLIRContext *ctx = func.getContext(); + SmallVector conflictOps; + + func.walk([&](memref::SubViewOp op) { + auto srcType = cast(op.getSource().getType()); + auto dstType = cast(op.getResult().getType()); + if (srcType.getMemorySpace() && dstType.getMemorySpace() && + srcType.getMemorySpace() != dstType.getMemorySpace()) { + conflictOps.push_back(op); + } + }); + + if (conflictOps.empty()) + return; + + for (auto op : conflictOps) { + auto srcType = cast(op.getSource().getType()); + auto dstType = cast(op.getResult().getType()); + + LLVM_DEBUG(llvm::dbgs() + << "Resolving SubView memory space conflict: source " + << srcType.getMemorySpace() << " vs result " + << dstType.getMemorySpace() << "\n"); + + // Fix the subview to match its source memory space. + auto fixedSubviewType = MemRefType::get( + dstType.getShape(), dstType.getElementType(), + dstType.getLayout(), srcType.getMemorySpace()); + op.getResult().setType(fixedSubviewType); + + // Insert alloc + copy after the subview to materialize in target space. + OpBuilder builder(op->getBlock(), std::next(op->getIterator())); + auto allocType = MemRefType::get( + dstType.getShape(), dstType.getElementType(), + MemRefLayoutAttrInterface(), dstType.getMemorySpace()); + Value alloc = builder.create(op.getLoc(), allocType); + builder.create(op.getLoc(), op.getResult(), alloc); + + // Replace all uses of the subview with the alloc, except the copy. + op.getResult().replaceAllUsesExcept(alloc, + alloc.getDefiningOp()->getNextNode()); + } + + // Update function signature to match actual return operand types. + auto returnOp = cast( + func.getBody().front().getTerminator()); + FunctionType funcType = func.getFunctionType(); + SmallVector newResultTypes; + for (Value operand : returnOp.getOperands()) + newResultTypes.push_back(operand.getType()); + func.setType(FunctionType::get(ctx, funcType.getInputs(), newResultTypes)); + } + + /// Infer HBM memory space for unannotated allocs that feed into copies + /// to on-chip memory (e.g., gather output allocs that are DMA-copied + /// into SBUF). These are internal intermediates, not user-facing. + /// Returns true if any type was changed. + bool inferHbmForCopySources(func::FuncOp func) { + bool changed = false; + func.walk([&](memref::CopyOp op) { + auto srcType = cast(op.getSource().getType()); + auto dstType = cast(op.getTarget().getType()); + if (srcType.getMemorySpace() || !dstType.getMemorySpace()) + return; + // Walk backward through subview chain to find the root alloc. + Value root = op.getSource(); + while (auto sv = root.getDefiningOp()) + root = sv.getSource(); + auto rootType = cast(root.getType()); + if (rootType.getMemorySpace()) + return; + auto hbm = nkipy::MemSpaceEnumAttr::get( + func.getContext(), nkipy::MemSpaceEnum::Hbm); + root.setType(MemRefType::get(rootType.getShape(), + rootType.getElementType(), + rootType.getLayout(), hbm)); + changed = true; + }); + return changed; + } + + /// Propagate memory spaces through view-like ops until convergence. + /// Includes HBM inference for copy sources, which must interleave with + /// propagation (view ops need their source memspace propagated first). 
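The convergence behaviour described above does not depend on any MLIR machinery. A standalone sketch, with values as indices and view relationships as made-up (source, result) pairs, mirroring the fixed-point loop implemented below:

    #include <cassert>
    #include <optional>
    #include <string>
    #include <utility>
    #include <vector>

    // Standalone model: each value may know its memory space; `views` are
    // (source, result) pairs created by view-like ops (subview, cast, ...).
    // Spaces spread along the view edges in both directions until a full pass
    // over the edges changes nothing.
    int main() {
      std::vector<std::optional<std::string>> space(4);
      space[0] = "sbuf";                       // only the root alloc is annotated
      std::vector<std::pair<int, int>> views = {{0, 1}, {1, 2}, {2, 3}};

      bool changed = true;
      while (changed) {
        changed = false;
        for (auto [src, res] : views) {
          if (space[src] && !space[res]) { space[res] = space[src]; changed = true; }
          if (space[res] && !space[src]) { space[src] = space[res]; changed = true; }
        }
      }

      for (int i = 0; i < 4; ++i)
        assert(space[i] == "sbuf");            // every view of the alloc ends up in SBUF
      return 0;
    }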
+ void propagateMemSpaces(func::FuncOp func) { + bool changed = true; + while (changed) { + changed = false; + func.walk([&](Operation *op) { + TypeSwitch(op) + .Case([&](auto op) { + changed |= propagateMemSpace(op.getSource(), op.getResult()); + changed |= propagateMemSpace(op.getResult(), op.getSource()); + }) + .Case([&](auto op) { + changed |= propagateMemSpace(op.getSrc(), op.getResult()); + changed |= propagateMemSpace(op.getResult(), op.getSrc()); + }) + .Case([&](auto op) { + changed |= propagateMemSpace(op.getSource(), op.getResult()); + changed |= propagateMemSpace(op.getResult(), op.getSource()); + }) + .Case([&](auto op) { + changed |= propagateMemSpace(op.getSource(), op.getResult()); + }); + }); + changed |= inferHbmForCopySources(func); + } + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + LLVM_DEBUG(llvm::dbgs() << "Processing function: " << func.getName() << "\n"); + + annotateInputOutput(func); + if (func.isDeclaration()) + return; + + applyAnnotations(func); + resolveSubViewConflicts(func); + propagateMemSpaces(func); + } +}; +} // namespace + +std::unique_ptr> createAnnotateMemorySpacePass() { + return std::make_unique(); +} +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/ApplyAndStripTransforms.cpp b/kernelgen/mlir/lib/Transforms/ApplyAndStripTransforms.cpp new file mode 100644 index 0000000..d61c010 --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/ApplyAndStripTransforms.cpp @@ -0,0 +1,92 @@ +//===- ApplyAndStripTransforms.cpp - Run + strip transform sequence -------===// +// +// This pass runs the transform dialect's @__transform_main named sequence on +// the enclosing module (same semantics as upstream --transform-interpreter) +// and then erases the transform module (the NamedSequenceOp and the +// `transform.with_named_sequence` module attribute). +// +// Motivation: after tiling, nothing downstream consumes the transform block, +// but it stays in the IR all the way until prepare-for-nki. The Python-side +// linalg→NISA phase needs to parse the IR through upstream MLIR bindings, +// which don't know about our custom transform op `transform.nkipy.promote_tensor` +// (it lives in our own dialect). Stripping the transform module right after +// interpretation gives the Python phase clean IR. +// +//===----------------------------------------------------------------------===// + +#include "nkipy/Transforms/Passes.h" + +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +namespace { + +struct ApplyAndStripTransformsPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ApplyAndStripTransformsPass) + + StringRef getArgument() const final { return "apply-and-strip-transforms"; } + + StringRef getDescription() const final { + return "Apply @__transform_main named sequence, then erase the transform " + "module (transform ops + with_named_sequence attr)"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + + // Locate the entry point sequence. If no transform module exists (e.g., + // kernel had no tilable ops), nothing to do. 
+ transform::TransformOpInterface entry = + transform::detail::findTransformEntryPoint( + module, /*module=*/ModuleOp(), + transform::TransformDialect::kTransformEntryPointSymbolName); + if (entry) { + if (failed(transform::applyTransformNamedSequence( + module, entry, /*transformModule=*/ModuleOp(), + transform::TransformOptions()))) { + module->emitError() + << "[apply-and-strip-transforms] transform interpretation failed"; + return signalPassFailure(); + } + } + + // Erase all top-level transform ops (NamedSequenceOps etc.) regardless of + // whether we found an entry point — if an empty sequence got generated, it + // still needs to go. + SmallVector toErase; + for (Operation &op : module.getBody()->getOperations()) { + if (isa(op.getDialect())) + toErase.push_back(&op); + } + for (Operation *op : toErase) + op->erase(); + + if (module->hasAttr("transform.with_named_sequence")) + module->removeAttr("transform.with_named_sequence"); + } +}; + +} // namespace + +namespace mlir { +namespace nkipy { + +std::unique_ptr> createApplyAndStripTransformsPass() { + return std::make_unique(); +} + +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/AssignLinalgOpIds.cpp b/kernelgen/mlir/lib/Transforms/AssignLinalgOpIds.cpp new file mode 100644 index 0000000..ce75deb --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/AssignLinalgOpIds.cpp @@ -0,0 +1,74 @@ +//===- AssignLinalgOpIds.cpp - Assign unique IDs to linalg ops ------------===// +// +// This pass assigns unique nkipy.op_id attributes to ALL linalg operations. +// +// The op_id enables per-instance matching during transform dialect application, +// allowing different tile sizes to be applied to different instances of the +// same linalg operation type. +// +//===----------------------------------------------------------------------===// + +#include "PassGen.h" +#include "nkipy/Transforms/Passes.h" +#include "nkipy/Transforms/IRHelpers.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/IR/Builders.h" + +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace nkipy; + +namespace mlir { +namespace nkipy { + +namespace { + +struct NkipyAssignLinalgOpIdsPass + : public AssignLinalgOpIdsBase { + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + + llvm::errs() << "[AssignLinalgOpIds] Processing function: " + << func.getName() << "\n"; + + // Counter for unique op_id + int64_t opIdCounter = 0; + + // Walk all linalg ops and assign unique IDs. + // Skip ops inside nkipy regions (e.g., reference_impl bodies) — these + // exist only for CPU simulation and must not participate in tiling. 
+ func.walk([&](linalg::LinalgOp linalgOp) { + if (isInsideNkipyRegion(linalgOp)) + return; + // Only add op_id if it doesn't already have one + if (!linalgOp->hasAttr("nkipy.op_id")) { + OpBuilder builder(linalgOp); + linalgOp->setAttr("nkipy.op_id", + builder.getI64IntegerAttr(opIdCounter++)); + llvm::errs() << "[AssignLinalgOpIds] Added op_id=" + << (opIdCounter - 1) << " to " + << linalgOp->getName() << "\n"; + } + }); + + llvm::errs() << "[AssignLinalgOpIds] Assigned " << opIdCounter + << " unique op_ids\n"; + } +}; + +} // namespace + +std::unique_ptr> createAssignLinalgOpIdsPass() { + return std::make_unique(); +} + +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/CMakeLists.txt b/kernelgen/mlir/lib/Transforms/CMakeLists.txt new file mode 100644 index 0000000..83ff473 --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/CMakeLists.txt @@ -0,0 +1,43 @@ +add_mlir_library(MLIRNkipyPasses + Passes.cpp + AnnotateMemorySpace.cpp + CanonicalizeReshape.cpp + CanonicalizePartitionDim.cpp + AssignLinalgOpIds.cpp + ApplyAndStripTransforms.cpp + InferLayout.cpp + KnobDrivenTiling.cpp + EliminateUninitializedCopies.cpp + EliminateSameMemSpaceCopy.cpp + InsertSpillReload.cpp + InsertMemRefDealloc.cpp + CanonicalizeLoopStep.cpp + LegalizeLayout.cpp + SimplifyLinalg.cpp + PrepareArithmetic.cpp + RemoveRedundantZeroFill.cpp + InlineNkipyReference.cpp + OpClassification.cpp + IRHelpers.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/nkipy + + DEPENDS + MLIRNkipyPassesIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRMemRefDialect + MLIRAffineDialect + MLIRFuncDialect + MLIRDestinationStyleOpInterface + MLIRTilingInterface + MLIRTransformDialect + MLIRTransformDialectTransforms + MLIRNkipyTransformOps +) diff --git a/kernelgen/mlir/lib/Transforms/CanonicalizeLoopStep.cpp b/kernelgen/mlir/lib/Transforms/CanonicalizeLoopStep.cpp new file mode 100644 index 0000000..1c2a073 --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/CanonicalizeLoopStep.cpp @@ -0,0 +1,117 @@ +//===- CanonicalizeLoopStep.cpp - Canonicalize scf.for steps to 1 ---------===// +// +// This pass transforms scf.for loops to have step=1, which simplifies index +// computation for subsequent passes and is required for NISA lowering. +// +// The transformation: +// scf.for %i = %lb to %ub step %step { ... uses %i ... } +// => +// scf.for %i_idx = 0 to (%ub - %lb) / %step step 1 { +// %i = %lb + %i_idx * %step +// ... uses %i ... +// } +// +// Loops are processed in post-order (innermost first) to handle nesting. +// +//===----------------------------------------------------------------------===// + +#include "PassGen.h" +#include "nkipy/Transforms/Passes.h" +#include "nkipy/Transforms/IRHelpers.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace nkipy; + +namespace mlir { +namespace nkipy { + +namespace { + +/// Canonicalize a single scf.for loop to have step=1. +/// Returns true if the loop was modified. +static bool canonicalizeForOp(scf::ForOp forOp) { + auto stepConst = getConstantInt(forOp.getStep()); + + // Skip if already step=1 or step is dynamic (can't canonicalize). 
+ if (!stepConst || *stepConst == 1) + return false; + + auto lbConst = getConstantInt(forOp.getLowerBound()); + auto ubConst = getConstantInt(forOp.getUpperBound()); + + // Check divisibility for all-constant case. + if (lbConst && ubConst) { + int64_t range = *ubConst - *lbConst; + if (range % *stepConst != 0) { + llvm::errs() << "[CanonicalizeLoopStep] Skipping: range " << range + << " not divisible by step " << *stepConst << "\n"; + return false; + } + } + + OpBuilder builder(forOp); + Location loc = forOp.getLoc(); + Value lb = forOp.getLowerBound(); + Value ub = forOp.getUpperBound(); + Value step = forOp.getStep(); + + // Compute new bounds: trip count = (ub - lb) / step. + // Canonicalize runs after this pass and will fold constants. + Value range = builder.create(loc, ub, lb); + Value tripCount = builder.create(loc, range, step); + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); + + // Reconstruct original IV at top of body: i = lb + idx * step. + builder.setInsertionPointToStart(forOp.getBody()); + Value iv = forOp.getInductionVar(); + Value scaled = builder.create(loc, iv, step); + Value originalIV = builder.create(loc, lb, scaled); + SmallPtrSet exceptions; + exceptions.insert(scaled.getDefiningOp()); + exceptions.insert(originalIV.getDefiningOp()); + iv.replaceAllUsesExcept(originalIV, exceptions); + + // Update loop bounds in place. + forOp.setLowerBound(zero); + forOp.setUpperBound(tripCount); + forOp.setStep(one); + + llvm::errs() << "[CanonicalizeLoopStep] Transformed loop (step=" + << *stepConst << ") to step=1\n"; + return true; +} + +struct NkipyCanonicalizeLoopStepPass + : public CanonicalizeLoopStepBase { + + void runOnOperation() override { + func::FuncOp func = getOperation(); + + // PostOrder walk visits children before parents, so inner loops + // are canonicalized before their enclosing outer loops. + bool changed = false; + func.walk([&](scf::ForOp forOp) { + if (canonicalizeForOp(forOp)) + changed = true; + }); + + if (changed) + llvm::errs() << "[CanonicalizeLoopStep] Pass completed with modifications\n"; + } +}; + +} // namespace + +std::unique_ptr> createCanonicalizeLoopStepPass() { + return std::make_unique(); +} + +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/CanonicalizePartitionDim.cpp b/kernelgen/mlir/lib/Transforms/CanonicalizePartitionDim.cpp new file mode 100644 index 0000000..dbc13f0 --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/CanonicalizePartitionDim.cpp @@ -0,0 +1,979 @@ +//===- CanonicalizePartitionDim.cpp - Ensure partition_dim=0 everywhere ----===// +// +// This pass inserts transposes so that partition_dim=0 holds for all annotated +// tensors. NISA hardware assumes dimension 0 is always the partition +// dimension, and every downstream pass relies on this. +// +// Algorithm: +// 1. Collect all nkipy.annotate ops with partition_dim != 0. +// 2. For each such annotation, BFS through the connected elementwise +// component to find all values that share the same non-zero partition_dim. +// 3. At component boundaries (inputs from non-elementwise producers, outputs +// to non-elementwise consumers), insert linalg.transpose to move +// partition_dim to position 0. +// 4. Rewrite all elementwise ops inside the component with permuted shapes. +// 5. Update all nkipy.annotate ops: partition_dim -> 0, permute tile_size. +// +// The pass runs BEFORE assign-linalg-op-ids so that new transpose ops get IDs +// (needed for knob-driven tiling). 
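The index arithmetic behind the rewrite above is easy to check in isolation. A standalone sketch with plain integers (no SCF): when the step divides the range, the original loop and the normalized step-1 loop with the rebuilt induction variable visit exactly the same values.

    #include <cassert>
    #include <vector>

    int main() {
      const int lb = 4, ub = 20, step = 4;      // (ub - lb) divisible by step
      std::vector<int> original, normalized;

      // Original loop: for %i = lb to ub step step
      for (int i = lb; i < ub; i += step)
        original.push_back(i);

      // Canonicalized loop: for %idx = 0 to (ub - lb) / step step 1,
      // with the original induction variable rebuilt as i = lb + idx * step.
      const int tripCount = (ub - lb) / step;
      for (int idx = 0; idx < tripCount; ++idx)
        normalized.push_back(lb + idx * step);

      assert(original == normalized);           // same iteration sequence: 4, 8, 12, 16
      return 0;
    }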
It runs AFTER infer-layout so that all +// tensors in the chain already have partition_dim annotations. +// +//===----------------------------------------------------------------------===// + +#include "PassGen.h" +#include "nkipy/Transforms/Passes.h" +#include "nkipy/Transforms/HardwareConstants.h" +#include "nkipy/Transforms/OpClassification.h" +#include "nkipy/Dialect/NkipyAttrs.h" +#include "nkipy/Dialect/NkipyDialect.h" +#include "nkipy/Dialect/NkipyOps.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace nkipy; + +namespace mlir { +namespace nkipy { + +namespace { + +//===----------------------------------------------------------------------===// +// Helpers +//===----------------------------------------------------------------------===// + +/// Build the permutation that moves dimension `partDim` to position 0 +/// and shifts the others right. E.g. for rank=4, partDim=2: +/// [2, 0, 1, 3] +static SmallVector buildPermutation(int64_t rank, int64_t partDim) { + SmallVector perm; + perm.push_back(partDim); + for (int64_t i = 0; i < rank; ++i) { + if (i != partDim) + perm.push_back(i); + } + return perm; +} + +/// Build the inverse permutation. E.g. for perm=[2,0,1,3]: +/// inv=[1,2,0,3] +static SmallVector invertPermutation(ArrayRef perm) { + SmallVector inv(perm.size()); + for (size_t i = 0; i < perm.size(); ++i) + inv[perm[i]] = i; + return inv; +} + +/// Apply a permutation to a vector. +template +static SmallVector permuteVector(ArrayRef vec, ArrayRef perm) { + SmallVector result; + for (int64_t p : perm) + result.push_back(vec[p]); + return result; +} + +/// For >2D transposes: equalize tile sizes of swapped dim pairs so that +/// after tiling, the tiled transpose has ≤2 non-unit dims (required by +/// linalg-to-nisa, since NISA ops are 2D). +/// E.g. perm=[1,0,2], tile=[4,128,128] → tile=[4,4,128] +/// +/// Exception: when one of the swapped dims already has tile=1, the +/// transpose along that pair is trivial (a reshape), so equalization +/// would be harmful — it would shrink the non-unit dim to 1 and create +/// tile size mismatches with downstream consumers. +static void equalizeSwappedTileDims(SmallVector &tile, + ArrayRef perm) { + for (int64_t i = 0; i < static_cast(tile.size()); ++i) { + int64_t j = perm[i]; + if (j != i && j < static_cast(tile.size())) { + if (tile[i] == 1 || tile[j] == 1) + continue; + int64_t minTile = std::min(tile[i], tile[j]); + tile[i] = minTile; + tile[j] = minTile; + } + } +} + +/// Permute a reduced-rank tile_size (from a reduction op) through a +/// full-rank permutation. Expands to full rank using 1s for size-1 dims, +/// permutes, then strips back to parallel dims only. +static DenseI64ArrayAttr permuteReducedTileSize( + ArrayRef oldTileSize, ArrayRef perm, + ArrayRef invPerm, ArrayRef permutedShape, + int64_t rank, MLIRContext *ctx) { + // Recover original shape to find parallel dims. 
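A quick sanity check of the permutation helpers above: standalone copies of buildPermutation, invertPermutation, and permuteVector, exercised on the rank=4, partition_dim=2 case from their comments (the shape values are invented for illustration).

    #include <cassert>
    #include <cstdint>
    #include <vector>

    static std::vector<int64_t> buildPermutation(int64_t rank, int64_t partDim) {
      std::vector<int64_t> perm{partDim};       // partition dim moves to front
      for (int64_t i = 0; i < rank; ++i)
        if (i != partDim)
          perm.push_back(i);                    // remaining dims shift right
      return perm;
    }

    static std::vector<int64_t> invertPermutation(const std::vector<int64_t> &perm) {
      std::vector<int64_t> inv(perm.size());
      for (size_t i = 0; i < perm.size(); ++i)
        inv[perm[i]] = static_cast<int64_t>(i);
      return inv;
    }

    static std::vector<int64_t> permuteVector(const std::vector<int64_t> &vec,
                                              const std::vector<int64_t> &perm) {
      std::vector<int64_t> result;
      for (int64_t p : perm)
        result.push_back(vec[p]);
      return result;
    }

    int main() {
      auto perm = buildPermutation(4, 2);
      assert((perm == std::vector<int64_t>{2, 0, 1, 3}));
      auto inv = invertPermutation(perm);
      assert((inv == std::vector<int64_t>{1, 2, 0, 3}));

      // Moving dim 2 of a made-up BxSxHxD shape to the front, then undoing it.
      std::vector<int64_t> shape{8, 128, 16, 64};
      auto permuted = permuteVector(shape, perm);
      assert((permuted == std::vector<int64_t>{16, 8, 128, 64}));
      assert(permuteVector(permuted, inv) == shape);   // inverse restores the shape
      return 0;
    }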
+ SmallVector origShape = permuteVector(permutedShape, invPerm); + SmallVector origParDims; + for (int64_t i = 0; i < rank; ++i) { + if (origShape[i] > 1) + origParDims.push_back(i); + } + + if (static_cast(oldTileSize.size()) != + static_cast(origParDims.size())) + return {}; + + // Build full-rank tile with 1s for size-1 dims, then permute. + SmallVector fullTile(rank, 1); + for (size_t i = 0; i < origParDims.size(); ++i) + fullTile[origParDims[i]] = oldTileSize[i]; + SmallVector permFull = permuteVector(fullTile, perm); + + // Strip back to only parallel dims in the permuted shape. + SmallVector newTileSize; + for (int64_t i = 0; i < rank; ++i) { + if (permutedShape[i] > 1) + newTileSize.push_back(permFull[i]); + } + return DenseI64ArrayAttr::get(ctx, newTileSize); +} + +/// Emit an nkipy.annotate for a boundary transpose result. +/// Applies >2D tile equalization, then creates the annotation with +/// partition_dim=0 and the given mem_space/tile_size. +static void annotateBoundaryTranspose(OpBuilder &builder, Location loc, + Value transposed, + MemSpaceEnumAttr memSpace, + DenseI64ArrayAttr tileSize, + ArrayRef perm, int64_t rank) { + DenseI64ArrayAttr finalTileSize; + if (tileSize) { + SmallVector tileSizeVec(tileSize.asArrayRef()); + if (rank > 2) + equalizeSwappedTileDims(tileSizeVec, perm); + finalTileSize = DenseI64ArrayAttr::get(builder.getContext(), tileSizeVec); + } + auto zeroPdAttr = builder.getIntegerAttr( + builder.getIntegerType(32, /*isSigned=*/false), 0); + builder.create( + loc, transposed, memSpace, zeroPdAttr, + finalTileSize, /*reduction_tile=*/DenseI64ArrayAttr{}); +} + +/// Wrappers that accept Operation* for use in BFS traversal where we +/// iterate over generic Operations rather than typed LinalgOps. +static bool isElementwiseOp(Operation *op) { + auto linalgOp = dyn_cast(op); + return linalgOp && ::mlir::nkipy::isElementwiseOp(linalgOp); +} + +static bool isReductionGeneric(Operation *op) { + auto linalgOp = dyn_cast(op); + return linalgOp && ::mlir::nkipy::isReductionGeneric(linalgOp); +} + +static bool isMatmulOp(Operation *op) { + return ::mlir::nkipy::isMatmulOp(op->getName().getStringRef()); +} + +/// Get the partition_dim for a value from its nkipy.annotate op, if any. +/// Returns -1 if no annotation or no partition_dim. +static int64_t getPartitionDim(Value val, + DenseMap &partDimMap) { + auto it = partDimMap.find(val); + if (it != partDimMap.end()) + return it->second; + return -1; +} + +//===----------------------------------------------------------------------===// +// Pass implementation +//===----------------------------------------------------------------------===// + +struct NkipyCanonicalizePartitionDimPass + : public CanonicalizePartitionDimBase { + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + + /// Convert ALL batch_matmul ops to scf.for + matmul. + /// + /// SBUF output: produces MxBxN (M at dim 0 = partition), so LegalizeLayout + /// can properly expand it to physical format. + /// + /// HBM output (or no annotation): produces BxMxN (standard layout), same + /// as the original preprocessBatchedOps in KnobDrivenTiling. + /// + /// Returns the set of SBUF converted output values (the for-loop results) + /// so the partition_dim component processing can skip boundary + /// transposes for these values (they are already in MxBxN order). 
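The layout difference described above, BxMxN for HBM outputs versus MxBxN for SBUF outputs, is only a permutation of where each per-batch matmul result lands. A standalone sketch of the conversion below using plain loops and row-major indexing; the sizes are arbitrary.

    #include <cassert>
    #include <vector>

    int main() {
      const int B = 2, M = 3, K = 4, N = 5;
      std::vector<float> lhs(B * M * K), rhs(B * K * N);
      for (size_t i = 0; i < lhs.size(); ++i) lhs[i] = 0.5f * i;
      for (size_t i = 0; i < rhs.size(); ++i) rhs[i] = 0.25f * i;

      std::vector<float> bmn(B * M * N, 0.0f);   // HBM-style  [B][M][N]
      std::vector<float> mbn(M * B * N, 0.0f);   // SBUF-style [M][B][N], M at dim 0

      for (int b = 0; b < B; ++b)                // one 2D matmul per batch element
        for (int m = 0; m < M; ++m)
          for (int n = 0; n < N; ++n) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k)
              acc += lhs[(b * M + m) * K + k] * rhs[(b * K + k) * N + n];
            bmn[(b * M + m) * N + n] = acc;      // insert at [b, m, n]
            mbn[(m * B + b) * N + n] = acc;      // insert at [m, b, n]
          }

      // Same data, permuted layout: BxMxN[b][m][n] == MxBxN[m][b][n].
      for (int b = 0; b < B; ++b)
        for (int m = 0; m < M; ++m)
          for (int n = 0; n < N; ++n)
            assert(bmn[(b * M + m) * N + n] == mbn[(m * B + b) * N + n]);
      return 0;
    }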
+ DenseSet preprocessBatchMatmul(func::FuncOp func) { + DenseSet convertedBmmOutputs; + + SmallVector batchOps; + func.walk([&](linalg::BatchMatmulOp op) { + batchOps.push_back(op); + }); + + for (auto bmmOp : batchOps) { + Value lhs = bmmOp.getInputs()[0]; + Value rhs = bmmOp.getInputs()[1]; + Value init = bmmOp.getOutputs()[0]; + + auto initType = cast(init.getType()); + if (initType.getRank() != 3 || !initType.hasStaticShape()) + continue; + + // Check if the output is annotated as SBUF. + bool isSbuf = false; + SmallVector annotateOps; + for (auto *user : bmmOp.getResult(0).getUsers()) { + if (auto ann = dyn_cast(user)) { + annotateOps.push_back(ann); + if (auto memSpace = ann.getMemSpaceAttr()) { + if (memSpace.getValue() == nkipy::MemSpaceEnum::Sbuf) + isSbuf = true; + } + } + } + + Location loc = bmmOp.getLoc(); + // BxMxN shape + int64_t B = initType.getShape()[0]; + int64_t M = initType.getShape()[1]; + int64_t N = initType.getShape()[2]; + Type elemTy = initType.getElementType(); + + OpBuilder builder(bmmOp); + + // For SBUF: output is MxBxN (partition-correct, M at dim 0). + // Create a new MxBxN tensor.empty (no fill needed — NISA matmul + // auto-zeroes PSUM, so the init value is unused by hardware). + // For HBM: output is BxMxN (standard layout). + // Reuse the original init tensor directly. + Value loopInit; + if (isSbuf) { + loopInit = builder.create( + loc, SmallVector{M, B, N}, elemTy); + } else { + loopInit = init; + } + + // Create loop bounds: for %b = 0 to B step 1 + Value c0 = builder.create(loc, 0); + Value cB = builder.create(loc, B); + Value c1 = builder.create(loc, 1); + + auto forOp = builder.create(loc, c0, cB, c1, + ValueRange{loopInit}); + + // Build loop body + builder.setInsertionPointToStart(forOp.getBody()); + Value iv = forOp.getInductionVar(); + Value acc = forOp.getRegionIterArg(0); + + // Extract 2D slices from LHS and RHS (rank-reducing from BxMxK/BxKxN) + auto extract2DFromBatch = [&](Value src) -> Value { + auto srcType = cast(src.getType()); + auto shape = srcType.getShape(); // [B, dim1, dim2] + auto sliceType = RankedTensorType::get( + {shape[1], shape[2]}, srcType.getElementType()); + + SmallVector offsets = { + iv, builder.getIndexAttr(0), builder.getIndexAttr(0)}; + SmallVector sizes = { + builder.getIndexAttr(1), + builder.getIndexAttr(shape[1]), + builder.getIndexAttr(shape[2])}; + SmallVector strides = { + builder.getIndexAttr(1), builder.getIndexAttr(1), + builder.getIndexAttr(1)}; + + return builder.create( + loc, sliceType, src, offsets, sizes, strides); + }; + + Value lhsSlice = extract2DFromBatch(lhs); + Value rhsSlice = extract2DFromBatch(rhs); + + // Extract init slice from accumulator and run matmul. + auto mmType = RankedTensorType::get({M, N}, elemTy); + SmallVector extractOffsets, extractSizes; + SmallVector extractStrides = { + builder.getIndexAttr(1), builder.getIndexAttr(1), + builder.getIndexAttr(1)}; + + Value inserted; + linalg::MatmulOp matmulOp; + if (isSbuf) { + // SBUF (MxBxN): use a separate 2D tensor for matmul output. + // We cannot use a collapse_shape view of the 3D acc because after + // LegalizeLayout expands the 3D alloc to 5D, the collapsed strides + // produce an interleaved 2D view that corrupts the matmul output. + // Instead: matmul → separate 2D → rank-expanding insert_slice into 3D. 
+ Value mmInit = builder.create( + loc, SmallVector{M, N}, elemTy); + + matmulOp = builder.create( + loc, TypeRange{mmType}, + ValueRange{lhsSlice, rhsSlice}, ValueRange{mmInit}); + + // Expand 2D [M, N] → 3D [M, 1, N] so the insert_slice is same-rank. + // This avoids rank-reducing subviews after bufferization, which + // LegalizeLayout's tileCopyAndTranspose cannot handle. + auto expandedType = RankedTensorType::get({M, 1, N}, elemTy); + SmallVector reassoc = {{0}, {1, 2}}; + Value expanded = builder.create( + loc, expandedType, matmulOp.getResult(0), reassoc); + + // Same-rank insert: 3D [M, 1, N] → 3D acc at [0, b, 0] [M, 1, N] + extractOffsets = {builder.getIndexAttr(0), iv, builder.getIndexAttr(0)}; + extractSizes = {builder.getIndexAttr(M), builder.getIndexAttr(1), + builder.getIndexAttr(N)}; + inserted = builder.create( + loc, expanded, acc, + extractOffsets, extractSizes, extractStrides); + } else { + // HBM (BxMxN): rank-reducing extract [%b,0,0][1,M,N] → MxN + extractOffsets = {iv, builder.getIndexAttr(0), builder.getIndexAttr(0)}; + extractSizes = {builder.getIndexAttr(1), builder.getIndexAttr(M), + builder.getIndexAttr(N)}; + + Value initSlice = builder.create( + loc, mmType, acc, extractOffsets, extractSizes, extractStrides); + + matmulOp = builder.create( + loc, TypeRange{mmType}, + ValueRange{lhsSlice, rhsSlice}, ValueRange{initSlice}); + + inserted = builder.create( + loc, matmulOp.getResult(0), acc, + extractOffsets, extractSizes, extractStrides); + } + + // Copy nkipy.op_id attribute + if (auto opIdAttr = bmmOp->getAttrOfType("nkipy.op_id")) + matmulOp->setAttr("nkipy.op_id", opIdAttr); + + builder.create(loc, ValueRange{inserted}); + + // Handle annotations: create separate annotations for the matmul + // result and the 3D accumulator. + for (auto ann : annotateOps) { + // 2D tile for matmul: drop batch dim from original tile_size + DenseI64ArrayAttr newTileSize; + if (auto ts = ann.getTileSizeAttr()) { + auto arr = ts.asArrayRef(); + if (arr.size() >= 2) { + SmallVector adjusted(arr.begin() + 1, arr.end()); + newTileSize = DenseI64ArrayAttr::get( + func.getContext(), adjusted); + } + } + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(matmulOp); + + if (isSbuf) { + // Matmul output: annotate as SharedHbm so LegalizeLayout won't + // touch it (it's a temporary that gets copied into the 3D SBUF acc). + auto hbmAttr = nkipy::MemSpaceEnumAttr::get( + func.getContext(), nkipy::MemSpaceEnum::SharedHbm); + builder.create( + ann.getLoc(), matmulOp.getResult(0), + hbmAttr, /*partitionDim=*/IntegerAttr{}, + newTileSize, ann.getReductionTileAttr()); + + // 3D accumulator: annotate as SBUF with 3D tile (partition, batch, free). + // Tile is [M_tile, 1, N_tile] from the original [B_tile, M_tile, N_tile]. 
+ DenseI64ArrayAttr accTileSize; + if (auto ts = ann.getTileSizeAttr()) { + auto arr = ts.asArrayRef(); + if (arr.size() >= 3) { + // Original: [B, M, N] → transposed: [M, B, N] → tile: [M_tile, B_tile, N_tile] + SmallVector accTile = {arr[1], arr[0], arr[2]}; + accTileSize = DenseI64ArrayAttr::get(func.getContext(), accTile); + } + } + builder.setInsertionPointAfter(forOp); + builder.create( + ann.getLoc(), forOp.getResult(0), + ann.getMemSpaceAttr(), /*partitionDim=*/IntegerAttr{}, + accTileSize, /*reductionTile=*/DenseI64ArrayAttr{}); + } else { + // HBM: annotate matmul result with original mem_space + builder.create( + ann.getLoc(), matmulOp.getResult(0), + ann.getMemSpaceAttr(), /*partitionDim=*/IntegerAttr{}, + newTileSize, ann.getReductionTileAttr()); + } + + ann.erase(); + } + + // Replace all uses of batch_matmul result with for loop result. + Value forResult = forOp.getResult(0); + + // Collect non-annotate uses before replacing + SmallVector usesToReplace; + for (OpOperand &use : bmmOp.getResult(0).getUses()) { + if (!isa(use.getOwner())) + usesToReplace.push_back(&use); + } + + // Replace uses + for (OpOperand *use : usesToReplace) + use->set(forResult); + + // Erase remaining annotate uses and the bmm op + for (auto *user : llvm::make_early_inc_range( + bmmOp.getResult(0).getUsers())) { + if (isa(user)) + user->erase(); + } + bmmOp.erase(); + + if (isSbuf) { + convertedBmmOutputs.insert(forResult); + llvm::errs() << "[CanonicalizePartitionDim] Converted batch_matmul " + << "(B=" << B << ", M=" << M << ", N=" << N + << ") to loop + matmul with MxBxN output (SBUF)\n"; + } else { + llvm::errs() << "[CanonicalizePartitionDim] Converted batch_matmul " + << "(B=" << B << ", M=" << M << ", N=" << N + << ") to loop + matmul with BxMxN output (HBM)\n"; + } + } + + return convertedBmmOutputs; + } + + /// Collect annotations with partition_dim info from the function. + void collectAnnotations( + func::FuncOp func, + DenseMap &partDimMap, + DenseMap &valueAnnotateMap, + SmallVector &nonZeroAnnotations) { + func.walk([&](nkipy::AnnotateOp annotateOp) { + valueAnnotateMap[annotateOp.getTarget()] = annotateOp; + auto partDimAttr = annotateOp.getPartitionDimAttr(); + if (!partDimAttr) + return; + uint32_t partDim = partDimAttr.getUInt(); + partDimMap[annotateOp.getTarget()] = partDim; + if (partDim != 0) + nonZeroAnnotations.push_back(annotateOp); + }); + } + + /// BFS from seedOp to find connected elementwise/reduction component. + llvm::SetVector findComponent(Operation *seedOp) { + llvm::SetVector componentOps; + + auto canInclude = [](Operation *op) { + return isElementwiseOp(op) || isReductionGeneric(op); + }; + + if (!seedOp || !canInclude(seedOp)) + return componentOps; + + SmallVector bfsQueue; + bfsQueue.push_back(seedOp); + componentOps.insert(seedOp); + + while (!bfsQueue.empty()) { + Operation *op = bfsQueue.pop_back_val(); + + // Backward through DPS inputs. + if (auto linalgOp = dyn_cast(op)) { + for (Value input : linalgOp.getDpsInputs()) { + Operation *producer = input.getDefiningOp(); + if (producer && canInclude(producer) && + !componentOps.count(producer)) { + componentOps.insert(producer); + bfsQueue.push_back(producer); + } + } + } + + // Forward through uses. 
+ for (Value result : op->getResults()) { + for (Operation *user : result.getUsers()) { + if (isa(user)) + continue; + if (canInclude(user) && !componentOps.count(user)) { + componentOps.insert(user); + bfsQueue.push_back(user); + } + } + } + } + + return componentOps; + } + + /// Find boundary inputs: values used by component ops but defined outside. + /// Skips tensor.empty and linalg.fill (recreated with permuted shapes). + llvm::SetVector findBoundaryInputs( + const llvm::SetVector &componentOps) { + llvm::SetVector boundaryInputs; + for (Operation *op : componentOps) { + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (!defOp || !componentOps.count(defOp)) { + if (defOp && (isa(defOp) || + isa(defOp))) + continue; + boundaryInputs.insert(operand); + } + } + } + return boundaryInputs; + } + + /// Find boundary outputs: results of component ops used outside. + llvm::SetVector findBoundaryOutputs( + const llvm::SetVector &componentOps) { + llvm::SetVector boundaryOutputs; + for (Operation *op : componentOps) { + for (Value result : op->getResults()) { + for (Operation *user : result.getUsers()) { + if (isa(user)) + continue; + if (!componentOps.count(user)) { + boundaryOutputs.insert(result); + break; + } + } + } + } + return boundaryOutputs; + } + + /// Insert input boundary transposes (original -> permuted). + void insertInputTransposes( + OpBuilder &builder, const llvm::SetVector &boundaryInputs, + ArrayRef perm, int64_t rank, + DenseI64ArrayAttr seedTileSizeAttr, + const DenseSet &convertedBmmOutputs, + IRMapping &valueMapping) { + for (Value input : boundaryInputs) { + auto inputType = dyn_cast(input.getType()); + if (!inputType || inputType.getRank() != rank) + continue; + + if (convertedBmmOutputs.count(input)) { + valueMapping.map(input, input); + llvm::errs() << "[CanonicalizePartitionDim] Skipping boundary " + << "transpose for converted BMM output (already MxBxN)\n"; + continue; + } + + SmallVector newShape = + permuteVector(inputType.getShape(), perm); + + if (input.getDefiningOp()) + builder.setInsertionPointAfter(input.getDefiningOp()); + else + builder.setInsertionPointToStart(input.getParentBlock()); + + Location loc = input.getLoc(); + Value init = builder.create( + loc, newShape, inputType.getElementType()); + auto transposeOp = + builder.create(loc, input, init, perm); + Value transposed = transposeOp.getResult()[0]; + valueMapping.map(input, transposed); + + DenseI64ArrayAttr transposeTileSize; + if (seedTileSizeAttr) { + SmallVector permutedTile = + permuteVector(seedTileSizeAttr.asArrayRef(), perm); + transposeTileSize = + DenseI64ArrayAttr::get(builder.getContext(), permutedTile); + } + auto sbufAttr = nkipy::MemSpaceEnumAttr::get( + builder.getContext(), nkipy::MemSpaceEnum::Sbuf); + annotateBoundaryTranspose(builder, loc, transposed, + sbufAttr, transposeTileSize, perm, rank); + } + } + + /// Rewrite component ops with permuted shapes. + void rewriteComponentOps( + OpBuilder &builder, func::FuncOp func, + const llvm::SetVector &componentOps, + ArrayRef perm, int64_t rank, + IRMapping &valueMapping) { + // Process in topological order. + SmallVector topoOrder; + func.walk([&](Operation *op) { + if (componentOps.count(op)) + topoOrder.push_back(op); + }); + + for (Operation *op : topoOrder) { + auto linalgOp = dyn_cast(op); + if (!linalgOp) + continue; + + // Replace operands with mapped (transposed) values. 
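The boundary computation above reduces to a simple membership question on the def-use graph. A minimal standalone sketch, with ops modeled as integers and the op names in the comments invented for illustration:

    #include <cassert>
    #include <set>
    #include <utility>
    #include <vector>

    // Ops are plain ints, each op produces one value named after itself, and
    // `edges` are producer -> consumer pairs. Boundary inputs are values produced
    // outside the component but consumed inside; boundary outputs are component
    // values with at least one consumer outside.
    int main() {
      // 0 (matmul) -> 1 (exp) -> 2 (mul) -> 3 (copy-out), 4 (bias) -> 2 (mul)
      std::vector<std::pair<int, int>> edges = {{0, 1}, {1, 2}, {4, 2}, {2, 3}};
      std::set<int> component = {1, 2};          // the elementwise chain

      std::set<int> boundaryInputs, boundaryOutputs;
      for (auto [producer, consumer] : edges) {
        if (component.count(consumer) && !component.count(producer))
          boundaryInputs.insert(producer);
        if (component.count(producer) && !component.count(consumer))
          boundaryOutputs.insert(producer);
      }

      assert((boundaryInputs == std::set<int>{0, 4}));   // matmul result and bias
      assert((boundaryOutputs == std::set<int>{2}));     // only the mul escapes
      return 0;
    }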
+ for (unsigned i = 0; i < op->getNumOperands(); ++i) { + if (Value mapped = valueMapping.lookupOrNull(op->getOperand(i))) + op->setOperand(i, mapped); + } + + // Recreate init operands with permuted shape. + for (auto [idx, initOperand] : + llvm::enumerate(linalgOp.getDpsInits())) { + if (auto emptyOp = initOperand.getDefiningOp()) { + auto emptyType = + dyn_cast(emptyOp.getResult().getType()); + if (!emptyType || emptyType.getRank() != rank) + continue; + SmallVector newShape = + permuteVector(emptyType.getShape(), perm); + builder.setInsertionPoint(emptyOp); + auto newEmpty = builder.create( + emptyOp.getLoc(), newShape, emptyType.getElementType()); + emptyOp.getResult().replaceAllUsesWith(newEmpty.getResult()); + emptyOp->erase(); + continue; + } + + if (auto fillOp = dyn_cast( + initOperand.getDefiningOp())) { + auto fillEmptyOp = + fillOp.getDpsInits()[0].getDefiningOp(); + if (!fillEmptyOp) + continue; + auto emptyType = dyn_cast( + fillEmptyOp.getResult().getType()); + if (!emptyType || emptyType.getRank() != rank) + continue; + SmallVector newShape = + permuteVector(emptyType.getShape(), perm); + builder.setInsertionPoint(fillEmptyOp); + auto newEmpty = builder.create( + fillEmptyOp.getLoc(), newShape, emptyType.getElementType()); + fillEmptyOp.getResult().replaceAllUsesWith(newEmpty.getResult()); + fillEmptyOp->erase(); + fillOp->getResult(0).setType( + RankedTensorType::get(newShape, emptyType.getElementType())); + } + } + + // Update result types. + for (Value result : op->getResults()) { + auto resultType = dyn_cast(result.getType()); + if (!resultType || resultType.getRank() != rank) + continue; + SmallVector newShape = + permuteVector(resultType.getShape(), perm); + result.setType( + RankedTensorType::get(newShape, resultType.getElementType())); + } + + // Permute indexing maps of linalg.generic ops. + // Named ops (add, mul, exp, etc.) have implicit identity maps that + // remain valid after consistent shape permutation. But generic ops + // may have non-identity maps (e.g. broadcast: (d0,d1,d2)->(0,d1,d2)) + // that must be permuted to match the new shape layout. + if (auto genericOp = dyn_cast(op)) { + MLIRContext *ctx = genericOp.getContext(); + SmallVector invPerm = invertPermutation(perm); + + // Build dimension remapping: d_j -> d_{invPerm[j]}. + SmallVector dimReplacements; + for (int64_t j = 0; j < rank; ++j) + dimReplacements.push_back(getAffineDimExpr(invPerm[j], ctx)); + + SmallVector newMaps; + for (AffineMap map : genericOp.getIndexingMapsArray()) { + unsigned numResults = map.getNumResults(); + + SmallVector exprs; + if (static_cast(numResults) == rank) { + // Full-rank map: reorder result positions by perm, then + // remap dimension references. + for (int64_t i = 0; i < rank; ++i) + exprs.push_back(map.getResult(perm[i])); + } else { + // Reduced-rank map (e.g. reduction output): keep result + // order, only remap dimension references. + for (unsigned i = 0; i < numResults; ++i) + exprs.push_back(map.getResult(i)); + } + + SmallVector finalExprs; + for (AffineExpr expr : exprs) + finalExprs.push_back( + expr.replaceDimsAndSymbols(dimReplacements, {})); + + newMaps.push_back( + AffineMap::get(map.getNumDims(), 0, finalExprs, ctx)); + } + genericOp.setIndexingMapsAttr( + builder.getAffineMapArrayAttr(newMaps)); + } + } + } + + /// Insert output boundary transposes (permuted -> original) and rewire uses. 
+ void insertOutputTransposes( + OpBuilder &builder, + const llvm::SetVector &boundaryOutputs, + const llvm::SetVector &componentOps, + ArrayRef invPerm, int64_t rank, + DenseI64ArrayAttr seedTileSizeAttr, + DenseMap &valueAnnotateMap, + nkipy::AnnotateOp annotateOp, int64_t partDim) { + for (Value output : boundaryOutputs) { + auto outputType = dyn_cast(output.getType()); + if (!outputType) + continue; + + SmallVector origShape = + permuteVector(outputType.getShape(), invPerm); + + builder.setInsertionPointAfterValue(output); + Location loc = output.getLoc(); + Value init = builder.create( + loc, origShape, outputType.getElementType()); + auto transposeOp = + builder.create(loc, output, init, invPerm); + Value transposedBack = transposeOp.getResult()[0]; + + // Derive tile_size and mem_space for the output annotation. + DenseI64ArrayAttr outputTileSize; + MemSpaceEnumAttr outputMemSpace; + auto annIt = valueAnnotateMap.find(output); + if (annIt != valueAnnotateMap.end()) { + outputMemSpace = annIt->second.getMemSpaceAttr(); + outputTileSize = annIt->second.getTileSizeAttr(); + } + if (!outputTileSize && seedTileSizeAttr) + outputTileSize = seedTileSizeAttr; + + // Expand reduced-rank tile_size to full rank. + if (outputTileSize && + static_cast(outputTileSize.size()) < rank) { + SmallVector expanded; + ArrayRef reduced = outputTileSize.asArrayRef(); + size_t ri = 0; + for (int64_t i = 0; i < rank; ++i) { + if (origShape[i] > 1 && ri < reduced.size()) + expanded.push_back(reduced[ri++]); + else + expanded.push_back(1); + } + outputTileSize = + DenseI64ArrayAttr::get(builder.getContext(), expanded); + } + + // Clamp dim 0 to MAX_PARTITION_DIM. + if (outputTileSize) { + SmallVector tileSizeVec(outputTileSize.asArrayRef()); + if (tileSizeVec[0] > MAX_PARTITION_DIM) { + annotateOp.emitWarning("partition_dim=") + << partDim << ": clamping boundary transpose tile_size[0] " + << "from " << tileSizeVec[0] << " to " << MAX_PARTITION_DIM; + tileSizeVec[0] = MAX_PARTITION_DIM; + } + outputTileSize = + DenseI64ArrayAttr::get(builder.getContext(), tileSizeVec); + } + + annotateBoundaryTranspose(builder, loc, transposedBack, + outputMemSpace, outputTileSize, + invPerm, rank); + + // Rewire non-component uses to the inverse-transposed value. + SmallVector usesToReplace; + for (OpOperand &use : output.getUses()) { + Operation *user = use.getOwner(); + if (user == transposeOp || isa(user)) + continue; + if (!componentOps.count(user)) + usesToReplace.push_back(&use); + } + for (OpOperand *use : usesToReplace) + use->set(transposedBack); + } + } + + /// Update nkipy.annotate ops for values in the component: + /// set partition_dim=0 and permute tile_size. 
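Two details above are easy to model in isolation: expanding a reduced-rank tile_size back to full rank, and clamping the partition entry. A standalone sketch; the 128 limit is an assumed stand-in for MAX_PARTITION_DIM from HardwareConstants.h.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      const int64_t kMaxPartitionDim = 128;     // assumed value, for illustration only

      std::vector<int64_t> origShape = {256, 1, 64};   // rank 3, one unit dim
      std::vector<int64_t> reducedTile = {256, 32};    // one entry per non-unit dim

      // Expand to full rank: take the next reduced entry for non-unit dims, 1 otherwise.
      std::vector<int64_t> fullTile;
      size_t ri = 0;
      for (int64_t dim : origShape)
        fullTile.push_back(dim > 1 && ri < reducedTile.size() ? reducedTile[ri++] : 1);
      assert((fullTile == std::vector<int64_t>{256, 1, 32}));

      // Clamp the partition dimension (dim 0) to the hardware limit.
      if (fullTile[0] > kMaxPartitionDim)
        fullTile[0] = kMaxPartitionDim;
      assert((fullTile == std::vector<int64_t>{128, 1, 32}));
      return 0;
    }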
+ void updateComponentAnnotations( + OpBuilder &builder, func::FuncOp func, + const llvm::SetVector &componentOps, + ArrayRef perm, ArrayRef invPerm, int64_t rank) { + func.walk([&](nkipy::AnnotateOp annOp) { + Value annTarget = annOp.getTarget(); + Operation *defOp = annTarget.getDefiningOp(); + if (!defOp || !componentOps.count(defOp)) + return; + + annOp.setPartitionDimAttr(builder.getIntegerAttr( + builder.getIntegerType(32, /*isSigned=*/false), 0)); + + if (auto tileSizeAttr = annOp.getTileSizeAttr()) { + ArrayRef oldTileSize = tileSizeAttr.asArrayRef(); + if (static_cast(oldTileSize.size()) == rank) { + SmallVector newTileSize = + permuteVector(oldTileSize, perm); + annOp.setTileSizeAttr( + DenseI64ArrayAttr::get(builder.getContext(), newTileSize)); + } else { + auto annTargetType = + dyn_cast(annTarget.getType()); + if (annTargetType) { + auto result = permuteReducedTileSize( + oldTileSize, perm, invPerm, + annTargetType.getShape(), rank, builder.getContext()); + if (result) + annOp.setTileSizeAttr(result); + } + } + } + }); + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + + llvm::errs() << "[CanonicalizePartitionDim] Processing function: " + << func.getName() << "\n"; + + // Phase 0: Convert batch_matmul to loop + matmul. + DenseSet convertedBmmOutputs = preprocessBatchMatmul(func); + + // Phase 1: Collect partition_dim annotations. + DenseMap partDimMap; + DenseMap valueAnnotateMap; + SmallVector nonZeroAnnotations; + collectAnnotations(func, partDimMap, valueAnnotateMap, nonZeroAnnotations); + + if (nonZeroAnnotations.empty()) { + llvm::errs() << "[CanonicalizePartitionDim] No non-zero partition_dim " + "annotations found\n"; + return; + } + + // Phase 2: Process each non-zero partition_dim component. + OpBuilder builder(func.getContext()); + int numTransposed = 0; + DenseSet processedOps; + + for (nkipy::AnnotateOp annotateOp : nonZeroAnnotations) { + Value target = annotateOp.getTarget(); + Operation *seedOp = target.getDefiningOp(); + + if (seedOp && processedOps.count(seedOp)) + continue; + + auto tensorType = dyn_cast(target.getType()); + if (!tensorType) { + annotateOp.emitError("partition_dim != 0 on non-tensor type"); + return signalPassFailure(); + } + + int64_t partDim = partDimMap[target]; + int64_t rank = tensorType.getRank(); + if (partDim >= rank) { + annotateOp.emitError("partition_dim ") + << partDim << " >= tensor rank " << rank; + return signalPassFailure(); + } + + SmallVector perm = buildPermutation(rank, partDim); + SmallVector invPerm = invertPermutation(perm); + auto seedTileSizeAttr = annotateOp.getTileSizeAttr(); + + // Validate partition tile size fits hardware. + if (seedTileSizeAttr) { + ArrayRef tileVals = seedTileSizeAttr.asArrayRef(); + if (static_cast(tileVals.size()) > partDim && + tileVals[partDim] > MAX_PARTITION_DIM) { + annotateOp.emitError("tile_size[") + << partDim << "] = " << tileVals[partDim] + << " exceeds hardware partition limit " << MAX_PARTITION_DIM; + return signalPassFailure(); + } + } + + // Skip non-linalg ops (partition_dim is informational only). + if (seedOp && !isa(seedOp)) { + processedOps.insert(seedOp); + continue; + } + + // Error on matmul with partition_dim != 0. + if (seedOp && isMatmulOp(seedOp)) { + annotateOp.emitError( + "partition_dim != 0 on matmul/bmm is not supported. " + "Please annotate downstream elementwise ops instead."); + return signalPassFailure(); + } + + // Step 1: Find connected component. 
+ auto componentOps = findComponent(seedOp); + for (Operation *op : componentOps) + processedOps.insert(op); + + if (componentOps.empty()) { + annotateOp.emitError( + "partition_dim != 0 on an unsupported op. " + "Supported: elementwise and reduction ops."); + return signalPassFailure(); + } + + // Step 2: Find boundaries. + auto boundaryInputs = findBoundaryInputs(componentOps); + auto boundaryOutputs = findBoundaryOutputs(componentOps); + + // Step 3: Insert input boundary transposes. + IRMapping valueMapping; + insertInputTransposes(builder, boundaryInputs, perm, rank, + seedTileSizeAttr, convertedBmmOutputs, + valueMapping); + + // Step 4: Rewrite component ops with permuted shapes. + rewriteComponentOps(builder, func, componentOps, perm, rank, + valueMapping); + + // Step 5: Insert output boundary transposes. + insertOutputTransposes(builder, boundaryOutputs, componentOps, + invPerm, rank, seedTileSizeAttr, + valueAnnotateMap, annotateOp, partDim); + + // Step 6: Update annotations. + updateComponentAnnotations(builder, func, componentOps, + perm, invPerm, rank); + + numTransposed += componentOps.size(); + llvm::errs() << "[CanonicalizePartitionDim] Processed component of " + << componentOps.size() << " ops with partition_dim=" + << partDim << "\n"; + } + + llvm::errs() << "[CanonicalizePartitionDim] Rewritten " << numTransposed + << " op(s) total\n"; + } +}; + +} // namespace + +std::unique_ptr> +createCanonicalizePartitionDimPass() { + return std::make_unique(); +} + +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/CanonicalizeReshape.cpp b/kernelgen/mlir/lib/Transforms/CanonicalizeReshape.cpp new file mode 100644 index 0000000..c64b15f --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/CanonicalizeReshape.cpp @@ -0,0 +1,367 @@ +//===- CanonicalizeReshape.cpp - Canonicalize memref reshape ops ----------===// +// +// Post-bufferization pass that canonicalizes reshape (expand_shape / +// collapse_shape) operations on memrefs. +// +// Runs after annotate-memory-space so all memrefs have explicit memory spaces +// (#nisa.mem, #nisa.mem, etc.) and partition_dim is +// guaranteed to be 0 (from canonicalize-partition-dim). +// +// Classification logic: +// +// +-------------------+--------+-----------+-----------------------------+ +// | Reshape type | MemSpc | Pdim(d0)? | Action | +// +-------------------+--------+-----------+-----------------------------+ +// | Any | HBM | N/A | View (no partition concept) | +// | Merge (collapse) | SBUF | any | View (contiguous in memory) | +// | Split (expand) | SBUF | no | View (partition unchanged) | +// | Split (expand) | SBUF | yes | Copy (needs modulo for | +// | | | | partition reassignment) | +// | Split N->(N,1) | SBUF | trivial | View (no modulo needed) | +// +-------------------+--------+-----------+-----------------------------+ +// +// Additionally, returned expand_shape views of function arguments need +// alloc+copy to ensure separate HBM output allocations. +// +//===----------------------------------------------------------------------===// + +#include "nkipy/Transforms/Passes.h" +#include "nkipy/Transforms/IRHelpers.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using nkipy::isSbuf; +using nkipy::isAnyHbm; + +namespace { + +/// Trace through memref view ops to find the base memref. 
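The decision table above can be restated as a single predicate, which makes the view-versus-copy rule easy to test. A standalone sketch keyed on the same inputs as the table:

    #include <cassert>

    enum class Kind { Collapse, Expand };

    // Does this reshape stay a zero-cost view, or does it need a copy?
    static bool needsCopy(Kind kind, bool inSbuf, bool splitsPartitionDim,
                          bool trivialSplit) {
      if (!inSbuf)
        return false;                 // HBM: always a view, no partition concept
      if (kind == Kind::Collapse)
        return false;                 // merging dims is contiguous in memory: view
      if (!splitsPartitionDim)
        return false;                 // split that leaves dim 0 alone: view
      return !trivialSplit;           // partition split needs a copy unless N->(N,1)/(1,N)
    }

    int main() {
      assert(!needsCopy(Kind::Expand, /*inSbuf=*/false, true, false));   // HBM expand
      assert(!needsCopy(Kind::Collapse, true, false, false));            // SBUF merge
      assert(!needsCopy(Kind::Expand, true, /*splitsPartitionDim=*/false, false));
      assert(needsCopy(Kind::Expand, true, true, /*trivialSplit=*/false));
      assert(!needsCopy(Kind::Expand, true, true, /*trivialSplit=*/true));
      return 0;
    }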
+static Value traceToBase(Value v) { + while (auto defOp = v.getDefiningOp()) { + if (auto op = dyn_cast(defOp)) + v = op.getSrc(); + else if (auto op = dyn_cast(defOp)) + v = op.getSrc(); + else if (auto op = dyn_cast(defOp)) + v = op.getSource(); + else if (auto op = dyn_cast(defOp)) + v = op.getSource(); + else if (auto op = dyn_cast(defOp)) + v = op.getSource(); + else + break; + } + return v; +} + +/// Check if an expand_shape splits the partition dim (dim 0). +/// After canonicalize-partition-dim, partition dim is always 0. +/// A split of dim 0 means the reassociation group for src dim 0 maps to +/// multiple dst dims. +static bool splitsPartitionDim(memref::ExpandShapeOp expandOp) { + auto reassoc = expandOp.getReassociationIndices(); + // Reassociation is indexed by src dims. If src dim 0's group has + // multiple dst dims, the partition dim is being split. + if (reassoc.empty()) + return false; + return reassoc[0].size() > 1; +} + +/// Check if a partition dim split is trivial: N -> (N, 1) or (1, N). +/// Trivial splits don't need modulo because one of the factors is 1. +static bool isTrivialPartitionSplit(memref::ExpandShapeOp expandOp) { + if (!splitsPartitionDim(expandOp)) + return false; + + auto resultType = cast(expandOp.getResultType()); + auto reassoc = expandOp.getReassociationIndices(); + // Check the dims in the partition group + for (int64_t dstIdx : reassoc[0]) { + int64_t dimSize = resultType.getShape()[dstIdx]; + if (dimSize == 1) + return true; // One factor is 1, no modulo needed + } + return false; +} + +struct CanonicalizeReshapePass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CanonicalizeReshapePass) + + StringRef getArgument() const final { + return "canonicalize-reshape"; + } + + StringRef getDescription() const final { + return "Canonicalize memref reshape ops based on mem_space and partition_dim"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + + // --------------------------------------------------------------- + // Phase 0: Convert memref.reshape to memref.reinterpret_cast. + // + // Category 2 reshapes (no contiguous dim grouping) are emitted by the + // tracer as tensor.reshape, which bufferizes to memref.reshape. + // Since the data is contiguous in row-major order, we can replace with + // a reinterpret_cast using contiguous (row-major) strides for the new + // shape — a true zero-cost view reinterpretation. + // --------------------------------------------------------------- + SmallVector reshapeOps; + func.walk([&](memref::ReshapeOp op) { reshapeOps.push_back(op); }); + + for (auto reshapeOp : reshapeOps) { + auto srcType = cast(reshapeOp.getSource().getType()); + auto resultType = cast(reshapeOp.getResult().getType()); + + OpBuilder builder(reshapeOp); + Location loc = reshapeOp.getLoc(); + + // Decompose into collapse_shape(→1D) + expand_shape(→target). + // Both are pure views on contiguous memory — no data movement. + // FoldHbmReshapePattern in linalg-to-nisa handles these: the DMA + // addresses the alloc with its original shape, while the return + // reinterprets the same buffer via view() with the target shape. + int64_t totalElems = 1; + for (int64_t d : srcType.getShape()) + totalElems *= d; + + // Step 1: collapse source to 1D. 
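The two reassociation checks above are easy to exercise on their own. A standalone sketch using plain vectors for the reassociation indices; the shapes are invented for illustration.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // A reassociation is indexed by source dims; each group lists the result dims
    // it expands into. The partition dim is split when source dim 0's group has
    // more than one result dim, and the split is "trivial" when one of those
    // result dims has extent 1 (no modulo arithmetic needed).
    using Reassoc = std::vector<std::vector<int64_t>>;

    static bool splitsPartitionDim(const Reassoc &reassoc) {
      return !reassoc.empty() && reassoc[0].size() > 1;
    }

    static bool isTrivialPartitionSplit(const Reassoc &reassoc,
                                        const std::vector<int64_t> &resultShape) {
      if (!splitsPartitionDim(reassoc))
        return false;
      for (int64_t dstIdx : reassoc[0])
        if (resultShape[dstIdx] == 1)
          return true;
      return false;
    }

    int main() {
      // 128x6 -> 4x32x6: dim 0 split into {4, 32} -> needs a copy.
      Reassoc splitPartition = {{0, 1}, {2}};
      assert(splitsPartitionDim(splitPartition));
      assert(!isTrivialPartitionSplit(splitPartition, {4, 32, 6}));

      // 128x6 -> 128x1x6 under the same grouping: one factor is 1 -> still a view.
      assert(isTrivialPartitionSplit(splitPartition, {128, 1, 6}));

      // 128x6 -> 128x2x3: only dim 1 is split -> partition dim untouched.
      Reassoc splitFree = {{0}, {1, 2}};
      assert(!splitsPartitionDim(splitFree));
      return 0;
    }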
+ auto flatType = MemRefType::get( + {totalElems}, srcType.getElementType(), + MemRefLayoutAttrInterface{}, srcType.getMemorySpace()); + SmallVector collapseReassoc = {{}}; + for (int64_t i = 0; i < srcType.getRank(); ++i) + collapseReassoc[0].push_back(i); + auto collapseOp = builder.create( + loc, flatType, reshapeOp.getSource(), collapseReassoc); + + // Step 2: expand 1D to target shape. + auto expandType = MemRefType::get( + resultType.getShape(), resultType.getElementType(), + MemRefLayoutAttrInterface{}, resultType.getMemorySpace()); + SmallVector expandReassoc = {{}}; + for (int64_t i = 0; i < resultType.getRank(); ++i) + expandReassoc[0].push_back(i); + auto expandOp = builder.create( + loc, expandType, collapseOp.getResult(), expandReassoc); + + reshapeOp.getResult().replaceAllUsesWith(expandOp.getResult()); + reshapeOp->erase(); + + llvm::errs() << "[CanonicalizeReshape] Category 2 reshape: " + << "decomposed to collapse(1D)+expand(target)\n"; + } + + // --------------------------------------------------------------- + // Phase 1: Classify expand_shape ops and insert copies where needed. + // + // collapse_shape (merge) is always a view -- merging dims never needs + // modulo regardless of mem_space or partition_dim involvement. + // + // expand_shape (split) needs analysis: + // - HBM: always view (no partition concept) + // - SBUF, non-partition split: view + // - SBUF, partition dim split: copy (NISA has no modulo) + // Exception: trivial split N->(N,1)/(1,N) is view + // --------------------------------------------------------------- + SmallVector expandOps; + func.walk([&](memref::ExpandShapeOp op) { expandOps.push_back(op); }); + + for (auto expandOp : expandOps) { + auto srcType = cast(expandOp.getSrc().getType()); + + // HBM: always a view, no partition concept + if (isAnyHbm(srcType) || !isSbuf(srcType)) + continue; + + // SBUF: check if partition dim (dim 0) is being split + if (!splitsPartitionDim(expandOp)) + continue; + + // Trivial split N->(N,1) or (1,N): no modulo needed, keep as view + if (isTrivialPartitionSplit(expandOp)) + continue; + + // SBUF partition dim split that needs modulo -- NISA can't do this + // as a view. Insert alloc + copy to materialize the reshape. + // + // TODO: This should ideally be a tiled copy loop that explicitly + // computes the partition reassignment. For now, emit alloc + copy + // and let downstream passes handle the lowering. + OpBuilder builder(expandOp); + Location loc = expandOp.getLoc(); + + auto resultType = cast(expandOp.getResultType()); + auto allocOp = builder.create(loc, resultType); + + // Copy from the expanded view into the fresh allocation. + // The expand_shape itself is a valid view for memref.copy's purposes. + builder.create( + loc, expandOp.getResult(), allocOp.getResult()); + + expandOp.getResult().replaceAllUsesExcept( + allocOp.getResult(), allocOp->getNextNode()); + + llvm::errs() << "[CanonicalizeReshape] SBUF partition dim split: " + << "inserted alloc+copy for expand_shape\n"; + } + + // --------------------------------------------------------------- + // Phase 2: Returned view ops of function arguments. + // + // After bufferization, view ops (expand_shape, reinterpret_cast) + // produce no allocation, but NISA requires function outputs to be + // separate HBM allocations. Insert alloc + copy to materialize. + // + // collapse_shape already gets alloc+copy from bufferization. 
+ // --------------------------------------------------------------- + func.walk([&](func::ReturnOp returnOp) { + for (unsigned i = 0; i < returnOp.getNumOperands(); ++i) { + Value retVal = returnOp.getOperand(i); + if (!isa(retVal.getType())) + continue; + + Value base = traceToBase(retVal); + if (!isa(base)) + continue; + if (retVal == base) + continue; + + Operation *defOp = retVal.getDefiningOp(); + if (!defOp) + continue; + + // Handle expand_shape views. + if (auto expandOp = dyn_cast(defOp)) { + Value expandSrc = expandOp.getSrc(); + if (!isa(traceToBase(expandSrc))) + continue; + + // Check if this expand_shape already has a non-arg source + // (Phase 1 may have already inserted a copy) + if (expandSrc.getDefiningOp()) + continue; + + auto srcType = cast(expandSrc.getType()); + auto oldResultType = cast(expandOp.getResultType()); + auto contigResultType = MemRefType::get( + oldResultType.getShape(), oldResultType.getElementType(), + MemRefLayoutAttrInterface{}, oldResultType.getMemorySpace()); + Location loc = expandOp.getLoc(); + + if (srcType.getRank() < 2) { + // Source is 1D: NISA DMA requires at least 2D tiles. + // Alloc with the expanded (2D) shape and copy from the + // expand_shape view to avoid a 1D DMA. + OpBuilder builder(expandOp->getNextNode()); + + auto allocOp = builder.create( + loc, contigResultType); + auto copyOp = builder.create( + loc, expandOp.getResult(), allocOp.getResult()); + + expandOp.getResult().replaceAllUsesExcept( + allocOp.getResult(), {copyOp}); + } else { + // Source is 2D+: alloc+copy before expand, then rebuild expand. + auto contigSrcType = MemRefType::get( + srcType.getShape(), srcType.getElementType(), + MemRefLayoutAttrInterface{}, srcType.getMemorySpace()); + OpBuilder builder(expandOp); + + auto allocOp = builder.create( + loc, contigSrcType); + builder.create( + loc, expandSrc, allocOp.getResult()); + + auto newExpandOp = builder.create( + loc, contigResultType, allocOp.getResult(), + expandOp.getReassociationIndices(), + expandOp.getMixedOutputShape()); + + expandOp.getResult().replaceAllUsesWith(newExpandOp.getResult()); + expandOp->erase(); + } + + // Update function signature return type. + auto funcType = func.getFunctionType(); + SmallVector newResultTypes(funcType.getResults()); + newResultTypes[i] = contigResultType; + func.setFunctionType(FunctionType::get( + func.getContext(), funcType.getInputs(), newResultTypes)); + + llvm::errs() << "[CanonicalizeReshape] Returned view of func arg: " + << "inserted alloc+copy for expand_shape\n"; + continue; + } + + } + }); + + // --------------------------------------------------------------- + // Phase 3: Direct return of function arguments (identity reshape). + // + // When a function argument is returned directly (no reshape at all), + // there's no allocation for the output. NISA requires function outputs + // to be separate HBM allocations. Insert alloc + copy. 
+ // + // Example: np.reshape(x, x.shape) -> identity -> return %arg0 + // --------------------------------------------------------------- + func.walk([&](func::ReturnOp returnOp) { + for (unsigned i = 0; i < returnOp.getNumOperands(); ++i) { + Value retVal = returnOp.getOperand(i); + if (!isa(retVal.getType())) + continue; + + // Only handle direct return of block arguments + if (!isa(retVal)) + continue; + + auto retType = cast(retVal.getType()); + auto contigType = MemRefType::get( + retType.getShape(), retType.getElementType(), + MemRefLayoutAttrInterface{}, retType.getMemorySpace()); + OpBuilder builder(returnOp); + Location loc = returnOp.getLoc(); + + auto allocOp = builder.create(loc, contigType); + builder.create(loc, retVal, allocOp.getResult()); + returnOp.setOperand(i, allocOp.getResult()); + + // Update function signature to match contiguous type + auto funcType = func.getFunctionType(); + SmallVector newResultTypes(funcType.getResults()); + newResultTypes[i] = contigType; + func.setFunctionType(FunctionType::get( + func.getContext(), funcType.getInputs(), newResultTypes)); + + llvm::errs() << "[CanonicalizeReshape] Direct return of func arg: " + << "inserted alloc+copy\n"; + } + }); + } +}; + +} // namespace + +namespace mlir { +namespace nkipy { + +std::unique_ptr> +createCanonicalizeReshapePass() { + return std::make_unique(); +} + +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/EliminateSameMemSpaceCopy.cpp b/kernelgen/mlir/lib/Transforms/EliminateSameMemSpaceCopy.cpp new file mode 100644 index 0000000..2534c58 --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/EliminateSameMemSpaceCopy.cpp @@ -0,0 +1,413 @@ +#include "PassGen.h" +#include "nkipy/Transforms/Passes.h" +#include "nkipy/Transforms/IRHelpers.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "eliminate-same-memspace-copy" + +using namespace mlir; +using nkipy::getBaseMemRef; +using nkipy::getNkipyMemSpace; +using nkipy::isAnyHbm; + +namespace mlir { +namespace nkipy { + +namespace { + +/// Check if a copy is between the same memory space (e.g., SBUF→SBUF). +static bool isSameMemSpaceCopy(memref::CopyOp copyOp) { + auto srcMemSpace = getNkipyMemSpace(copyOp.getSource().getType()); + auto dstMemSpace = getNkipyMemSpace(copyOp.getTarget().getType()); + if (!srcMemSpace || !dstMemSpace) + return false; + return *srcMemSpace == *dstMemSpace; +} + +/// Check if two subview operations access the exact same memory region. +/// They must have the same base, offsets, sizes, and strides. +static bool isSameRegion(memref::SubViewOp a, memref::SubViewOp b) { + return getBaseMemRef(a.getSource()) == getBaseMemRef(b.getSource()) && + a.getMixedOffsets() == b.getMixedOffsets() && + a.getMixedSizes() == b.getMixedSizes() && + a.getMixedStrides() == b.getMixedStrides(); +} + +/// Check if src and dst of a copy point to the exact same memory region. +/// If so, the copy is a no-op and can be eliminated. 
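+///
+/// Illustrative example (value names made up):
+///   %a = memref.subview %buf[0, 0] [64, 128] [1, 1]
+///   %b = memref.subview %buf[0, 0] [64, 128] [1, 1]
+///   memref.copy %a, %b   // same base, offsets, sizes, strides -> no-op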
+static bool isCopyToSelf(memref::CopyOp copyOp) { + Value src = copyOp.getSource(); + Value dst = copyOp.getTarget(); + + // Trivial case: same SSA value + if (src == dst) + return true; + + // Check if both are subviews of the same base with same parameters + auto srcSubview = src.getDefiningOp(); + auto dstSubview = dst.getDefiningOp(); + + if (srcSubview && dstSubview) + return isSameRegion(srcSubview, dstSubview); + + return false; +} + +/// Pattern to eliminate self-copies (where src and dst point to same region). +/// These are no-ops and can be erased directly. +struct EliminateSelfCopyPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::CopyOp copyOp, + PatternRewriter &rewriter) const override { + if (!isCopyToSelf(copyOp)) + return failure(); + + LLVM_DEBUG(llvm::dbgs() << "Eliminating self-copy: " << copyOp << "\n"); + rewriter.eraseOp(copyOp); + return success(); + } +}; + +/// Pattern to eliminate same-memory-space copies where dst is a fresh alloc. +/// Transforms: +/// %alloc = memref.alloc() : memref<...xf32, #sbuf> +/// memref.copy %src, %alloc : memref<...xf32, #sbuf> to memref<...xf32, #sbuf> +/// use(%alloc) +/// Into: +/// use(%src) +struct EliminateAllocCopyPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::CopyOp copyOp, + PatternRewriter &rewriter) const override { + // Only eliminate same memory space copies + if (!isSameMemSpaceCopy(copyOp)) + return failure(); + + Value src = copyOp.getSource(); + Value dst = copyOp.getTarget(); + + // Only eliminate SBUF→SBUF copies (not PSUM→PSUM which doesn't make sense) + auto srcMemSpace = getNkipyMemSpace(src.getType()); + if (*srcMemSpace != nkipy::MemSpaceEnum::Sbuf) + return failure(); + + // Destination must be a fresh allocation (not written to before this copy) + auto allocOp = dst.getDefiningOp(); + if (!allocOp) + return failure(); + + // Check that the allocation hasn't been written to before this copy + // by ensuring no user of allocOp in the same block comes before copyOp. + // Users in nested regions (e.g., loop bodies) execute after the copy, + // so they don't count as "before". + // Verify the alloc hasn't been written to before this copy. + Block *copyBlock = copyOp->getBlock(); + for (Operation *user : allocOp->getUsers()) { + if (user == copyOp.getOperation()) + continue; + if (user->getBlock() != copyBlock) + continue; // nested regions execute after the copy + if (user->isBeforeInBlock(copyOp.getOperation())) + return failure(); // another use before the copy — bail out + } + + // Check types are compatible (shapes should match for direct replacement) + auto srcType = cast(src.getType()); + auto dstType = cast(dst.getType()); + + if (srcType.getShape() != dstType.getShape()) + return failure(); + + if (srcType.getElementType() != dstType.getElementType()) + return failure(); + + LLVM_DEBUG(llvm::dbgs() << "Eliminating SBUF copy: " << copyOp << "\n"); + rewriter.replaceAllUsesWith(dst, src); + + // Erase the copy operation + rewriter.eraseOp(copyOp); + + // Erase the now-unused allocation + if (allocOp->use_empty()) + rewriter.eraseOp(allocOp); + + return success(); + } +}; + +/// Helper: convert an OpFoldResult to a Value, materializing constants. 
+static Value materialize(OpFoldResult ofr, Location loc, + PatternRewriter &rewriter) { + if (auto attr = dyn_cast(ofr)) + return rewriter.create( + loc, cast(attr).getInt()); + return cast(ofr); +} + +/// Helper: check if an OpFoldResult is a static zero. +static bool isStaticZero(OpFoldResult ofr) { + if (auto attr = dyn_cast(ofr)) + return cast(attr).getInt() == 0; + return false; +} + +/// Check if defOp properly dominates insertionPt. +/// A value dominates if it's defined before the insertion point in the same +/// block, or if it's defined before the ancestor op in an ancestor block. +static bool properlyDominates(Operation *defOp, Operation *insertionPt) { + Block *insertBlock = insertionPt->getBlock(); + // Same block: defOp must come before insertionPt + if (defOp->getBlock() == insertBlock) + return defOp->isBeforeInBlock(insertionPt); + // Walk up from insertion point to find if defOp is in an ancestor block, + // and if so, check it comes before the child region's parent op. + for (Operation *ancestor = insertBlock->getParentOp(); ancestor; + ancestor = ancestor->getBlock() + ? ancestor->getBlock()->getParentOp() + : nullptr) { + if (defOp->getBlock() == ancestor->getBlock()) + return defOp->isBeforeInBlock(ancestor); + } + return false; +} + +/// Ensure an OpFoldResult is usable at the given insertion point. +/// If it's a static attribute, just return it. If it's a Value whose defining +/// op doesn't dominate the insertion point, clone the defining op there. +static OpFoldResult ensureDominates(OpFoldResult ofr, Location loc, + PatternRewriter &rewriter) { + if (isa(ofr)) + return ofr; + Value v = cast(ofr); + Operation *insertionPt = rewriter.getInsertionBlock() + ? &*rewriter.getInsertionPoint() + : nullptr; + if (!insertionPt) + return ofr; + // Block arguments dominate their block and all nested blocks + if (auto blockArg = dyn_cast(v)) { + Block *argBlock = blockArg.getOwner(); + // Check the arg's block is the same or an ancestor of insertionPt + for (Block *b = insertionPt->getBlock(); b; + b = b->getParentOp() ? b->getParentOp()->getBlock() : nullptr) { + if (b == argBlock) + return ofr; + } + // Block arg doesn't dominate - fall through to clone logic + // (shouldn't normally happen) + } + Operation *defOp = v.getDefiningOp(); + if (!defOp) + return ofr; + if (properlyDominates(defOp, insertionPt)) + return ofr; + // defOp doesn't dominate - clone it at the insertion point. + // First ensure its operands dominate too. + IRMapping mapping; + for (Value operand : defOp->getOperands()) { + OpFoldResult ensured = ensureDominates(operand, loc, rewriter); + if (auto ensuredVal = dyn_cast(ensured)) { + if (ensuredVal != operand) + mapping.map(operand, ensuredVal); + } + } + Operation *cloned = rewriter.clone(*defOp, mapping); + return cloned->getResult(0); +} + +/// Pattern to eliminate intermediate SharedHBM allocations that are written +/// to via tiled copies and then bulk-copied to the output via reshape+copy. +/// +/// Matches: +/// %tmp = memref.alloc() : memref +/// memref.copy %tile, subview(%tmp, [row, col]) // tiled SBUF→HBM writes +/// %reshape = memref.reshape %tmp(...) 
-> memref<1xMxNxf32> +/// memref.copy %reshape, subview(%out, [batch, 0, 0]) // HBM→HBM bulk copy +/// +/// Transforms to: +/// memref.copy %tile, subview(%out, [batch, row, col]) // write directly to output +/// +struct EliminateHbmIntermediatePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::CopyOp copyOp, + PatternRewriter &rewriter) const override { + // 1. Source must come from a memref.reshape or memref.expand_shape + // (expand_shape is produced when np.expand_dims emits + // tensor.expand_shape instead of tensor.reshape) + Operation *expandOp = nullptr; + Value intermediateAllocResult; + + if (auto reshapeOp = copyOp.getSource().getDefiningOp()) { + expandOp = reshapeOp; + intermediateAllocResult = reshapeOp.getSource(); + } else if (auto expandShapeOp = + copyOp.getSource().getDefiningOp()) { + expandOp = expandShapeOp; + intermediateAllocResult = expandShapeOp.getSrc(); + } else { + return failure(); + } + + // 2. Reshape/expand source must be a fresh alloc in HBM or SharedHBM + auto intermediateAlloc = + intermediateAllocResult.getDefiningOp(); + if (!intermediateAlloc) + return failure(); + + if (!isAnyHbm(intermediateAlloc.getResult().getType())) + return failure(); + + // 3. Destination must be a subview of an HBM or SharedHBM buffer + auto dstSubview = copyOp.getTarget().getDefiningOp(); + if (!dstSubview) + return failure(); + + Value outputBase = dstSubview.getSource(); + if (!isAnyHbm(outputBase.getType())) + return failure(); + + // 4. Verify reshape/expand is a trivial leading-1 expansion (MxN → 1xMxN) + auto intermediateType = + cast(intermediateAlloc.getResult().getType()); + auto expandResultType = cast(expandOp->getResult(0).getType()); + auto intermediateShape = intermediateType.getShape(); + auto expandShape = expandResultType.getShape(); + if (expandShape.size() != intermediateShape.size() + 1) + return failure(); + if (expandShape[0] != 1) + return failure(); + for (unsigned i = 0; i < intermediateShape.size(); ++i) { + if (expandShape[i + 1] != intermediateShape[i]) + return failure(); + } + + // 5. Reshape/expand must have exactly one use (this copy) + if (!expandOp->getResult(0).hasOneUse()) + return failure(); + + // 6. Collect all subviews of the intermediate alloc. + // Only subviews and the reshape/expand are allowed as users. + SmallVector intermediateSubviews; + for (Operation *user : intermediateAlloc->getUsers()) { + if (user == expandOp) + continue; + auto subview = dyn_cast(user); + if (!subview) + return failure(); + intermediateSubviews.push_back(subview); + } + + // 7. Get batch offset and other offsets from the dst subview + SmallVector dstOffsets = dstSubview.getMixedOffsets(); + auto outputBaseType = cast(outputBase.getType()); + + LLVM_DEBUG(llvm::dbgs() << "Eliminating HBM intermediate: " + << *intermediateAlloc << "\n"); + + // 8. For each subview of the intermediate, create a new rank-reducing + // subview of the output base with the batch offset prepended. + for (auto subview : intermediateSubviews) { + SmallVector tileOffsets = subview.getMixedOffsets(); + SmallVector tileSizes = subview.getMixedSizes(); + SmallVector tileStrides = subview.getMixedStrides(); + Location loc = subview.getLoc(); + + // New offsets: [batch, tile_row + dst_row, tile_col + dst_col, ...] + // Use ensureDominates because dstOffsets come from the dstSubview which + // is defined after the nested loops, but subview is inside the loops. 
+ rewriter.setInsertionPoint(subview); + SmallVector newOffsets; + newOffsets.push_back(ensureDominates(dstOffsets[0], loc, rewriter)); + for (unsigned i = 0; i < tileOffsets.size(); ++i) { + unsigned dstIdx = i + 1; // skip batch dim in dst + if (dstIdx < dstOffsets.size() && !isStaticZero(dstOffsets[dstIdx])) { + // Non-zero dst offset: add to tile offset + Value sum = rewriter.create( + loc, materialize(tileOffsets[i], loc, rewriter), + materialize(ensureDominates(dstOffsets[dstIdx], loc, rewriter), + loc, rewriter)); + newOffsets.push_back(sum); + } else { + newOffsets.push_back(tileOffsets[i]); + } + } + + // New sizes: [1, tile_sizes...] + SmallVector newSizes; + newSizes.push_back(rewriter.getIndexAttr(1)); + for (auto sz : tileSizes) + newSizes.push_back(sz); + + // New strides: [1, tile_strides...] + SmallVector newStrides; + newStrides.push_back(rewriter.getIndexAttr(1)); + for (auto st : tileStrides) + newStrides.push_back(st); + + // Infer rank-reduced result type (drops the leading size-1 batch dim) + auto originalResultType = cast(subview.getResult().getType()); + auto reducedType = memref::SubViewOp::inferRankReducedResultType( + originalResultType.getShape(), outputBaseType, newOffsets, newSizes, + newStrides); + + auto newSubview = rewriter.create( + loc, reducedType, outputBase, newOffsets, newSizes, newStrides); + + rewriter.replaceOp(subview, newSubview.getResult()); + } + + // 9. Erase the HBM→HBM copy + rewriter.eraseOp(copyOp); + + // 10. Clean up dead ops + if (expandOp->use_empty()) + rewriter.eraseOp(expandOp); + if (intermediateAlloc->use_empty()) + rewriter.eraseOp(intermediateAlloc); + + // dstSubview will be cleaned up by canonicalize (now has no users) + + return success(); + } +}; + +struct EliminateSameMemSpaceCopyPass + : public EliminateSameMemSpaceCopyBase { + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); + + if (failed(applyPatternsGreedily(func, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> createEliminateSameMemSpaceCopyPass() { + return std::make_unique(); +} + +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/EliminateUninitializedCopies.cpp b/kernelgen/mlir/lib/Transforms/EliminateUninitializedCopies.cpp new file mode 100644 index 0000000..7184ae6 --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/EliminateUninitializedCopies.cpp @@ -0,0 +1,136 @@ +//===- EliminateUninitializedCopies.cpp - Remove copies from uninit allocs ===// +// +// This pass eliminates memref.copy operations where the source is a freshly +// allocated buffer that has never been written to (contains undefined values). +// Such copies are effectively no-ops and can be safely eliminated. +// +// This commonly occurs after buffer promotion when the original tensor was +// freshly allocated (e.g., for accumulator initialization). 
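+//
+// Illustrative example (value names made up):
+//   %acc = memref.alloc() : memref<128x512xf32, #sbuf>   // never written to
+//   memref.copy %acc, %dst   // source holds undefined values -> removable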
+// +//===----------------------------------------------------------------------===// + +#include "nkipy/Transforms/Passes.h" +#include "nkipy/Transforms/IRHelpers.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using nkipy::getBaseMemRef; + +namespace { + +/// Check if `a` executes before `b` in program order. Walks up to the nearest +/// common ancestor block to compare positions. Conservative (returns true) if +/// no common ancestor is found. +static bool executesBefore(Operation *a, Operation *b) { + if (a->getBlock() == b->getBlock()) + return a->isBeforeInBlock(b); + + // Collect a's ancestor chain (block pointers) + SmallPtrSet aBlocks; + for (Operation *cur = a; cur; cur = cur->getParentOp()) + aBlocks.insert(cur->getBlock()); + + // Walk b up to find a common ancestor block + for (Operation *bAnc = b; bAnc; bAnc = bAnc->getParentOp()) { + if (aBlocks.count(bAnc->getBlock())) { + Block *common = bAnc->getBlock(); + Operation *aAnc = a; + while (aAnc->getBlock() != common) + aAnc = aAnc->getParentOp(); + return aAnc->isBeforeInBlock(bAnc); + } + } + return true; // conservative +} + +/// Check if `val` (or any view derived from it) is written to by any user +/// that executes before `beforeOp`. Recurses through view-like ops to catch +/// writes to subviews of the allocation. +static bool hasAnyWriteBefore(Value val, Operation *beforeOp) { + for (Operation *user : val.getUsers()) { + if (user == beforeOp) + continue; + // View-like ops: recurse into their users + if (isa(user)) { + if (hasAnyWriteBefore(user->getResult(0), beforeOp)) + return true; + continue; + } + bool isWrite = false; + if (auto copy = dyn_cast(user)) { + isWrite = (copy.getTarget() == val); + } else if (isa(user)) { + isWrite = true; + } else if (auto dps = dyn_cast(user)) { + isWrite = llvm::is_contained(dps.getDpsInits(), val); + } + if (isWrite && executesBefore(user, beforeOp)) + return true; + } + return false; +} + +struct EliminateUninitializedCopiesPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(EliminateUninitializedCopiesPass) + + StringRef getArgument() const final { + return "eliminate-uninitialized-copies"; + } + + StringRef getDescription() const final { + return "Eliminate memref.copy operations where the source is an " + "uninitialized allocation"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + + SmallVector toErase; + + func.walk([&](memref::CopyOp copyOp) { + Value srcBase = getBaseMemRef(copyOp.getSource()); + if (auto alloc = srcBase.getDefiningOp()) { + if (!hasAnyWriteBefore(alloc.getResult(), copyOp)) { + llvm::errs() << "[EliminateUninitializedCopies] Eliminating copy from " + "uninitialized alloc: " + << *copyOp << "\n"; + toErase.push_back(copyOp); + } + } + }); + + // Erase the identified copy operations + for (auto copyOp : toErase) { + copyOp.erase(); + } + + if (!toErase.empty()) { + llvm::errs() << "[EliminateUninitializedCopies] Eliminated " + << toErase.size() << " copy operation(s)\n"; + } + } +}; + +} // namespace + +namespace mlir { +namespace nkipy { + +std::unique_ptr> +createEliminateUninitializedCopiesPass() { + return std::make_unique(); +} + +} // namespace nkipy +} // namespace mlir diff --git 
a/kernelgen/mlir/lib/Transforms/IRHelpers.cpp b/kernelgen/mlir/lib/Transforms/IRHelpers.cpp new file mode 100644 index 0000000..b34d637 --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/IRHelpers.cpp @@ -0,0 +1,50 @@ +//===- IRHelpers.cpp - Shared IR utility functions -------------------------===// + +#include "nkipy/Transforms/IRHelpers.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Interfaces/ViewLikeInterface.h" + +namespace mlir { +namespace nkipy { + +std::optional getConstantInt(Value v) { + if (auto constOp = v.getDefiningOp()) + if (auto intAttr = dyn_cast(constOp.getValue())) + return intAttr.getInt(); + if (auto constOp = v.getDefiningOp()) + return constOp.value(); + return std::nullopt; +} + +Value getBaseMemRef(Value v) { + while (auto *def = v.getDefiningOp()) { + if (auto view = dyn_cast(def)) { + v = view.getViewSource(); + continue; + } + break; + } + return v; +} + +std::optional getNkipyMemSpace(Type type) { + auto memrefType = dyn_cast(type); + if (!memrefType) + return std::nullopt; + auto memSpaceAttr = memrefType.getMemorySpace(); + if (!memSpaceAttr) + return std::nullopt; + if (auto ms = dyn_cast(memSpaceAttr)) + return ms.getValue(); + return std::nullopt; +} + +Operation *getAncestorInBlock(Operation *op, Block *block) { + while (op && op->getBlock() != block) + op = op->getParentOp(); + return op; +} + +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/InferLayout.cpp b/kernelgen/mlir/lib/Transforms/InferLayout.cpp new file mode 100644 index 0000000..5b2be58 --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/InferLayout.cpp @@ -0,0 +1,825 @@ +//===- InferLayout.cpp - Infer layout (tiling + placement) for all ops -----===// +// +// This pass infers layout information (tile_size, mem_space, partition_dim, +// reduction_tile) for operations that lack explicit user annotations. +// +// Algorithm: +// 1. Collect existing user annotations into a map (no IR mutation). +// 2. BFS propagation from user annotations through elementwise chains +// (forward + backward along SSA edges). +// 3. Matmul seeding: for unannotated matmul ops, apply hardware-derived +// defaults. Validate any user annotations that reached matmul operands. +// 4. BFS propagation from matmul seeds. +// 5. Elementwise fallback: for any remaining unannotated ops, apply defaults. +// 6. BFS propagation from fallback seeds. +// 7. Error if any linalg op still has no annotation. +// 8. Materialize: create nkipy.annotate ops for all newly-inferred layouts. +// +//===----------------------------------------------------------------------===// + +#include "PassGen.h" +#include "nkipy/Transforms/Passes.h" +#include "nkipy/Transforms/HardwareConstants.h" +#include "nkipy/Transforms/IRHelpers.h" +#include "nkipy/Transforms/OpClassification.h" +#include "nkipy/Dialect/NkipyAttrs.h" +#include "nkipy/Dialect/NkipyDialect.h" +#include "nkipy/Dialect/NkipyOps.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "llvm/Support/raw_ostream.h" + +#include + +using namespace mlir; +using namespace nkipy; + +namespace mlir { +namespace nkipy { + +namespace { + +/// Parsed matmul dimensions: A[M,K] x B[K,N] -> C[M,N]. +/// For batch_matmul: A[B,M,K] x B[B,K,N] -> C[B,M,N]. 
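+/// e.g. A=128x256, B=256x512 gives M=128, K=256, N=512 (non-batched);
+/// A=4x128x256, B=4x256x512 gives the same M/K/N with batch size 4.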
+struct MatmulDims { + int64_t M, K, N; + int64_t B = 0; // 0 means non-batched + + bool isBatched() const { return B > 0; } + + /// Try to parse from a matmul op. Returns std::nullopt on failure. + static std::optional parse(linalg::LinalgOp matmulOp) { + SmallVector inputs(matmulOp.getDpsInputs()); + if (inputs.size() < 2) + return std::nullopt; + auto typeA = dyn_cast(inputs[0].getType()); + auto typeB = dyn_cast(inputs[1].getType()); + if (!typeA || !typeB || + !typeA.hasStaticShape() || !typeB.hasStaticShape()) + return std::nullopt; + int64_t rankA = typeA.getRank(); + if (rankA == 3) { + // batch_matmul: A[B,M,K] x B[B,K,N] + return MatmulDims{typeA.getShape()[1], typeA.getShape()[2], + typeB.getShape()[2], typeA.getShape()[0]}; + } + return MatmulDims{typeA.getShape()[0], typeA.getShape()[1], + typeB.getShape()[1]}; + } +}; + +//===----------------------------------------------------------------------===// +// Layout entry for tracking tiling + placement annotations +//===----------------------------------------------------------------------===// + +struct LayoutEntry { + DenseI64ArrayAttr tileSize; + DenseI64ArrayAttr reductionTile; + MemSpaceEnumAttr memSpace; + IntegerAttr partitionDim; // UI32Attr, optional + int seedId = -1; // Which seed originated this entry +}; + +//===----------------------------------------------------------------------===// +// Pass implementation +//===----------------------------------------------------------------------===// + +struct NkipyInferLayoutPass + : public InferLayoutBase { + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + /// Check if a value traces back to a function argument (possibly through + /// tensor.extract_slice). Such values live in HBM and cannot be placed + /// in SBUF. + bool tracesToFuncArg(Value val) { + while (val) { + if (isa(val)) + return true; + Operation *def = val.getDefiningOp(); + if (!def) + return false; + if (auto extractOp = dyn_cast(def)) { + val = extractOp.getSource(); + continue; + } + return false; + } + return false; + } + + /// Determine the appropriate memory space for a value. If the value + /// already has a known mem_space (user annotation) respect it; if it + /// traces to a function argument, use SharedHBM; otherwise default to SBUF. + MemSpaceEnumAttr inferMemSpace(Value val, + const DenseMap &annotatedValues, + MLIRContext *ctx) { + auto it = annotatedValues.find(val); + if (it != annotatedValues.end() && it->second.memSpace) + return it->second.memSpace; + if (tracesToFuncArg(val)) + return getSharedHbmMemSpace(ctx); + return getSbufMemSpace(ctx); + } + + /// Create a default MemSpaceEnumAttr for SBUF. + MemSpaceEnumAttr getSbufMemSpace(MLIRContext *ctx) { + return MemSpaceEnumAttr::get(ctx, MemSpaceEnum::Sbuf); + } + + MemSpaceEnumAttr getSharedHbmMemSpace(MLIRContext *ctx) { + return MemSpaceEnumAttr::get(ctx, MemSpaceEnum::SharedHbm); + } + + /// Create a UI32 IntegerAttr for partition_dim. + IntegerAttr makePartitionDimAttr(MLIRContext *ctx, unsigned pdim) { + return IntegerAttr::get(IntegerType::get(ctx, 32, IntegerType::Unsigned), + pdim); + } + + /// Compute the default elementwise tile_size for a given shape and + /// partition_dim. Rule: partition dim gets min(size, 128), last dim + /// (in original coordinates, excluding partition dim) gets full extent, + /// middle dims get 1. 
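+  /// e.g. shape=[256, 8, 512] with partDim=0 gives tile=[128, 1, 512];
+  /// shape=[300, 512] with partDim=0 gives tile=[128, 512].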
+ SmallVector computeElementwiseTileSize(ArrayRef shape, + unsigned partDim) { + unsigned rank = shape.size(); + SmallVector tile(rank, 1); + + // Partition dim: capped at MAX_PARTITION_DIM. + tile[partDim] = std::min(shape[partDim], MAX_PARTITION_DIM); + + // Free dim: the last dim that is not the partition dim. + unsigned freeDim = rank - 1; + if (freeDim == partDim && rank > 1) + freeDim = rank - 2; + tile[freeDim] = shape[freeDim]; + + return tile; + } + + /// Try to compute a propagated layout for targetVal, given sourceVal's + /// layout. Does NOT mutate IR or annotatedValues. Returns true on + /// success, with the result in outLayout. + bool computePropagatedLayout(Value sourceVal, Value targetVal, + const LayoutEntry &sourceLayout, + linalg::LinalgOp targetLinalgOp, + const DenseMap &annotatedValues, + MLIRContext *ctx, + LayoutEntry &outLayout) { + auto targetType = dyn_cast(targetVal.getType()); + if (!targetType || !targetType.hasStaticShape()) + return false; + + auto sourceType = dyn_cast(sourceVal.getType()); + if (!sourceType || !sourceType.hasStaticShape()) + return false; + + ArrayRef targetShape = targetType.getShape(); + ArrayRef sourceShape = sourceType.getShape(); + + SmallVector tileValues(sourceLayout.tileSize.asArrayRef()); + + // If the source tile has fewer elements than the source shape rank, + // this is a reduction result with stripped dims. Reconstruct the + // full-rank tile by re-inserting reduction_tile values at the reduced + // dimensions (where sourceShape[i] == 1 && targetShape[i] > 1). + // Only applies to generic reductions (sum, mean), NOT matmul. + if (tileValues.size() < sourceShape.size() && sourceLayout.reductionTile) { + Operation *sourceDefOp = sourceVal.getDefiningOp(); + auto sourceLinalgOp = sourceDefOp + ? dyn_cast(sourceDefOp) : nullptr; + if (sourceLinalgOp && isReductionGeneric(sourceLinalgOp) && + !isMatmulOp(sourceLinalgOp)) { + ArrayRef redTile = sourceLayout.reductionTile.asArrayRef(); + SmallVector expanded; + size_t tileIdx = 0, redIdx = 0; + for (size_t i = 0; i < sourceShape.size(); i++) { + if (sourceShape[i] == 1 && targetShape.size() > i && + targetShape[i] > 1 && redIdx < redTile.size()) { + expanded.push_back(redTile[redIdx++]); + } else if (tileIdx < tileValues.size()) { + expanded.push_back(tileValues[tileIdx++]); + } else { + expanded.push_back(1); + } + } + tileValues = expanded; + } + } + + if (targetShape != sourceShape) { + if (targetShape.size() != sourceShape.size()) + return false; + bool broadcastable = true; + for (size_t i = 0; i < targetShape.size(); i++) { + if (targetShape[i] != sourceShape[i] && + targetShape[i] != 1 && sourceShape[i] != 1) { + broadcastable = false; + break; + } + } + if (!broadcastable) + return false; + + for (size_t i = 0; i < tileValues.size() && i < targetShape.size(); i++) { + tileValues[i] = std::min(tileValues[i], targetShape[i]); + } + } + + // Clamp partition tile to the minimum partition dim across all inputs. + // This ensures that after tiling, broadcast operands with smaller + // partition dims (e.g., [1,N] in add(a[128,N], b[1,N])) produce + // matching tile shapes for NISA ops like tensor_tensor_arith. + if (isElementwiseOp(targetLinalgOp) || isReductionGeneric(targetLinalgOp)) { + unsigned partDim = sourceLayout.partitionDim + ? 
sourceLayout.partitionDim.getValue().getZExtValue() : 0; + if (partDim < tileValues.size()) { + int64_t minPartSize = tileValues[partDim]; + for (Value input : targetLinalgOp.getDpsInputs()) { + auto inType = dyn_cast(input.getType()); + if (!inType || !inType.hasStaticShape()) + continue; + if (partDim < (unsigned)inType.getRank()) + minPartSize = std::min(minPartSize, inType.getShape()[partDim]); + } + if (minPartSize < tileValues[partDim]) { + tileValues[partDim] = minPartSize; + llvm::errs() << "[InferLayout] Clamped partition tile to " + << minPartSize << " for broadcast operand\n"; + } + } + } + + // Handle reduction ops: strip tile dimensions for size-1 output dims. + DenseI64ArrayAttr inferredReductionTile; + if (isReductionGeneric(targetLinalgOp)) { + SmallVector stripped; + SmallVector reductionTileValues; + + bool needInferReductionTile = !sourceLayout.reductionTile; + + Value reductionInput = targetLinalgOp.getDpsInputs()[0]; + auto inputType = dyn_cast(reductionInput.getType()); + ArrayRef inputTile; + if (needInferReductionTile) { + auto inputIt = annotatedValues.find(reductionInput); + if (inputIt != annotatedValues.end()) + inputTile = inputIt->second.tileSize.asArrayRef(); + } + + for (size_t i = 0; i < tileValues.size() && i < targetShape.size(); i++) { + if (targetShape[i] > 1) { + stripped.push_back(tileValues[i]); + } else if (needInferReductionTile) { + if (!inputTile.empty() && i < inputTile.size()) + reductionTileValues.push_back(inputTile[i]); + else if (inputType && inputType.hasStaticShape() && + i < (size_t)inputType.getRank()) + reductionTileValues.push_back(inputType.getShape()[i]); + } + } + tileValues = stripped; + + if (!reductionTileValues.empty()) + inferredReductionTile = + DenseI64ArrayAttr::get(ctx, reductionTileValues); + } + + outLayout.tileSize = DenseI64ArrayAttr::get(ctx, tileValues); + outLayout.reductionTile = + inferredReductionTile ? inferredReductionTile : sourceLayout.reductionTile; + // Always default to SBUF for propagated layouts. mem_space is + // determined by the value's role (return value → SharedHbm, else SBUF), + // not by propagation from neighbors. + outLayout.memSpace = getSbufMemSpace(ctx); + outLayout.partitionDim = sourceLayout.partitionDim; + return true; + } + + /// Check if two layouts conflict. Returns a human-readable description + /// of the conflict, or empty string if they are compatible. + /// + /// Tile sizes are compatible if one divides the other in every dimension + /// (the consumer with the smaller tile can subdivide the producer's tiles). + /// mem_space is not checked: propagation defaults to SBUF while user + /// annotations may specify SharedHbm — this is expected, not a conflict. 
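+  /// e.g. tile_size [128, 512] vs [64, 512] is compatible (64 divides 128,
+  /// 512 divides 512), while [128, 384] vs [96, 384] is a conflict because
+  /// 96 does not divide 128.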
+ std::string describeConflict(const LayoutEntry &existing, + const LayoutEntry &proposed) { + if (existing.partitionDim && proposed.partitionDim && + existing.partitionDim.getInt() != proposed.partitionDim.getInt()) { + return "conflicting partition_dim: existing=" + + std::to_string(existing.partitionDim.getInt()) + + " vs proposed=" + + std::to_string(proposed.partitionDim.getInt()); + } + if (existing.tileSize && proposed.tileSize) { + auto existArr = existing.tileSize.asArrayRef(); + auto propArr = proposed.tileSize.asArrayRef(); + if (existArr.size() == propArr.size()) { + for (size_t i = 0; i < existArr.size(); i++) { + int64_t larger = std::max(existArr[i], propArr[i]); + int64_t smaller = std::min(existArr[i], propArr[i]); + if (smaller > 0 && larger % smaller != 0) { + std::string msg = "incompatible tile_size: ["; + for (auto [j, v] : llvm::enumerate(existArr)) { + if (j > 0) msg += ", "; + msg += std::to_string(v); + } + msg += "] vs ["; + for (auto [j, v] : llvm::enumerate(propArr)) { + if (j > 0) msg += ", "; + msg += std::to_string(v); + } + msg += "] (one must divide the other in each dimension)"; + return msg; + } + } + } + } + return ""; // no conflict + } + + /// Compute the matmul-specific layout for an operand, given the matmul's + /// shapes. operandIdx: 0 = A (stationary), 1 = B (moving). + /// Returns true on success. + bool computeMatmulOperandLayout(linalg::LinalgOp matmulOp, + unsigned operandIdx, + Value operandValue, + const LayoutEntry &resultLayout, + const DenseMap &annotatedValues, + MLIRContext *ctx, + LayoutEntry &outLayout) { + auto dims = MatmulDims::parse(matmulOp); + if (!dims) + return false; + + outLayout.memSpace = inferMemSpace(operandValue, annotatedValues, ctx); + outLayout.seedId = resultLayout.seedId; + + if (dims->isBatched()) { + // batch_matmul: operands are [B,M,K] and [B,K,N] + // Tile batch dim to 1 (will be looped over by CanonicalizePartitionDim) + if (operandIdx == 0) { + // A (stationary): [B, M, K], partition_dim=2 (K) + outLayout.tileSize = DenseI64ArrayAttr::get( + ctx, {1, + std::min(dims->M, MAX_PARTITION_DIM), + std::min(dims->K, MAX_PARTITION_DIM)}); + outLayout.partitionDim = makePartitionDimAttr(ctx, 2); + } else { + // B (moving): [B, K, N], partition_dim=1 (K) + outLayout.tileSize = DenseI64ArrayAttr::get( + ctx, {1, + std::min(dims->K, MAX_PARTITION_DIM), + std::min(dims->N, MAX_FREE_DIM_MATMUL)}); + outLayout.partitionDim = makePartitionDimAttr(ctx, 1); + } + } else if (operandIdx == 0) { + // A (stationary): [M, K], partition_dim=1 (K) + outLayout.tileSize = DenseI64ArrayAttr::get( + ctx, {std::min(dims->M, MAX_PARTITION_DIM), + std::min(dims->K, MAX_PARTITION_DIM)}); + outLayout.partitionDim = makePartitionDimAttr(ctx, 1); + } else { + // B (moving): [K, N], partition_dim=0 (K) + outLayout.tileSize = DenseI64ArrayAttr::get( + ctx, {std::min(dims->K, MAX_PARTITION_DIM), + std::min(dims->N, MAX_FREE_DIM_MATMUL)}); + outLayout.partitionDim = makePartitionDimAttr(ctx, 0); + } + return true; + } + + /// Result of tryInsertLayout: whether the value was newly inserted, + /// already existed (compatible), or conflicted. + enum class InsertResult { Inserted, Exists, Conflict }; + + /// Try to insert a propagated layout for targetValue. If already annotated, + /// check for conflicts. On success (Inserted), adds to queue. 
+ InsertResult tryInsertLayout(Value targetValue, + const LayoutEntry &propagated, + const LayoutEntry &sourceLayout, + Operation *targetOp, + std::deque &queue, + DenseMap &annotatedValues, + int &numInferred, + StringRef direction) { + auto existingIt = annotatedValues.find(targetValue); + if (existingIt != annotatedValues.end()) { + if (existingIt->second.seedId != sourceLayout.seedId) { + std::string conflict = + describeConflict(existingIt->second, propagated); + if (!conflict.empty()) { + llvm::errs() << "[InferLayout] CONFLICT (" << direction << "): " + << "current seed=" << sourceLayout.seedId + << " existing seed=" << existingIt->second.seedId + << " at " << targetOp->getName() << "\n"; + targetOp->emitError("infer-layout: ") << conflict; + return InsertResult::Conflict; + } + } + return InsertResult::Exists; + } + + LayoutEntry entry = propagated; + entry.seedId = sourceLayout.seedId; + annotatedValues[targetValue] = entry; + queue.push_back(targetValue); + numInferred++; + llvm::errs() << "[InferLayout] " << direction << "-propagated to " + << targetOp->getName() << "\n"; + return InsertResult::Inserted; + } + + /// BFS propagation from all values currently in the queue. + /// Explores both forward (to consumers) and backward (to producers) + /// along SSA edges through elementwise/reduction chains. + /// Only modifies annotatedValues (no IR mutation). + int bfsPropagation(std::deque &queue, + DenseMap &annotatedValues, + MLIRContext *ctx, + bool &hasConflict) { + int numInferred = 0; + + while (!queue.empty()) { + Value current = queue.front(); + queue.pop_front(); + + if (hasConflict) + break; + + auto it = annotatedValues.find(current); + if (it == annotatedValues.end()) + continue; + LayoutEntry layout = it->second; + + Operation *defOp = current.getDefiningOp(); + auto defLinalgOp = defOp ? dyn_cast(defOp) : nullptr; + + // --- Backward: from result to producer inputs --- + if (defLinalgOp) { + bool isMatmul = isMatmulOp(defLinalgOp); + SmallVector inputs(defLinalgOp.getDpsInputs()); + + for (auto [idx, input] : llvm::enumerate(inputs)) { + Operation *producerOp = input.getDefiningOp(); + if (!producerOp) + continue; + + Value targetValue = input; + linalg::LinalgOp targetLinalgOp; + if (!isMatmul) { + targetLinalgOp = dyn_cast(producerOp); + if (!targetLinalgOp || !isAnnotatableOp(targetLinalgOp) || + targetLinalgOp->getNumResults() == 0) + continue; + targetValue = targetLinalgOp->getResult(0); + } + + LayoutEntry propagated; + bool computed = isMatmul + ? 
computeMatmulOperandLayout(defLinalgOp, idx, targetValue, + layout, annotatedValues, ctx, + propagated) + : computePropagatedLayout(current, targetValue, layout, + targetLinalgOp, annotatedValues, ctx, + propagated); + if (!computed) + continue; + + auto result = tryInsertLayout(targetValue, propagated, layout, + producerOp, queue, annotatedValues, + numInferred, "Backward"); + if (result == InsertResult::Conflict) { + hasConflict = true; + return numInferred; + } + } + } + + // --- Forward: from a value to its consumer op results --- + for (Operation *userOp : current.getUsers()) { + auto userLinalgOp = dyn_cast(userOp); + if (!userLinalgOp) + continue; + if (!isElementwiseOp(userLinalgOp) && !isReductionGeneric(userLinalgOp)) + continue; + if (userLinalgOp->getNumResults() == 0) + continue; + + Value userResult = userLinalgOp->getResult(0); + + LayoutEntry propagated; + if (!computePropagatedLayout(current, userResult, layout, + userLinalgOp, annotatedValues, + ctx, propagated)) + continue; + + auto result = tryInsertLayout(userResult, propagated, layout, + userOp, queue, annotatedValues, + numInferred, "Forward"); + if (result == InsertResult::Conflict) { + hasConflict = true; + return numInferred; + } + } + } + + return numInferred; + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + MLIRContext *ctx = func.getContext(); + + llvm::errs() << "[InferLayout] Processing function: " + << func.getName() << "\n"; + + // This map tracks the layout assignment for each Value. + // It is built up across all phases. IR is only mutated at the very end. + DenseMap annotatedValues; + // Track which values already had user-created nkipy.annotate ops + // so we don't duplicate them during materialization. + DenseSet userAnnotatedValues; + int numInferred = 0; + int nextSeedId = 0; + + // Collect values that flow into func.return — these must be SharedHbm. 
+ DenseSet returnValues; + func.walk([&](func::ReturnOp returnOp) { + for (Value operand : returnOp.getOperands()) + returnValues.insert(operand); + }); + + // ================================================================ + // Phase 1: Collect existing user annotations + // ================================================================ + func.walk([&](nkipy::AnnotateOp annotateOp) { + if (isInsideNkipyRegion(annotateOp)) + return; + if (auto tileSizeAttr = annotateOp.getTileSizeAttr()) { + LayoutEntry entry; + entry.tileSize = tileSizeAttr; + entry.reductionTile = annotateOp.getReductionTileAttr(); + entry.memSpace = annotateOp.getMemSpaceAttr(); + entry.partitionDim = annotateOp.getPartitionDimAttr(); + entry.seedId = nextSeedId++; + annotatedValues[annotateOp.getTarget()] = entry; + userAnnotatedValues.insert(annotateOp.getTarget()); + } + }); + + llvm::errs() << "[InferLayout] Found " << annotatedValues.size() + << " user annotation(s)\n"; + + // ================================================================ + // Phase 2: BFS propagation from user annotations + // ================================================================ + bool hasConflict = false; + { + std::deque queue; + for (auto &kv : annotatedValues) + queue.push_back(kv.first); + numInferred += bfsPropagation(queue, annotatedValues, ctx, hasConflict); + } + if (hasConflict) + return signalPassFailure(); + + // ================================================================ + // Phase 3: Matmul seeding + // ================================================================ + { + std::deque matmulSeeds; + + func.walk([&](linalg::LinalgOp linalgOp) { + if (isInsideNkipyRegion(linalgOp)) + return; + if (!isMatmulOp(linalgOp) || linalgOp->getNumResults() == 0) + return; + + Value resultVal = linalgOp->getResult(0); + if (annotatedValues.count(resultVal)) + return; + + auto dims = MatmulDims::parse(linalgOp); + if (!dims) + return; + + // Seed result C[M,N] (or C[B,M,N] for batch_matmul): + // partition_dim=0 (M for 2D, B for 3D — batch dim gets looped over) + LayoutEntry cLayout; + if (dims->isBatched()) { + // batch_matmul: tile_size=[1, M_tile, N_tile] — batch dim tiled to 1 + // (CanonicalizePartitionDim will loop over batch and drop it) + cLayout.tileSize = DenseI64ArrayAttr::get( + ctx, + {1, + std::min(dims->M, MAX_PARTITION_DIM), + std::min(dims->N, MAX_FREE_DIM_MATMUL)}); + } else { + cLayout.tileSize = DenseI64ArrayAttr::get( + ctx, + {std::min(dims->M, MAX_PARTITION_DIM), + std::min(dims->N, MAX_FREE_DIM_MATMUL)}); + } + cLayout.reductionTile = DenseI64ArrayAttr::get( + ctx, {std::min(dims->K, MAX_PARTITION_DIM)}); + cLayout.memSpace = getSbufMemSpace(ctx); + cLayout.partitionDim = makePartitionDimAttr(ctx, 0); + cLayout.seedId = nextSeedId++; + + annotatedValues[resultVal] = cLayout; + matmulSeeds.push_back(resultVal); + numInferred++; + llvm::errs() << "[InferLayout] Matmul-seeded result C\n"; + + // Operands A and B are seeded during BFS backward propagation + // from the matmul result via computeMatmulOperandLayout(). 
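+        // For a 2-D matmul C[M,N] = A[M,K] x B[K,N] this yields
+        // A: tile [min(M,128), min(K,128)] with partition_dim=1 (K), and
+        // B: tile [min(K,128), min(N, MAX_FREE_DIM_MATMUL)] with
+        // partition_dim=0 (K).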
+ }); + + // Phase 4: BFS propagation from matmul seeds + numInferred += bfsPropagation(matmulSeeds, annotatedValues, ctx, + hasConflict); + } + if (hasConflict) + return signalPassFailure(); + + // ================================================================ + // Phase 5: Elementwise / standalone fallback + // ================================================================ + { + std::deque fallbackSeeds; + + func.walk([&](linalg::LinalgOp linalgOp) { + if (isInsideNkipyRegion(linalgOp)) + return; + if (!isAnnotatableOp(linalgOp)) + return; + if (linalgOp->getNumResults() == 0) + return; + + Value result = linalgOp->getResult(0); + if (annotatedValues.count(result)) + return; + + auto resultType = dyn_cast(result.getType()); + if (!resultType || !resultType.hasStaticShape()) + return; + + ArrayRef shape = resultType.getShape(); + unsigned partDim = 0; + + LayoutEntry layout; + layout.memSpace = getSbufMemSpace(ctx); + layout.partitionDim = makePartitionDimAttr(ctx, partDim); + layout.seedId = nextSeedId++; + + if (isReductionGeneric(linalgOp)) { + // For reduction ops, tile_size covers only parallel dims and + // reduction_tile covers only reduction dims. + auto genericOp = cast(linalgOp.getOperation()); + auto iterTypes = genericOp.getIteratorTypesArray(); + + // Get input shape to determine reduction dim sizes. + Value reductionInput = linalgOp.getDpsInputs()[0]; + auto inputType = dyn_cast(reductionInput.getType()); + + SmallVector parallelTile; + SmallVector reductionTile; + for (size_t i = 0; i < iterTypes.size(); i++) { + if (iterTypes[i] == utils::IteratorType::parallel) { + int64_t dimSize = shape[parallelTile.size()]; + if (parallelTile.size() == partDim) + parallelTile.push_back(std::min(dimSize, MAX_PARTITION_DIM)); + else + parallelTile.push_back(dimSize); + } else { + int64_t redDimSize = (inputType && inputType.hasStaticShape() && + i < (size_t)inputType.getRank()) + ? inputType.getShape()[i] + : 1; + reductionTile.push_back(redDimSize); + } + } + + layout.tileSize = DenseI64ArrayAttr::get(ctx, parallelTile); + layout.reductionTile = DenseI64ArrayAttr::get(ctx, reductionTile); + } else { + SmallVector tile = computeElementwiseTileSize(shape, partDim); + // Clamp partition tile for broadcast inputs. 
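+          // e.g. add(a[128, N], b[1, N]): the [1, N] operand forces the
+          // partition tile down to 1 so both operands tile to matching shapes.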
+ int64_t minPartSize = tile[partDim]; + for (Value input : linalgOp.getDpsInputs()) { + auto inType = dyn_cast(input.getType()); + if (!inType || !inType.hasStaticShape()) + continue; + if (partDim < (unsigned)inType.getRank()) + minPartSize = std::min(minPartSize, inType.getShape()[partDim]); + } + tile[partDim] = minPartSize; + layout.tileSize = DenseI64ArrayAttr::get(ctx, tile); + } + + annotatedValues[result] = layout; + fallbackSeeds.push_back(result); + numInferred++; + llvm::errs() << "[InferLayout] Fallback-seeded " + << linalgOp->getName() << "\n"; + }); + + // Phase 6: BFS propagation from fallback seeds + numInferred += bfsPropagation(fallbackSeeds, annotatedValues, ctx, + hasConflict); + } + if (hasConflict) + return signalPassFailure(); + + // ================================================================ + // Phase 7: Error check — all linalg ops should be annotated + // ================================================================ + bool hasError = false; + func.walk([&](linalg::LinalgOp linalgOp) { + if (isInsideNkipyRegion(linalgOp)) + return; + if (!isAnnotatableOp(linalgOp)) + return; + if (linalgOp->getNumResults() == 0) + return; + Value result = linalgOp->getResult(0); + if (!annotatedValues.count(result)) { + linalgOp.emitError("infer-layout: unable to determine layout for op"); + hasError = true; + } + }); + if (hasError) + return signalPassFailure(); + + // ================================================================ + // Phase 7.5b: Override mem_space for return values to SharedHbm + // ================================================================ + for (Value retVal : returnValues) { + auto it = annotatedValues.find(retVal); + if (it != annotatedValues.end()) { + it->second.memSpace = getSharedHbmMemSpace(ctx); + } + } + + // ================================================================ + // Phase 7.6: Fill in missing mem_space for all entries + // ================================================================ + for (auto &kv : annotatedValues) { + LayoutEntry &layout = kv.second; + if (!layout.memSpace) { + layout.memSpace = inferMemSpace(kv.first, annotatedValues, ctx); + } + } + + // ================================================================ + // Phase 8: Materialize — create nkipy.annotate ops for all + // newly-inferred layouts, and update user-annotated ops + // that had mem_space filled in. 
+ // ================================================================ + OpBuilder builder(ctx); + for (auto &kv : annotatedValues) { + Value val = kv.first; + const LayoutEntry &layout = kv.second; + + if (userAnnotatedValues.count(val)) { + // Update existing annotate op if mem_space was inferred + for (Operation *user : val.getUsers()) { + if (auto annotateOp = dyn_cast(user)) { + if (!annotateOp.getMemSpace() && layout.memSpace) { + annotateOp.setMemSpaceAttr(layout.memSpace); + } + break; + } + } + continue; + } + + builder.setInsertionPointAfterValue(val); + builder.create( + val.getLoc(), val, + layout.memSpace, layout.partitionDim, + layout.tileSize, layout.reductionTile); + } + + llvm::errs() << "[InferLayout] Inferred " << numInferred + << " layout annotation(s)\n"; + } +}; + +} // namespace + +std::unique_ptr> createInferLayoutPass() { + return std::make_unique(); +} + +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/InlineNkipyReference.cpp b/kernelgen/mlir/lib/Transforms/InlineNkipyReference.cpp new file mode 100644 index 0000000..b3f2303 --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/InlineNkipyReference.cpp @@ -0,0 +1,140 @@ +#include "PassGen.h" +#include "nkipy/Transforms/Passes.h" +#include "nkipy/Dialect/NkipyDialect.h" +#include "nkipy/Dialect/NkipyOps.h" + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" + +using namespace mlir; +using namespace nkipy; + +namespace mlir { +namespace nkipy { + +namespace { + +/// After inlining a reference body into post-bufferization IR, there may be +/// tensor.extract(bufferization.to_tensor(memref), indices) patterns that +/// canonicalize alone cannot fold. Walk the function and replace them with +/// memref.load(memref, indices). +static void foldTensorExtractOfToTensor(func::FuncOp func) { + SmallVector toFold; + func.walk([&](tensor::ExtractOp extractOp) { + if (extractOp.getTensor().getDefiningOp()) + toFold.push_back(extractOp); + }); + for (auto extractOp : toFold) { + auto toTensor = + extractOp.getTensor().getDefiningOp(); + OpBuilder b(extractOp); + Value loaded = b.create( + extractOp.getLoc(), toTensor.getBuffer(), extractOp.getIndices()); + extractOp.getResult().replaceAllUsesWith(loaded); + extractOp.erase(); + } +} + +struct InlineNkipyReferencePass + : public InlineNkipyReferenceBase { + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + + // Collect nkipy ops with non-empty reference_impl regions. + SmallVector opsToInline; + func.walk([&](Operation *op) { + if (!op->getDialect() || + op->getDialect()->getNamespace() != "nkipy") + return; + if (isa(op) || isa(op)) + return; + if (op->getNumRegions() == 0) + return; + Region ®ion = op->getRegion(0); + if (region.empty()) + return; + opsToInline.push_back(op); + }); + + for (Operation *op : opsToInline) + inlineReferenceRegion(op); + + // Clean up tensor.extract(to_tensor(memref)) → memref.load(memref) + // patterns left after inlining into post-bufferization IR. + foldTensorExtractOfToTensor(func); + } + + void inlineReferenceRegion(Operation *nkipyOp) { + Region ®ion = nkipyOp->getRegion(0); + Block &refBlock = region.front(); + + // Build value mapping: block args → op operands. 
+ // Post-bufferization, operands may be memrefs while block args are tensors. + // Insert to_tensor conversions so the reference body works correctly. + IRMapping mapping; + OpBuilder builder(nkipyOp); + for (unsigned i = 0; i < refBlock.getNumArguments(); ++i) { + Value blockArg = refBlock.getArgument(i); + Value operand = nkipyOp->getOperand(i); + if (isa(blockArg.getType()) && + isa(operand.getType())) { + operand = builder.create( + nkipyOp->getLoc(), blockArg.getType(), operand); + } + mapping.map(blockArg, operand); + } + + // Clone each op (except the yield) before the nkipy op. + SmallVector yieldValues; + + for (Operation &innerOp : llvm::make_early_inc_range(refBlock)) { + if (isa(innerOp)) { + for (Value v : innerOp.getOperands()) + yieldValues.push_back(mapping.lookupOrDefault(v)); + } else { + Operation *cloned = builder.clone(innerOp, mapping); + (void)cloned; + } + } + + // Post-bufferization: the inlined body yields tensors, but the outer code + // reads from the DPS output memref. Copy each result into its DPS init. + if (auto dpsOp = dyn_cast(nkipyOp)) { + for (unsigned i = 0; i < yieldValues.size(); ++i) { + if (!isa(yieldValues[i].getType())) + continue; + auto inits = dpsOp.getDpsInits(); + if (i >= inits.size() || !isa(inits[i].getType())) + continue; + auto tensorType = cast(yieldValues[i].getType()); + auto bufType = MemRefType::get(tensorType.getShape(), + tensorType.getElementType()); + auto buf = builder.create( + nkipyOp->getLoc(), bufType, yieldValues[i]); + builder.create(nkipyOp->getLoc(), buf, inits[i]); + } + } + + // Replace uses and erase. + assert(yieldValues.size() == nkipyOp->getNumResults()); + for (unsigned i = 0; i < nkipyOp->getNumResults(); ++i) + nkipyOp->getResult(i).replaceAllUsesWith(yieldValues[i]); + nkipyOp->erase(); + } +}; +} // namespace + +std::unique_ptr> createInlineNkipyReferencePass() { + return std::make_unique(); +} + +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/InsertMemRefDealloc.cpp b/kernelgen/mlir/lib/Transforms/InsertMemRefDealloc.cpp new file mode 100644 index 0000000..73a4881 --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/InsertMemRefDealloc.cpp @@ -0,0 +1,189 @@ +//===- InsertMemRefDealloc.cpp - Insert dealloc after last use -----------===// +// +// This pass analyzes the lifetime of memref.alloc operations with NISA memory +// space attributes and inserts memref.dealloc after each allocation's last use. +// These are later lowered to nisa.release by the linalg-to-nisa pass. +// +// Uses last-use deallocation: traces through view chains (subview, +// reinterpret_cast, collapse_shape, expand_shape, cast) via BFS to find the +// last consumer, then inserts dealloc immediately after. Falls back to +// scope-based dealloc if the alloc has no uses. Uses inside nested regions +// (e.g., scf.for loop bodies) are mapped to their ancestor op in the +// allocation's block. 
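+//
+// Illustrative example (value names made up):
+//   %buf = memref.alloc() : memref<128x512xf32, #sbuf>
+//   %sv  = memref.subview %buf[...]        // derived view
+//   ... last op that reads or writes %sv ...
+//   memref.dealloc %buf                    // inserted immediately after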
+// +//===----------------------------------------------------------------------===// + +#include "nkipy/Transforms/Passes.h" +#include "nkipy/Transforms/IRHelpers.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "mlir/Pass/Pass.h" +#include "nkipy/Dialect/NkipyAttrs.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "insert-memref-dealloc" + +using namespace mlir; +using nkipy::getAncestorInBlock; +using nkipy::getBaseMemRef; +using nkipy::getNkipyMemSpace; + +namespace { + +/// Collect all allocations that escape via function return values. +/// For each return operand, trace back through view chains to find the base +/// allocation and add it to the set of escaped allocations. +static void collectEscapedAllocations(func::FuncOp func, + llvm::DenseSet &escaped) { + func.walk([&](func::ReturnOp returnOp) { + for (Value operand : returnOp.getOperands()) { + // Only care about memref types + if (!isa(operand.getType())) + continue; + + // Trace back to base allocation + Value base = getBaseMemRef(operand); + if (auto allocOp = base.getDefiningOp()) { + escaped.insert(allocOp); + } + } + }); +} + +struct InsertMemRefDeallocPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertMemRefDeallocPass) + + // Track if any errors occurred during the pass + bool hasError = false; + + StringRef getArgument() const final { return "insert-memref-dealloc"; } + + StringRef getDescription() const final { + return "Insert memref.dealloc operations to mark allocation lifetime ends"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + hasError = false; + + // First, collect all allocations that escape via return values + llvm::DenseSet escapedAllocs; + collectEscapedAllocations(func, escapedAllocs); + + // Collect allocations to process and validate they have NISA memspace. + // Skip ops inside nkipy regions (reference_impl bodies are for CPU + // simulation only and don't have NISA memory spaces). + SmallVector allocOps; + func.walk([&](memref::AllocOp op) { + if (nkipy::isInsideNkipyRegion(op)) + return; + auto memSpace = getNkipyMemSpace(op.getType()); + if (memSpace) { + // Skip SharedHbm, Hbm (externally managed) and Constant (scalar + // broadcasts — small and not real SBUF allocations). + if (*memSpace == nkipy::MemSpaceEnum::SharedHbm || + *memSpace == nkipy::MemSpaceEnum::Hbm || + *memSpace == nkipy::MemSpaceEnum::Constant) + return; + allocOps.push_back(op); + return; + } + + // Error: memref.alloc without memory space annotation. + op.emitError() << "memref.alloc must have an nkipy memory space " + << "annotation (SBUF, PSUM, HBM, or SHAREDHBM)"; + hasError = true; + }); + + if (hasError) { + signalPassFailure(); + return; + } + + if (allocOps.empty()) + return; + + int numInserted = 0; + for (auto allocOp : allocOps) { + // Skip if allocation escapes (returned from function) + if (escapedAllocs.contains(allocOp)) + continue; + + Block *block = allocOp->getBlock(); + Value allocVal = allocOp.getResult(); + + // Find the last use of this alloc (or any derived view) in its block. + // We trace through the entire view chain (subview, reinterpret_cast, + // collapse_shape, expand_shape, cast) because the alloc may be + // accessed only through derived views, not directly. 
+ // Uses inside nested regions (e.g., scf.for loop bodies) are mapped + // to their ancestor op that lives directly in `block`. + Operation *lastUseOp = nullptr; + + // BFS through all transitively derived values from the alloc. + SmallVector worklist; + llvm::DenseSet visited; + worklist.push_back(allocVal); + visited.insert(allocVal); + + while (!worklist.empty()) { + Value val = worklist.pop_back_val(); + for (Operation *user : val.getUsers()) { + // Walk up to find the ancestor in allocOp's block. + Operation *ancestor = getAncestorInBlock(user, block); + if (!ancestor) + continue; + if (!lastUseOp || lastUseOp->isBeforeInBlock(ancestor)) + lastUseOp = ancestor; + + // If this user produces a view-like result, trace through it. + if (isa(user)) { + for (Value result : user->getResults()) { + if (isa(result.getType()) && !visited.contains(result)) { + visited.insert(result); + worklist.push_back(result); + } + } + } + } + } + + if (lastUseOp) { + // Insert dealloc right after the last use. This allows the backend + // allocator to reuse the SBUF space earlier, reducing peak memory + // pressure after loop unrolling. + OpBuilder builder(lastUseOp->getBlock(), ++Block::iterator(lastUseOp)); + builder.create(allocOp.getLoc(), allocVal); + } else { + // No uses found — fall back to scope-based dealloc. + Operation *terminator = block->getTerminator(); + OpBuilder builder(terminator); + builder.create(allocOp.getLoc(), allocVal); + } + ++numInserted; + } + + LLVM_DEBUG(llvm::dbgs() << "Inserted " << numInserted + << " dealloc operation(s)\n"); + } +}; + +} // namespace + +namespace mlir { +namespace nkipy { + +std::unique_ptr> createInsertMemRefDeallocPass() { + return std::make_unique(); +} + +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/InsertSpillReload.cpp b/kernelgen/mlir/lib/Transforms/InsertSpillReload.cpp new file mode 100644 index 0000000..4d8e4c0 --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/InsertSpillReload.cpp @@ -0,0 +1,536 @@ +//===- InsertSpillReload.cpp - Insert spill/reload for SBUF pressure ====// +// +// This pass analyzes per-partition SBUF memory pressure and inserts spill +// (SBUF→HBM) and reload (HBM→SBUF) operations when capacity is exceeded. +// +// Runs after legalize-layout so SBUF allocs are in physical layout +// [partTile, numBlocks..., freeTile]. The per-partition size is +// total_size / shape[0] (partTile), matching getSbufPartitionUsableSize. +// +// Algorithm: +// 1. Collect all SBUF allocations and compute their sizes +// 2. Perform liveness analysis to find peak memory pressure points +// 3. At high-pressure points, select victims to spill using a heuristic +// 4. Insert memref.copy operations for spill/reload +// +// These copies are lowered to nisa.dma_copy in the LinalgToNisa pass. +// +// Limitations: +// - Only analyzes the function entry block. SBUF pressure that arises +// exclusively inside a loop body (e.g., a loop-local alloc that exceeds +// capacity only within its iteration) is not detected. Full loop-body +// analysis would require per-block traversal and loop-carried liveness. +// The preferred solution for loop-body pressure is tiling (reduce the working +// set per iteration) combined with multi-buffering (overlap DMA and compute +// by keeping only the current and next tile in SBUF simultaneously). 
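As a quick illustration of the per-partition sizing rule stated above, here is a small self-contained C++ sketch; it is a simplified model (static shapes only, element widths rounded up to whole bytes), not the pass's own helper.

#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <vector>

// Simplified model of the sizing rule described above: after legalize-layout
// an SBUF buffer has physical shape [partTile, numBlocks..., freeTile], and
// each hardware partition holds total_elements / partTile elements.
std::optional<int64_t> perPartitionBytes(const std::vector<int64_t> &shape,
                                         unsigned elementBits) {
  if (shape.empty() || shape[0] == 0)
    return std::nullopt;
  int64_t totalElements = std::accumulate(shape.begin(), shape.end(),
                                          int64_t{1}, std::multiplies<>());
  int64_t elementBytes = (elementBits + 7) / 8;
  return (totalElements / shape[0]) * elementBytes;
}

// Example: a [128, 2, 4, 512] fp32 buffer holds 2 * 4 * 512 * 4 = 16384 bytes
// per partition, which is then compared against the SBUF partition capacity.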
+// +//===----------------------------------------------------------------------===// + +#include "PassGen.h" +#include "nkipy/Transforms/Passes.h" +#include "nkipy/Transforms/IRHelpers.h" +#include "nkipy/Transforms/HardwareConstants.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Dominance.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "insert-spill-reload" + +using namespace mlir; +using nkipy::getNkipyMemSpace; + +namespace mlir { +namespace nkipy { + +namespace { + +//===----------------------------------------------------------------------===// +// Configuration +//===----------------------------------------------------------------------===// + +// Spill heuristic strategy +enum class SpillHeuristic { + FARTHEST_NEXT_USE, // Belady's MIN (optimal) + LRU, // Least recently used + SIZE_BASED, // Largest first +}; + +//===----------------------------------------------------------------------===// +// Helper: Compute memref size in bytes +//===----------------------------------------------------------------------===// + +// Returns the per-partition size of an SBUF memref in bytes. +// +// After legalize-layout, SBUF allocs have physical shape +// [partTile, numBlocks..., freeTile]. Each of the 128 hardware partitions +// holds total_size / partTile bytes, so we divide by shape[0]. +static std::optional computePerPartitionSize(MemRefType type) { + if (!type.hasStaticShape()) + return std::nullopt; + + auto shape = type.getShape(); + if (shape.empty() || shape[0] == 0) + return std::nullopt; + + int64_t numElements = type.getNumElements() / shape[0]; + unsigned elementBits = type.getElementTypeBitWidth(); + int64_t elementBytes = (elementBits + 7) / 8; + + return numElements * elementBytes; +} + +//===----------------------------------------------------------------------===// +// Allocation Info +//===----------------------------------------------------------------------===// + +struct AllocationInfo { + memref::AllocOp allocOp; + Value value; + int64_t sizeBytes; + Operation *firstUse = nullptr; + Operation *lastUse = nullptr; + bool isSpilled = false; + Value spillSlot; // HBM buffer for spilled data +}; + +//===----------------------------------------------------------------------===// +// Liveness Analysis (Simplified) +//===----------------------------------------------------------------------===// + +class SimpleLivenessAnalysis { +public: + explicit SimpleLivenessAnalysis(Block *block) : block(block) { + analyze(); + } + + // Get the first operation that uses this value + Operation *getFirstUse(Value val) const { + auto it = uses.find(val); + if (it == uses.end() || it->second.empty()) + return nullptr; + return it->second.front(); + } + + // Get the last operation that uses this value + Operation *getLastUse(Value val) const { + auto it = uses.find(val); + if (it == uses.end() || it->second.empty()) + return nullptr; + return it->second.back(); + } + + // Get the first use of val that comes strictly AFTER point in the block + Operation *getNextUseAfter(Value val, Operation *point) const { + auto it = uses.find(val); + if (it == uses.end()) + return nullptr; + for (Operation *use : it->second) { + if (point->isBeforeInBlock(use)) + return use; + } + return nullptr; + } + + // Get the last use of val that comes strictly BEFORE point in the block + Operation *getLastUseBefore(Value val, 
Operation *point) const { + auto it = uses.find(val); + if (it == uses.end()) + return nullptr; + Operation *result = nullptr; + for (Operation *use : it->second) { + if (use->isBeforeInBlock(point)) + result = use; + } + return result; + } + + // Check if value is live at a given operation + bool isLive(Value val, Operation *op) const { + auto first = getFirstUse(val); + auto last = getLastUse(val); + if (!first || !last) + return false; + + // Check if op is between [first, last] in the block + if (op->getBlock() != block) + return false; + + return !op->isBeforeInBlock(first) && !last->isBeforeInBlock(op); + } + +private: + Block *block; + DenseMap> uses; + + void analyze() { + for (Operation &op : *block) { + // Record uses of all operands + for (Value operand : op.getOperands()) { + uses[operand].push_back(&op); + } + + // Recursively analyze nested regions (e.g., loop bodies). + // Record the ancestor op directly in `block` rather than the nested op + // itself — isBeforeInBlock requires both ops to be in the same block. + op.walk([&](Operation *nestedOp) { + if (nestedOp == &op) + return; + for (Value operand : nestedOp->getOperands()) { + if (operand.getParentBlock() != block) + continue; + // Walk up to find the direct child of `block`. + Operation *ancestor = nkipy::getAncestorInBlock(nestedOp, block); + // Avoid duplicate entries for the same ancestor. + if (uses[operand].empty() || uses[operand].back() != ancestor) + uses[operand].push_back(ancestor); + } + }); + } + } +}; + +//===----------------------------------------------------------------------===// +// Memory Pressure Tracking +//===----------------------------------------------------------------------===// + +struct PressurePoint { + Operation *op; + int64_t sbufUsageBytes; + SmallVector liveAllocs; +}; + +static SmallVector +computeMemoryPressure(Block *block, ArrayRef allocs) { + SimpleLivenessAnalysis liveness(block); + SmallVector pressurePoints; + + for (Operation &op : *block) { + PressurePoint point; + point.op = &op; + point.sbufUsageBytes = 0; + + // Check which allocations are live at this point + for (auto &alloc : allocs) { + if (liveness.isLive(alloc.value, &op)) { + point.sbufUsageBytes += alloc.sizeBytes; + point.liveAllocs.push_back(const_cast(&alloc)); + } + } + + pressurePoints.push_back(point); + } + + return pressurePoints; +} + +//===----------------------------------------------------------------------===// +// Spill Decision: Pick victims based on heuristic +//===----------------------------------------------------------------------===// + +static SmallVector +selectSpillVictims(const PressurePoint &point, int64_t capacityBytes, + SpillHeuristic heuristic, SimpleLivenessAnalysis &liveness) { + SmallVector victims; + + // Compute effective pressure from unspilled live allocs only. Already-spilled + // allocs are excluded because they were selected at an earlier pressure point + // and their contribution has already been accounted for. 
+ int64_t livePressure = 0; + SmallVector candidates; + for (AllocationInfo *a : point.liveAllocs) { + if (!a->isSpilled) { + livePressure += a->sizeBytes; + candidates.push_back(a); + } + } + + if (livePressure <= capacityBytes) + return victims; // Effective pressure within capacity + + int64_t excessBytes = livePressure - capacityBytes; + int64_t spilledBytes = 0; + + switch (heuristic) { + case SpillHeuristic::SIZE_BASED: + // Spill largest allocations first + llvm::sort(candidates, [](AllocationInfo *a, AllocationInfo *b) { + return a->sizeBytes > b->sizeBytes; + }); + break; + + case SpillHeuristic::LRU: + // Spill least recently used: the candidate whose last use before this + // pressure point is earliest (i.e., most stale). + llvm::sort(candidates, [&](AllocationInfo *a, AllocationInfo *b) { + auto aLast = liveness.getLastUseBefore(a->value, point.op); + auto bLast = liveness.getLastUseBefore(b->value, point.op); + if (!aLast) return true; // Never used before this point → most stale + if (!bLast) return false; + return aLast->isBeforeInBlock(bLast); // Earlier last use → more stale + }); + break; + + case SpillHeuristic::FARTHEST_NEXT_USE: + // Belady's MIN: spill the value whose next use AFTER this pressure point + // is farthest away (optimal page-replacement policy). + llvm::sort(candidates, [&](AllocationInfo *a, AllocationInfo *b) { + auto aNext = liveness.getNextUseAfter(a->value, point.op); + auto bNext = liveness.getNextUseAfter(b->value, point.op); + if (!aNext) return true; // No future use → spill first + if (!bNext) return false; + return !bNext->isBeforeInBlock(aNext); // Farther next use → spill first + }); + break; + } + + // Select victims until we free enough space + for (AllocationInfo *candidate : candidates) { + if (spilledBytes >= excessBytes) + break; + victims.push_back(candidate); + spilledBytes += candidate->sizeBytes; + } + + return victims; +} + +//===----------------------------------------------------------------------===// +// Spill/Reload Insertion +//===----------------------------------------------------------------------===// + +static Value createSpillSlot(AllocationInfo &alloc, OpBuilder &builder) { + auto sbufType = cast(alloc.value.getType()); + + // Create HBM memory space attribute + auto hbmMemSpace = nkipy::MemSpaceEnumAttr::get( + builder.getContext(), nkipy::MemSpaceEnum::Hbm); + + // Create HBM type with same shape/element type + auto hbmType = MemRefType::get(sbufType.getShape(), sbufType.getElementType(), + sbufType.getLayout(), hbmMemSpace); + + // Insert HBM allocation after SBUF allocation + builder.setInsertionPointAfter(alloc.allocOp); + auto spillSlot = + builder.create(alloc.allocOp.getLoc(), hbmType); + + LLVM_DEBUG(llvm::dbgs() << " Created HBM spill slot for " + << alloc.value << " (size: " << alloc.sizeBytes << " bytes)\n"); + + return spillSlot.getResult(); +} + +static void insertSpill(AllocationInfo &alloc, Operation *insertAfter, + OpBuilder &builder) { + builder.setInsertionPointAfter(insertAfter); + builder.create(insertAfter->getLoc(), alloc.value, + alloc.spillSlot); + + LLVM_DEBUG(llvm::dbgs() << " Inserted spill (SBUF→HBM) after " + << *insertAfter << "\n"); +} + +static void insertReload(AllocationInfo &alloc, Operation *insertBefore, + OpBuilder &builder) { + builder.setInsertionPoint(insertBefore); + builder.create(insertBefore->getLoc(), alloc.spillSlot, + alloc.value); + + LLVM_DEBUG(llvm::dbgs() << " Inserted reload (HBM→SBUF) before " + << *insertBefore << "\n"); +} + 
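For reference, a minimal standalone C++ model of the farthest-next-use selection implemented above: candidates carry only a size and a next-use position, and victims are picked greedily until the excess over capacity is covered. The Candidate record and its position field are illustrative stand-ins, not the pass's AllocationInfo or liveness types.

#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative stand-in for a live, unspilled SBUF allocation at one
// pressure point. nextUsePos is the position of its next use after the
// pressure point, or INT64_MAX if it is never used again.
struct Candidate {
  int64_t sizeBytes;
  int64_t nextUsePos;
};

// Farthest-next-use (Belady's MIN) selection: sort candidates so the one
// whose next use is farthest away comes first, then spill greedily until
// the excess over capacity is covered.
std::vector<Candidate> selectVictims(std::vector<Candidate> candidates,
                                     int64_t capacityBytes) {
  int64_t livePressure = 0;
  for (const Candidate &c : candidates)
    livePressure += c.sizeBytes;

  std::vector<Candidate> victims;
  if (livePressure <= capacityBytes)
    return victims;                       // within capacity, nothing to spill

  std::sort(candidates.begin(), candidates.end(),
            [](const Candidate &a, const Candidate &b) {
              return a.nextUsePos > b.nextUsePos;   // farthest next use first
            });

  int64_t excess = livePressure - capacityBytes;
  int64_t freed = 0;
  for (const Candidate &c : candidates) {
    if (freed >= excess)
      break;
    victims.push_back(c);
    freed += c.sizeBytes;
  }
  return victims;
}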
+//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +struct InsertSpillReloadPass + : public InsertSpillReloadBase { + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + // Resolve the SBUF capacity to use: + // - sbufCapacityOverride >= 0 → use override directly (for testing) + // - otherwise → query getSbufPartitionUsableSize(target) + std::optional resolveSbufCapacity(func::FuncOp func) const { + if (sbufCapacityOverride >= 0) + return sbufCapacityOverride; + + StringRef targetStr = target.empty() ? StringRef("trn2") : StringRef(target); + auto size = nkipy::getSbufPartitionUsableSize(targetStr); + if (!size) { + func.emitError("insert-spill-reload: unknown target '") << targetStr << "'"; + return std::nullopt; + } + return *size; + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + + // Resolve SBUF capacity + auto capacityOpt = resolveSbufCapacity(func); + if (!capacityOpt) + return signalPassFailure(); + int64_t capacityBytes = *capacityOpt; + + LLVM_DEBUG(llvm::dbgs() << " Processing function: " << func.getName() + << " (SBUF capacity: " << capacityBytes << " bytes)\n"); + + // Phase 1: Collect SBUF allocations + SmallVector sbufAllocs; + func.walk([&](memref::AllocOp allocOp) { + auto memSpace = getNkipyMemSpace(allocOp.getType()); + if (memSpace && *memSpace == nkipy::MemSpaceEnum::Sbuf) { + AllocationInfo info; + info.allocOp = allocOp; + info.value = allocOp.getResult(); + + auto sizeOpt = computePerPartitionSize(allocOp.getType()); + if (!sizeOpt) { + LLVM_DEBUG(llvm::dbgs() << " Warning: Skipping dynamic-shaped " + "SBUF allocation\n"); + return; + } + info.sizeBytes = *sizeOpt; + + sbufAllocs.push_back(info); + + LLVM_DEBUG(llvm::dbgs() << " Found SBUF alloc: " << info.value + << " (" << info.sizeBytes << " bytes)\n"); + } + }); + + if (sbufAllocs.empty()) { + LLVM_DEBUG(llvm::dbgs() << "No SBUF allocations found\n"); + return; + } + + // Calculate total SBUF usage + int64_t totalSbufBytes = 0; + for (const auto &alloc : sbufAllocs) { + totalSbufBytes += alloc.sizeBytes; + } + + LLVM_DEBUG(llvm::dbgs() << " Total SBUF usage: " << totalSbufBytes + << " bytes (capacity: " << capacityBytes << " bytes)\n"); + + if (totalSbufBytes <= capacityBytes) { + LLVM_DEBUG(llvm::dbgs() << " SBUF usage within capacity, no " + "spilling needed\n"); + return; + } + + // Phase 2: Analyze each block (function body, loop bodies) + Block &entryBlock = func.getBody().front(); + SimpleLivenessAnalysis liveness(&entryBlock); + + // Compute liveness for each allocation + for (auto &alloc : sbufAllocs) { + alloc.firstUse = liveness.getFirstUse(alloc.value); + alloc.lastUse = liveness.getLastUse(alloc.value); + } + + // Phase 3: Compute memory pressure at each program point + auto pressurePoints = computeMemoryPressure(&entryBlock, sbufAllocs); + + // Find peak pressure + int64_t peakPressure = 0; + for (auto &point : pressurePoints) + peakPressure = std::max(peakPressure, point.sbufUsageBytes); + + if (peakPressure <= capacityBytes) { + LLVM_DEBUG(llvm::dbgs() << " Peak pressure within capacity (due " + "to non-overlapping lifetimes)\n"); + return; + } + + LLVM_DEBUG(llvm::dbgs() << " Peak SBUF pressure: " << peakPressure + << " bytes\n"); + + // Phase 4: Select spill victims at ALL over-capacity pressure points. + // Pressure points are visited in program order. 
Once a victim is marked + // isSpilled, subsequent pressure points see reduced effective pressure + // (the victim's bytes are excluded), so each alloc is selected at most once. + SmallVector> toSpill; + for (auto &point : pressurePoints) { + auto victims = selectSpillVictims(point, capacityBytes, + SpillHeuristic::FARTHEST_NEXT_USE, liveness); + for (AllocationInfo *v : victims) { + if (!v->isSpilled) { + v->isSpilled = true; + toSpill.push_back({v, point.op}); + } + } + } + + LLVM_DEBUG(llvm::dbgs() << " Selected " << toSpill.size() + << " total victims to spill\n"); + + // Phase 5: Insert spill/reload for each (victim, spillPoint) pair. + OpBuilder builder(func.getContext()); + + for (auto [victim, spillPoint] : toSpill) { + // Create HBM spill slot + victim->spillSlot = createSpillSlot(*victim, builder); + + LLVM_DEBUG(llvm::dbgs() << " Spilling " << victim->value + << " at " << *spillPoint << "\n"); + + // Collect uses that come after spillPoint, including uses inside nested + // regions (e.g., loop bodies). For each user, walk up the op-parent + // chain until we reach spillPoint's block, then check ordering. + // Do this BEFORE inserting the spill so the new memref.copy is not + // counted as a "use after spill". + SmallVector usesAfterSpill; + for (Operation *user : victim->value.getUsers()) { + Operation *ancestor = user; + while (ancestor->getBlock() != spillPoint->getBlock()) + ancestor = ancestor->getParentOp(); + if (spillPoint->isBeforeInBlock(ancestor)) + usesAfterSpill.push_back(ancestor); + } + // Sort and deduplicate: multiple uses inside the same nested region + // all map to the same ancestor op. + llvm::sort(usesAfterSpill, [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); + usesAfterSpill.erase( + std::unique(usesAfterSpill.begin(), usesAfterSpill.end()), + usesAfterSpill.end()); + + // Insert spill after spillPoint + insertSpill(*victim, spillPoint, builder); + + // Insert a single reload before the first use after the spill + if (!usesAfterSpill.empty()) { + LLVM_DEBUG(llvm::dbgs() << "Found " << usesAfterSpill.size() + << " uses after spill, inserting reload before first use: " + << *usesAfterSpill.front() << "\n"); + insertReload(*victim, usesAfterSpill.front(), builder); + } else { + LLVM_DEBUG(llvm::dbgs() << " No uses after spill for " + << victim->value << " (dead after spill)\n"); + } + } + + LLVM_DEBUG(llvm::dbgs() << "Pass completed\n"); + } +}; + +} // namespace + +std::unique_ptr> createInsertSpillReloadPass() { + return std::make_unique(); +} + +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/KnobDrivenTiling.cpp b/kernelgen/mlir/lib/Transforms/KnobDrivenTiling.cpp new file mode 100644 index 0000000..2d01032 --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/KnobDrivenTiling.cpp @@ -0,0 +1,687 @@ +//===- KnobDrivenTiling.cpp - Generate Transform dialect for tiling -------===// +// +// This pass generates Transform dialect sequences for tiling operations that +// implement TilingInterface, based on knob annotations. Linalg ops get +// op-specific treatment (matmul blocking, reduction interleaving, etc.); +// other TilingInterface ops (e.g., nkipy.gather) get elementwise-like tiling. +// +// The pass adds a transform.named_sequence @__transform_main to the module. +// Run --transform-interpreter afterwards to apply the generated transforms. 
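One detail worth a small illustration is the reduction interleaving mentioned above: parallel tile sizes and reduction tile sizes arrive as two separate lists and have to be merged back into a single per-dimension vector in iterator order before tiling. A standalone C++ sketch of that merge, using a plain bool flag per dimension in place of MLIR iterator types:

#include <cstddef>
#include <cstdint>
#include <vector>

// Merge parallel tile sizes and reduction tile sizes into a single
// per-dimension tile-size vector, following the op's iterator order.
// `isReductionDim[i]` is true when loop dimension i is a reduction dim
// (a stand-in for MLIR's utils::IteratorType). Missing entries default
// to 0, which means "do not tile this dimension".
std::vector<int64_t> interleaveTileSizes(
    const std::vector<bool> &isReductionDim,
    const std::vector<int64_t> &parallelTiles,
    const std::vector<int64_t> &reductionTiles) {
  std::vector<int64_t> combined;
  std::size_t p = 0, r = 0;
  for (bool isReduction : isReductionDim) {
    if (isReduction)
      combined.push_back(r < reductionTiles.size() ? reductionTiles[r++] : 0);
    else
      combined.push_back(p < parallelTiles.size() ? parallelTiles[p++] : 0);
  }
  return combined;
}

// Example: a sum over the last axis of a [256, 1024] tensor has iterator
// order [parallel, reduction]; tile_size=[128] and reduction_tile=[512]
// merge into the combined tile sizes [128, 512].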
+// +//===----------------------------------------------------------------------===// + +#include "PassGen.h" +#include "nkipy/Transforms/Passes.h" +#include "nkipy/Transforms/HardwareConstants.h" +#include "nkipy/Transforms/IRHelpers.h" +#include "nkipy/Transforms/OpClassification.h" +#include "nkipy/Dialect/NkipyAttrs.h" +#include "nkipy/Dialect/NkipyDialect.h" +#include "nkipy/Dialect/NkipyOps.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "nkipy/TransformOps/NkipyTransformOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "llvm/Support/raw_ostream.h" +#include "nkipy/Dialect/NkipyAttrs.h" + +#include +#include + +using namespace mlir; +using namespace nkipy; + +namespace mlir { +namespace nkipy { + +namespace { + +/// Structure to hold knob information +struct KnobInfo { + SmallVector tileSize; + SmallVector reductionTile; // Tile sizes for reduction dims (e.g., K for matmul) + int64_t opId = -1; // nkipy.op_id for per-instance matching, -1 if not set + int numDpsInputs = -1; // Number of DPS inputs, -1 if unknown + bool isElementwise = false; // Whether this op is elementwise (verified during extraction) + bool isReduction = false; // Whether this op has both parallel and reduction iterators + SmallVector iteratorTypes; // For generic ops (needed for reduction tiling) + SmallVector matmulDims; // [M, N, K] for matmul ops (for dynamic blocking) + + bool isValid() const { return !tileSize.empty(); } +}; + +/// Check if an op name corresponds to a transpose operation. +bool isTransposeOp(StringRef opName) { + return opName == "linalg.transpose"; +} + +//===----------------------------------------------------------------------===// +// Transform dialect helpers +//===----------------------------------------------------------------------===// + +/// Emit transform.structured.match for a linalg op by name and optional op_id. +Value emitMatch(OpBuilder &builder, Location loc, Value moduleArg, + StringRef opName, DictionaryAttr opAttrs) { + auto anyOpType = transform::AnyOpType::get(builder.getContext()); + return builder.create( + loc, anyOpType, moduleArg, + builder.getStrArrayAttr({opName}), + /*matchInterfaceEnum=*/transform::MatchInterfaceEnumAttr{}, + /*opAttrs=*/opAttrs, + /*filterResultType=*/TypeAttr{}, + /*filterOperandTypes=*/ArrayAttr{}).getResult(); +} + +/// Emit transform.structured.tile_using_for. Returns the tiled op (result 0). 
+Value emitTile(OpBuilder &builder, Location loc, Value target, + ArrayRef tileSizes) { + auto anyOpType = transform::AnyOpType::get(builder.getContext()); + SmallVector scalableSizes(tileSizes.size(), false); + + // 1 result for tiled op + 1 per non-zero tile (loop handles) + SmallVector resultTypes; + resultTypes.push_back(anyOpType); + for (int64_t t : tileSizes) + if (t != 0) resultTypes.push_back(anyOpType); + + auto tileOp = builder.create( + loc, TypeRange(resultTypes), target, ValueRange{}, + tileSizes, ArrayRef{}, ArrayRef(scalableSizes)); + return tileOp.getResult(0); +} + +/// Emit promote_tensor for a specific DPS operand position. +void emitPromoteOperand(OpBuilder &builder, Location loc, Value tiledOp, + int64_t operandIdx, Attribute memSpace) { + auto anyValueType = transform::AnyValueType::get(builder.getContext()); + SmallVector position = {operandIdx}; + auto getOp = builder.create( + loc, anyValueType, tiledOp, ArrayRef(position), + /*is_inverted=*/false, /*is_all=*/false); + builder.create( + loc, anyValueType, getOp.getResult(), /*memory_space=*/memSpace); +} + +/// Promote all DPS inputs and the output to SBUF. +void emitPromoteAllToSbuf(OpBuilder &builder, Location loc, Value tiledOp, + int numInputs) { + auto sbufMemSpace = nkipy::MemSpaceEnumAttr::get( + builder.getContext(), nkipy::MemSpaceEnum::Sbuf); + for (int i = 0; i < numInputs; ++i) + emitPromoteOperand(builder, loc, tiledOp, i, sbufMemSpace); + emitPromoteOperand(builder, loc, tiledOp, numInputs, sbufMemSpace); +} + +/// Log tile sizes for debugging. +void logTileSizes(StringRef label, ArrayRef tiles) { + llvm::errs() << "[KnobDrivenTiling] " << label << ": ["; + for (size_t i = 0; i < tiles.size(); ++i) { + llvm::errs() << tiles[i]; + if (i + 1 < tiles.size()) llvm::errs() << ", "; + } + llvm::errs() << "]\n"; +} + +/// Validate matmul tile sizes against matrix dimensions. +/// Returns empty string if valid, error message if invalid. 
+std::string validateMatmulTileSize(linalg::LinalgOp linalgOp, const KnobInfo &knob) { + Value output = linalgOp.getDpsInits()[0]; + auto outputType = dyn_cast(output.getType()); + if (!outputType || !outputType.hasStaticShape() || outputType.getRank() < 2) + return ""; + + int64_t outputRank = outputType.getRank(); + if (knob.tileSize.size() != static_cast(outputRank)) { + return "matmul tile_size has " + std::to_string(knob.tileSize.size()) + + " elements but output tensor has rank " + std::to_string(outputRank) + + "; tile_size must match output rank"; + } + + if (knob.reductionTile.empty()) + return "matmul requires reduction_tile (e.g., reduction_tile=[K]), got none"; + + int64_t tileM = knob.tileSize[outputRank - 2]; + int64_t tileN = knob.tileSize[outputRank - 1]; + int64_t tileK = knob.reductionTile[0]; + + int64_t dimM = outputType.getDimSize(outputRank - 2); + int64_t dimN = outputType.getDimSize(outputRank - 1); + + if (tileM > dimM) { + return "matmul tile_size M (" + std::to_string(tileM) + + ") is larger than M dimension (" + std::to_string(dimM) + ")"; + } + if (tileN > dimN) { + return "matmul tile_size N (" + std::to_string(tileN) + + ") is larger than N dimension (" + std::to_string(dimN) + ")"; + } + + Value lhs = linalgOp.getDpsInputs()[0]; + auto lhsType = dyn_cast(lhs.getType()); + int64_t dimK = -1; + if (lhsType && lhsType.hasStaticShape() && lhsType.getRank() >= 2) + dimK = lhsType.getDimSize(lhsType.getRank() - 1); + + if (dimK > 0 && tileK > dimK) { + return "matmul K tile (" + std::to_string(tileK) + + ") is larger than K dimension (" + std::to_string(dimK) + ")"; + } + + return ""; +} + +/// Validate that tensor dimensions are large enough for tiling. +/// Returns empty string if valid, error message if invalid. +std::string validateElementwiseTileSize(linalg::LinalgOp linalgOp, const KnobInfo &knob) { + if (linalgOp.getDpsInits().empty()) + return ""; + + Value output = linalgOp.getDpsInits()[0]; + auto outputType = dyn_cast(output.getType()); + if (!outputType || !outputType.hasStaticShape()) + return ""; // Can't validate dynamic shapes + + int64_t rank = outputType.getRank(); + if (knob.tileSize.size() != static_cast(rank)) { + return "tile_size has " + std::to_string(knob.tileSize.size()) + + " elements but tensor has rank " + std::to_string(rank) + + "; tile_size must have exactly one element per dimension"; + } + + for (size_t i = 0; i < knob.tileSize.size(); ++i) { + int64_t tile = knob.tileSize[i]; + int64_t dim = outputType.getDimSize(i); + if (tile > dim) { + return "tile_size[" + std::to_string(i) + "]=" + std::to_string(tile) + + " is larger than dimension[" + std::to_string(i) + "]=" + std::to_string(dim); + } + } + + return ""; +} + +/// Validate tile sizes for reduction generic ops. +/// tile_size must match the number of parallel dims, reduction_tile must match reduction dims. 
+std::string validateReductionTileSize(linalg::LinalgOp linalgOp, const KnobInfo &knob) { + auto genericOp = dyn_cast(linalgOp.getOperation()); + if (!genericOp) + return "expected linalg.generic for reduction validation"; + + auto iterTypes = genericOp.getIteratorTypesArray(); + size_t numParallel = 0, numReduction = 0; + for (auto t : iterTypes) { + if (t == utils::IteratorType::parallel) numParallel++; + else if (t == utils::IteratorType::reduction) numReduction++; + } + + if (knob.tileSize.size() != numParallel) { + return "reduction op tile_size has " + std::to_string(knob.tileSize.size()) + + " elements but op has " + std::to_string(numParallel) + + " parallel dimensions; tile_size must match parallel dim count"; + } + + if (knob.reductionTile.empty()) { + return "reduction op requires reduction_tile, got none"; + } + + if (knob.reductionTile.size() != numReduction) { + return "reduction op reduction_tile has " + std::to_string(knob.reductionTile.size()) + + " elements but op has " + std::to_string(numReduction) + + " reduction dimensions; reduction_tile must match reduction dim count"; + } + + return ""; +} + +/// Extract knob map from nkipy.annotate operations +/// Only extracts knobs that have valid tile_size attributes +/// Returns empty map and sets errorMsg if validation fails +std::map> extractKnobsByOpType( + ModuleOp module, std::string &errorMsg) { + std::map> knobsByOp; + errorMsg = ""; + + module.walk([&](func::FuncOp func) { + if (!errorMsg.empty()) return; // Already failed + + // First, build a map of Value → KnobInfo. + // Skip annotations nested inside nkipy regions (reference_impl bodies). + DenseMap valueToKnob; + func.walk([&](nkipy::AnnotateOp annotateOp) { + if (isInsideNkipyRegion(annotateOp)) + return; + Value target = annotateOp.getTarget(); + KnobInfo info; + + // Extract tile_size attribute + if (auto tileSizeAttr = annotateOp.getTileSizeAttr()) { + auto arrayRef = tileSizeAttr.asArrayRef(); + info.tileSize.assign(arrayRef.begin(), arrayRef.end()); + } + + // Extract reduction_tile attribute (for matmul K dimension, etc.) + if (auto reductionTileAttr = annotateOp.getReductionTileAttr()) { + auto arrayRef = reductionTileAttr.asArrayRef(); + info.reductionTile.assign(arrayRef.begin(), arrayRef.end()); + } + + // Only add to map if it has valid tile_size + if (info.isValid()) { + valueToKnob[target] = info; + } + }); + + // Then, group knobs by op type, validating each. + // Walk all ops with TilingInterface (covers linalg ops and any nkipy ops + // that implement TilingInterface, e.g., nkipy.gather). + // Skip ops nested inside nkipy regions (e.g., the linalg.generic inside + // nkipy.gather's reference_impl — it must not be tiled independently). 
+ func.walk([&](Operation *op) { + if (!errorMsg.empty()) return WalkResult::interrupt(); + if (!isa(op)) + return WalkResult::advance(); + if (op->getNumResults() == 0) + return WalkResult::advance(); + if (isInsideNkipyRegion(op)) + return WalkResult::advance(); + + for (Value result : op->getResults()) { + auto it = valueToKnob.find(result); + if (it != valueToKnob.end()) { + std::string opName = op->getName().getStringRef().str(); + + // Copy knob and extract op_id + KnobInfo knobWithId = it->second; + if (auto opIdAttr = op->getAttrOfType("nkipy.op_id")) { + knobWithId.opId = opIdAttr.getInt(); + } + + // Validate tile sizes against dimensions + std::string validationError; + + // Linalg-specific classification + if (auto linalgOp = dyn_cast(op)) { + if (isMatmulOp(opName)) { + // Capture matrix dimensions for dynamic blocking. + Value output = linalgOp.getDpsInits()[0]; + auto outType = dyn_cast(output.getType()); + Value lhs = linalgOp.getDpsInputs()[0]; + auto lhsType = dyn_cast(lhs.getType()); + if (outType && outType.hasStaticShape() && outType.getRank() >= 2 && + lhsType && lhsType.hasStaticShape() && lhsType.getRank() >= 2) { + int64_t r = outType.getRank(); + knobWithId.matmulDims = { + outType.getDimSize(r - 2), // M + outType.getDimSize(r - 1), // N + lhsType.getDimSize(lhsType.getRank() - 1) // K + }; + } + validationError = validateMatmulTileSize(linalgOp, knobWithId); + } else if (isTransposeOp(opName)) { + knobWithId.isElementwise = true; + knobWithId.numDpsInputs = linalgOp.getNumDpsInputs(); + validationError = validateElementwiseTileSize(linalgOp, knobWithId); + } else if (isElementwiseOp(linalgOp)) { + knobWithId.isElementwise = true; + knobWithId.numDpsInputs = linalgOp.getNumDpsInputs(); + validationError = validateElementwiseTileSize(linalgOp, knobWithId); + } else if (isReductionGeneric(linalgOp)) { + knobWithId.isReduction = true; + knobWithId.numDpsInputs = linalgOp.getNumDpsInputs(); + knobWithId.iteratorTypes = linalgOp.getIteratorTypesArray(); + validationError = validateReductionTileSize(linalgOp, knobWithId); + } + } else { + // Non-linalg TilingInterface op (e.g., nkipy.gather). + // Default to elementwise-like tiling: tile all dims, promote. + knobWithId.isElementwise = true; + if (auto dstOp = dyn_cast(op)) + knobWithId.numDpsInputs = dstOp.getNumDpsInputs(); + } + + if (!validationError.empty()) { + errorMsg = validationError; + return WalkResult::interrupt(); + } + + knobsByOp[opName].push_back(knobWithId); + + llvm::errs() << "[KnobDrivenTiling] Found knob for " << opName; + if (knobWithId.opId >= 0) { + llvm::errs() << " (op_id=" << knobWithId.opId << ")"; + } + llvm::errs() << ": tile_size=["; + for (size_t i = 0; i < it->second.tileSize.size(); ++i) { + llvm::errs() << it->second.tileSize[i]; + if (i + 1 < it->second.tileSize.size()) llvm::errs() << ", "; + } + llvm::errs() << "]"; + if (!it->second.reductionTile.empty()) { + llvm::errs() << ", reduction_tile=["; + for (size_t i = 0; i < it->second.reductionTile.size(); ++i) { + llvm::errs() << it->second.reductionTile[i]; + if (i + 1 < it->second.reductionTile.size()) llvm::errs() << ", "; + } + llvm::errs() << "]"; + } + llvm::errs() << "\n"; + break; + } + } + return WalkResult::advance(); + }); + }); + + return knobsByOp; +} + +/// Build tiling + SBUF promotion for elementwise (and transpose) operations. 
+void buildElementwiseTiling(OpBuilder &builder, Location loc, + Value moduleArg, + const std::string &opName, + const KnobInfo &knob, + DictionaryAttr opAttrs) { + logTileSizes(opName + " elementwise tile_size", knob.tileSize); + + Value matched = emitMatch(builder, loc, moduleArg, opName, opAttrs); + Value tiledOp = emitTile(builder, loc, matched, knob.tileSize); + + int numInputs = knob.numDpsInputs >= 0 ? knob.numDpsInputs + : (isNamedUnaryElementwiseOp(opName) ? 1 : 2); + emitPromoteAllToSbuf(builder, loc, tiledOp, numInputs); + + llvm::errs() << "[KnobDrivenTiling] Elementwise: promoted " << numInputs + << " inputs + 1 output to SBUF\n"; +} + +/// Build tiling + SBUF promotion for reduction operations. +/// Interleaves tile_size (parallel) and reduction_tile (reduction) based on +/// iterator types, then tiles and promotes all operands. +void buildReductionTiling(OpBuilder &builder, Location loc, + Value moduleArg, + const std::string &opName, + const KnobInfo &knob, + DictionaryAttr opAttrs) { + // Interleave parallel and reduction tile sizes by iterator type order. + SmallVector combinedTileSizes; + int parallelIdx = 0, reductionIdx = 0; + for (auto iterType : knob.iteratorTypes) { + if (iterType == utils::IteratorType::parallel) { + combinedTileSizes.push_back( + parallelIdx < static_cast(knob.tileSize.size()) + ? knob.tileSize[parallelIdx++] : 0); + } else { + combinedTileSizes.push_back( + reductionIdx < static_cast(knob.reductionTile.size()) + ? knob.reductionTile[reductionIdx++] : 0); + } + } + + logTileSizes(opName + " reduction combined_tile_sizes", combinedTileSizes); + + Value matched = emitMatch(builder, loc, moduleArg, opName, opAttrs); + Value tiledOp = emitTile(builder, loc, matched, combinedTileSizes); + + int numInputs = knob.numDpsInputs >= 0 ? knob.numDpsInputs : 1; + emitPromoteAllToSbuf(builder, loc, tiledOp, numInputs); + + llvm::errs() << "[KnobDrivenTiling] Reduction: promoted " << numInputs + << " inputs + 1 output to SBUF\n"; +} + +/// Build the transform sequence for matmul with dynamic blocking. +/// +/// Uses 2-tile blocking when dimensions are large enough (BLOCK = TILE * 2), +/// degenerating to 1-tile blocking (BLOCK = TILE) for small dimensions. +/// +/// Generated loop structure: +/// for block_m in [0, M, BLOCK_M): // BLOCK_M = TILE_M * blocksM +/// LOAD LHS to SBUF // Reused across N-blocks +/// for block_n in [0, N, BLOCK_N): // BLOCK_N = TILE_N * blocksN +/// LOAD RHS to SBUF // Reused within block +/// for tile_m in [0, BLOCK_M, TILE_M): // trip count = blocksM (1 or 2) +/// for tile_n in [0, BLOCK_N, TILE_N): // trip count = blocksN (1 or 2) +/// ALLOC psum buffer +/// for k in [0, K, TILE_K): +/// psum += matmul(lhs_tile, rhs_tile) +/// STORE result tile +/// +/// Returns false if the knob is invalid for matmul. +bool buildMatmulBlockingTransforms(OpBuilder &builder, Location loc, + Value moduleArg, + const std::string &opName, + const KnobInfo &knob, + DictionaryAttr opAttrs) { + if (knob.tileSize.size() < 2) { + llvm::errs() << "[KnobDrivenTiling] Matmul tile_size must have >= 2 elements\n"; + return false; + } + if (knob.reductionTile.empty()) { + llvm::errs() << "[KnobDrivenTiling] Matmul requires reduction_tile\n"; + return false; + } + + size_t tsz = knob.tileSize.size(); + int64_t tileM = knob.tileSize[tsz - 2]; + int64_t tileN = knob.tileSize[tsz - 1]; + int64_t tileK = knob.reductionTile[0]; + + // Dynamic blocking: use block size 2 if dimension is large enough, + // otherwise degenerate to block size 1 (no blocking, less data reuse). 
+ int64_t blocksM = 2, blocksN = 2; + if (knob.matmulDims.size() == 3) { + int64_t dimM = knob.matmulDims[0]; + int64_t dimN = knob.matmulDims[1]; + if (dimM < tileM * 2) blocksM = 1; + if (dimN < tileN * 2) blocksN = 1; + } + int64_t blockM = tileM * blocksM; + int64_t blockN = tileN * blocksN; + + llvm::errs() << "[KnobDrivenTiling] Matmul: TILE=[" << tileM << "," << tileN + << "," << tileK << "], BLOCK=[" << blockM << "," << blockN + << "] (blocksM=" << blocksM << ", blocksN=" << blocksN << ")\n"; + + auto anyOpType = transform::AnyOpType::get(builder.getContext()); + auto sbufMemSpace = nkipy::MemSpaceEnumAttr::get( + builder.getContext(), nkipy::MemSpaceEnum::Sbuf); + auto psumMemSpace = nkipy::MemSpaceEnumAttr::get( + builder.getContext(), nkipy::MemSpaceEnum::Psum); + + // --- Match --- + Value matmul = emitMatch(builder, loc, moduleArg, opName, opAttrs); + + // --- Level 1: Block-level tiling --- + + // Tile M blocks + Value blockMTiled = emitTile(builder, loc, matmul, {blockM, 0, 0}); + + // Transpose matmul: matmul(A,B) → matmul_transpose_a(transpose(A), B) + auto transposeMatmul = builder.create( + loc, anyOpType, blockMTiled, + transform::TransposeMatmulInput::lhs); + Value transposedMatmul = transposeMatmul.getResult(); + + // Promote the transpose output to SBUF (inserted by TransposeMatmulOp). + // Use GetProducerOfOperand to target only this specific transpose, + // not user-provided transposes that have nkipy.op_id. + auto getTransposeOp = builder.create( + loc, anyOpType, transposedMatmul, /*operand_number=*/0); + emitPromoteOperand(builder, loc, getTransposeOp.getResult(), 1, sbufMemSpace); + + // Promote LHS at block-M level (reused across all N-blocks) + emitPromoteOperand(builder, loc, transposedMatmul, 0, sbufMemSpace); + + // Tile N blocks + Value blockNTiled = emitTile(builder, loc, transposedMatmul, {0, blockN, 0}); + + // Promote RHS at block-N level (reused within this N-block) + emitPromoteOperand(builder, loc, blockNTiled, 1, sbufMemSpace); + + // --- Level 2: Tile-level tiling (within blocks) --- + + Value tileMTiled = emitTile(builder, loc, blockNTiled, {tileM, 0, 0}); + Value tileNTiled = emitTile(builder, loc, tileMTiled, {0, tileN, 0}); + + // Promote output to PSUM (accumulator for partial sums) + emitPromoteOperand(builder, loc, tileNTiled, 2, psumMemSpace); + + // Tile K (innermost reduction loop) + emitTile(builder, loc, tileNTiled, {0, 0, tileK}); + + return true; +} + +struct NkipyKnobDrivenTilingPass + : public KnobDrivenTilingBase { + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + MLIRContext *ctx = &getContext(); + + // Note: batch_matmul -> loop + matmul conversion is now handled by + // canonicalize-partition-dim pass. All matmul ops here are rank-2. + + // Extract knobs from annotations (only those with valid tile_size) + std::string validationError; + auto knobsByOp = extractKnobsByOpType(module, validationError); + + // Check for validation errors (tile sizes too large for dimensions) + if (!validationError.empty()) { + module->emitError() << "[KnobDrivenTiling] Invalid tile configuration: " + << validationError; + return signalPassFailure(); + } + + if (knobsByOp.empty()) { + llvm::errs() << "[KnobDrivenTiling] No knob annotations found\n"; + // Still emit an empty __transform_main so transform-interpreter + // doesn't crash. 
This happens when a kernel has only data-movement + // ops (e.g. np.concatenate) and no compute linalg ops. + OpBuilder emptyBuilder(ctx); + module->setAttr("transform.with_named_sequence", + emptyBuilder.getUnitAttr()); + emptyBuilder.setInsertionPointToEnd(module.getBody()); + auto anyOpType = transform::AnyOpType::get(ctx); + auto emptySeq = emptyBuilder.create( + module.getLoc(), "__transform_main", + TypeAttr::get(FunctionType::get(ctx, {anyOpType}, {})), + /*sym_visibility=*/StringAttr{}, + /*arg_attrs=*/ArrayAttr{}, + /*res_attrs=*/ArrayAttr{}); + emptySeq.addEntryBlock(); + emptySeq.setArgAttr(0, "transform.readonly", + emptyBuilder.getUnitAttr()); + emptyBuilder.setInsertionPointToStart(&emptySeq.getBody().front()); + emptyBuilder.create(module.getLoc()); + return; + } + + OpBuilder builder(ctx); + Location loc = module.getLoc(); + + // Add transform.with_named_sequence attribute to module + module->setAttr("transform.with_named_sequence", builder.getUnitAttr()); + + // Create transform.named_sequence @__transform_main at end of module + builder.setInsertionPointToEnd(module.getBody()); + + auto anyOpType = transform::AnyOpType::get(ctx); + + // Create the named sequence with proper signature + auto namedSeq = builder.create( + loc, + "__transform_main", + TypeAttr::get(FunctionType::get(ctx, {anyOpType}, {})), + /*sym_visibility=*/StringAttr{}, + /*arg_attrs=*/ArrayAttr{}, + /*res_attrs=*/ArrayAttr{}); + + // Add entry block with module argument + namedSeq.addEntryBlock(); + + // Add attribute to mark the argument as readonly + namedSeq.setArgAttr(0, "transform.readonly", builder.getUnitAttr()); + + // Get the module argument for use in generated transforms + Value moduleArg = namedSeq.getBody().getArgument(0); + + // Set insertion point to the start of the named sequence body + builder.setInsertionPointToStart(&namedSeq.getBody().front()); + + bool hasAnyTransforms = false; + + // Process all ops with knobs (per-instance) + for (const auto &[opName, knobs] : knobsByOp) { + for (const KnobInfo &knob : knobs) { + // Build op_attrs for per-instance matching if op_id is set + DictionaryAttr opAttrs; + if (knob.opId >= 0) { + opAttrs = builder.getDictionaryAttr({ + builder.getNamedAttr("nkipy.op_id", builder.getI64IntegerAttr(knob.opId)) + }); + } + + if (isMatmulOp(opName)) { + // Matmul gets special 6-level blocking treatment + if (!buildMatmulBlockingTransforms(builder, loc, moduleArg, opName, knob, opAttrs)) { + llvm::errs() << "[KnobDrivenTiling] Failed to build matmul transforms\n"; + continue; + } + hasAnyTransforms = true; + } else if (knob.isElementwise) { + // Single-level tiling for elementwise ops (named or elementwise generic) + buildElementwiseTiling(builder, loc, moduleArg, opName, knob, opAttrs); + hasAnyTransforms = true; + } else if (knob.isReduction) { + // Single-level tiling for reduction ops (generic with reduction iterators) + buildReductionTiling(builder, loc, moduleArg, opName, knob, opAttrs); + hasAnyTransforms = true; + } else { + llvm::errs() << "[KnobDrivenTiling] Unknown op type: " << opName << " - skipping\n"; + } + } + } + + // Add transform.yield at the end + builder.create(loc); + + if (!hasAnyTransforms) { + // No transforms generated - clean up + namedSeq.erase(); + module->removeAttr("transform.with_named_sequence"); + llvm::errs() << "[KnobDrivenTiling] No transforms generated\n"; + return; + } + + llvm::errs() << "[KnobDrivenTiling] Generated transform sequence\n"; + } +}; + +} // namespace + +std::unique_ptr> createKnobDrivenTilingPass() { + 
return std::make_unique(); +} + +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/LegalizeLayout.cpp b/kernelgen/mlir/lib/Transforms/LegalizeLayout.cpp new file mode 100644 index 0000000..ddad83f --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/LegalizeLayout.cpp @@ -0,0 +1,2691 @@ +//===- LegalizeLayout.cpp - Legalize SBUF tensor layouts ------------------===// +// +// This pass transforms SBUF tensor layouts to satisfy NKI hardware constraints +// where the first dimension (partition dimension) must be ≤128. +// +// The pass identifies SBUF tensors via: +// 1. bufferization.alloc_tensor with memory_space = #nisa.mem +// 2. nkipy.annotate ops with mem_space = Sbuf (traced back to tensor.empty) +// +// For each SBUF tensor needing legalization, the pass: +// 1. Computes the target (R+2)-D physical shape based on tile_size annotation +// 2. Propagates the shape change through the entire use-def chain via BFS +// 3. Updates scf.for init_args, block args, and results +// 4. Transforms extract_slice to (R+2)-D indexing + collapse_shape +// 5. Transforms insert_slice with expand_shape + (R+2)-D indexing +// +// Prerequisites: +// - Runs after knob-driven-tiling + transform-interpreter +// - Runs after canonicalize-loop-step (loops have step=1) +// +//===----------------------------------------------------------------------===// + +#include "PassGen.h" +#include "nkipy/Transforms/Passes.h" +#include "nkipy/Transforms/HardwareConstants.h" +#include "nkipy/Transforms/IRHelpers.h" +#include "nkipy/Dialect/NkipyAttrs.h" +#include "nkipy/Dialect/NkipyDialect.h" +#include "nkipy/Dialect/NkipyOps.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#include + +#define DEBUG_TYPE "legalize-layout" + +using namespace mlir; +using namespace nkipy; + +namespace mlir { +namespace nkipy { + +namespace { + +/// Structure to hold layout transformation info for an SBUF tensor. +/// +/// For a rank-R tensor [d_0, ..., d_{R-1}] with tile [t_0, ..., t_{R-1}]: +/// numBlocks[i] = d_i / t_i +/// Physical shape (rank R+2): [t_0, numBlocks[0], ..., numBlocks[R-1], t_{R-1}] +/// +/// Key constraint: only dim 0 (partition) and dim R-1 (free) may have tile > 1. +/// All middle dims (1..R-2) must have tile = 1. This means single-block subviews +/// have shape [partTile, 1, ..., 1, freeTile] which can be collapsed directly to +/// 2D [partTile, freeTile] for NISA consumption. +/// +/// For R=2: physical is 4D [t0, nB0, nB1, t1], collapse [[0,1,2],[3]] → 2D. 
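A standalone numeric sketch of the physical-shape rule documented above (the real bookkeeping lives in the LayoutInfo struct that follows); the function name and the even-divisibility assumption here are illustrative, not part of the pass:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Standalone model of the physical-shape rule documented above: for a
// rank-R logical shape and per-dimension tile sizes, the physical layout is
// [t_0, d_0/t_0, ..., d_{R-1}/t_{R-1}, t_{R-1}] of rank R+2. Assumes every
// dimension is evenly divisible by its tile.
std::vector<int64_t> physicalShape(const std::vector<int64_t> &origShape,
                                   const std::vector<int64_t> &tileSize) {
  assert(origShape.size() == tileSize.size() && !origShape.empty());
  std::vector<int64_t> shape;
  shape.push_back(tileSize.front());               // partition tile
  for (std::size_t i = 0; i < origShape.size(); ++i) {
    assert(origShape[i] % tileSize[i] == 0);
    shape.push_back(origShape[i] / tileSize[i]);   // numBlocks[i]
  }
  shape.push_back(tileSize.back());                // free tile
  return shape;
}

// Example: origShape=[256, 512] with tileSize=[128, 128] gives
// numBlocks=[2, 4] and physical shape [128, 2, 4, 128]; a single-block
// subview then collapses [[0,1,2],[3]] back to a 2-D [128, 128] tile.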
+struct LayoutInfo { + Value originalValue; // The original tensor allocation + SmallVector origShape; // [d_0, ..., d_{R-1}] + SmallVector tileSize; // [t_0, ..., t_{R-1}] + SmallVector numBlocks; // [d_0/t_0, ..., d_{R-1}/t_{R-1}] + + int64_t rank() const { return origShape.size(); } + int64_t physicalRank() const { return rank() + 2; } + + /// Physical shape: [tileSize[0], numBlocks[0], ..., numBlocks[R-1], tileSize[R-1]] + SmallVector getPhysicalShape() const { + SmallVector shape; + shape.push_back(tileSize[0]); + for (auto nb : numBlocks) + shape.push_back(nb); + shape.push_back(tileSize.back()); + return shape; + } + + /// Collapse reassociation from (R+2)-D physical layout directly to 2D: + /// [[0, 1, ..., R], [R+1]] → [partTile, freeTile] + /// + /// This works for single-block subviews where all numBlocks dims are 1. + /// For R=2: [[0,1,2], [3]] on [128,1,1,128] → [128, 128] (same as legacy). + SmallVector getCollapseReassociation() const { + int64_t R = rank(); + SmallVector reassoc; + ReassociationIndices group0; + for (int64_t i = 0; i <= R; i++) + group0.push_back(i); // partTile + all numBlocks dims + reassoc.push_back(group0); + reassoc.push_back({R + 1}); // freeTile + return reassoc; + } + +}; + +/// Build collapse reassociation from (R+2)-D to 2D: [[0, 1, ..., R], [R+1]] +/// +/// For single-block subviews with shape [partTile, 1, ..., 1, freeTile], +/// this produces [partTile, freeTile] (2D) suitable for NISA ops. +static SmallVector build2DCollapseFromPhysical( + int64_t physicalRank) { + int64_t R = physicalRank - 2; // logical rank + SmallVector reassoc; + ReassociationIndices group0; + for (int64_t i = 0; i <= R; i++) + group0.push_back(i); + reassoc.push_back(group0); + reassoc.push_back({R + 1}); + return reassoc; +} + +/// Build collapse reassociation from R-D to 2D: [[0, 1, ..., R-2], [R-1]] +/// +/// For R-D tiles with shape [partTile, 1, ..., 1, freeTile] (where middle +/// dims have tile=1), this produces [partTile, freeTile] (2D). +/// For R=2 this is [[0], [1]] (identity, no-op). +static SmallVector build2DCollapseFromLogical( + int64_t logicalRank) { + SmallVector reassoc; + ReassociationIndices group0; + for (int64_t i = 0; i < logicalRank - 1; i++) + group0.push_back(i); + reassoc.push_back(group0); + reassoc.push_back({logicalRank - 1}); + return reassoc; +} + +/// Result of createBlockLoopNest: IVs and the builder insertion point is +/// set to the innermost loop body. +struct BlockLoopNest { + SmallVector ivs; // One IV per logical dim +}; + +/// Create an R-level scf.for loop nest iterating over block indices. +/// After this call, builder insertion point is inside the innermost loop body. +static BlockLoopNest createBlockLoopNest( + OpBuilder &builder, Location loc, ArrayRef numBlocks) { + BlockLoopNest result; + Value c0 = builder.create(loc, 0); + Value c1 = builder.create(loc, 1); + for (int64_t nb : numBlocks) { + Value ub = builder.create(loc, nb); + auto loop = builder.create(loc, c0, ub, c1); + builder.setInsertionPointToStart(loop.getBody()); + result.ivs.push_back(loop.getInductionVar()); + } + return result; +} + +/// Create a single-tile subview of a physical (R+2)-D SBUF buffer and +/// collapse it to 2D [partTile, freeTile]. 
+/// +/// Physical layout: [partTile, numBlocks[0], ..., numBlocks[R-1], freeTile] +/// Subview: [0, iv0, ..., ivR-1, 0] / [partTile, 1, ..., 1, freeTile] +/// Collapse: [[0, 1, ..., R], [R+1]] → [partTile, freeTile] +static Value createTileSubviewAndCollapse( + OpBuilder &builder, Location loc, Value physBuf, + int64_t partTile, int64_t freeTile, int64_t R, + ArrayRef blockIVs) { + SmallVector offsets, sizes, strides; + offsets.push_back(builder.getIndexAttr(0)); + for (Value iv : blockIVs) + offsets.push_back(OpFoldResult(iv)); + offsets.push_back(builder.getIndexAttr(0)); + + sizes.push_back(builder.getIndexAttr(partTile)); + for (int64_t i = 0; i < R; i++) + sizes.push_back(builder.getIndexAttr(1)); + sizes.push_back(builder.getIndexAttr(freeTile)); + + for (int64_t i = 0; i < R + 2; i++) + strides.push_back(builder.getIndexAttr(1)); + + auto subview = builder.create( + loc, physBuf, offsets, sizes, strides); + return builder.create( + loc, subview, build2DCollapseFromPhysical(R + 2)); +} + +static bool isSbuf(Attribute memSpaceAttr) { + if (auto a = dyn_cast_or_null(memSpaceAttr)) + return a.getValue() == nkipy::MemSpaceEnum::Sbuf; + return false; +} + +static bool isHbm(Attribute memSpaceAttr) { + if (auto a = dyn_cast_or_null(memSpaceAttr)) + return a.getValue() == nkipy::MemSpaceEnum::Hbm || + a.getValue() == nkipy::MemSpaceEnum::SharedHbm; + return false; +} + +/// Check if a copy/transpose needs tiling (HBM↔SBUF transfer). +static bool needsTiledTransfer(MemRefType srcType, MemRefType dstType) { + bool srcH = isHbm(srcType.getMemorySpace()); + bool srcS = isSbuf(srcType.getMemorySpace()); + bool dstH = isHbm(dstType.getMemorySpace()); + bool dstS = isSbuf(dstType.getMemorySpace()); + return (srcH && dstS) || (srcS && dstH); +} + +/// Look through memref.cast ops to find the base value. +/// memref.cast changes static type information but doesn't create new data. +static Value lookThroughCast(Value v) { + while (auto castOp = v.getDefiningOp()) { + v = castOp.getSource(); + } + return v; +} + +/// Like lookThroughCast, but also tries to resolve through collapse_shape +/// when the base alloc is in valueMapping (i.e., was actually legalized). +/// This handles chains like: legalized_alloc → collapse_shape → cast → user +/// where Step 1.5 replaced the Phase 0 collapse with one from the legalized alloc. +static Value lookThroughCastAndResolve(Value v, IRMapping &valueMapping) { + v = lookThroughCast(v); + // If the value after cast-stripping is in valueMapping, we're done. + if (valueMapping.lookupOrNull(v)) + return v; + // Try going through collapse_shape to find a mapped (legalized) base alloc. + if (auto collapseOp = v.getDefiningOp()) { + Value src = lookThroughCast(collapseOp.getSrc()); + if (valueMapping.lookupOrNull(src)) + return src; + } + return v; +} + +/// Walk backward from `current` through subviews/casts to find a +/// memref.collapse_shape in the defining chain. Returns it if found, or nullptr. +static memref::CollapseShapeOp findCollapseInDefChain(Value current) { + Value v = current; + while (v) { + if (auto collapseOp = v.getDefiningOp()) + return collapseOp; + if (auto subviewOp = v.getDefiningOp()) { + v = subviewOp.getSource(); + continue; + } + if (auto castOp = v.getDefiningOp()) { + v = castOp.getSource(); + continue; + } + break; + } + return nullptr; +} + +/// Given a tile shape from a collapsed domain (e.g. [128, 128] from a 2D view), +/// expand it back to the alloc's rank using the collapse_shape's reassociation +/// indices. 
+/// +/// For each reassociation group with multiple source dims, try to assign the +/// full collapsed tile to a single source dim (keeping other dims at 1). +/// This preserves the middle-tile-must-be-1 invariant for legalize-layout. +/// +/// Priority: +/// 1. Exact match: a dim whose size == the collapsed tile. +/// 2. Outermost dim that can hold the tile (srcShape[dim] >= tile). +/// 3. Fallback: distribute across dims from innermost outward. +/// +/// Example: reassoc = [[0, 1], [2]], srcShape = [4, 128, 64], tile = [128, 64] +/// group [0, 1], srcShape=[4,128], tile=128: +/// dim 1 (128) == 128 → exact match → [1, 128] +/// group [2], tile=64: → [64] +/// result: [1, 128, 64] +static SmallVector expandTileShape( + ArrayRef collapsedTile, + memref::CollapseShapeOp collapseOp) { + auto reassoc = collapseOp.getReassociationIndices(); + auto srcType = cast(collapseOp.getSrc().getType()); + auto srcShape = srcType.getShape(); + SmallVector expanded(srcShape.size(), 1); + + for (size_t g = 0; g < reassoc.size(); g++) { + auto &group = reassoc[g]; + int64_t tile = collapsedTile[g]; + + if (group.size() == 1) { + expanded[group[0]] = tile; + continue; + } + + // Try to assign the full tile to a single dim in the group. + // Priority: exact match first, then outermost dim that can hold it. + int bestIdx = -1; + for (size_t i = 0; i < group.size(); i++) { + int64_t srcDim = srcShape[group[i]]; + if (srcDim == tile) { + bestIdx = static_cast(i); + break; // Exact match — stop looking. + } + if (bestIdx == -1 && srcDim >= tile) { + bestIdx = static_cast(i); // First (outermost) holder. + } + } + + if (bestIdx >= 0) { + // Assign full tile to the chosen dim; rest stay 1. + expanded[group[bestIdx]] = tile; + } else { + // No single dim can hold the tile — distribute from innermost outward. + int64_t remaining = tile; + for (int i = static_cast(group.size()) - 1; i >= 0; i--) { + int64_t srcDim = srcShape[group[i]]; + int64_t tileForDim = std::min(srcDim, remaining); + expanded[group[i]] = tileForDim; + if (tileForDim > 0) + remaining /= tileForDim; + } + } + } + return expanded; +} + +/// Trace a value forward through uses to find ALL linalg operands it feeds into. +/// Returns a vector of (linalgOp, operandIndex, operandShape) tuples. +/// +/// Post-bufferization version: follows memref.subview ops to find linalg users. +/// Used to collect tile sizes from all linalg uses of an SBUF allocation. +static void traceToLinalgOperands( + Value val, + SmallVector>> &results) { + + // Track visited values to avoid infinite loops + llvm::SmallPtrSet visited; + std::queue workList; + workList.push(val); + + while (!workList.empty()) { + Value current = workList.front(); + workList.pop(); + + if (visited.contains(current)) + continue; + visited.insert(current); + + for (OpOperand &use : current.getUses()) { + Operation *user = use.getOwner(); + + // Check if this is a linalg op + if (auto linalgOp = dyn_cast(user)) { + // Record the operand shape as tile size. + // If a collapse_shape exists anywhere in the def chain back to the alloc, + // expand the tile shape back to the alloc's rank using the reassociation. 
+ // This handles chains like: alloc[MxBxN] → subview → collapse_shape → subview → linalg + MemRefType operandType = dyn_cast(current.getType()); + if (operandType) { + SmallVector tileShape(operandType.getShape().begin(), + operandType.getShape().end()); + if (auto collapseOp = findCollapseInDefChain(current)) { + // Only expand if tile rank matches the collapsed rank (number of + // reassociation groups). If they differ, the tile went through + // additional rank changes (e.g., a subview that dropped dims after + // the collapse_shape) and cannot be directly expanded back. + // Skip recording this tile — it doesn't represent the alloc's tiling. + auto reassoc = collapseOp.getReassociationIndices(); + if (tileShape.size() == reassoc.size()) { + tileShape = expandTileShape(tileShape, collapseOp); + } else { + LLVM_DEBUG(llvm::dbgs() << " Skipping linalg operand (tile rank " + << tileShape.size() << " != collapsed rank " + << reassoc.size() << ")\n"); + continue; + } + } + LLVM_DEBUG({ + llvm::dbgs() << " Found linalg operand with shape ["; + llvm::interleave(tileShape, llvm::dbgs(), "x"); + llvm::dbgs() << "] in " << linalgOp->getName() << "\n"; + }); + // Find which operand index this use corresponds to + for (unsigned i = 0; i < linalgOp->getNumOperands(); ++i) { + if (linalgOp->getOperand(i) == current) { + results.push_back({linalgOp, i, tileShape}); + break; + } + } + } + continue; // Don't stop - keep looking for more linalg uses + } + + // Follow through memref.subview (post-bufferization) + // Don't record shape here - we'll record it when we see actual uses + // (linalg ops or copy destinations) + if (auto subviewOp = dyn_cast(user)) { + workList.push(subviewOp.getResult()); + continue; + } + + // Follow through memref.cast (type refinement, doesn't change data) + if (auto castOp = dyn_cast(user)) { + workList.push(castOp.getResult()); + continue; + } + + // Follow through memref.collapse_shape (created by Phase 0's foldReshapeIntoAlloc) + if (auto collapseOp = dyn_cast(user)) { + workList.push(collapseOp.getResult()); + continue; + } + + // Follow through memref.expand_shape + if (auto expandOp = dyn_cast(user)) { + workList.push(expandOp.getResult()); + continue; + } + + // Record shape when current value is used as copy destination + // This captures tile sizes for buffers that are written via copy. + // If a collapse_shape exists in the def chain, expand back to alloc rank. 
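+ // Standalone illustration (not part of the pass sources, names are illustrative):
+ // the tile-expansion priority implemented by expandTileShape above — exact match
+ // first, then the outermost dim large enough to hold the tile, then distribute
+ // innermost-out. Plain C++ over std::vector so it compiles on its own; the
+ // copy-destination case continues below.
+ //
+ //   #include <algorithm>
+ //   #include <cstdint>
+ //   #include <vector>
+ //
+ //   inline std::vector<int64_t> expandTileSketch(
+ //       const std::vector<int64_t> &collapsedTile,
+ //       const std::vector<std::vector<int64_t>> &reassoc,
+ //       const std::vector<int64_t> &srcShape) {
+ //     std::vector<int64_t> expanded(srcShape.size(), 1);
+ //     for (size_t g = 0; g < reassoc.size(); ++g) {
+ //       const std::vector<int64_t> &group = reassoc[g];
+ //       int64_t tile = collapsedTile[g];
+ //       if (group.size() == 1) { expanded[group[0]] = tile; continue; }
+ //       int best = -1;
+ //       for (size_t i = 0; i < group.size(); ++i) {
+ //         if (srcShape[group[i]] == tile) { best = (int)i; break; }    // 1. exact match
+ //         if (best == -1 && srcShape[group[i]] >= tile) best = (int)i; // 2. outermost holder
+ //       }
+ //       if (best >= 0) {
+ //         expanded[group[best]] = tile;
+ //       } else {                                                       // 3. distribute innermost-out
+ //         int64_t remaining = tile;
+ //         for (int i = (int)group.size() - 1; i >= 0; --i) {
+ //           int64_t take = std::min(srcShape[group[i]], remaining);
+ //           expanded[group[i]] = take;
+ //           if (take > 0) remaining /= take;
+ //         }
+ //       }
+ //     }
+ //     return expanded;
+ //   }
+ //   // expandTileSketch({128, 64}, {{0, 1}, {2}}, {4, 128, 64}) == {1, 128, 64}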
+ if (auto copyOp = dyn_cast(user)) { + if (copyOp.getTarget() == current) { + MemRefType memrefType = dyn_cast(current.getType()); + if (memrefType) { + SmallVector tileShape(memrefType.getShape().begin(), + memrefType.getShape().end()); + if (auto collapseOp = findCollapseInDefChain(current)) { + auto reassoc = collapseOp.getReassociationIndices(); + if (tileShape.size() == reassoc.size()) { + tileShape = expandTileShape(tileShape, collapseOp); + } else { + LLVM_DEBUG(llvm::dbgs() << " Skipping copy dest (tile rank " + << tileShape.size() << " != collapsed rank " + << reassoc.size() << ")\n"); + continue; + } + } + LLVM_DEBUG({ + llvm::dbgs() << " Found copy destination tile shape ["; + llvm::interleave(tileShape, llvm::dbgs(), "x"); + llvm::dbgs() << "]\n"; + }); + results.push_back({linalg::LinalgOp(nullptr), 0, tileShape}); + } + } + // Don't follow through copies + continue; + } + } + } +} + +struct NkipyLegalizeLayoutPass + : public LegalizeLayoutBase { + + // Track if any errors occurred during the pass + bool hasError = false; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + + /// Try to compute a reassociation map that collapses newShape back to origShape. + /// Greedily groups consecutive dims of newShape whose product equals each dim + /// of origShape. Works for any rank combination. + /// + /// Example: origShape=[256,64], newShape=[256,1,64] + /// dim 0: 256 == 256 → group [0] + /// dim 1: 1*64 == 64 → group [1,2] + /// result: [[0], [1,2]] + static bool computeCollapseReassociation( + ArrayRef newShape, + ArrayRef origShape, + SmallVector &reassociation) { + reassociation.clear(); + int64_t newIdx = 0; + int64_t newRank = newShape.size(); + + for (int64_t origDim : origShape) { + ReassociationIndices group; + int64_t product = 1; + + while (newIdx < newRank && product < origDim) { + product *= newShape[newIdx]; + group.push_back(newIdx); + newIdx++; + } + + if (product != origDim || group.empty()) + return false; + + reassociation.push_back(group); + } + + return newIdx == newRank; + } + + /// Pre-processing: fold alloc+reshape patterns into a single higher-rank alloc. + /// + /// Transforms: + /// %a = memref.alloc() : memref + /// memref.copy %src, %a + /// %r = memref.reshape %a(%shape) -> memref + /// Into: + /// %a = memref.alloc() : memref + /// %v = memref.collapse_shape %a [...] -> memref + /// memref.copy %src, %v + /// // uses of %r replaced with %a + /// + /// This eliminates lower-rank SBUF allocs that can't be legalized because + /// their tile sizes are only discoverable through the reshape. 
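+ // Standalone illustration (not part of the pass sources, names are illustrative):
+ // the greedy grouping performed by computeCollapseReassociation above, which the
+ // alloc+reshape fold below relies on to build the collapse_shape view.
+ //
+ //   #include <cstdint>
+ //   #include <vector>
+ //
+ //   // Returns true and fills `groups` when newShape can be collapsed back to
+ //   // origShape by grouping consecutive dims whose product matches each orig dim.
+ //   inline bool collapseGroupsSketch(const std::vector<int64_t> &newShape,
+ //                                    const std::vector<int64_t> &origShape,
+ //                                    std::vector<std::vector<int64_t>> &groups) {
+ //     groups.clear();
+ //     size_t newIdx = 0;
+ //     for (int64_t origDim : origShape) {
+ //       std::vector<int64_t> group;
+ //       int64_t product = 1;
+ //       while (newIdx < newShape.size() && product < origDim) {
+ //         product *= newShape[newIdx];
+ //         group.push_back((int64_t)newIdx++);
+ //       }
+ //       if (product != origDim || group.empty())
+ //         return false;
+ //       groups.push_back(group);
+ //     }
+ //     return newIdx == newShape.size();
+ //   }
+ //   // origShape = [256, 64], newShape = [256, 1, 64]  →  groups = [[0], [1, 2]]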
+ void foldReshapeIntoAlloc(func::FuncOp func) { + SmallVector reshapesToFold; + + func.walk([&](memref::ReshapeOp reshapeOp) { + Value source = reshapeOp.getSource(); + auto sourceType = cast(source.getType()); + auto resultType = cast(reshapeOp.getResult().getType()); + + // Only handle SBUF allocs + if (!isSbuf(sourceType.getMemorySpace())) + return; + auto allocOp = source.getDefiningOp(); + if (!allocOp) + return; + + // Both must have identity layout + if (!sourceType.getLayout().isIdentity() || + !resultType.getLayout().isIdentity()) + return; + + // Must be able to compute a valid collapse reassociation + SmallVector reassociation; + if (!computeCollapseReassociation( + resultType.getShape(), sourceType.getShape(), reassociation)) + return; + + reshapesToFold.push_back(reshapeOp); + }); + + for (auto reshapeOp : reshapesToFold) { + Value source = reshapeOp.getSource(); + auto allocOp = source.getDefiningOp(); + auto sourceType = cast(source.getType()); + auto resultType = cast(reshapeOp.getResult().getType()); + + SmallVector reassociation; + computeCollapseReassociation( + resultType.getShape(), sourceType.getShape(), reassociation); + + OpBuilder builder(allocOp); + + // Create new alloc with the reshaped (higher-rank) type + auto newAllocType = MemRefType::get( + resultType.getShape(), + resultType.getElementType(), + /*layout=*/nullptr, + sourceType.getMemorySpace()); + auto newAlloc = builder.create( + allocOp.getLoc(), newAllocType, allocOp.getAlignmentAttr()); + + // Create collapse_shape to provide the original shape for existing copy users + auto collapseOp = builder.create( + allocOp.getLoc(), newAlloc.getResult(), reassociation); + + LLVM_DEBUG(llvm::dbgs() << " Folded reshape into alloc: " + << sourceType << " -> " << newAllocType << "\n"); + + // Replace uses: old alloc users → collapse_shape, old reshape users → new alloc + allocOp.getResult().replaceAllUsesWith(collapseOp.getResult()); + reshapeOp.getResult().replaceAllUsesWith(newAlloc.getResult()); + + reshapeOp.erase(); + allocOp.erase(); + } + + // Also fold memref.expand_shape on SBUF allocs. + // expand_shape increases rank (e.g., [1,128] → [1,128,1] from expand_dims). + // Without folding, the alloc is legalized at its original rank but the + // expand_shape result has a different rank, causing downstream mismatches. 
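+ // Standalone illustration (not part of the pass sources, names are illustrative):
+ // the shape bookkeeping behind this fold — collapsing the expanded shape back
+ // through the expand_shape's own reassociation reproduces the original alloc shape,
+ // which is why the same reassociation can be reused for the collapse_shape view.
+ //
+ //   #include <cstdint>
+ //   #include <vector>
+ //
+ //   // Collapse `shape` by taking the product of each reassociation group.
+ //   inline std::vector<int64_t> collapseByGroups(
+ //       const std::vector<int64_t> &shape,
+ //       const std::vector<std::vector<int64_t>> &groups) {
+ //     std::vector<int64_t> out;
+ //     for (const auto &g : groups) {
+ //       int64_t p = 1;
+ //       for (int64_t d : g) p *= shape[d];
+ //       out.push_back(p);
+ //     }
+ //     return out;
+ //   }
+ //   // expand_dims example from above: alloc [1, 128] expanded to [1, 128, 1] with
+ //   // reassociation [[0], [1, 2]]; collapseByGroups({1,128,1}, {{0},{1,2}}) == {1,128}.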
+ SmallVector expandsToFold; + + func.walk([&](memref::ExpandShapeOp expandOp) { + Value source = expandOp.getSrc(); + auto sourceType = cast(source.getType()); + + if (!isSbuf(sourceType.getMemorySpace())) + return; + auto allocOp = source.getDefiningOp(); + if (!allocOp) + return; + if (!sourceType.getLayout().isIdentity()) + return; + + expandsToFold.push_back(expandOp); + }); + + for (auto expandOp : expandsToFold) { + Value source = expandOp.getSrc(); + auto allocOp = source.getDefiningOp(); + auto sourceType = cast(source.getType()); + auto resultType = cast(expandOp.getResult().getType()); + + OpBuilder builder(allocOp); + + // Create new alloc with the expanded (higher-rank) type + auto newAllocType = MemRefType::get( + resultType.getShape(), + resultType.getElementType(), + /*layout=*/nullptr, + sourceType.getMemorySpace()); + auto newAlloc = builder.create( + allocOp.getLoc(), newAllocType, allocOp.getAlignmentAttr()); + + // Create collapse_shape to provide the original shape for existing users + // (e.g., HBM→SBUF copies that wrote into the original 2D alloc) + auto reassociation = expandOp.getReassociationIndices(); + auto collapseOp = builder.create( + allocOp.getLoc(), newAlloc.getResult(), reassociation); + + LLVM_DEBUG(llvm::dbgs() << " Folded expand_shape into alloc: " + << sourceType << " -> " << newAllocType << "\n"); + + // Replace uses: old alloc users → collapse_shape, expand_shape users → new alloc + allocOp.getResult().replaceAllUsesWith(collapseOp.getResult()); + expandOp.getResult().replaceAllUsesWith(newAlloc.getResult()); + + expandOp.erase(); + allocOp.erase(); + } + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + hasError = false; + + LLVM_DEBUG(llvm::dbgs() << "[LegalizeLayout] Processing function: " + << func.getName() << "\n"); + + // Phase 0: Fold alloc+reshape patterns into single higher-rank allocs + // This handles SBUF allocs created by expand_dims bufferization where the + // alloc has a lower rank than the reshape result. Without folding, the + // alloc's tile sizes can't be determined (they're only visible through + // the reshape's downstream subviews). 
+ foldReshapeIntoAlloc(func); + + // Phase 1: Identify all SBUF tensors needing legalization + SmallVector layoutInfos = findSbufTensorsToLegalize(func); + + if (hasError) { + llvm::errs() << "[LegalizeLayout] FAILED in Phase 1 (find SBUF tensors)\n"; + signalPassFailure(); + return; + } + + if (layoutInfos.empty()) { + llvm::errs() << "[LegalizeLayout] No SBUF tensors need legalization\n"; + // Still need to decompose HBM fills even when no SBUF layout changes + decomposeHbmFills(func); + return; + } + + llvm::errs() << "[LegalizeLayout] Found " << layoutInfos.size() + << " SBUF tensor(s) to legalize:\n"; + for (auto &info : layoutInfos) { + llvm::errs() << " tensor<"; + llvm::interleave(info.origShape, llvm::errs(), "x"); + llvm::errs() << "> tile=["; + llvm::interleave(info.tileSize, llvm::errs(), ","); + llvm::errs() << "] numBlocks=["; + llvm::interleave(info.numBlocks, llvm::errs(), ","); + llvm::errs() << "] -> physical<"; + auto phys = info.getPhysicalShape(); + llvm::interleave(phys, llvm::errs(), "x"); + llvm::errs() << ">\n"; + } + + // Phase 2: Transform allocations and subviews to physical (R+2)-D layout + IRMapping valueMapping; + transformToPhysicalLayout(func, layoutInfos, valueMapping); + + if (hasError) { + llvm::errs() << "[LegalizeLayout] FAILED in Phase 2 (transform to physical layout)\n"; + signalPassFailure(); + return; + } + + // Phase 3a: Tile HBM↔SBUF copies and transposes + // These ops require same-rank inputs/outputs, so we generate tiled loops + tileCopyAndTranspose(func, layoutInfos, valueMapping); + + if (hasError) { + llvm::errs() << "[LegalizeLayout] FAILED in Phase 3a (tile copy/transpose)\n"; + signalPassFailure(); + return; + } + + // Phase 3b: Decompose linalg.fill on HBM + // nisa.memset only supports SBUF/PSUM, so fill on HBM is decomposed into: + // alloc SBUF temp → linalg.fill SBUF → scf.for { memref.copy SBUF → HBM } + decomposeHbmFills(func); + + if (hasError) { + llvm::errs() << "[LegalizeLayout] FAILED in Phase 3b (decompose HBM fills)\n"; + signalPassFailure(); + return; + } + + // Phase 4: Fix rank mismatches + // After Phase 2, some ops have (R+2)-D operands but expect R-D: + // - linalg ops: (R+2)-D operands but indexing maps expect R-D + // - memref.copy: PSUM↔SBUF where PSUM is R-D and SBUF is now (R+2)-D + // Insert collapse_shape to convert (R+2)-D -> R-D where needed. + fixRankMismatches(func); + + if (hasError) { + llvm::errs() << "[LegalizeLayout] FAILED in Phase 4 (fix rank mismatches)\n"; + signalPassFailure(); + return; + } + + llvm::errs() << "[LegalizeLayout] Pass completed successfully\n"; + } + + /// Build a map from 2D values to their layout info for quick lookup + DenseMap buildValueToLayoutMap(SmallVector &layoutInfos) { + DenseMap valueMap; + for (auto &info : layoutInfos) { + valueMap[info.originalValue] = &info; + } + return valueMap; + } + + /// Transform all identified SBUF memrefs to (R+2)-D physical layout + /// + /// Post-bufferization algorithm: + /// 1. Transform memref.alloc to (R+2)-D shape + /// 2. Collect all ops using transformed allocations via BFS + /// 3. Transform memref.subview ops to use (R+2)-D indexing + /// + /// Note: memref.copy and linalg.transpose are handled separately in + /// tileCopyAndTranspose() because they require same-rank src/dst. 
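+ // Standalone illustration (not part of the pass sources, names are illustrative):
+ // the logical-to-physical shape mapping the steps above produce, assuming
+ // numBlocks[i] = origShape[i] / tileSize[i] as printed by runOnOperation.
+ // Physical layout convention: [tile0, numBlocks_0, ..., numBlocks_{R-1}, tile_{R-1}].
+ //
+ //   #include <cstdint>
+ //   #include <vector>
+ //
+ //   inline std::vector<int64_t> physicalShapeSketch(
+ //       const std::vector<int64_t> &origShape,
+ //       const std::vector<int64_t> &tileSize) {
+ //     size_t R = origShape.size();
+ //     std::vector<int64_t> phys;
+ //     phys.push_back(tileSize[0]);                  // partition tile
+ //     for (size_t i = 0; i < R; ++i)
+ //       phys.push_back(origShape[i] / tileSize[i]); // one block count per logical dim
+ //     phys.push_back(tileSize[R - 1]);              // free tile
+ //     return phys;
+ //   }
+ //   // e.g. origShape = [256, 512], tileSize = [128, 512]  →  physical [128, 2, 1, 512]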
+  void transformToPhysicalLayout(func::FuncOp func,
+                                 SmallVector<LayoutInfo> &layoutInfos,
+                                 IRMapping &valueMapping) {
+    if (layoutInfos.empty())
+      return;
+
+    OpBuilder builder(func.getContext());
+    auto valueMap = buildValueToLayoutMap(layoutInfos);
+
+    // Step 1: Transform allocations (memref.alloc)
+    for (auto &info : layoutInfos) {
+      transformAllocation(builder, info, valueMapping);
+      if (hasError) return;
+    }
+
+    // Step 1.5: Update collapse_shape ops from Phase 0 (foldReshapeIntoAlloc)
+    //
+    // Phase 0 creates: alloc(R-D) → collapse_shape → alloc_expanded((R+k)-D)
+    // After Step 1, alloc_expanded is mapped to a new (R+k+2)-D physical alloc.
+    // The collapse_shape still references the old alloc_expanded — update it to
+    // reference the new physical alloc with a composed reassociation.
+    for (auto &info : layoutInfos) {
+      Value origAlloc = info.originalValue;
+      Value legAlloc = valueMapping.lookupOrNull(origAlloc);
+      if (!legAlloc) continue;
+
+      // Walk through reinterpret_cast to get the raw identity-layout alloc.
+      // Narrow buffers use reinterpret_cast for strided layout, but
+      // collapse_shape requires contiguous (identity) layout.
+      Value rawAlloc = legAlloc;
+      while (auto reinterpret = rawAlloc.getDefiningOp<memref::ReinterpretCastOp>())
+        rawAlloc = reinterpret.getSource();
+
+      SmallVector<memref::CollapseShapeOp> collapseOps;
+      for (auto &use : origAlloc.getUses()) {
+        if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(use.getOwner()))
+          collapseOps.push_back(collapseOp);
+      }
+
+      for (auto collapseOp : collapseOps) {
+        // Map the R-D → targetRank reassociation through the (R+2)-D physical layout.
+        // Physical layout convention: [t0, nB0, nB1, ..., nB_{R-1}, t_{R-1}]
+        // So R-dim i maps to physical dims:
+        //   i = 0       → {0, 1}     (tile0, numBlocks0)
+        //   0 < i < R-1 → {i + 1}    (numBlocks_i)
+        //   i = R-1     → {R, R+1}   (numBlocks_{R-1}, tile_{R-1})
+        auto oldReassoc = collapseOp.getReassociationIndices();
+        int64_t R = info.rank();
+        SmallVector<ReassociationIndices> newReassoc;
+        for (auto &group : oldReassoc) {
+          ReassociationIndices newGroup;
+          for (int64_t rdim : group) {
+            if (rdim == 0) {
+              newGroup.push_back(0);
+              newGroup.push_back(1);
+            } else if (rdim == R - 1) {
+              newGroup.push_back(R);
+              newGroup.push_back(R + 1);
+            } else {
+              newGroup.push_back(rdim + 1);
+            }
+          }
+          newReassoc.push_back(newGroup);
+        }
+
+        OpBuilder b(collapseOp);
+        auto newCollapse = b.create<memref::CollapseShapeOp>(
+            collapseOp.getLoc(), rawAlloc, newReassoc);
+        collapseOp.getResult().replaceAllUsesWith(newCollapse.getResult());
+        collapseOp.erase();
+
+        LLVM_DEBUG(llvm::dbgs() << "  Updated Phase 0 collapse_shape to "
+                                << "reference legalized alloc\n");
+      }
+    }
+
+    // Step 2: Collect which ops need transformation via BFS (into a set)
+    auto opsToTransform = collectOpsToTransform(layoutInfos, valueMapping);
+
+    // Step 3: Walk the function in program order (topological order)
+    // and transform subview ops that are in the set.
+    // Note: memref.copy is handled separately in tileCopyAndTranspose()
+    //
+    // IMPORTANT: We collect ops to transform first, then do all transforms,
+    // then do cleanup. This is because replaceAllUsesWith on a parent subview
+    // would modify the child subview's source operand before we transform it.
+ SmallVector> subviewReplacements; + + func.walk([&](Operation *op) { + if (hasError) + return WalkResult::interrupt(); + + // Only process ops that need transformation + if (!opsToTransform.contains(op)) + return WalkResult::advance(); + + if (auto subviewOp = dyn_cast(op)) { + auto newSubview = transformSubview(builder, subviewOp, valueMapping, valueMap); + if (newSubview) { + subviewReplacements.push_back({subviewOp, newSubview}); + } + } + + return WalkResult::advance(); + }); + + // Step 4: Now do subview replacements and cleanup + for (auto &[oldOp, newOp] : subviewReplacements) { + oldOp.getResult().replaceAllUsesWith(newOp.getResult()); + oldOp.erase(); + } + + // Step 5: Redirect linalg ops that still reference the original alloc + // (e.g., fill, generic) to a collapse_shape view of the legalized alloc. + // This handles untiled ops that KnobDrivenTiling left as full-buffer accesses. + redirectDirectLinalgUses(layoutInfos, valueMapping); + } + + /// Redirect linalg ops that directly reference original (stale) alloc values. + /// + /// Two cases based on whether the linalg op mixes legalized/non-legalized operands: + /// - All-SBUF (e.g. fill): redirect to collapse_shape view — interleaving is + /// consistent across all operands so element-wise computation is correct. + /// - Mixed (e.g. generic with SBUF input + non-SBUF output): copy legalized + /// data block-by-block into a sequential temp, then redirect to temp. + void redirectDirectLinalgUses(SmallVector &layoutInfos, + IRMapping &valueMapping) { + DenseMap origToLayout; + for (auto &info : layoutInfos) + origToLayout[info.originalValue] = &info; + + // Classify linalg ops BEFORE any redirects (to avoid stale operand checks) + struct UseRecord { + linalg::LinalgOp op; + unsigned operandIdx; + LayoutInfo *info; + bool mixed; + }; + SmallVector records; + + for (auto &info : layoutInfos) { + Value origAlloc = info.originalValue; + if (!valueMapping.lookupOrNull(origAlloc)) continue; + + for (OpOperand &use : origAlloc.getUses()) { + auto linalgOp = dyn_cast(use.getOwner()); + if (!linalgOp) continue; + + // Skip TransposeOp — handled separately by Phase 3a (tileCopyAndTranspose) + if (isa(use.getOwner())) continue; + + // Mixed = any rank≥2 operand is NOT a legalized orig alloc + bool mixed = false; + for (unsigned i = 0; i < linalgOp->getNumOperands(); ++i) { + auto mt = dyn_cast(linalgOp->getOperand(i).getType()); + if (mt && mt.getRank() >= 2 && !origToLayout.count(linalgOp->getOperand(i))) { + mixed = true; + break; + } + } + records.push_back({linalgOp, use.getOperandNumber(), &info, mixed}); + } + } + + // Cache one collapse_shape per legalized alloc (for all-SBUF ops) + DenseMap collapseCache; + // Ops replaced by tiled versions that need to be erased + llvm::SmallPtrSet opsToErase; + + for (auto &rec : records) { + if (opsToErase.contains(rec.op)) continue; + Value legAlloc = valueMapping.lookup(rec.info->originalValue); + + if (!rec.mixed) { + auto legType = cast(legAlloc.getType()); + + if (!legType.getLayout().isIdentity()) { + // Partition-contiguous strides (e.g. [1,128,128,1] for narrow + // buffers): collapse_shape on the full alloc fails because the + // strides aren't row-major contiguous when numBlocks dims > 1. + // Generate a per-tile loop instead: subview each tile (middle dims + // become 1 → collapse works), then clone the linalg op on each tile. 
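+ // Standalone illustration (not part of the pass sources, names are illustrative):
+ // the contiguity condition behind the branch above. collapse_shape of a full buffer
+ // is only valid when the strides are the row-major suffix products of the shape;
+ // the partition-contiguous strides used for narrow buffers are not, so the linalg
+ // op is cloned per tile instead.
+ //
+ //   #include <cstdint>
+ //   #include <vector>
+ //
+ //   inline bool isRowMajorContiguous(const std::vector<int64_t> &shape,
+ //                                    const std::vector<int64_t> &strides) {
+ //     int64_t expected = 1;
+ //     for (int i = (int)shape.size() - 1; i >= 0; --i) {
+ //       if (strides[i] != expected)
+ //         return false;
+ //       expected *= shape[i];
+ //     }
+ //     return true;
+ //   }
+ //   // Row-major shape [128, 2, 1, 1] has strides [2, 1, 1, 1] → contiguous, collapsible.
+ //   // Partition-contiguous strides such as [1, 128, 128, 1] are not → per-tile loop.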
+ int64_t R = rec.info->rank(); + OpBuilder b(rec.op); + Location loc = rec.op.getLoc(); + + auto nest = createBlockLoopNest(b, loc, rec.info->numBlocks); + Value tileCollapsed = createTileSubviewAndCollapse( + b, loc, legAlloc, rec.info->tileSize[0], + rec.info->tileSize[R - 1], R, nest.ivs); + + // Clone the linalg op into the loop body, redirecting the + // legalized operand to the tile-sized collapsed view + auto *cloned = b.clone(*rec.op); + cloned->setOperand(rec.operandIdx, tileCollapsed); + // Mark original for erasure (will be erased after loop) + opsToErase.insert(rec.op); + LLVM_DEBUG(llvm::dbgs() << " Tiled " << rec.op->getName() + << " over " << R << " block dims for strided buffer\n"); + } else { + // Default row-major layout: collapse_shape on full alloc works. + Value &collapsed = collapseCache[rec.info->originalValue]; + if (!collapsed) { + OpBuilder b(legAlloc.getContext()); + b.setInsertionPointAfterValue(legAlloc); + collapsed = b.create( + legAlloc.getLoc(), legAlloc, + build2DCollapseFromPhysical(rec.info->physicalRank())); + } + rec.op->setOperand(rec.operandIdx, collapsed); + LLVM_DEBUG(llvm::dbgs() << " Redirected " << rec.op->getName() + << " operand " << rec.operandIdx << " to collapsed view\n"); + } + } else { + // Mixed: deinterleave via block-by-block copy into sequential temp + OpBuilder b(rec.op); + Location loc = rec.op.getLoc(); + int64_t R = rec.info->rank(); + auto elemTy = cast(legAlloc.getType()).getElementType(); + auto temp = b.create( + loc, MemRefType::get(rec.info->origShape, elemTy)); + + auto nest = createBlockLoopNest(b, loc, rec.info->numBlocks); + Value legC = createTileSubviewAndCollapse( + b, loc, legAlloc, rec.info->tileSize[0], + rec.info->tileSize[R - 1], R, nest.ivs); + + // Temp subview [b0*t0, ...][t0, ...] + SmallVector so, ss, sr; + for (int64_t d = 0; d < R; d++) { + if (rec.info->tileSize[d] == 1) { + so.push_back(OpFoldResult(nest.ivs[d])); + } else { + Value ts = b.create(loc, rec.info->tileSize[d]); + so.push_back(OpFoldResult(b.create(loc, nest.ivs[d], ts))); + } + ss.push_back(b.getIndexAttr(rec.info->tileSize[d])); + sr.push_back(b.getIndexAttr(1)); + } + auto seqSV = b.create(loc, temp, so, ss, sr); + b.create(loc, legC, seqSV); + + rec.op->setOperand(rec.operandIdx, temp.getResult()); + LLVM_DEBUG(llvm::dbgs() << " Deinterleaved " << rec.op->getName() + << " operand " << rec.operandIdx << " via temp copy\n"); + } + } + + // Erase original ops that were replaced by tiled versions + for (auto *op : opsToErase) + op->erase(); + } + + /// Convert an R-D offset to a block index by dividing by tile size. + /// + /// NOTE: This is fragile! It assumes offsets after canonicalize-loop-step + /// are either: + /// 1. Static constants (divisible by tileSize) + /// 2. Dynamic values of the form: arith.muli %idx, %tileSize + /// + /// For case 2, we pattern-match and extract %idx directly (avoiding division). + /// If the pattern doesn't match, we emit an error. + std::optional computeBlockIndex(OpBuilder &builder, + OpFoldResult offset2D, + int64_t tileSize, + Location loc) { + // Special case: tile_size == 1, block index equals the offset directly. + // This handles middle dims of R-D tensors where tile=1 and the offset + // is a bare loop induction variable (not wrapped in arith.muli). 
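+ // Standalone illustration (not part of the pass sources, names are illustrative):
+ // the block-index arithmetic computeBlockIndex implements for the static and
+ // tile==1 cases. The dynamic case pattern-matches `arith.muli %idx, %tileSize`
+ // in the IR and reuses %idx, which has no meaningful standalone equivalent.
+ //
+ //   #include <cstdint>
+ //   #include <optional>
+ //
+ //   inline std::optional<int64_t> staticBlockIndex(int64_t offset, int64_t tileSize) {
+ //     if (tileSize == 1)
+ //       return offset;        // middle dims: block index is the offset itself
+ //     if (offset % tileSize != 0)
+ //       return std::nullopt;  // not tile-aligned: the pass reports an error
+ //     return offset / tileSize;
+ //   }
+ //   // staticBlockIndex(384, 128) == 3;  staticBlockIndex(5, 1) == 5;
+ //   // staticBlockIndex(100, 128) == nullopt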
+ if (tileSize == 1) { + return offset2D; + } + + // Case 1: Static constant offset + if (auto attr = dyn_cast(offset2D)) { + if (auto intAttr = dyn_cast(attr)) { + int64_t val = intAttr.getInt(); + if (val % tileSize != 0) { + llvm::errs() << "[LegalizeLayout] Error: static offset " << val + << " not divisible by tile size " << tileSize << "\n"; + return std::nullopt; + } + return builder.getIndexAttr(val / tileSize); + } + } + + // Case 2: Dynamic value - try to pattern match arith.muli %idx, %multiplier + // where multiplier is divisible by tileSize + Value offsetVal = cast(offset2D); + if (auto mulOp = offsetVal.getDefiningOp()) { + // Check if RHS is a constant divisible by tileSize + if (auto constOp = mulOp.getRhs().getDefiningOp()) { + if (auto intAttr = dyn_cast(constOp.getValue())) { + int64_t multiplier = intAttr.getInt(); + if (multiplier % tileSize == 0) { + int64_t scale = multiplier / tileSize; + if (scale == 1) { + // offset = idx * tileSize, so block_index = idx + return OpFoldResult(mulOp.getLhs()); + } else { + // offset = idx * (scale * tileSize), so block_index = idx * scale + Value scaleVal = builder.create(loc, scale); + Value blockIdx = builder.create(loc, mulOp.getLhs(), scaleVal); + return OpFoldResult(blockIdx); + } + } + } + } + // Also check LHS in case constant is on the left + if (auto constOp = mulOp.getLhs().getDefiningOp()) { + if (auto intAttr = dyn_cast(constOp.getValue())) { + int64_t multiplier = intAttr.getInt(); + if (multiplier % tileSize == 0) { + int64_t scale = multiplier / tileSize; + if (scale == 1) { + return OpFoldResult(mulOp.getRhs()); + } else { + Value scaleVal = builder.create(loc, scale); + Value blockIdx = builder.create(loc, mulOp.getRhs(), scaleVal); + return OpFoldResult(blockIdx); + } + } + } + } + } + + // Case 3: Check if it's just 0 (common for first dimension) + if (auto constOp = offsetVal.getDefiningOp()) { + if (auto intAttr = dyn_cast(constOp.getValue())) { + int64_t val = intAttr.getInt(); + if (val % tileSize != 0) { + llvm::errs() << "[LegalizeLayout] Error: constant offset " << val + << " not divisible by tile size " << tileSize << "\n"; + return std::nullopt; + } + return builder.getIndexAttr(val / tileSize); + } + } + + llvm::errs() << "[LegalizeLayout] Error: cannot compute block index from dynamic offset. " + << "Expected pattern: constant (divisible by " << tileSize << ") or " + << "arith.muli %idx, " << tileSize << "\n"; + return std::nullopt; + } + + /// Transform memref.subview to physical-layout subview with correct sizes/offsets + /// + /// For rank-R tensor with tile [t_0, ..., t_{R-1}]: + /// R-D subview: [off_0, ..., off_{R-1}][sz_0, ..., sz_{R-1}] + /// (R+2)-D subview: [0, off_0/t_0, ..., off_{R-1}/t_{R-1}, 0] + /// [t_0, sz_0/t_0, ..., sz_{R-1}/t_{R-1}, t_{R-1}] + /// + /// The first dim (partition tile) and last dim (free tile) always span full tiles. + /// The middle R dims are block indices computed from the original offsets/sizes. 
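+ // Standalone illustration (not part of the pass sources, names are illustrative):
+ // the offset/size mapping described above, shown for static offsets only.
+ // Logical [off_i][sz_i] with tiles t_i becomes physical
+ // [0, off_i/t_i, ..., 0] / [t_0, sz_i/t_i, ..., t_{R-1}].
+ //
+ //   #include <cstdint>
+ //   #include <vector>
+ //
+ //   struct PhysSubviewSketch {
+ //     std::vector<int64_t> offsets, sizes;
+ //   };
+ //
+ //   inline PhysSubviewSketch mapSubviewSketch(const std::vector<int64_t> &offs,
+ //                                             const std::vector<int64_t> &szs,
+ //                                             const std::vector<int64_t> &tile) {
+ //     size_t R = tile.size();
+ //     PhysSubviewSketch out;
+ //     out.offsets.push_back(0);                   // partition tile always starts at 0
+ //     out.sizes.push_back(tile[0]);
+ //     for (size_t i = 0; i < R; ++i) {
+ //       out.offsets.push_back(offs[i] / tile[i]); // block index (divisibility checked by the pass)
+ //       out.sizes.push_back(szs[i] / tile[i]);    // number of blocks covered
+ //     }
+ //     out.offsets.push_back(0);                   // free tile always starts at 0
+ //     out.sizes.push_back(tile[R - 1]);
+ //     return out;
+ //   }
+ //   // e.g. offs = [256, 0], szs = [128, 512], tile = [128, 512]
+ //   //      → offsets [0, 2, 0, 0], sizes [128, 1, 1, 512] (strides all 1)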
+ memref::SubViewOp transformSubview(OpBuilder &builder, memref::SubViewOp op, + IRMapping &valueMapping, + DenseMap &valueMap) { + Value source = op.getSource(); + + // Find the physical source value + Value sourcePhys = valueMapping.lookupOrNull(source); + if (!sourcePhys) { + // Source wasn't transformed - skip + return nullptr; + } + + // Find layout info for this source + LayoutInfo *info = findLayoutInfo(source, valueMapping, valueMap); + if (!info) { + LLVM_DEBUG(llvm::dbgs() << " Warning: subview source not in layout map\n"); + return nullptr; + } + + int64_t R = info->rank(); + + // Get the R-D offsets and sizes from original subview + auto mixedOffsets = op.getMixedOffsets(); + auto mixedSizes = op.getMixedSizes(); + if ((int64_t)mixedOffsets.size() != R || (int64_t)mixedSizes.size() != R) { + llvm::errs() << "[LegalizeLayout] Error: expected " << R << "D subview but got " + << mixedOffsets.size() << "D offsets, " + << mixedSizes.size() << "D sizes\n"; + hasError = true; + return nullptr; + } + + builder.setInsertionPoint(op); + Location loc = op.getLoc(); + + // Get static sizes (we require them to be static for now) + auto staticSizes = op.getStaticSizes(); + for (int64_t i = 0; i < R; i++) { + if (staticSizes[i] == ShapedType::kDynamic) { + llvm::errs() << "[LegalizeLayout] Error: dynamic subview sizes not supported\n"; + hasError = true; + return nullptr; + } + } + + // Validate divisibility for all dims + for (int64_t i = 0; i < R; i++) { + if (staticSizes[i] % info->tileSize[i] != 0) { + llvm::errs() << "[LegalizeLayout] Error: subview size " << staticSizes[i] + << " in dim " << i << " not divisible by tile size " << info->tileSize[i] << "\n"; + hasError = true; + return nullptr; + } + } + + // Compute block-index offsets from R-D offsets + SmallVector> blockIndices; + for (int64_t i = 0; i < R; i++) { + auto idx = computeBlockIndex(builder, mixedOffsets[i], info->tileSize[i], loc); + if (!idx) { + hasError = true; + return nullptr; + } + blockIndices.push_back(idx); + } + + // Build (R+2)-dim offsets: [0, blockIdx_0, ..., blockIdx_{R-1}, 0] + SmallVector offsetsPhys; + offsetsPhys.push_back(builder.getIndexAttr(0)); // Partition tile: always 0 + for (int64_t i = 0; i < R; i++) + offsetsPhys.push_back(*blockIndices[i]); + offsetsPhys.push_back(builder.getIndexAttr(0)); // Free tile: always 0 + + // Build (R+2)-dim sizes: [tileSize[0], sz_0/t_0, ..., sz_{R-1}/t_{R-1}, tileSize[R-1]] + SmallVector sizesPhys; + sizesPhys.push_back(builder.getIndexAttr(info->tileSize[0])); + for (int64_t i = 0; i < R; i++) + sizesPhys.push_back(builder.getIndexAttr(staticSizes[i] / info->tileSize[i])); + sizesPhys.push_back(builder.getIndexAttr(info->tileSize[R - 1])); + + // Build (R+2)-dim strides: all 1s + SmallVector stridesPhys(info->physicalRank(), builder.getIndexAttr(1)); + + // Create the new physical subview + // Result type is inferred - rank reduction happens for dims with size=1 + auto newSubview = builder.create( + loc, sourcePhys, offsetsPhys, sizesPhys, stridesPhys); + + // Add mapping from original result to new result + valueMapping.map(op.getResult(), newSubview.getResult()); + + // Add layout info for the new subview result so nested subviews can find it + valueMap[newSubview.getResult()] = info; + + LLVM_DEBUG(llvm::dbgs() << " Transformed subview: " + << op.getSourceType() << " -> " << newSubview.getType() << "\n"); + + return newSubview; + } + + /// Transform a memref allocation to physical layout shape + /// + /// Post-bufferization: transforms memref.alloc from R-D to 
(R+2)-D shape + /// Input: %alloc = memref.alloc() : memref> + /// Output: %alloc_phys = memref.alloc() : memref> + void transformAllocation(OpBuilder &builder, LayoutInfo &info, IRMapping &valueMapping) { + Value origValue = info.originalValue; + auto physShape = info.getPhysicalShape(); + + // Post-bufferization: must be memref.alloc + auto allocOp = origValue.getDefiningOp(); + if (!allocOp) { + llvm::errs() << "[LegalizeLayout] Error: expected memref.alloc but got "; + if (auto defOp = origValue.getDefiningOp()) + llvm::errs() << defOp->getName() << "\n"; + else + llvm::errs() << "block argument\n"; + hasError = true; + return; + } + + auto origType = allocOp.getType(); + + // Determine layout for the physical memref. + // + // For narrow SBUF buffers (free dim < 128) with multiple tile rows, + // default row-major strides make the partition dimension non-contiguous. + // Example: shape [128, 2, 1, 1] gets strides [2, 1, 1, 1], so accessing + // tile 1 gives stride-2 partition access [[2,128],[1,1]] — hardware rejects + // because the free dim (1 element) doesn't fill a full partition width (128). + // When free dim = 128, the DMA engine handles arbitrary partition strides. + // + // Fix: use a partition-contiguous layout where dim 0 has stride 1 and the + // block dimensions are outer. For [128, 2, 1, 1] this gives strides + // [1, 128, 128, 1], so each tile's partitions are contiguous in memory. + MemRefLayoutAttrInterface layout; + int64_t R = info.rank(); + int64_t tileN = info.tileSize[R - 1]; + int64_t numBlocksM = info.numBlocks[0]; + if (tileN < 128 && numBlocksM > 1) { + int64_t t0 = info.tileSize[0]; + int64_t physRank = physShape.size(); + SmallVector strides(physRank); + // Partition dim (first) is contiguous + strides[0] = 1; + // Free dim (last): stride = t0 so that partition elements are contiguous + // for each free index (avoids aliasing with partition stride in flat memory) + strides[physRank - 1] = t0; + // Block dims from right to left: each block region is t0 * tileN elements + int64_t stride = t0 * tileN; + for (int64_t i = R; i >= 1; --i) { + strides[i] = stride; + stride *= physShape[i]; + } + layout = StridedLayoutAttr::get(builder.getContext(), /*offset=*/0, strides); + LLVM_DEBUG(llvm::dbgs() << " Using partition-contiguous strides for narrow buffer\n"); + } + + builder.setInsertionPoint(allocOp); + Location loc = allocOp.getLoc(); + + // Always allocate with default (identity) layout — LLVM lowering requires + // contiguous allocs. For strided layouts, apply a reinterpret_cast after. 
+ auto defaultType = MemRefType::get( + physShape, + origType.getElementType(), + /*layout=*/nullptr, + origType.getMemorySpace() + ); + + auto newAlloc = builder.create( + loc, + defaultType, + /*dynamicSizes=*/ValueRange{}, + /*symbolOperands=*/ValueRange{}, + allocOp.getAlignmentAttr() + ); + + Value result = newAlloc.getResult(); + + // For narrow buffers: reinterpret_cast to the partition-contiguous layout + if (layout) { + auto stridedType = MemRefType::get( + physShape, + origType.getElementType(), + layout, + origType.getMemorySpace() + ); + SmallVector sizes, strides; + for (auto s : physShape) + sizes.push_back(builder.getIndexAttr(s)); + auto stridedAttr = cast(layout); + for (auto s : stridedAttr.getStrides()) + strides.push_back(builder.getIndexAttr(s)); + auto reinterpret = builder.create( + loc, stridedType, newAlloc.getResult(), + /*offset=*/builder.getIndexAttr(0), sizes, strides); + result = reinterpret.getResult(); + } + + valueMapping.map(origValue, result); + // Also self-map the legalized alloc so lookThroughCastAndResolve can + // identify it when tracing through collapse_shape ops from Step 1.5. + // Without this, the legalized alloc is only a VALUE in the mapping + // (not a KEY), so lookupOrNull would miss it. + valueMapping.map(result, result); + LLVM_DEBUG(llvm::dbgs() << " Transformed memref.alloc to physical layout: " + << result.getType() << "\n"); + } + + /// Collect all memref operations that need transformation via BFS + /// + /// Post-bufferization: memrefs are modified in-place, so we follow: + /// alloc -> copy (src or dest) -> subview -> linalg + /// + /// Returns a set of ops that need transformation + /// (iteration order will be determined by func.walk for topological order) + llvm::SmallPtrSet collectOpsToTransform(SmallVector &layoutInfos, + IRMapping &valueMapping) { + llvm::SmallPtrSet opsToTransform; + llvm::SmallPtrSet visited; + std::queue worklist; + + for (auto &info : layoutInfos) { + worklist.push(info.originalValue); + } + + while (!worklist.empty()) { + Value current = worklist.front(); + worklist.pop(); + + if (visited.contains(current)) + continue; + visited.insert(current); + + // For each use of current value + for (OpOperand &use : current.getUses()) { + Operation *user = use.getOwner(); + + if (opsToTransform.contains(user)) + continue; + + // Skip dealloc and annotate ops + if (isa(user)) { + continue; + } + + // Add to set if it's a copy or subview that needs transformation + if (isa(user)) { + opsToTransform.insert(user); + } + + // Follow through subview result - linalg ops use this + if (auto subviewOp = dyn_cast(user)) { + worklist.push(subviewOp.getResult()); + } + } + } + + LLVM_DEBUG(llvm::dbgs() << " Collected " << opsToTransform.size() + << " ops to transform\n"); + return opsToTransform; + } + + /// Tile HBM↔SBUF copy and transpose operations + /// + /// memref.copy and linalg.transpose require same-rank src/dst. + /// When one operand is R-D HBM and the other is (R+2)-D SBUF, we need to: + /// 1. Generate a tiled R-level loop nest (over block dimensions) + /// 2. Create subviews of both R-D HBM and (R+2)-D SBUF for each tile + /// 3. 
Collapse SBUF subview to R-D, then copy/transpose tile by tile + /// + /// Cases handled: + /// - HBM→SBUF / SBUF→HBM: R-level loop with collapse + /// - SBUF→SBUF: R-level loop with permuted block indices + /// - PSUM↔HBM, PSUM↔SBUF: No transform needed (already tile-sized) + void tileCopyAndTranspose(func::FuncOp func, SmallVector &layoutInfos, + IRMapping &valueMapping) { + OpBuilder builder(func.getContext()); + auto valueMap = buildValueToLayoutMap(layoutInfos); + + // Collect copy and transpose ops to tile (can't modify while walking) + SmallVector copiesToTile; + SmallVector transposesToTile; + + func.walk([&](Operation *op) { + if (auto copyOp = dyn_cast(op)) { + // Check if this copy needs tiling (HBM↔SBUF transfer) + Value src = copyOp.getSource(); + Value dst = copyOp.getTarget(); + auto srcType = cast(src.getType()); + auto dstType = cast(dst.getType()); + + if (needsTiledTransfer(srcType, dstType)) { + copiesToTile.push_back(copyOp); + } + } else if (auto transposeOp = dyn_cast(op)) { + Value input = transposeOp.getDpsInputs()[0]; + Value output = transposeOp.getDpsInits()[0]; + // Look through casts (SBUF→SBUF transpose input may be a cast result) + Value inputBase = lookThroughCast(input); + auto inputBaseType = cast(inputBase.getType()); + auto outputType = cast(output.getType()); + + // Tile HBM↔SBUF transposes (existing) and SBUF→SBUF transposes (new) + if (needsTiledTransfer(inputBaseType, outputType) || + (isSbuf(inputBaseType.getMemorySpace()) && + isSbuf(outputType.getMemorySpace()))) { + transposesToTile.push_back(transposeOp); + } + } + }); + + LLVM_DEBUG(llvm::dbgs() << " Found " << copiesToTile.size() + << " copies and " << transposesToTile.size() + << " transposes to tile\n"); + + // Process copies + for (auto copyOp : copiesToTile) { + tileMemrefCopy(builder, copyOp, valueMapping, valueMap); + if (hasError) return; + } + + // Process transposes + for (auto transposeOp : transposesToTile) { + tileTranspose(builder, transposeOp, valueMapping, valueMap); + if (hasError) return; + } + } + + /// Tile a memref.copy between R-D HBM and (R+2)-D SBUF + /// + /// Generates an R-level loop nest (one scf.for per block dimension). + /// For each block: creates an R-D HBM subview and an (R+2)-D SBUF subview, + /// collapses the SBUF subview to R-D, then copies. + /// Look through collapse_shape/expand_shape to find a value that has a + /// mapping in valueMapping. This handles the pattern created by Phase 0 + /// (foldReshapeIntoAlloc) where a collapse_shape sits between the legalized + /// alloc and the copy. 
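+ // Standalone illustration (not part of the pass sources, names are illustrative):
+ // for the case where HBM and SBUF share the same logical rank, tileMemrefCopy below
+ // uses per-dim tile sizes [partTile, 1, ..., 1, freeTile]; the loop nest visits one
+ // block per combination of block indices and each iteration moves one
+ // [partTile, freeTile] tile whose HBM offsets are computed as sketched here.
+ //
+ //   #include <cstdint>
+ //   #include <vector>
+ //
+ //   inline std::vector<int64_t> hbmTileOffsetsSketch(const std::vector<int64_t> &blockIdx,
+ //                                                    int64_t partTile, int64_t freeTile) {
+ //     size_t R = blockIdx.size();
+ //     std::vector<int64_t> tile(R, 1);
+ //     tile[0] = partTile;
+ //     if (R > 1) tile[R - 1] = freeTile;
+ //     std::vector<int64_t> offsets(R);
+ //     for (size_t i = 0; i < R; ++i)
+ //       offsets[i] = blockIdx[i] * tile[i]; // tile==1 dims: offset is the block index itself
+ //     return offsets;
+ //   }
+ //   // e.g. blockIdx = {1, 2, 0}, partTile = 128, freeTile = 256
+ //   //      → HBM offsets [128, 2, 0], sizes [128, 1, 256]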
+ Value resolveToMappedValue(Value v, IRMapping &valueMapping) { + // Direct lookup first + if (Value mapped = valueMapping.lookupOrNull(v)) + return mapped; + + // Look through collapse_shape: the source alloc may be mapped + if (auto collapseOp = v.getDefiningOp()) { + if (Value mapped = valueMapping.lookupOrNull(collapseOp.getSrc())) + return mapped; + } + + // Look through expand_shape + if (auto expandOp = v.getDefiningOp()) { + if (Value mapped = valueMapping.lookupOrNull(expandOp.getSrc())) + return mapped; + } + + return v; + } + + void tileMemrefCopy(OpBuilder &builder, memref::CopyOp op, + IRMapping &valueMapping, + DenseMap &valueMap) { + Value src = op.getSource(); + Value dst = op.getTarget(); + + // Look up mapped values (transformations from Phase 2) + // Also look through collapse_shape/expand_shape created by Phase 0 + Value srcMapped = resolveToMappedValue(src, valueMapping); + Value dstMapped = resolveToMappedValue(dst, valueMapping); + + auto srcType = cast(srcMapped.getType()); + auto dstType = cast(dstMapped.getType()); + + // Determine which is HBM (R-D) and which is SBUF ((R+2)-D after transformation) + bool srcIsSbuf = isSbuf(srcType.getMemorySpace()); + bool dstIsSbuf = isSbuf(dstType.getMemorySpace()); + + Value bufHBM = srcIsSbuf ? dst : src; // HBM stays R-D + Value bufSBUF = srcIsSbuf ? srcMapped : dstMapped; // SBUF is now (R+2)-D + auto bufHBMType = cast(bufHBM.getType()); + auto bufSBUFType = cast(bufSBUF.getType()); + + int64_t hbmRank = bufHBMType.getRank(); + + // If SBUF wasn't transformed (partition dim ≤ 128), skip tiling + if (bufSBUFType.getRank() <= hbmRank) { + LLVM_DEBUG(llvm::dbgs() << " Skipping tiling for copy (SBUF not transformed): " + << srcType << " -> " << dstType << "\n"); + return; + } + + // Extract tile info from physical shape: [partTile, nB_0, ..., nB_{R-1}, freeTile] + auto physShape = bufSBUFType.getShape(); + int64_t R = physShape.size() - 2; // logical rank + int64_t partTile = physShape[0]; + int64_t freeTile = physShape[R + 1]; + + // Collect tile sizes for each logical dim from the physical shape + // tileSize[0] = partTile, tileSize[R-1] = freeTile, middle = 1 + SmallVector tileSizePerDim(R); + tileSizePerDim[0] = partTile; + for (int64_t i = 1; i < R - 1; i++) + tileSizePerDim[i] = 1; // middle dims always tile=1 + if (R > 1) + tileSizePerDim[R - 1] = freeTile; + + // Check if the copy's SBUF operand goes through a collapse_shape (Phase 0 + // pattern). When Phase 0 folds a reshape into the alloc, the copy still + // operates at the original (lower) rank via a collapse_shape view. We need + // the reassociation to build the HBM subview at its actual rank. + SmallVector> hbmReassoc; + Value sbufCopyOperand = srcIsSbuf ? src : dst; + if (hbmRank < R) { + if (auto collapseOp = sbufCopyOperand.getDefiningOp()) { + for (auto &indices : collapseOp.getReassociationIndices()) + hbmReassoc.push_back(SmallVector(indices.begin(), indices.end())); + } + if (hbmReassoc.empty()) { + llvm::errs() << "[LegalizeLayout] Error: HBM rank " << hbmRank + << " < SBUF logical rank " << R + << " but no collapse_shape found\n"; + hasError = true; + return; + } + } + + builder.setInsertionPoint(op); + Location loc = op.getLoc(); + + // Generate R-level nested loop nest (one scf.for per block dim) + SmallVector numBlocksVec(physShape.begin() + 1, physShape.begin() + 1 + R); + auto nest = createBlockLoopNest(builder, loc, numBlocksVec); + auto &blockIdxVars = nest.ivs; + Value c0 = builder.create(loc, 0); + + // Build HBM subview. 
+ // When hbmRank == R: straightforward R-D subview. + // When hbmRank < R (Phase 0 pattern): build hbmRank-D subview using the + // collapse_shape reassociation to map R block indices to hbmRank HBM dims. + SmallVector offsetsHBM, sizesHBM, stridesHBM; + if (hbmReassoc.empty()) { + // Normal case: HBM and SBUF have the same logical rank + for (int64_t i = 0; i < R; i++) { + if (tileSizePerDim[i] == 1) { + offsetsHBM.push_back(OpFoldResult(blockIdxVars[i])); + } else { + Value tileSizeVal = builder.create(loc, tileSizePerDim[i]); + Value offset = builder.create(loc, blockIdxVars[i], tileSizeVal); + offsetsHBM.push_back(OpFoldResult(offset)); + } + sizesHBM.push_back(builder.getIndexAttr(tileSizePerDim[i])); + stridesHBM.push_back(builder.getIndexAttr(1)); + } + } else { + // Phase 0 pattern: HBM has fewer dims than SBUF logical rank. + // Each HBM dim corresponds to a group of SBUF logical dims via the + // collapse_shape reassociation. + // For each group, compute: + // size = product of tileSizePerDim[i] for i in group + // offset = linearized index from block indices within the group + for (auto &group : hbmReassoc) { + int64_t combinedSize = 1; + for (int64_t idx : group) + combinedSize *= tileSizePerDim[idx]; + + // Compute linearized offset within this group. + // offset = sum_i( blockIdx[group[i]] * product(tileSizePerDim[group[j]] for j>i) ) + Value offset = nullptr; + int64_t innerProduct = combinedSize; + for (int64_t i = 0; i < (int64_t)group.size(); i++) { + int64_t dimIdx = group[i]; + innerProduct /= tileSizePerDim[dimIdx]; + if (innerProduct == 1 && !offset) { + // Last or only contributing dim: offset += blockIdx * tileSize + if (tileSizePerDim[dimIdx] == 1) { + offset = blockIdxVars[dimIdx]; + } else { + Value ts = builder.create(loc, tileSizePerDim[dimIdx]); + offset = builder.create(loc, blockIdxVars[dimIdx], ts); + } + } else if (innerProduct >= 1) { + int64_t stride = tileSizePerDim[dimIdx] * innerProduct; + Value strideVal = builder.create(loc, stride); + Value term = builder.create(loc, blockIdxVars[dimIdx], strideVal); + offset = offset ? builder.create(loc, offset, term).getResult() + : term; + } + } + if (!offset) + offset = c0; + + offsetsHBM.push_back(OpFoldResult(offset)); + sizesHBM.push_back(builder.getIndexAttr(combinedSize)); + stridesHBM.push_back(builder.getIndexAttr(1)); + } + } + + auto tileHBM = builder.create( + loc, bufHBM, offsetsHBM, sizesHBM, stridesHBM); + + // Build (R+2)-D SBUF subview and collapse to 2D + Value tileSBUFCollapsed = createTileSubviewAndCollapse( + builder, loc, bufSBUF, partTile, freeTile, R, blockIdxVars); + + // For R>2 (or hbmRank>2), collapse the HBM tile to 2D: + // [[0, ..., N-2], [N-1]] where N = actual HBM tile rank + int64_t hbmTileRank = hbmReassoc.empty() ? R : hbmRank; + Value tileHBM2D = tileHBM; + if (hbmTileRank > 2) { + auto reassocHBM = build2DCollapseFromLogical(hbmTileRank); + tileHBM2D = builder.create( + loc, tileHBM, reassocHBM); + } + + // Create copy for this tile (both are now 2D) + if (srcIsSbuf) { + builder.create(loc, tileSBUFCollapsed, tileHBM2D); + } else { + builder.create(loc, tileHBM2D, tileSBUFCollapsed); + } + + LLVM_DEBUG(llvm::dbgs() << " Tiled memref.copy: " << srcType + << " -> " << dstType << "\n"); + + // Erase original copy + op.erase(); + } + + /// Decompose linalg.fill on HBM into SBUF fill + tiled copy to HBM. + /// + /// nisa.memset only supports SBUF/PSUM destinations, so a fill on HBM must + /// be decomposed before linalg-to-nisa. 
The pattern mirrors tileMemrefCopy: + /// + /// %sbuf = memref.alloc() : memref> + /// linalg.fill ins(%cst) outs(%sbuf) + /// scf.for %i = 0 to numBlocks step 1 { + /// %hbm_tile = memref.subview %hbm[%i*P, 0][P, F][1, 1] + /// memref.copy %sbuf, %hbm_tile + /// } + /// + /// where P = min(partition_dim, 128). + void decomposeHbmFills(func::FuncOp func) { + SmallVector fillsToDecompose; + + func.walk([&](linalg::FillOp fillOp) { + Value output = fillOp.getOutputs()[0]; + auto outputType = dyn_cast(output.getType()); + if (!outputType) + return; + if (!isHbm(outputType.getMemorySpace())) + return; + // Only handle static shapes + if (!outputType.hasStaticShape()) + return; + // Need at least 2D for SBUF (partition + free) + if (outputType.getRank() < 2) + return; + fillsToDecompose.push_back(fillOp); + }); + + if (fillsToDecompose.empty()) + return; + + LLVM_DEBUG(llvm::dbgs() << " Decomposing " << fillsToDecompose.size() + << " linalg.fill on HBM\n"); + + OpBuilder builder(func.getContext()); + + for (auto fillOp : fillsToDecompose) { + builder.setInsertionPoint(fillOp); + Location loc = fillOp.getLoc(); + + Value scalarValue = fillOp.getInputs()[0]; + Value hbmBuf = fillOp.getOutputs()[0]; + auto hbmType = cast(hbmBuf.getType()); + auto hbmShape = hbmType.getShape(); + int64_t rank = hbmType.getRank(); + + // Partition dim capped at MAX_PARTITION_DIM (128) + int64_t partDim = hbmShape[0]; + int64_t partTile = std::min(partDim, MAX_PARTITION_DIM); + int64_t numBlocks = (partDim + partTile - 1) / partTile; + + // Build SBUF shape: [partTile, numBlocks * hbmShape[1], hbmShape[2], ...] + // Fold numBlocks into dim 1 so SBUF holds all the data, + // matching the DMA loading pattern in tileMemrefCopy. + SmallVector sbufShape; + sbufShape.push_back(partTile); + sbufShape.push_back(numBlocks * hbmShape[1]); + for (int64_t i = 2; i < rank; ++i) + sbufShape.push_back(hbmShape[i]); + + // Alloc SBUF temp + auto sbufMemSpace = nkipy::MemSpaceEnumAttr::get( + builder.getContext(), nkipy::MemSpaceEnum::Sbuf); + auto sbufType = MemRefType::get( + sbufShape, hbmType.getElementType(), nullptr, sbufMemSpace); + auto sbufAlloc = builder.create(loc, sbufType); + + // Fill the entire SBUF with the constant + builder.create(loc, scalarValue, sbufAlloc.getResult()); + + if (numBlocks == 1) { + // Fits in one SBUF tile — single copy, no loop needed + builder.create(loc, sbufAlloc.getResult(), hbmBuf); + } else { + // Build scf.for loop over blocks, subviewing both SBUF and HBM + int64_t freeDim = hbmShape[1]; + + Value c0 = builder.create(loc, 0); + Value c1 = builder.create(loc, 1); + Value numBlocksVal = + builder.create(loc, numBlocks); + Value partTileVal = + builder.create(loc, partTile); + Value freeDimVal = + builder.create(loc, freeDim); + + auto loop = builder.create(loc, c0, numBlocksVal, c1); + builder.setInsertionPointToStart(loop.getBody()); + Value iv = loop.getInductionVar(); + + // SBUF subview: offsets=[0, iv*freeDim, 0, ...], + // sizes=[partTile, freeDim, hbmShape[2], ...] 
+ Value sbufDim1Offset = + builder.create(loc, iv, freeDimVal); + SmallVector sbufOffsets, sbufSizes, sbufStrides; + sbufOffsets.push_back(builder.getIndexAttr(0)); + sbufOffsets.push_back(OpFoldResult(sbufDim1Offset)); + sbufSizes.push_back(builder.getIndexAttr(partTile)); + sbufSizes.push_back(builder.getIndexAttr(freeDim)); + sbufStrides.push_back(builder.getIndexAttr(1)); + sbufStrides.push_back(builder.getIndexAttr(1)); + for (int64_t i = 2; i < rank; ++i) { + sbufOffsets.push_back(builder.getIndexAttr(0)); + sbufSizes.push_back(builder.getIndexAttr(hbmShape[i])); + sbufStrides.push_back(builder.getIndexAttr(1)); + } + auto sbufTile = builder.create( + loc, sbufAlloc.getResult(), sbufOffsets, sbufSizes, sbufStrides); + + // HBM subview: offsets=[iv*partTile, 0, ...], + // sizes=[partTile, freeDim, hbmShape[2], ...] + Value hbmPartOffset = + builder.create(loc, iv, partTileVal); + SmallVector hbmOffsets, hbmSizes, hbmStrides; + hbmOffsets.push_back(OpFoldResult(hbmPartOffset)); + hbmSizes.push_back(builder.getIndexAttr(partTile)); + hbmStrides.push_back(builder.getIndexAttr(1)); + for (int64_t i = 1; i < rank; ++i) { + hbmOffsets.push_back(builder.getIndexAttr(0)); + hbmSizes.push_back(builder.getIndexAttr(hbmShape[i])); + hbmStrides.push_back(builder.getIndexAttr(1)); + } + auto hbmTile = builder.create( + loc, hbmBuf, hbmOffsets, hbmSizes, hbmStrides); + + // Copy SBUF tile → HBM tile + builder.create(loc, sbufTile, hbmTile); + + builder.setInsertionPointAfter(loop); + } + + LLVM_DEBUG(llvm::dbgs() << " Decomposed linalg.fill on HBM " + << hbmType << " -> SBUF " << sbufType << "\n"); + + // Erase original fill + fillOp.erase(); + } + } + + /// Tile a linalg.transpose between HBM and SBUF (or SBUF→SBUF) + /// + /// Generalized for rank-R tensors with (R+2)-D physical layout. + /// Generates an R-level loop nest over block indices, applies permutation + /// to map output block indices to source block indices. + void tileTranspose(OpBuilder &builder, linalg::TransposeOp op, + IRMapping &valueMapping, + DenseMap &valueMap) { + Value input = op.getDpsInputs()[0]; + Value output = op.getDpsInits()[0]; + + // Look through casts (and collapse_shape if the base alloc was legalized) + // to find the base alloc for valueMapping lookup. + // This handles chains like: legalized_alloc → collapse_shape → cast → transpose + // where Step 1.5 replaced the Phase 0 collapse with one from the legalized alloc. + Value inputBase = lookThroughCastAndResolve(input, valueMapping); + + // Look up mapped values (transformations from Phase 2) + Value inputMapped = valueMapping.lookupOrDefault(inputBase); + Value outputMapped = valueMapping.lookupOrDefault(output); + + auto inputMappedType = cast(inputMapped.getType()); + auto outputMappedType = cast(outputMapped.getType()); + + // Determine memory spaces + bool inputIsSbuf = isSbuf(inputMappedType.getMemorySpace()); + bool outputIsSbuf = isSbuf(outputMappedType.getMemorySpace()); + + auto permutation = op.getPermutation(); + + // Handle SBUF→SBUF transpose: at least one side is (R+2)-D after legalization + if (inputIsSbuf && outputIsSbuf) { + auto dstPhysShape = outputMappedType.getShape(); + auto srcPhysShape = inputMappedType.getShape(); + + // If NEITHER side was expanded to physical layout (both mapped ranks + // equal the original ranks), leave the transpose for linalg-to-nisa. + // This handles boundary transposes from canonicalize-partition-dim + // which operate on tile-sized buffers (partition dim ≤ 128). 
+ int64_t origInputRank = cast(input.getType()).getRank(); + int64_t origOutputRank = cast(output.getType()).getRank(); + if (inputMappedType.getRank() == origInputRank && + outputMappedType.getRank() == origOutputRank) { + LLVM_DEBUG(llvm::dbgs() << " Skipping SBUF→SBUF transpose " + << "(not expanded to physical layout): " + << inputMappedType << " -> " << outputMappedType << "\n"); + return; + } + + bool inputLegalized = (inputMappedType.getRank() != origInputRank); + bool outputLegalized = (outputMappedType.getRank() != origOutputRank); + + // Determine physical parameters from the legalized side. + // If both legalized: use output. If only one: use that one. + auto physRefShape = outputLegalized ? dstPhysShape : srcPhysShape; + int64_t R = physRefShape.size() - 2; // logical rank + int64_t partTile = physRefShape[0]; + int64_t freeTile = physRefShape[R + 1]; + + builder.setInsertionPoint(op); + Location loc = op.getLoc(); + + // R-level loop nest over block indices (from the legalized side) + SmallVector numBlocksRef(physRefShape.begin() + 1, + physRefShape.begin() + 1 + R); + auto nest = createBlockLoopNest(builder, loc, numBlocksRef); + + // Tile sizes per logical dim (for R-D subview of non-legalized side) + SmallVector tileSizePerDim(R); + tileSizePerDim[0] = partTile; + for (int64_t i = 1; i < R - 1; i++) + tileSizePerDim[i] = 1; + if (R > 1) + tileSizePerDim[R - 1] = freeTile; + + // Helper: create a tile from the legalized (R+2)-D side + auto makeLegalizedTile = [&](Value buf, ArrayRef ivs) { + return createTileSubviewAndCollapse( + builder, loc, buf, partTile, freeTile, R, ivs); + }; + + // Helper: create a tile from the non-legalized R-D side (plain subview) + auto makeNonLegalizedTile = [&](Value buf, ArrayRef ivs) -> Value { + SmallVector offsets, sizes, strides; + for (int64_t i = 0; i < R; i++) { + if (tileSizePerDim[i] == 1) { + offsets.push_back(OpFoldResult(ivs[i])); + } else { + Value ts = builder.create(loc, tileSizePerDim[i]); + offsets.push_back(OpFoldResult( + builder.create(loc, ivs[i], ts))); + } + sizes.push_back(builder.getIndexAttr(tileSizePerDim[i])); + strides.push_back(builder.getIndexAttr(1)); + } + auto tile = builder.create(loc, buf, offsets, sizes, strides); + // Collapse R-D to 2D if needed + Value result = tile; + if (R > 2) { + result = builder.create( + loc, tile, build2DCollapseFromLogical(R)); + } + return result; + }; + + // Dest tile: straight block indices + SmallVector permutedIVs; + for (int64_t i = 0; i < R; i++) + permutedIVs.push_back(nest.ivs[permutation[i]]); + + Value dstTile = outputLegalized + ? makeLegalizedTile(outputMapped, nest.ivs) + : makeNonLegalizedTile(outputMapped, nest.ivs); + + // Source tile: apply permutation to output block indices + Value srcTile = inputLegalized + ? makeLegalizedTile(inputMapped, permutedIVs) + : makeNonLegalizedTile(inputMapped, permutedIVs); + + // Compute 2D permutation from R-D permutation. + SmallVector invPerm(R); + for (int64_t i = 0; i < R; i++) + invPerm[permutation[i]] = i; + SmallVector perm2D = (invPerm[0] < invPerm[R - 1]) + ? SmallVector{0, 1} // partition stays first + : SmallVector{1, 0}; // partition and free swap + + // If perm2D is identity [0,1], the transpose is a no-op (e.g. swapping + // dims where one is size 1). Emit memref.copy instead of a transpose + // so linalg-to-nisa can convert it to nisa.tensor_copy. 
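+ // Standalone illustration (not part of the pass sources, names are illustrative):
+ // how the 2D permutation is derived from the R-D permutation above. Only the
+ // relative order of the partition dim (logical 0) and the free dim (logical R-1)
+ // survives at tile granularity.
+ //
+ //   #include <cstdint>
+ //   #include <vector>
+ //
+ //   inline std::vector<int64_t> perm2DSketch(const std::vector<int64_t> &permutation) {
+ //     size_t R = permutation.size();
+ //     std::vector<int64_t> invPerm(R);
+ //     for (size_t i = 0; i < R; ++i)
+ //       invPerm[permutation[i]] = (int64_t)i;
+ //     return (invPerm[0] < invPerm[R - 1]) ? std::vector<int64_t>{0, 1}  // partition stays first
+ //                                          : std::vector<int64_t>{1, 0}; // partition and free swap
+ //   }
+ //   // perm2DSketch({1, 0})    == {1, 0}   (real 2D transpose)
+ //   // perm2DSketch({0, 2, 1}) == {0, 1}   (identity at tile level → lowered to a copy below)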
+ if (perm2D[0] == 0 && perm2D[1] == 1) + builder.create(loc, srcTile, dstTile); + else + builder.create(loc, srcTile, dstTile, perm2D); + + LLVM_DEBUG(llvm::dbgs() << " Tiled SBUF→SBUF linalg.transpose: " + << inputMappedType << " -> " << outputMappedType << "\n"); + + op.erase(); + return; + } + + // HBM↔SBUF case + Value bufSBUF = inputIsSbuf ? inputMapped : outputMapped; + Value bufHBM = inputIsSbuf ? output : input; + auto bufSBUFType = cast(bufSBUF.getType()); + auto bufHBMType = cast(bufHBM.getType()); + int64_t hbmRank = bufHBMType.getRank(); + + // If SBUF wasn't transformed, skip tiling + if (bufSBUFType.getRank() <= hbmRank) { + LLVM_DEBUG(llvm::dbgs() << " Skipping tiling for transpose (SBUF not transformed): " + << inputMappedType << " -> " << outputMappedType << "\n"); + return; + } + + // Extract tile info from SBUF physical shape + auto physShape = bufSBUFType.getShape(); + int64_t R = physShape.size() - 2; + int64_t partTile = physShape[0]; + int64_t freeTile = physShape[R + 1]; + + // Collect tile sizes per logical dim + SmallVector tileSizePerDim(R); + tileSizePerDim[0] = partTile; + for (int64_t i = 1; i < R - 1; i++) + tileSizePerDim[i] = 1; + if (R > 1) + tileSizePerDim[R - 1] = freeTile; + + builder.setInsertionPoint(op); + Location loc = op.getLoc(); + + // R-level loop nest over SBUF block indices + SmallVector numBlocksVec(physShape.begin() + 1, + physShape.begin() + 1 + R); + auto nest = createBlockLoopNest(builder, loc, numBlocksVec); + auto &blockIdxVars = nest.ivs; + + // Build SBUF (R+2)-D subview and collapse to 2D + Value tileSBUFCollapsed = createTileSubviewAndCollapse( + builder, loc, bufSBUF, partTile, freeTile, R, blockIdxVars); + + // Build HBM R-D subview with permuted offsets + // For transpose: HBM dim j corresponds to SBUF dim invPerm[j] + // So HBM offset at dim j = blockIdxVars[invPerm[j]] * tileSizePerDim[invPerm[j]] + // HBM size at dim j = tileSizePerDim[invPerm[j]] + SmallVector invPerm(R); + for (int64_t i = 0; i < R; i++) + invPerm[permutation[i]] = i; + + SmallVector offsetsHBM, sizesHBM, stridesHBM; + for (int64_t j = 0; j < R; j++) { + int64_t srcDim = invPerm[j]; + int64_t tileForDim = tileSizePerDim[srcDim]; + if (tileForDim == 1) { + offsetsHBM.push_back(OpFoldResult(blockIdxVars[srcDim])); + } else { + Value tileSizeVal = builder.create(loc, tileForDim); + Value offset = builder.create(loc, blockIdxVars[srcDim], tileSizeVal); + offsetsHBM.push_back(OpFoldResult(offset)); + } + sizesHBM.push_back(builder.getIndexAttr(tileForDim)); + stridesHBM.push_back(builder.getIndexAttr(1)); + } + + auto tileHBM = builder.create( + loc, bufHBM, offsetsHBM, sizesHBM, stridesHBM); + + // For R>2, collapse HBM R-D to 2D. + // HBM tile has permuted sizes: non-unit dims at positions permutation[0] + // (partition) and permutation[R-1] (free). Split into 2 contiguous groups + // at the first non-unit dim boundary. + Value tileHBM2D = tileHBM; + if (R > 2) { + int64_t partPosInHBM = permutation[0]; + int64_t freePosInHBM = permutation[R - 1]; + int64_t splitAt = std::min(partPosInHBM, freePosInHBM) + 1; + SmallVector reassocHBM; + ReassociationIndices g0, g1; + for (int64_t i = 0; i < splitAt; i++) g0.push_back(i); + for (int64_t i = splitAt; i < R; i++) g1.push_back(i); + reassocHBM = {g0, g1}; + tileHBM2D = builder.create(loc, tileHBM, reassocHBM); + } + + // Compute 2D permutation: check relative order of partition and free dims + SmallVector perm2D = (invPerm[0] < invPerm[R - 1]) + ? 
SmallVector{0, 1} + : SmallVector{1, 0}; + + // Create tiled 2D transpose with correct src/dst based on direction + if (outputIsSbuf) { + // HBM input → SBUF output + builder.create(loc, tileHBM2D, tileSBUFCollapsed, perm2D); + } else { + // SBUF input → HBM output + builder.create(loc, tileSBUFCollapsed, tileHBM2D, perm2D); + } + + LLVM_DEBUG(llvm::dbgs() << " Tiled linalg.transpose: " << inputMappedType + << " -> " << outputMappedType << "\n"); + + // Erase original transpose + op.erase(); + } + + /// Fix rank mismatches by collapsing all operands directly to 2D. + /// + /// After transforming subviews to (R+2)-D, and given that middle dims always + /// have tile=1, we collapse everything directly to 2D [partTile, freeTile]: + /// + /// 1. linalg ops: collapse (R+2)-D SBUF and R-D non-SBUF operands to 2D. + /// For R>2, also reconstruct the linalg op with 2D identity maps since + /// the original maps have R dims. + /// 2. memref.copy: collapse the higher-rank operand to 2D to match the other. + void fixRankMismatches(func::FuncOp func) { + OpBuilder builder(func.getContext()); + + // --- Phase 1: Fix linalg ops --- + // Only collect ops where an operand's rank doesn't match its indexing map, + // which signals the layout transformation changed the operand's rank. + // Ops where rank matches the map (e.g. 3D SBUF fills from + // decomposeHbmFills) are already consistent and must not be touched. + SmallVector linalgOpsToFix; + func.walk([&](linalg::LinalgOp op) { + auto maps = op.getIndexingMapsArray(); + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + auto memrefType = dyn_cast(op->getOperand(i).getType()); + if (!memrefType) + continue; + if (memrefType.getRank() != (int64_t)maps[i].getNumResults()) { + linalgOpsToFix.push_back(op); + return; + } + } + }); + + for (auto linalgOp : linalgOpsToFix) { + builder.setInsertionPoint(linalgOp); + Location loc = linalgOp.getLoc(); + auto indexingMaps = linalgOp.getIndexingMapsArray(); + + // --- Special case: TransposeOp with one expanded operand --- + // + // When canonicalize-partition-dim inserts boundary transposes and one + // operand comes from a legalized (expanded) alloc while the other is a + // tile-sized alloc (not legalized), the generic 2D collapse path would + // create an identity copy where the tile-sized alloc's partition dim + // doesn't match the tile (e.g., base dim0=1 but tile wants 128). + // + // Fix: replace the tile-sized alloc with a 2D alloc whose dim0 matches + // the partition count, then emit a 2D copy or transpose. + if (auto transposeOp = dyn_cast(linalgOp.getOperation())) { + Value input = transposeOp.getInput(); + Value output = transposeOp.getInit(); + auto inputType = cast(input.getType()); + auto outputType = cast(output.getType()); + int64_t inputRank = inputType.getRank(); + int64_t outputRank = outputType.getRank(); + int64_t logicalRank = (int64_t)transposeOp.getPermutation().size(); + + bool inputExpanded = (inputRank > logicalRank); + bool outputExpanded = (outputRank > logicalRank); + + // Only handle the asymmetric case: one side expanded, other not + if (inputExpanded != outputExpanded) { + Value expandedVal = inputExpanded ? input : output; + Value nonExpandedVal = inputExpanded ? 
output : input; + auto nonExpandedType = cast(nonExpandedVal.getType()); + + // Non-expanded operand must be a direct AllocOp for replacement + auto nonExpandedAlloc = nonExpandedVal.getDefiningOp(); + if (nonExpandedAlloc) { + // Step 1: Collapse expanded operand to 2D + auto expandedRank = cast(expandedVal.getType()).getRank(); + auto reassocExpanded = build2DCollapseFromPhysical(expandedRank); + auto collapsed2D = builder.create( + loc, expandedVal, reassocExpanded); + auto collapsed2DType = cast(collapsed2D.getType()); + + // Step 2: Compute 2D shape for the non-expanded alloc + // Use same logic as simplify-linalg's computeCollapse: merge dims + // up to (and including) the first non-unit dim into dim 0, rest + // into dim 1. + auto neShape = nonExpandedType.getShape(); + unsigned firstNonUnit = 0; + for (unsigned i = 0; i < neShape.size(); i++) + if (neShape[i] != 1) { firstNonUnit = i; break; } + + int64_t d0 = 1, d1 = 1; + for (unsigned i = 0; i <= firstNonUnit; i++) d0 *= neShape[i]; + for (unsigned i = firstNonUnit + 1; + i < (unsigned)nonExpandedType.getRank(); i++) + d1 *= neShape[i]; + + SmallVector allocReassoc; + { + ReassociationIndices g0, g1; + for (int64_t i = 0; i <= (int64_t)firstNonUnit; i++) + g0.push_back(i); + for (int64_t i = firstNonUnit + 1; + i < nonExpandedType.getRank(); i++) + g1.push_back(i); + allocReassoc.push_back(g0); + allocReassoc.push_back(g1); + } + + auto new2DType = MemRefType::get( + {d0, d1}, nonExpandedType.getElementType(), + /*layout=*/nullptr, nonExpandedType.getMemorySpace()); + + // Collect downstream copy users of the old alloc BEFORE erasing + SmallVector downstreamCopies; + for (auto *user : nonExpandedAlloc.getResult().getUsers()) { + if (user == transposeOp.getOperation()) + continue; // skip the transpose itself + if (auto copyOp = dyn_cast(user)) { + if (copyOp.getSource() == nonExpandedAlloc.getResult() || + copyOp.getTarget() == nonExpandedAlloc.getResult()) + downstreamCopies.push_back(copyOp); + } + } + + // Create 2D alloc at the old alloc's position + OpBuilder allocBuilder(nonExpandedAlloc); + auto new2DAlloc = allocBuilder.create( + nonExpandedAlloc.getLoc(), new2DType, + nonExpandedAlloc.getAlignmentAttr()); + + // Step 3: Create 2D copy or transpose replacing the original + auto cShape = collapsed2DType.getShape(); + bool shapesMatch = (cShape[0] == d0 && cShape[1] == d1); + + if (shapesMatch) { + // Unit-dim-only movement: just copy + if (inputExpanded) + builder.create( + loc, collapsed2D.getResult(), new2DAlloc.getResult()); + else + builder.create( + loc, new2DAlloc.getResult(), collapsed2D.getResult()); + } else { + // Real transpose between 2D operands + if (inputExpanded) + builder.create( + loc, collapsed2D.getResult(), new2DAlloc.getResult(), + ArrayRef{1, 0}); + else + builder.create( + loc, new2DAlloc.getResult(), collapsed2D.getResult(), + ArrayRef{1, 0}); + } + + // Step 4: Redirect downstream copies to use 2D paths + // Instead of copying through expand_shape (which loses indexing + // info in getBaseAndOffsets), collapse the destination to 2D + // and copy directly from the 2D alloc. + for (auto copyOp : downstreamCopies) { + bool isSource = + (copyOp.getSource() == nonExpandedAlloc.getResult()); + Value otherOperand = + isSource ? 
copyOp.getTarget() : copyOp.getSource(); + auto otherType = cast(otherOperand.getType()); + + OpBuilder copyBuilder(copyOp); + if (otherType.getRank() > 2) { + // Collapse the other operand from R-D to 2D + auto otherReassoc = + build2DCollapseFromLogical(otherType.getRank()); + auto collapsedOther = + copyBuilder.create( + loc, otherOperand, otherReassoc); + if (isSource) + copyBuilder.create( + loc, new2DAlloc.getResult(), + collapsedOther.getResult()); + else + copyBuilder.create( + loc, collapsedOther.getResult(), + new2DAlloc.getResult()); + } else { + if (isSource) + copyBuilder.create( + loc, new2DAlloc.getResult(), otherOperand); + else + copyBuilder.create( + loc, otherOperand, new2DAlloc.getResult()); + } + copyOp.erase(); + } + + LLVM_DEBUG(llvm::dbgs() + << " Handled TransposeOp with expanded operand: " + << inputType << " -> " << outputType + << " → 2D " << collapsed2DType << " / " << new2DType << "\n"); + + transposeOp->erase(); + nonExpandedAlloc.erase(); + continue; + } + } + } + + bool needsReconstruction = false; + + // Collapse operands whose rank doesn't match the indexing map to 2D + for (unsigned i = 0; i < linalgOp->getNumOperands(); ++i) { + auto memrefType = dyn_cast(linalgOp->getOperand(i).getType()); + if (!memrefType) + continue; + + int64_t operandRank = memrefType.getRank(); + int64_t expectedRank = indexingMaps[i].getNumResults(); + + // Skip operands that already match their indexing map + if (operandRank == expectedRank) + continue; + + // Choose collapse based on whether operand was transformed + SmallVector reassoc; + if (operandRank > expectedRank) { + // SBUF operand transformed to (R+2)-D → collapse to 2D + reassoc = build2DCollapseFromPhysical(operandRank); + } else { + // Non-transformed operand (e.g., PSUM R-D where R>2) → collapse to 2D + reassoc = build2DCollapseFromLogical(operandRank); + } + + auto collapsed = builder.create( + loc, linalgOp->getOperand(i), reassoc); + linalgOp->setOperand(i, collapsed.getResult()); + + LLVM_DEBUG(llvm::dbgs() << " Collapsed operand " << i + << " of " << linalgOp->getName() << ": " + << memrefType << " -> " << collapsed.getType() << "\n"); + + if (expectedRank > 2) + needsReconstruction = true; + } + + // For R>2: the linalg op's indexing maps still have R dims but operands + // are now 2D. Reconstruct with 2D identity maps and parallel iterators. + // This only applies to elementwise ops (identity maps, all parallel). + if (needsReconstruction) { + // Before gathering operands, collapse any remaining rank>2 operands + // to 2D. These are non-legalized operands (rank == expectedRank) that + // still need collapsing because the op is being reconstructed as 2D. + // E.g. memref<128x1x128> (tile-sized 3D SBUF, partition ≤ 128) → 2D. + for (unsigned i = 0; i < linalgOp->getNumOperands(); ++i) { + auto mt = dyn_cast(linalgOp->getOperand(i).getType()); + if (!mt || mt.getRank() <= 2) + continue; + auto reassoc = build2DCollapseFromLogical(mt.getRank()); + auto collapsed = builder.create( + loc, linalgOp->getOperand(i), reassoc); + linalgOp->setOperand(i, collapsed.getResult()); + LLVM_DEBUG(llvm::dbgs() << " Collapsed remaining operand " << i + << " of " << linalgOp->getName() << ": " + << mt << " -> " << collapsed.getType() << "\n"); + } + + // Gather collapsed operands (now all 2D) + SmallVector inputs(linalgOp.getDpsInputs()); + SmallVector outputs(linalgOp.getDpsInits()); + + // Try to create a fresh named op of the same type. + // Named ops (AddOp, SubOp, etc.) 
infer correct 2D indexing maps + // automatically from the operand types. + Operation *newOp = nullptr; + if (isa(linalgOp.getOperation())) { + newOp = builder.create(loc, inputs, outputs); + } else if (isa(linalgOp.getOperation())) { + newOp = builder.create(loc, inputs, outputs); + } else if (isa(linalgOp.getOperation())) { + newOp = builder.create(loc, inputs, outputs); + } else if (isa(linalgOp.getOperation())) { + newOp = builder.create(loc, inputs, outputs); + } else if (isa(linalgOp.getOperation())) { + newOp = builder.create(loc, inputs, outputs); + } else if (isa(linalgOp.getOperation())) { + newOp = builder.create(loc, inputs, outputs); + } else if (isa(linalgOp.getOperation())) { + newOp = builder.create(loc, inputs, outputs); + } else if (isa(linalgOp.getOperation())) { + newOp = builder.create(loc, inputs, outputs); + } else { + // Fall back to linalg.generic for unrecognized ops. + // Build 2D indexing maps per-operand: if an operand's dim is 1 + // where other operands have a larger dim, use a constant-0 + // (broadcast/reduction) map for that dimension. + auto ctx = linalgOp->getContext(); + + // Determine the "full" 2D shape (max across all operands per dim). + int64_t fullShape[2] = {1, 1}; + SmallVector allOperands; + allOperands.append(inputs.begin(), inputs.end()); + allOperands.append(outputs.begin(), outputs.end()); + for (Value v : allOperands) { + auto mt = dyn_cast(v.getType()); + if (!mt || mt.getRank() != 2) + continue; + for (int d = 0; d < 2; d++) + fullShape[d] = std::max(fullShape[d], mt.getShape()[d]); + } + + // Check if the original op has reduction iterators → dim 1 is + // a reduction dimension in the collapsed 2D space. + auto origIterTypes = linalgOp.getIteratorTypesArray(); + bool hasReduction = llvm::any_of(origIterTypes, + [](utils::IteratorType t) { + return t == utils::IteratorType::reduction; + }); + + SmallVector newIterTypes = { + utils::IteratorType::parallel, + hasReduction ? utils::IteratorType::reduction + : utils::IteratorType::parallel}; + + // Build per-operand map: identity if shape matches full, + // constant-0 for dims that are broadcast/reduced (size 1). + auto d0 = getAffineDimExpr(0, ctx); + auto d1 = getAffineDimExpr(1, ctx); + auto c0 = getAffineConstantExpr(0, ctx); + + SmallVector newMaps; + for (Value v : allOperands) { + auto mt = cast(v.getType()); + AffineExpr e0 = (mt.getShape()[0] < fullShape[0]) ? c0 : d0; + AffineExpr e1 = (mt.getShape()[1] < fullShape[1]) ? c0 : d1; + newMaps.push_back(AffineMap::get(2, 0, {e0, e1}, ctx)); + } + + auto genericOp = builder.create( + loc, /*resultTypes=*/TypeRange{}, inputs, outputs, + newMaps, newIterTypes); + genericOp.getRegion().takeBody(linalgOp->getRegion(0)); + newOp = genericOp; + } + + // Copy relevant attributes (e.g., nkipy.op_id) + for (auto attr : linalgOp->getAttrs()) { + if (attr.getName().strref().starts_with("nkipy.")) + newOp->setAttr(attr.getName(), attr.getValue()); + } + + LLVM_DEBUG(llvm::dbgs() << " Reconstructed " << linalgOp->getName() + << " as " << newOp->getName() << " with 2D operands\n"); + + linalgOp->erase(); + } + } + + // --- Phase 2: Fix memref.copy rank mismatches --- + // Only fix copies involving at least one SBUF operand (affected by layout + // transformation). HBM↔HBM copies at rank > 2 are perfectly valid and + // must not be collapsed. 
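+ //
+ // For example (illustrative shapes): after legalization one side of a
+ // copy may be a physical (R+2)-D SBUF tile such as memref<128x1x1x512>
+ // while the other side stayed at its logical rank, e.g. memref<128x1x512>.
+ // Because the middle dims are all 1, both sides collapse to the same 2D
+ // tile shape [128, 512] and the copy stays element-for-element identical.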
+ SmallVector copiesToFix; + func.walk([&](memref::CopyOp copyOp) { + auto srcType = cast(copyOp.getSource().getType()); + auto dstType = cast(copyOp.getTarget().getType()); + int64_t srcRank = srcType.getRank(); + int64_t dstRank = dstType.getRank(); + + // Only collect copies with an actual rank mismatch (one side was + // transformed to (R+2)-D by the layout pass while the other stayed R-D). + // Same-rank copies are valid regardless of rank and must not be touched + // (e.g. 3D SBUF temps created by decomposeHbmFills). + if (srcRank != dstRank) + copiesToFix.push_back(copyOp); + }); + + for (auto copyOp : copiesToFix) { + Value src = copyOp.getSource(); + Value dst = copyOp.getTarget(); + auto srcType = cast(src.getType()); + auto dstType = cast(dst.getType()); + int64_t srcRank = srcType.getRank(); + int64_t dstRank = dstType.getRank(); + + builder.setInsertionPoint(copyOp); + + // Collapse each operand to 2D if rank > 2. + // - Higher-rank side was transformed to physical layout → use physical collapse + // - Equal/lower-rank side is untransformed → use logical collapse + // Safety: middle dims must be 1 (design constraint) for collapse to produce + // matching 2D shapes [partTile, freeTile] from both sides. + if (srcRank > 2) { + // Verify middle dims are 1 before collapsing + auto srcShape = srcType.getShape(); + bool middleDimsUnit = true; + for (int64_t i = 1; i < srcRank - 1; i++) { + if (srcShape[i] != 1) { middleDimsUnit = false; break; } + } + if (!middleDimsUnit) { + llvm::errs() << "[LegalizeLayout] Error: copy src has non-unit middle dims, " + << "cannot safely collapse to 2D: " << srcType << "\n"; + hasError = true; + return; + } + auto reassoc = (srcRank > dstRank) + ? build2DCollapseFromPhysical(srcRank) + : build2DCollapseFromLogical(srcRank); + auto collapsed = builder.create( + copyOp.getLoc(), src, reassoc); + copyOp->setOperand(0, collapsed.getResult()); + LLVM_DEBUG(llvm::dbgs() << " Collapsed copy src: " + << srcType << " -> " << collapsed.getType() << "\n"); + } + if (dstRank > 2) { + auto dstShape = dstType.getShape(); + bool middleDimsUnit = true; + for (int64_t i = 1; i < dstRank - 1; i++) { + if (dstShape[i] != 1) { middleDimsUnit = false; break; } + } + if (!middleDimsUnit) { + llvm::errs() << "[LegalizeLayout] Error: copy dst has non-unit middle dims, " + << "cannot safely collapse to 2D: " << dstType << "\n"; + hasError = true; + return; + } + auto reassoc = (dstRank > srcRank) + ? 
build2DCollapseFromPhysical(dstRank) + : build2DCollapseFromLogical(dstRank); + auto collapsed = builder.create( + copyOp.getLoc(), dst, reassoc); + copyOp->setOperand(1, collapsed.getResult()); + LLVM_DEBUG(llvm::dbgs() << " Collapsed copy dst: " + << dstType << " -> " << collapsed.getType() << "\n"); + } + } + + LLVM_DEBUG(llvm::dbgs() << " Fixed " << linalgOpsToFix.size() + << " linalg ops and " << copiesToFix.size() + << " copies for rank mismatches\n"); + } + + /// Find LayoutInfo for a value by checking all possible mappings + LayoutInfo* findLayoutInfo(Value val, IRMapping &valueMapping, + DenseMap &valueMap) { + // Direct lookup + if (valueMap.count(val)) + return valueMap[val]; + + // Check if val is mapped and lookup the mapped value + if (Value mapped = valueMapping.lookupOrNull(val)) { + if (valueMap.count(mapped)) + return valueMap[mapped]; + } + + // Reverse lookup: find original value that maps to val + for (auto &[origVal, layoutPtr] : valueMap) { + if (valueMapping.lookupOrNull(origVal) == val) { + return layoutPtr; + } + } + + return nullptr; + } + +private: + /// Find all SBUF memrefs that need layout legalization + /// + /// Algorithm: + /// 1. Walk entire function, collect all SBUF memref.alloc ops (rank >= 2) + /// 2. For each alloc, trace uses through subview chains to ALL linalg ops + /// 3. Collect operand shapes as candidate tile sizes + /// 4. Filter out invalid tile sizes (dim0 > 128) + /// 5. Verify all valid tile sizes are the same + /// 6. Create LayoutInfo with the validated tile size + SmallVector findSbufTensorsToLegalize(func::FuncOp func) { + SmallVector results; + + // Step 1: Collect all SBUF memref.alloc ops (rank >= 2) in the entire function + SmallVector sbufAllocs; + func.walk([&](memref::AllocOp allocOp) { + auto memrefType = allocOp.getType(); + if (!isSbuf(memrefType.getMemorySpace())) + return; // Not SBUF + if (memrefType.getRank() < 2) + return; // Scalars or 1D not supported + + sbufAllocs.push_back(allocOp); + LLVM_DEBUG(llvm::dbgs() << " Found SBUF alloc: " << memrefType << "\n"); + }); + + // Step 2-6: For each SBUF alloc, determine tile sizes and create LayoutInfo + for (auto allocOp : sbufAllocs) { + auto memrefType = allocOp.getType(); + auto origShape = memrefType.getShape(); + int64_t R = memrefType.getRank(); + + llvm::errs() << "[LegalizeLayout] Processing SBUF alloc: " << memrefType + << " at " << allocOp.getLoc() << "\n"; + + // Step 2-3: Trace uses to linalg ops, collect operand shapes as tile sizes + SmallVector>> linalgUses; + traceToLinalgOperands(allocOp.getResult(), linalgUses); + + // Extract just the tile sizes from the results + SmallVector> tileSizes; + for (auto &[op, idx, tileShape] : linalgUses) { + tileSizes.push_back(tileShape); + } + + if (tileSizes.empty()) { + llvm::errs() << " -> Skipping (no linalg uses)\n"; + continue; + } + + // Step 4: Filter out tile sizes that match the alloc's full shape + // (these are full-buffer writes like linalg.transpose, not tile accesses) + // and tile sizes with dim0 > 128 (invalid partition dim). 
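+ //
+ // For example (illustrative): an alloc memref<256x1024, #sbuf> whose
+ // linalg uses report candidate tiles [256,1024] (full buffer, filtered),
+ // [256,512] (dim0 > 128, filtered) and [128,512] keeps only [128,512].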
+ SmallVector> validTileSizes; + SmallVector origShapeVec(origShape.begin(), origShape.end()); + for (auto &tile : tileSizes) { + if (SmallVector(tile) == origShapeVec) { + LLVM_DEBUG({ + llvm::dbgs() << " Filtered out full-buffer tile ["; + llvm::interleave(tile, llvm::dbgs(), "x"); + llvm::dbgs() << "] (matches alloc shape)\n"; + }); + } else if (tile[0] > MAX_PARTITION_DIM) { + LLVM_DEBUG({ + llvm::dbgs() << " Filtered out invalid tile ["; + llvm::interleave(tile, llvm::dbgs(), "x"); + llvm::dbgs() << "] (dim0 > 128)\n"; + }); + } else { + validTileSizes.push_back(tile); + } + } + + if (validTileSizes.empty()) { + llvm::errs() << " -> Skipping (already at tile size)\n"; + continue; + } + + // Log all discovered tiles for debugging + llvm::errs() << " All tiles found (" << tileSizes.size() << " total, " + << validTileSizes.size() << " valid):\n"; + for (auto &tile : tileSizes) { + llvm::errs() << " ["; + llvm::interleave(tile, llvm::errs(), ","); + bool isFiltered = (SmallVector(tile) == origShapeVec) || + (!tile.empty() && tile[0] > MAX_PARTITION_DIM); + llvm::errs() << "]" << (isFiltered ? " (filtered)" : "") << "\n"; + } + + // Step 5: Verify all valid tile sizes are the same + auto &refTile = validTileSizes[0]; + for (size_t i = 1; i < validTileSizes.size(); ++i) { + if (validTileSizes[i] != refTile) { + llvm::errs() << "[LegalizeLayout] Error: SBUF alloc " << memrefType + << " has inconsistent tile sizes: ["; + llvm::interleave(refTile, llvm::errs(), ","); + llvm::errs() << "] vs ["; + llvm::interleave(validTileSizes[i], llvm::errs(), ","); + llvm::errs() << "]\n"; + hasError = true; + return results; + } + } + + // Validate rank match and divisibility for all dims + if ((int64_t)refTile.size() != R) { + llvm::errs() << "[LegalizeLayout] Error: SBUF alloc " << memrefType + << " tile rank " << refTile.size() + << " does not match alloc rank " << R << ". Tile: ["; + llvm::interleave(refTile, llvm::errs(), ","); + llvm::errs() << "]\n"; + hasError = true; + return results; + } + for (int64_t i = 0; i < R; i++) { + if (origShape[i] % refTile[i] != 0) { + llvm::errs() << "[LegalizeLayout] Error: SBUF alloc " << memrefType + << " dim " << i << " size " << origShape[i] + << " not divisible by tile size " << refTile[i] + << ". Full tile: ["; + llvm::interleave(refTile, llvm::errs(), ","); + llvm::errs() << "]\n"; + hasError = true; + return results; + } + } + + // Step 6: Compute numBlocks; skip if all 1 (already tile-sized) + SmallVector numBlocks; + for (int64_t i = 0; i < R; i++) + numBlocks.push_back(origShape[i] / refTile[i]); + + if (llvm::all_of(numBlocks, [](int64_t n) { return n == 1; })) { + llvm::errs() << " -> Skipping (numBlocks all 1, already tile-sized)\n"; + continue; + } + + // Validate: middle tile sizes (dims 1..R-2) must all be 1. + // The physical shape formula [tileSize[0], numBlocks..., tileSize[R-1]] + // only stores the first and last tile sizes. Non-unit middle tiles + // (e.g. tile=[1,128,64] for a 4x128x64 alloc from multi-head attention) + // cannot be represented by the physical layout. These allocs have their + // partition dim NOT at dim 0 (it's a batch dim) and don't need + // legalization — they're accessed tile-by-tile via subviews in loops. 
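+ // By contrast, the supported case (illustrative shapes): a
+ // memref<256x1024, #sbuf> alloc with tile [128, 512] and numBlocks [2, 2]
+ // maps to the physical shape [128, 2, 2, 512], i.e.
+ // [tileSize[0], numBlocks..., tileSize[R-1]].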
+ if (R > 2) { + bool middleTilesUnit = true; + for (int64_t i = 1; i < R - 1; i++) { + if (refTile[i] != 1) { + middleTilesUnit = false; + break; + } + } + if (!middleTilesUnit) { + if (numBlocks[0] > 1) { + // Multi-partition SBUF alloc with non-unit middle tiles: + // the vector engine cannot selectively address individual + // partitions, and the physical layout format can't represent + // non-unit middle tiles. Convert to SharedHbm so DMA can + // handle the per-partition access. + llvm::errs() << " -> Converting " << memrefType + << " to SharedHbm (non-unit middle tile, " + << "numBlocks[0]=" << numBlocks[0] << ")\n"; + auto hbmMemSpace = nkipy::MemSpaceEnumAttr::get( + allocOp.getContext(), nkipy::MemSpaceEnum::SharedHbm); + auto newType = MemRefType::get( + memrefType.getShape(), memrefType.getElementType(), + memrefType.getLayout(), hbmMemSpace); + OpBuilder builder(allocOp); + builder.setInsertionPointAfter(allocOp); + auto newAlloc = builder.create( + allocOp.getLoc(), newType, allocOp.getAlignmentAttr()); + + // Recreate subview users with updated memspace + SmallVector subviews; + for (auto *user : allocOp.getResult().getUsers()) + if (auto sv = dyn_cast(user)) + subviews.push_back(sv); + for (auto sv : subviews) { + OpBuilder svBuilder(sv); + svBuilder.setInsertionPointAfter(sv); + auto newSv = svBuilder.create( + sv.getLoc(), newAlloc.getResult(), + sv.getMixedOffsets(), sv.getMixedSizes(), + sv.getMixedStrides()); + sv.replaceAllUsesWith(newSv.getResult()); + sv.erase(); + } + allocOp.replaceAllUsesWith(newAlloc.getResult()); + allocOp.erase(); + } else { + llvm::errs() << " -> Skipping " << memrefType + << " (non-unit middle tile ["; + llvm::interleave(refTile, llvm::errs(), ","); + llvm::errs() << "], not supported by physical layout)\n"; + } + continue; + } + } + + LayoutInfo info; + info.originalValue = allocOp.getResult(); + info.origShape = SmallVector(origShape.begin(), origShape.end()); + info.tileSize = refTile; + info.numBlocks = numBlocks; + results.push_back(info); + } + + return results; + } +}; + +} // namespace + +std::unique_ptr> createLegalizeLayoutPass() { + return std::make_unique(); +} + +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/OpClassification.cpp b/kernelgen/mlir/lib/Transforms/OpClassification.cpp new file mode 100644 index 0000000..23db2af --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/OpClassification.cpp @@ -0,0 +1,69 @@ +//===- OpClassification.cpp - Shared op classification helpers -------------===// + +#include "nkipy/Transforms/OpClassification.h" + +using namespace mlir; + +namespace mlir { +namespace nkipy { + +bool isNamedUnaryElementwiseOp(StringRef opName) { + return opName == "linalg.exp" || opName == "linalg.log" || + opName == "linalg.tanh" || opName == "linalg.negf" || + opName == "linalg.abs" || opName == "linalg.ceil" || + opName == "linalg.floor" || opName == "linalg.sqrt" || + opName == "linalg.reciprocal" || opName == "linalg.square" || + opName == "linalg.copy"; +} + +bool isNamedBinaryElementwiseOp(StringRef opName) { + return opName == "linalg.add" || opName == "linalg.sub" || + opName == "linalg.mul" || opName == "linalg.div" || + opName == "linalg.max" || opName == "linalg.min"; +} + +bool isNamedElementwiseOp(StringRef opName) { + return isNamedUnaryElementwiseOp(opName) || + isNamedBinaryElementwiseOp(opName); +} + +bool isElementwiseGeneric(linalg::LinalgOp linalgOp) { + auto genericOp = dyn_cast(linalgOp.getOperation()); + if (!genericOp) + return false; + return 
llvm::all_of(genericOp.getIteratorTypesArray(), + [](utils::IteratorType t) { + return t == utils::IteratorType::parallel; + }); +} + +bool isElementwiseOp(linalg::LinalgOp linalgOp) { + return isNamedElementwiseOp(linalgOp->getName().getStringRef()) || + isElementwiseGeneric(linalgOp); +} + +bool isReductionGeneric(linalg::LinalgOp linalgOp) { + auto genericOp = dyn_cast(linalgOp.getOperation()); + if (!genericOp) + return false; + return llvm::any_of(genericOp.getIteratorTypesArray(), + [](utils::IteratorType t) { + return t == utils::IteratorType::reduction; + }); +} + +bool isMatmulOp(StringRef opName) { + return opName == "linalg.matmul" || opName == "linalg.batch_matmul"; +} + +bool isMatmulOp(linalg::LinalgOp linalgOp) { + return isMatmulOp(linalgOp->getName().getStringRef()); +} + +bool isAnnotatableOp(linalg::LinalgOp linalgOp) { + return isElementwiseOp(linalgOp) || isReductionGeneric(linalgOp) || + isMatmulOp(linalgOp); +} + +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/PassGen.h b/kernelgen/mlir/lib/Transforms/PassGen.h new file mode 100644 index 0000000..6229724 --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/PassGen.h @@ -0,0 +1,19 @@ + + +#ifndef NKIPY_MLIR_PASSDETAIL_H +#define NKIPY_MLIR_PASSDETAIL_H + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" + +namespace mlir { +namespace nkipy { + +#define GEN_PASS_CLASSES +#include "nkipy/Transforms/Passes.h.inc" + +} // namespace nkipy +} // end namespace mlir + +#endif // Allo_MLIR_PASSDETAIL_H diff --git a/kernelgen/mlir/lib/Transforms/Passes.cpp b/kernelgen/mlir/lib/Transforms/Passes.cpp new file mode 100644 index 0000000..1fa46cd --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/Passes.cpp @@ -0,0 +1,15 @@ +#include "nkipy/Transforms/Passes.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +namespace { +#define GEN_PASS_REGISTRATION +#include "nkipy/Transforms/Passes.h.inc" +} // end namespace + +void mlir::nkipy::registerNkipyPasses() { ::registerPasses(); } diff --git a/kernelgen/mlir/lib/Transforms/PrepareArithmetic.cpp b/kernelgen/mlir/lib/Transforms/PrepareArithmetic.cpp new file mode 100644 index 0000000..f998c4b --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/PrepareArithmetic.cpp @@ -0,0 +1,301 @@ +//===- PrepareArithmetic.cpp - Prepare arithmetic ops for NISA lowering ===// +// +// This pass prepares arithmetic operations for NISA lowering by transforming +// operations that don't have direct NISA equivalents. +// +// Transformations: +// - linalg.div(A, B) -> linalg.mul(A, linalg.reciprocal(B)) +// NISA's tensor_tensor_arith doesn't support DIVIDE, so we convert division +// to multiplication by reciprocal. +// - linalg.generic { arith.divf(%a, %b) } with broadcast indexing maps +// -> linalg.reciprocal(B) + linalg.generic { arith.mulf(%a, %recip) } +// Handles broadcast tensor-tensor division (e.g. tensor / tensor). +// +// This pass runs before tiling so that the generated reciprocal operations +// get tiled and bufferized normally. 
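+//
+// Sketch of the named-op rewrite (illustrative shapes and value names):
+//
+//   %q = linalg.div ins(%a, %b : tensor<128x512xf32>, tensor<128x512xf32>)
+//                   outs(%init : tensor<128x512xf32>) -> tensor<128x512xf32>
+//
+// becomes
+//
+//   %e = tensor.empty() : tensor<128x512xf32>
+//   %r = linalg.reciprocal ins(%b : tensor<128x512xf32>)
+//                          outs(%e : tensor<128x512xf32>) -> tensor<128x512xf32>
+//   %q = linalg.mul ins(%a, %r : tensor<128x512xf32>, tensor<128x512xf32>)
+//                   outs(%init : tensor<128x512xf32>) -> tensor<128x512xf32>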
+// +//===----------------------------------------------------------------------===// + +#include "nkipy/Transforms/Passes.h" +#include "nkipy/Dialect/NkipyOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +namespace { + +//===----------------------------------------------------------------------===// +// Helper: Clone nkipy.annotate ops from one value to another +//===----------------------------------------------------------------------===// + +static void cloneAnnotations(Value oldValue, Value newValue, + PatternRewriter &rewriter) { + for (Operation *user : oldValue.getUsers()) { + if (auto annotateOp = dyn_cast(user)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfterValue(newValue); + rewriter.create( + annotateOp.getLoc(), newValue, + annotateOp.getMemSpaceAttr(), annotateOp.getPartitionDimAttr(), + annotateOp.getTileSizeAttr(), annotateOp.getReductionTileAttr()); + } + } +} + +//===----------------------------------------------------------------------===// +// Pattern: Any linalg op with division -> reciprocal + multiply +//===----------------------------------------------------------------------===// + +/// Convert any linalg op containing division to use reciprocal+multiply. +/// Matches both the named linalg.div and linalg.generic with arith.divf. +struct ConvertDivToReciprocal + : public OpInterfaceRewritePattern { + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + + LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, + PatternRewriter &rewriter) const override { + if (auto divOp = dyn_cast(linalgOp.getOperation())) + return handleNamedDiv(divOp, rewriter); + + auto genericOp = dyn_cast(linalgOp.getOperation()); + if (!genericOp || genericOp.getNumDpsInits() != 1 || + !isAllParallel(genericOp)) + return failure(); + + arith::DivFOp divFOp = findUniqueDivF(genericOp); + if (!divFOp) + return failure(); + + unsigned numInputs = genericOp.getNumDpsInputs(); + if (numInputs == 1) + return handleScalarDiv(genericOp, divFOp, rewriter); + if (numInputs == 2) + return handleBroadcastDiv(genericOp, divFOp, rewriter); + return failure(); + } + +private: + // --- Helpers --- + + static arith::DivFOp findUniqueDivF(linalg::GenericOp op) { + arith::DivFOp found = nullptr; + for (Operation &bodyOp : op.getRegion().front().without_terminator()) { + if (auto d = dyn_cast(&bodyOp)) { + if (found) + return nullptr; + found = d; + } + } + return found; + } + + static linalg::ReciprocalOp createReciprocal(PatternRewriter &rewriter, + Location loc, Value input) { + auto inputType = cast(input.getType()); + auto recipOut = rewriter.create( + loc, inputType.getShape(), inputType.getElementType()); + return rewriter.create( + loc, TypeRange{inputType}, ValueRange{input}, + ValueRange{recipOut.getResult()}); + } + + static bool isAllParallel(linalg::GenericOp op) { + return llvm::all_of(op.getIteratorTypesArray(), + [](utils::IteratorType t) { + return t == utils::IteratorType::parallel; + }); + } + + /// Create reciprocal(divisor) and mul(numerator, recip), replacing origOp. 
+ void replaceWithReciprocalMul(Operation *origOp, Value numerator, + Value divisor, Value output, + PatternRewriter &rewriter) const { + Location loc = origOp->getLoc(); + auto outType = cast(output.getType()); + + auto recipOp = createReciprocal(rewriter, loc, divisor); + cloneAnnotations(origOp->getResult(0), recipOp.getResult(0), rewriter); + + auto mulOp = rewriter.create( + loc, TypeRange{outType}, + ValueRange{numerator, recipOp.getResult(0)}, ValueRange{output}); + rewriter.replaceOp(origOp, mulOp.getResults()); + } + + // --- Case handlers --- + + /// linalg.div(A, B) → mul(A, reciprocal(B)) + LogicalResult handleNamedDiv(linalg::DivOp op, + PatternRewriter &rewriter) const { + llvm::errs() << "[PrepareArithmetic] Converting linalg.div to " + "mul+reciprocal\n"; + replaceWithReciprocalMul(op, op.getInputs()[0], op.getInputs()[1], + op.getOutputs()[0], rewriter); + return success(); + } + + /// 1-input generic with divf involving a constant. + LogicalResult handleScalarDiv(linalg::GenericOp op, arith::DivFOp divOp, + PatternRewriter &rewriter) const { + Value divLhs = divOp.getLhs(); + Value divRhs = divOp.getRhs(); + auto lhsConst = divLhs.getDefiningOp(); + auto rhsConst = divRhs.getDefiningOp(); + + if ((!lhsConst && !rhsConst) || (lhsConst && rhsConst)) + return failure(); + + // divf(tensor, scalar) → mulf(tensor, 1/scalar) in body + if (rhsConst) { + auto floatAttr = dyn_cast(rhsConst.getValue()); + if (!floatAttr || floatAttr.getValueAsDouble() == 0.0) + return failure(); + double recipVal = 1.0 / floatAttr.getValueAsDouble(); + llvm::errs() << "[PrepareArithmetic] Converting divf(tensor, " + << floatAttr.getValueAsDouble() << ") to mulf(tensor, " + << recipVal << ")\n"; + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(divOp); + auto recipConst = rewriter.create( + divOp.getLoc(), + rewriter.getFloatAttr(floatAttr.getType(), recipVal)); + rewriter.replaceOpWithNewOp(divOp, divLhs, + recipConst.getResult()); + return success(); + } + + // divf(scalar, tensor) → reciprocal(tensor) [* scalar] + auto floatAttr = dyn_cast(lhsConst.getValue()); + if (!floatAttr) + return failure(); + double scalarVal = floatAttr.getValueAsDouble(); + llvm::errs() << "[PrepareArithmetic] Converting divf(" << scalarVal + << ", tensor) to reciprocal\n"; + + Location loc = op.getLoc(); + Value input = op.getDpsInputs()[0]; + Value output = op.getDpsInits()[0]; + auto outputType = cast(output.getType()); + + if (scalarVal == 1.0) { + auto recipOp = createReciprocal(rewriter, loc, input); + cloneAnnotations(op.getResult(0), recipOp.getResult(0), rewriter); + rewriter.replaceOp(op, recipOp.getResults()); + } else { + // Create fill(scalar) as the numerator, then use shared helper + auto fillEmpty = rewriter.create( + loc, outputType.getShape(), outputType.getElementType()); + auto scalarCst = rewriter.create( + loc, rewriter.getFloatAttr( + cast(outputType.getElementType()), scalarVal)); + auto fillOp = rewriter.create( + loc, TypeRange{outputType}, ValueRange{scalarCst.getResult()}, + ValueRange{fillEmpty.getResult()}); + replaceWithReciprocalMul(op, fillOp.getResult(0), input, output, + rewriter); + } + return success(); + } + + /// 2-input generic with divf between two block arguments. 
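+ /// For example (illustrative): a softmax-style normalization dividing a
+ /// tensor<128x512xf32> by a tensor<128x1xf32> row sum, where the divisor
+ /// is indexed through a broadcast map. The divisor gets a
+ /// linalg.reciprocal, the generic is cloned with the reciprocal as that
+ /// input, and the body's arith.divf becomes arith.mulf.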
+ LogicalResult handleBroadcastDiv(linalg::GenericOp op, arith::DivFOp divOp, + PatternRewriter &rewriter) const { + auto rhsArg = dyn_cast(divOp.getRhs()); + if (!rhsArg || !isa(divOp.getLhs())) + return failure(); + + unsigned rhsArgNum = rhsArg.getArgNumber(); + if (rhsArgNum >= op.getNumDpsInputs()) + return failure(); + + llvm::errs() << "[PrepareArithmetic] Converting broadcast divf to " + "reciprocal+mulf\n"; + + Location loc = op.getLoc(); + Value rhsInput = op.getDpsInputs()[rhsArgNum]; + auto recipOp = createReciprocal(rewriter, loc, rhsInput); + + // Clone generic with reciprocal replacing the rhs input, divf→mulf in body + SmallVector newInputs(op.getDpsInputs()); + newInputs[rhsArgNum] = recipOp.getResult(0); + + auto newGeneric = rewriter.create( + loc, op.getResultTypes(), newInputs, op.getDpsInits(), + op.getIndexingMapsArray(), op.getIteratorTypesArray()); + + rewriter.cloneRegionBefore(op.getRegion(), newGeneric.getRegion(), + newGeneric.getRegion().end()); + + for (Operation &bodyOp : newGeneric.getRegion().front().without_terminator()) { + if (auto d = dyn_cast(&bodyOp)) { + rewriter.setInsertionPoint(d); + rewriter.replaceOpWithNewOp(d, d.getLhs(), d.getRhs()); + break; + } + } + + rewriter.replaceOp(op, newGeneric.getResults()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +struct PrepareArithmeticPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareArithmeticPass) + + StringRef getArgument() const final { return "prepare-arithmetic"; } + + StringRef getDescription() const final { + return "Prepare arithmetic operations for NISA lowering"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + MLIRContext *ctx = &getContext(); + + RewritePatternSet patterns(ctx); + + // Add arithmetic preparation patterns + patterns.add(ctx); + + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { + llvm::errs() << "[PrepareArithmetic] Pattern application failed\n"; + signalPassFailure(); + return; + } + + llvm::errs() << "[PrepareArithmetic] Pass completed successfully\n"; + } +}; + +} // namespace + +namespace mlir { +namespace nkipy { + +std::unique_ptr> createPrepareArithmeticPass() { + return std::make_unique(); +} + +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/RemoveRedundantZeroFill.cpp b/kernelgen/mlir/lib/Transforms/RemoveRedundantZeroFill.cpp new file mode 100644 index 0000000..d1e4800 --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/RemoveRedundantZeroFill.cpp @@ -0,0 +1,138 @@ +//===- RemoveRedundantZeroFill.cpp - Remove zero fills before matmul ------===// +// +// This pass removes linalg.fill operations that fill with zero and are only +// consumed by matmul-like operations. NISA matmul hardware initializes PSUM +// accumulators to zero automatically (psum_zero_region), so the zero fill is +// redundant. +// +// This runs on tensor IR (before tiling/bufferization) so the fill is removed +// early, before it becomes a memref.copy chain that is harder to optimize. 
+// +// Pattern: +// %cst = arith.constant 0.0 : f32 +// %empty = tensor.empty() : tensor<...> +// %filled = linalg.fill ins(%cst) outs(%empty) -> tensor<...> +// %result = linalg.matmul ins(%a, %b) outs(%filled) -> tensor<...> +// +// After: +// %empty = tensor.empty() : tensor<...> +// %result = linalg.matmul ins(%a, %b) outs(%empty) -> tensor<...> +// +//===----------------------------------------------------------------------===// + +#include "nkipy/Transforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +namespace { + +static bool isMatmulLikeOp(Operation *op) { + return isa(op); +} + +/// Check if a value is defined by arith.constant with a zero value. +static bool isZeroConstant(Value value) { + auto constOp = value.getDefiningOp(); + if (!constOp) + return false; + + if (auto intAttr = dyn_cast(constOp.getValue())) + return intAttr.getValue().isZero(); + if (auto fpAttr = dyn_cast(constOp.getValue())) + return fpAttr.getValue().isZero(); + return false; +} + +/// Remove linalg.fill(zero) when all users of the fill result are matmul-like. +struct RemoveZeroFillBeforeMatmul : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::FillOp fillOp, + PatternRewriter &rewriter) const override { + // Check fill value is zero + if (!isZeroConstant(fillOp.getInputs()[0])) + return failure(); + + // The fill must have exactly one result (the filled tensor) + if (fillOp.getNumResults() != 1) + return failure(); + + Value fillResult = fillOp.getResult(0); + + // All users must be matmul-like ops + for (Operation *user : fillResult.getUsers()) { + if (!isMatmulLikeOp(user)) + return failure(); + } + + // Replace fill result with the unfilled output tensor + Value outputTensor = fillOp.getOutputs()[0]; + llvm::errs() << "[RemoveRedundantZeroFill] Removing zero fill before " + "matmul: " + << *fillOp << "\n"; + rewriter.replaceOp(fillOp, outputTensor); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +struct RemoveRedundantZeroFillPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RemoveRedundantZeroFillPass) + + StringRef getArgument() const final { return "remove-redundant-zero-fill"; } + + StringRef getDescription() const final { + return "Remove linalg.fill ops with zero values when only used by " + "matmul-like ops (NISA matmul auto-zeros PSUM)"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + MLIRContext *ctx = &getContext(); + + RewritePatternSet patterns(ctx); + patterns.add(ctx); + + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { + llvm::errs() + << "[RemoveRedundantZeroFill] Pattern application failed\n"; + signalPassFailure(); + return; + } + + llvm::errs() << "[RemoveRedundantZeroFill] Pass completed successfully\n"; + } +}; + +} // namespace + +namespace mlir { +namespace nkipy { + +std::unique_ptr> createRemoveRedundantZeroFillPass() { + return std::make_unique(); +} + +} // 
namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/lib/Transforms/SimplifyLinalg.cpp b/kernelgen/mlir/lib/Transforms/SimplifyLinalg.cpp new file mode 100644 index 0000000..91580c4 --- /dev/null +++ b/kernelgen/mlir/lib/Transforms/SimplifyLinalg.cpp @@ -0,0 +1,778 @@ +//===- SimplifyLinalg.cpp - Simplify linalg ops for NISA lowering ---------===// +// +// Pre-processing pass that simplifies linalg operations before linalg-to-nisa. +// +// Transformations: +// 1. Rewrites >2D SBUF linalg.transpose with unit dims to 2D. +// NISA dma_transpose only supports [1,0] (2D) or [2,1,0] (3D full reverse). +// Collapses SBUF allocs to 2D + expand_shape views, rewrites transpose or +// emits copy (when non-unit dims keep order, i.e. just a reshape). +// +// 2. Converts trivial-broadcast linalg.generic ops to named linalg ops. +// After tiling, broadcasts become same-shape operations. This converts +// them to named ops (linalg.mul, etc.) so LinalgToNisa patterns can match. +// +//===----------------------------------------------------------------------===// + +#include "nkipy/Transforms/Passes.h" +#include "nkipy/Transforms/IRHelpers.h" +#include "nkipy/Dialect/NkipyOps.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { + +//===----------------------------------------------------------------------===// +// Helper functions +//===----------------------------------------------------------------------===// + +using nkipy::isHbm; +using nkipy::isSbuf; + +/// Elementwise arith kinds recognized in linalg.generic bodies. Kept local +/// to this file so we do not depend on nki::nisa::ArithOpKind. +enum class LocalArithKind { + ADD, SUBTRACT, MULTIPLY, DIVIDE, MOD, MODINT, + ISEQ, ISGT, ISGE, ISLE, ISLT, ISNE, +}; + +/// Map arith dialect binary op to a local arith kind used to classify the +/// body of linalg.generic ops. 
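+/// For example, a generic body computing
+///   %c = arith.cmpf ogt, %a, %b : f32
+///   %r = arith.uitofp %c : i1 to f32
+/// is classified as ISGT; a plain binary arithmetic body op maps to
+/// ADD / SUBTRACT / MULTIPLY / DIVIDE / MOD accordingly.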
+static std::optional +getArithOpKindFromBodyOp(Operation *op) { + auto kind = llvm::TypeSwitch>(op) + .Case( + [](auto) { return LocalArithKind::ADD; }) + .Case( + [](auto) { return LocalArithKind::SUBTRACT; }) + .Case( + [](auto) { return LocalArithKind::MULTIPLY; }) + .Case( + [](auto) { return LocalArithKind::DIVIDE; }) + .Case([](auto) { return LocalArithKind::MOD; }) + .Case([](auto) { return LocalArithKind::MODINT; }) + .Default([](Operation *) { return std::nullopt; }); + if (kind) + return kind; + + // Comparison pattern: arith.uitofp(arith.cmpf(...)) + if (auto uitofp = dyn_cast(op)) { + if (auto cmpf = uitofp.getIn().getDefiningOp()) { + switch (cmpf.getPredicate()) { + case arith::CmpFPredicate::OEQ: return LocalArithKind::ISEQ; + case arith::CmpFPredicate::OGT: return LocalArithKind::ISGT; + case arith::CmpFPredicate::OGE: return LocalArithKind::ISGE; + case arith::CmpFPredicate::OLE: return LocalArithKind::ISLE; + case arith::CmpFPredicate::OLT: return LocalArithKind::ISLT; + case arith::CmpFPredicate::ONE: return LocalArithKind::ISNE; + default: return std::nullopt; + } + } + } + return std::nullopt; +} + +//===----------------------------------------------------------------------===// +// Preprocessing: Decompose high-rank transposes to loops of 2D transposes +//===----------------------------------------------------------------------===// + +/// Decompose N-D linalg.transpose where exactly 2 dimensions are swapped +/// (and the rest are identity) into a loop nest over the identity dims +/// with a 2D transpose (or copy) on the swapped pair. +/// +/// Example: [0, 2, 1, 3] on memref<2x128x2x128> (multi-head attention reshape) +/// → scf.for batch = 0..2: +/// scf.for head = 0..2: +/// linalg.transpose [1, 0] on memref<128x128> tiles +/// +/// This handles transposes across any memory space (HBM, SharedHBM, SBUF). +static void decomposeHighRankTranspose(func::FuncOp func) { + SmallVector worklist; + func.walk([&](linalg::TransposeOp op) { worklist.push_back(op); }); + + for (auto transposeOp : worklist) { + auto srcType = dyn_cast(transposeOp.getInput().getType()); + auto dstType = dyn_cast(transposeOp.getInit().getType()); + if (!srcType || !dstType) + continue; + + int64_t rank = srcType.getRank(); + // Only handle rank > 3. Rank-2 is already native NISA. + // Rank-3 with unit dims is handled by rewriteSbufTransposeTo2D. + if (rank <= 3) + continue; + + // Find which dimensions are swapped vs identity. + auto perm = transposeOp.getPermutation(); + SmallVector swappedDims; // dims where perm[i] != i + SmallVector identityDims; // dims where perm[i] == i + for (int64_t i = 0; i < rank; i++) { + if (perm[i] != i) + swappedDims.push_back(i); + else + identityDims.push_back(i); + } + + // Only handle the case where exactly 2 dims are swapped. + if (swappedDims.size() != 2) + continue; + + int64_t d0 = swappedDims[0]; // first swapped dim + int64_t d1 = swappedDims[1]; // second swapped dim + // Verify they are actually swapped with each other + if (perm[d0] != d1 || perm[d1] != d0) + continue; + + auto srcShape = srcType.getShape(); + auto dstShape = dstType.getShape(); + + OpBuilder b(transposeOp); + Location loc = transposeOp.getLoc(); + + // Helper: collapse a subview slice to 2D by grouping unit dims with + // their neighbors. dimA is the position of the first non-unit dim; + // everything up to and including dimA goes in group 0. 
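+ // Worked example (illustrative): a slice of type memref<1x128x1x128> with
+ // dimA = 1 groups dims {0,1} and {2,3}, collapsing to memref<128x128>.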
+ auto collapse2D = [&](Value slice, int64_t dimA, int64_t dimB, + MemRefType origType) -> Value { + auto sliceType = cast(slice.getType()); + if (sliceType.getRank() == 2) + return slice; + + SmallVector reassoc; + ReassociationIndices group0, group1; + bool seenFirst = false; + for (int64_t i = 0; i < (int64_t)sliceType.getRank(); i++) { + if (!seenFirst || i <= dimA) { + group0.push_back(i); + if (sliceType.getShape()[i] > 1) seenFirst = true; + } else { + group1.push_back(i); + } + } + if (group1.empty()) + return slice; + reassoc = {group0, group1}; + return b.create(loc, slice, reassoc); + }; + + // ---------------------------------------------------------------- + // Optimization: when one swapped dim is small (e.g. n_heads=2), + // iterate over it in the loop nest and emit a 2D *copy* instead of + // a 2D transpose. This avoids collapse_shape tiles whose non-unit + // strides are lost by getBaseAndOffsets in linalg-to-nisa, causing + // wrong column interleaving in NISA DMA access patterns. + // + // Example: perm [0,2,1,3] on (2,128,2,128) + // Before: loop batch(2) × hd(128) = 256 iters, inner (128,2) transpose + // After: loop batch(2) × head(2) = 4 iters, inner (128,128) copy + // ---------------------------------------------------------------- + constexpr int64_t kSmallSwapDimThreshold = 16; + int64_t smallSwapDim = -1, largeSwapDim = -1; + { + int64_t s0 = srcShape[d0], s1 = srcShape[d1]; + if (s0 <= kSmallSwapDimThreshold && s0 < s1) { + smallSwapDim = d0; largeSwapDim = d1; + } else if (s1 <= kSmallSwapDimThreshold && s1 < s0) { + smallSwapDim = d1; largeSwapDim = d0; + } + } + + // Find the largest identity dim — kept in the inner tile for efficiency. + int64_t keepIdDim = -1; + int64_t keepIdSize = 0; + for (int64_t idim : identityDims) { + if (srcShape[idim] > keepIdSize) { + keepIdSize = srcShape[idim]; + keepIdDim = idim; + } + } + + // Use the optimized copy path when a small swapped dim exists and + // there is a large identity dim to pair with the large swapped dim + // for the 2D inner tile. 
+ bool useSmallSwapOpt = (smallSwapDim >= 0 && keepIdDim >= 0 && + keepIdSize > 1); + + Value c0 = b.create(loc, 0); + Value c1 = b.create(loc, 1); + + if (useSmallSwapOpt) { + // Inverse permutation: invPerm[srcDim] = dstDim + SmallVector invPerm(rank); + for (int64_t j = 0; j < rank; j++) + invPerm[perm[j]] = j; + + // Build loop nest over: + // - identity dims (except keepIdDim) + // - small swapped dim (offset goes at different positions in src vs dst) + struct LoopInfo { int64_t srcDim; int64_t dstDim; }; + SmallVector loopInfos; + SmallVector loopIVs; + auto addLoop = [&](int64_t srcDim, int64_t dstDim, int64_t size) { + Value ub = b.create(loc, size); + auto loop = b.create(loc, c0, ub, c1); + b.setInsertionPointToStart(loop.getBody()); + loopInfos.push_back({srcDim, dstDim}); + loopIVs.push_back(loop.getInductionVar()); + }; + + for (int64_t idim : identityDims) { + if (idim == keepIdDim) continue; + addLoop(idim, idim, srcShape[idim]); + } + addLoop(smallSwapDim, invPerm[smallSwapDim], srcShape[smallSwapDim]); + + // Build src subview + SmallVector srcOffsets(rank, b.getIndexAttr(0)); + SmallVector srcSizes; + SmallVector srcStrides(rank, b.getIndexAttr(1)); + for (int64_t i = 0; i < rank; i++) { + bool inLoop = false; + for (size_t li = 0; li < loopInfos.size(); li++) { + if (loopInfos[li].srcDim == i) { + srcOffsets[i] = OpFoldResult(loopIVs[li]); + srcSizes.push_back(b.getIndexAttr(1)); + inLoop = true; break; + } + } + if (!inLoop) + srcSizes.push_back(b.getIndexAttr(srcShape[i])); + } + auto srcSlice = b.create( + loc, transposeOp.getInput(), srcOffsets, srcSizes, srcStrides); + + // Build dst subview (loop IVs placed at dstDim positions) + SmallVector dstOffsets(rank, b.getIndexAttr(0)); + SmallVector dstSizes; + SmallVector dstStrides(rank, b.getIndexAttr(1)); + for (int64_t i = 0; i < rank; i++) { + bool inLoop = false; + for (size_t li = 0; li < loopInfos.size(); li++) { + if (loopInfos[li].dstDim == i) { + dstOffsets[i] = OpFoldResult(loopIVs[li]); + dstSizes.push_back(b.getIndexAttr(1)); + inLoop = true; break; + } + } + if (!inLoop) + dstSizes.push_back(b.getIndexAttr(dstShape[i])); + } + auto dstSlice = b.create( + loc, transposeOp.getInit(), dstOffsets, dstSizes, dstStrides); + + // Collapse to 2D: the inner tile has two non-unit dims + // (largeSwapDim and keepIdDim in src; their permuted positions in dst). + int64_t srcDimA = std::min(largeSwapDim, keepIdDim); + int64_t dstDimA = std::min(invPerm[largeSwapDim], keepIdDim); + Value src2d = collapse2D(srcSlice, srcDimA, -1, srcType); + Value dst2d = collapse2D(dstSlice, dstDimA, -1, dstType); + + // Emit copy — no transpose needed since the two inner dims are in + // the same relative order in src and dst. + auto src2dType = cast(src2d.getType()); + auto dst2dType = cast(dst2d.getType()); + if (isSbuf(src2dType) || isSbuf(dst2dType)) { + b.create(loc, src2d, dst2d); + } else { + // HBM→HBM: route through SBUF temp. 
+ auto sbufMemSpace = nkipy::MemSpaceEnumAttr::get( + b.getContext(), nkipy::MemSpaceEnum::Sbuf); + auto shape2d = src2dType.getShape(); + auto sbufType = MemRefType::get( + shape2d, src2dType.getElementType(), nullptr, sbufMemSpace); + auto sbufTemp = b.create(loc, sbufType); + b.create(loc, src2d, sbufTemp); + b.create(loc, sbufTemp, dst2d); + } + + transposeOp.erase(); + LLVM_DEBUG(llvm::dbgs() << "[SimplifyLinalg] Decomposed " << rank + << "D transpose to loop of 2D copies" + " (small-swap-dim opt)\n"); + continue; + } + + // ---------------------------------------------------------------- + // Original path: loop over all identity dims, transpose inner 2D. + // ---------------------------------------------------------------- + SmallVector ivs; + for (int64_t idim : identityDims) { + Value ub = b.create(loc, srcShape[idim]); + auto loop = b.create(loc, c0, ub, c1); + b.setInsertionPointToStart(loop.getBody()); + ivs.push_back(loop.getInductionVar()); + } + + // Build subview for src: take a 2D slice at the swapped dims + SmallVector srcOffsets(rank, b.getIndexAttr(0)); + SmallVector srcSizes; + SmallVector srcStrides(rank, b.getIndexAttr(1)); + unsigned ivIdx = 0; + for (int64_t i = 0; i < rank; i++) { + if (perm[i] == i) { + srcOffsets[i] = OpFoldResult(ivs[ivIdx++]); + srcSizes.push_back(b.getIndexAttr(1)); + } else { + srcSizes.push_back(b.getIndexAttr(srcShape[i])); + } + } + auto srcSlice = b.create( + loc, transposeOp.getInput(), srcOffsets, srcSizes, srcStrides); + + // Build subview for dst: same loop IVs but at the permuted positions + SmallVector dstOffsets(rank, b.getIndexAttr(0)); + SmallVector dstSizes; + SmallVector dstStrides(rank, b.getIndexAttr(1)); + ivIdx = 0; + for (int64_t i = 0; i < rank; i++) { + if (perm[i] == i) { + dstOffsets[i] = OpFoldResult(ivs[ivIdx++]); + dstSizes.push_back(b.getIndexAttr(1)); + } else { + dstSizes.push_back(b.getIndexAttr(dstShape[i])); + } + } + auto dstSlice = b.create( + loc, transposeOp.getInit(), dstOffsets, dstSizes, dstStrides); + + Value src2d = collapse2D(srcSlice, d0, d1, srcType); + Value dst2d = collapse2D(dstSlice, d0, d1, dstType); + + // The 2D transpose is [1, 0] since the two swapped dims are now the + // only dims. + // + // For SBUF-involved transposes: emit linalg.transpose directly. + // For HBM-only transposes: route through an SBUF temp since there's no + // HBM→HBM transpose in hardware. Pattern: load → transpose in SBUF → store. + auto src2dType = cast(src2d.getType()); + auto dst2dType = cast(dst2d.getType()); + if (isSbuf(src2dType) || isSbuf(dst2dType)) { + auto newOp = b.create( + loc, src2d, dst2d, ArrayRef{1, 0}); + if (auto id = transposeOp->getAttr("nkipy.op_id")) + newOp->setAttr("nkipy.op_id", id); + } else { + // HBM→HBM: allocate SBUF temp, load src, transpose in SBUF, store to dst. 
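+ // For example (illustrative shapes): memref<64x512, #hbm> is copied into
+ // an SBUF temp memref<64x512, #sbuf>, transposed there into
+ // memref<512x64, #sbuf>, and the result is copied out to the
+ // memref<512x64, #hbm> destination.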
+ auto sbufMemSpace = nkipy::MemSpaceEnumAttr::get( + b.getContext(), nkipy::MemSpaceEnum::Sbuf); + auto srcShape2d = src2dType.getShape(); + auto dstShape2d = dst2dType.getShape(); + auto sbufSrcType = MemRefType::get( + srcShape2d, src2dType.getElementType(), nullptr, sbufMemSpace); + auto sbufDstType = MemRefType::get( + dstShape2d, dst2dType.getElementType(), nullptr, sbufMemSpace); + auto sbufSrc = b.create(loc, sbufSrcType); + auto sbufDst = b.create(loc, sbufDstType); + b.create(loc, src2d, sbufSrc); // HBM → SBUF + auto newOp = b.create( // transpose in SBUF + loc, sbufSrc, sbufDst, ArrayRef{1, 0}); + if (auto id = transposeOp->getAttr("nkipy.op_id")) + newOp->setAttr("nkipy.op_id", id); + b.create(loc, sbufDst, dst2d); // SBUF → HBM + } + + transposeOp.erase(); + LLVM_DEBUG(llvm::dbgs() << "[SimplifyLinalg] Decomposed " << rank + << "D transpose to loop of 2D transposes\n"); + } +} + +//===----------------------------------------------------------------------===// +// Preprocessing: Collapse >2D SBUF transpose to 2D +//===----------------------------------------------------------------------===// + +/// Replace a >2D SBUF memref.alloc (with unit dims, effectively 2D) with a +/// true 2D alloc + expand_shape view. Returns the 2D alloc Value, or nullptr +/// if the operand is not a collapsible SBUF alloc. +static Value collapseSbufAllocTo2D(Value operand, int64_t dim0, int64_t dim1, + SmallVector &reassoc) { + auto allocOp = operand.getDefiningOp(); + if (!allocOp) + return nullptr; + auto oldType = allocOp.getType(); + auto newType = MemRefType::get( + {dim0, dim1}, oldType.getElementType(), + /*layout=*/nullptr, oldType.getMemorySpace()); + + OpBuilder b(allocOp); + Location loc = allocOp.getLoc(); + auto newAlloc = b.create(loc, newType, + allocOp.getAlignmentAttr()); + auto expandOp = b.create( + loc, oldType, newAlloc, reassoc); + + allocOp.replaceAllUsesWith(expandOp.getResult()); + allocOp.erase(); + return newAlloc.getResult(); +} + +/// Rewrite >2D SBUF linalg.transpose (with unit dims) to a true 2D transpose. +static void rewriteSbufTransposeTo2D(func::FuncOp func) { + SmallVector worklist; + func.walk([&](linalg::TransposeOp op) { worklist.push_back(op); }); + + for (auto transposeOp : worklist) { + auto srcType = dyn_cast(transposeOp.getInput().getType()); + auto dstType = dyn_cast(transposeOp.getInit().getType()); + if (!srcType || !dstType || srcType.getRank() <= 2) + continue; + if (!isSbuf(srcType) || !isSbuf(dstType)) + continue; + + // Must have exactly 2 non-unit dims to collapse to 2D. + auto shape = srcType.getShape(); + unsigned nonUnitCount = 0; + SmallVector nonUnitDims; + for (unsigned i = 0; i < shape.size(); i++) { + if (shape[i] != 1) { nonUnitCount++; nonUnitDims.push_back(i); } + } + if (nonUnitCount != 2) + continue; + + // Only rewrite when the non-unit dims are actually transposed. + auto perm = transposeOp.getPermutation(); + unsigned dstPos0 = 0, dstPos1 = 0; + for (unsigned d = 0; d < perm.size(); d++) { + if (perm[d] == (int64_t)nonUnitDims[0]) dstPos0 = d; + if (perm[d] == (int64_t)nonUnitDims[1]) dstPos1 = d; + } + bool needsTranspose = (dstPos0 > dstPos1); + + // Helper: compute 2D collapse params for a given shape. 
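+ // Worked example (illustrative): shape [1, 128, 64] has firstNonUnit = 1,
+ // so d0 = 1*128 = 128, d1 = 64, with reassociation groups {0,1} and {2}.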
+ auto computeCollapse = [](ArrayRef<int64_t> sh, + int64_t &d0, int64_t &d1, + SmallVector<ReassociationIndices> &ra) { + int64_t rank = sh.size(); + unsigned firstNonUnit = 0; + for (unsigned i = 0; i < sh.size(); i++) + if (sh[i] != 1) { firstNonUnit = i; break; } + d0 = 1; + for (unsigned i = 0; i <= firstNonUnit; i++) d0 *= sh[i]; + d1 = 1; + for (unsigned i = firstNonUnit + 1; i < (unsigned)rank; i++) d1 *= sh[i]; + ra.clear(); + ReassociationIndices g0, g1; + for (int64_t i = 0; i <= (int64_t)firstNonUnit; i++) g0.push_back(i); + for (int64_t i = firstNonUnit + 1; i < rank; i++) g1.push_back(i); + ra.push_back(g0); + ra.push_back(g1); + }; + + int64_t srcDim0, srcDim1, dstDim0, dstDim1; + SmallVector<ReassociationIndices> srcReassoc, dstReassoc; + computeCollapse(srcType.getShape(), srcDim0, srcDim1, srcReassoc); + computeCollapse(dstType.getShape(), dstDim0, dstDim1, dstReassoc); + + // Replace both allocs with 2D + expand_shape. + Value src2d = collapseSbufAllocTo2D(transposeOp.getInput(), + srcDim0, srcDim1, srcReassoc); + Value dst2d = collapseSbufAllocTo2D(transposeOp.getInit(), + dstDim0, dstDim1, dstReassoc); + if (!src2d || !dst2d) + continue; + + // Redirect dealloc and copy ops from expand_shape (3D) to 2D alloc. + for (auto *val : {&src2d, &dst2d}) { + for (auto *user : val->getUsers()) { + auto expand = dyn_cast<memref::ExpandShapeOp>(user); + if (!expand) continue; + for (auto *eu : llvm::make_early_inc_range(expand->getUsers())) { + if (auto dealloc = dyn_cast<memref::DeallocOp>(eu)) { + dealloc.getMemrefMutable().assign(*val); + } else if (auto copyOp = dyn_cast<memref::CopyOp>(eu)) { + OpBuilder cb(copyOp); + Location cl = copyOp.getLoc(); + bool sbufIsDst = (copyOp.getTarget() == expand.getResult()); + Value hbmOperand = sbufIsDst ? copyOp.getSource() + : copyOp.getTarget(); + auto hbmType = cast<MemRefType>(hbmOperand.getType()); + int64_t hbmRank = hbmType.getRank(); + + // Build a rank-reducing subview: [1,128,128] -> [128,128] + SmallVector<OpFoldResult> offsets(hbmRank, + cb.getI64IntegerAttr(0)); + SmallVector<OpFoldResult> sizes; + for (int64_t s : hbmType.getShape()) + sizes.push_back(cb.getI64IntegerAttr(s)); + SmallVector<OpFoldResult> strides(hbmRank, + cb.getI64IntegerAttr(1)); + + auto sbufType = cast<MemRefType>(val->getType()); + auto hbm2dType = memref::SubViewOp::inferRankReducedResultType( + sbufType.getShape(), hbmType, offsets, sizes, strides); + auto hbm2d = cb.create<memref::SubViewOp>( + cl, cast<MemRefType>(hbm2dType), hbmOperand, + offsets, sizes, strides); + + if (sbufIsDst) { + cb.create<memref::CopyOp>(cl, hbm2d, *val); + } else { + cb.create<memref::CopyOp>(cl, *val, hbm2d); + } + copyOp.erase(); + } + } + } + } + + OpBuilder b(transposeOp); + Location loc = transposeOp.getLoc(); + + if (needsTranspose) { + // Real transpose: create 2D linalg.transpose with [1, 0]. + auto newOp = b.create<linalg::TransposeOp>( + loc, src2d, dst2d, ArrayRef<int64_t>{1, 0}); + if (auto id = transposeOp->getAttr("nkipy.op_id")) + newOp->setAttr("nkipy.op_id", id); + } else { + // Just moving unit dim -- same data layout. Emit a copy. + b.create<memref::CopyOp>(loc, src2d, dst2d); + } + + transposeOp.erase(); + } +} + +//===----------------------------------------------------------------------===// +// Preprocessing: Canonicalize trivial-broadcast generics to named ops +//===----------------------------------------------------------------------===// + +/// Convert linalg.generic with trivial broadcast to named ops. +/// After tiling, a broadcast like (128x4x64) * (128x1x64) becomes +/// (128x1x64) * (128x1x64) -- the broadcast dim is now size 1 on both sides. +/// The generic still carries broadcast indexing maps but is effectively +/// a same-shape elementwise op. Convert it to a named op (linalg.mul, etc.)
+/// so the existing LinalgElementwiseToNisaPattern can handle it. +static void canonicalizeTrivialBroadcastGenerics(func::FuncOp func) { + SmallVector<linalg::GenericOp> toConvert; + func.walk([&](linalg::GenericOp op) { + // 2 inputs, 1 output, all parallel + if (op.getNumDpsInputs() != 2 || op.getNumDpsInits() != 1) + return; + if (!llvm::all_of(op.getIteratorTypesArray(), + [](utils::IteratorType t) { + return t == utils::IteratorType::parallel; + })) + return; + + // Single binary arith op in body + Operation *binaryOp = nullptr; + for (Operation &bodyOp : op.getRegion().front().without_terminator()) { + if (getArithOpKindFromBodyOp(&bodyOp)) { + if (binaryOp) return; // multiple ops + binaryOp = &bodyOp; + } + } + if (!binaryOp) return; + + // Must be a direct binary op (not a wrapped pattern like uitofp(andi(...))) + if (binaryOp->getNumOperands() != 2) return; + + // Both operands must be block args (not constants) + if (binaryOp->getOperand(0).getDefiningOp() || + binaryOp->getOperand(1).getDefiningOp()) + return; + + // Check all indexing maps are identity or trivial broadcast + auto maps = op.getIndexingMapsArray(); + auto outType = dyn_cast<MemRefType>(op.getDpsInits()[0].getType()); + if (!outType) return; + + for (auto &map : maps) { + if (map.isIdentity()) continue; + for (unsigned i = 0; i < map.getNumResults(); ++i) { + auto expr = map.getResult(i); + if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) { + if (constExpr.getValue() == 0 && outType.getDimSize(i) == 1) + continue; // trivial broadcast + return; // non-trivial + } + if (!isa<AffineDimExpr>(expr)) return; + } + } + + toConvert.push_back(op); + }); + + for (auto op : toConvert) { + Operation *binaryOp = nullptr; + for (Operation &bodyOp : op.getRegion().front().without_terminator()) { + if (getArithOpKindFromBodyOp(&bodyOp)) { + binaryOp = &bodyOp; + break; + } + } + + // Figure out operand order: body may swap block args + Block &body = op.getRegion().front(); + Value lhs = op.getDpsInputs()[0]; + Value rhs = op.getDpsInputs()[1]; + if (binaryOp->getOperand(0) == body.getArgument(1) && + binaryOp->getOperand(1) == body.getArgument(0)) + std::swap(lhs, rhs); + + OpBuilder builder(op); + Value output = op.getDpsInits()[0]; + + Operation *namedOp = nullptr; + auto kind = *getArithOpKindFromBodyOp(binaryOp); + switch (kind) { + case LocalArithKind::ADD: + namedOp = builder.create<linalg::AddOp>( + op.getLoc(), ValueRange{lhs, rhs}, ValueRange{output}); + break; + case LocalArithKind::SUBTRACT: + namedOp = builder.create<linalg::SubOp>( + op.getLoc(), ValueRange{lhs, rhs}, ValueRange{output}); + break; + case LocalArithKind::MULTIPLY: + namedOp = builder.create<linalg::MulOp>( + op.getLoc(), ValueRange{lhs, rhs}, ValueRange{output}); + break; + case LocalArithKind::DIVIDE: + namedOp = builder.create<linalg::DivOp>( + op.getLoc(), ValueRange{lhs, rhs}, ValueRange{output}); + break; + default: + continue; // skip unsupported ops + } + + // Copy over any relevant attrs (like nkipy.op_id) + if (auto opId = op->getAttr("nkipy.op_id")) + namedOp->setAttr("nkipy.op_id", opId); + + op.replaceAllUsesWith(namedOp->getResults()); + op.erase(); + } +} + +//===----------------------------------------------------------------------===// +// Preprocessing: Replace SBUF gather operands with HBM originals +//===----------------------------------------------------------------------===// + +/// For nkipy.gather ops, replace SBUF source/indices with their HBM origins. +/// nisa.dma_copy_indirect requires the source table in HBM.
The annotation +/// pass may have copied source/indices to SBUF — undo that so linalg-to-nisa +/// sees HBM operands and can emit dma_copy_indirect directly. +static void prepareGatherForNisaLowering(func::FuncOp func) { + SmallVector<nkipy::GatherOp> gatherOps; + func.walk([&](nkipy::GatherOp op) { gatherOps.push_back(op); }); + + for (auto gatherOp : gatherOps) { + // Check source (operand 0) and indices (operand 1). + for (unsigned idx : {0u, 1u}) { + Value operand = gatherOp->getOperand(idx); + auto memrefType = dyn_cast<MemRefType>(operand.getType()); + if (!memrefType || !isSbuf(memrefType)) + continue; + + // Find the memref.copy that writes HBM data into this SBUF alloc. + Value hbmSource = nullptr; + memref::CopyOp deadCopy = nullptr; + for (auto *user : operand.getUsers()) { + auto copyOp = dyn_cast<memref::CopyOp>(user); + if (!copyOp || copyOp.getTarget() != operand) + continue; + Value base = nkipy::getBaseMemRef(copyOp.getSource()); + if (isHbm(cast<MemRefType>(base.getType()))) { + hbmSource = copyOp.getSource(); + deadCopy = copyOp; + break; + } + } + if (!hbmSource) + continue; + + // Replace the gather operand with the HBM source. + gatherOp->setOperand(idx, hbmSource); + + // Erase the dead copy. + deadCopy->erase(); + + // If the SBUF alloc has no remaining readers, erase it + dealloc. + if (auto allocOp = operand.getDefiningOp<memref::AllocOp>()) { + SmallVector<Operation *> toErase; + bool canErase = true; + for (auto *user : allocOp->getResult(0).getUsers()) { + if (isa<memref::DeallocOp>(user)) + toErase.push_back(user); + else { + canErase = false; + break; + } + } + if (canErase) { + for (auto *op : toErase) + op->erase(); + allocOp->erase(); + } + } + } + } +} + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +struct SimplifyLinalgPass + : public PassWrapper<SimplifyLinalgPass, OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SimplifyLinalgPass) + + StringRef getArgument() const final { return "simplify-linalg"; } + + StringRef getDescription() const final { + return "Prepare linalg operations for NISA lowering"; + } + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert<linalg::LinalgDialect>(); + registry.insert<memref::MemRefDialect>(); + registry.insert<scf::SCFDialect>(); + registry.insert<arith::ArithDialect>(); + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + + // Decompose high-rank transposes (e.g., 4D [0,2,1,3] from multi-head + // attention reshaping) into loops of 2D transposes. This handles any + // memory space (HBM, SharedHBM, SBUF) and any rank where exactly 2 + // dimensions are swapped. Must run before rewriteSbufTransposeTo2D. + decomposeHighRankTranspose(func); + + // Rewrite >2D SBUF transpose to 2D. + // NISA dma_transpose only supports [1,0] (2D) or [2,1,0] (3D full reverse). + // A 3D transpose [0,2,1] on [1,128,128] must use 2D [1,0] on [128,128]. + // Replace the >2D SBUF allocs with 2D allocs + expand_shape views. + // After pattern rewriting, getBaseAndOffsets traces through expand_shape + // so NISA ops use the 2D base. The expand_shape becomes dead -> DCE'd.
+ rewriteSbufTransposeTo2D(func); + + // Convert trivial-broadcast linalg.generic to named ops + canonicalizeTrivialBroadcastGenerics(func); + + // Replace SBUF gather operands with HBM originals for dma_copy_indirect + prepareGatherForNisaLowering(func); + } +}; + +} // namespace + +namespace mlir { +namespace nkipy { + +std::unique_ptr<OperationPass<func::FuncOp>> createSimplifyLinalgPass() { + return std::make_unique<SimplifyLinalgPass>(); +} + +} // namespace nkipy +} // namespace mlir diff --git a/kernelgen/mlir/tools/CMakeLists.txt b/kernelgen/mlir/tools/CMakeLists.txt new file mode 100644 index 0000000..8a73c1f --- /dev/null +++ b/kernelgen/mlir/tools/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(nkipy-opt) diff --git a/kernelgen/mlir/tools/nkipy-opt/CMakeLists.txt b/kernelgen/mlir/tools/nkipy-opt/CMakeLists.txt new file mode 100644 index 0000000..07b3f14 --- /dev/null +++ b/kernelgen/mlir/tools/nkipy-opt/CMakeLists.txt @@ -0,0 +1,20 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) + +set(LIBS + ${dialect_libs} + ${conversion_libs} + ${extension_libs} + MLIROptLib + MLIRNkipy + MLIRNkipyPasses + MLIRNkipyTransformOps + ) + +add_llvm_executable(nkipy-opt nkipy-opt.cpp) + +llvm_update_compile_flags(nkipy-opt) +target_link_libraries(nkipy-opt PRIVATE ${LIBS}) + +mlir_check_all_link_libraries(nkipy-opt) diff --git a/kernelgen/mlir/tools/nkipy-opt/README.md b/kernelgen/mlir/tools/nkipy-opt/README.md new file mode 100644 index 0000000..b2a0149 --- /dev/null +++ b/kernelgen/mlir/tools/nkipy-opt/README.md @@ -0,0 +1,148 @@ +# nkipy-opt - MLIR Optimizer Driver + +`nkipy-opt` is a command-line tool for testing and running MLIR passes on nkipy dialect code. It is analogous to LLVM's `mlir-opt` tool but specifically configured for the nkipy project. + +## Building + +To build `nkipy-opt`, you need to first set up your environment and then build the mlir project. + +### Prerequisites + +1. Source the environment setup script to configure LLVM/MLIR paths: +```bash +source setup.sh +``` + +This will set up the required environment variables: +- `LLVM_DIR` - Path to LLVM CMake config +- `MLIR_DIR` - Path to MLIR CMake config +- `PATH` - Includes LLVM binaries + +2. Ensure you have a compatible CMake version (3.20.0 or higher) + +### Build Steps + +```bash +cd mlir +rm -rf build # Clean any existing build +mkdir build +cd build + +# Configure with CMake +cmake .. -G Ninja \ + -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DLLVM_DIR=${LLVM_DIR} \ + -DMLIR_DIR=${MLIR_DIR} + +# Build nkipy-opt +ninja nkipy-opt +``` + +The resulting binary will be placed in `mlir/build/bin/nkipy-opt`. + +Alternatively, if Ninja is not available, you can use Make: +```bash +cmake ..
-DLLVM_DIR=${LLVM_DIR} -DMLIR_DIR=${MLIR_DIR} +make nkipy-opt -j$(nproc) +``` + +## Usage + +### Basic Usage + +```bash +nkipy-opt [options] +``` + +### Common Options + +- `--help`: Display available options and passes +- `--show-dialects`: Show all registered dialects +- `--print-ir-after-all`: Print IR after each pass +- `--mlir-print-ir-after-change`: Only print IR after a pass if it changed +- `--mlir-timing`: Display timing information for passes + +### Running Passes + +To run a specific pass: + +```bash +nkipy-opt --memref-dce input.mlir +``` + +To chain multiple passes: + +```bash +nkipy-opt --pass-pipeline='builtin.module(memref-dce,canonicalize)' input.mlir +``` + +### Available Nkipy Passes + +- `--memref-dce`: Remove MemRefs that are never loaded from + +### Example Workflows + +#### 1. View Available Passes + +```bash +nkipy-opt --help +``` + +#### 2. Run Dead Code Elimination + +```bash +nkipy-opt --memref-dce example.mlir -o output.mlir +``` + +#### 3. Run with Timing Information + +```bash +nkipy-opt --memref-dce --mlir-timing example.mlir +``` + +#### 4. Print IR After Each Pass + +```bash +nkipy-opt --memref-dce --print-ir-after-all example.mlir +``` + +#### 5. Run Standard MLIR Passes + +Since `nkipy-opt` registers all standard MLIR passes, you can also use standard passes: + +```bash +nkipy-opt --canonicalize --cse example.mlir +``` + +## Input File Format + +Input files should be valid MLIR text format. Example: + +```mlir +module { + func.func @example(%arg0: memref<10xf32>) -> f32 { + %0 = memref.alloc() : memref<10xf32> + %c0 = arith.constant 0 : index + %1 = memref.load %arg0[%c0] : memref<10xf32> + return %1 : f32 + } +} +``` + +## Integration with Build System + +The tool is automatically built as part of the nkipy-kg project when you build the mlir subdirectory. + +## Troubleshooting + +If you encounter build errors: + +1. Ensure MLIR and LLVM are properly installed and found by CMake +2. Check that all required dialects are registered +3. Verify the include paths are correct in CMakeLists.txt + +For runtime errors: + +1. Use `--help` to see all available options +2. Use `--show-dialects` to verify the nkipy dialect is registered +3. Enable verbose output with `--mlir-print-ir-after-all` diff --git a/kernelgen/mlir/tools/nkipy-opt/nkipy-opt.cpp b/kernelgen/mlir/tools/nkipy-opt/nkipy-opt.cpp new file mode 100644 index 0000000..6752984 --- /dev/null +++ b/kernelgen/mlir/tools/nkipy-opt/nkipy-opt.cpp @@ -0,0 +1,51 @@ +//===- nkipy-opt.cpp - MLIR Optimizer Driver ------------------------------===// +// +// This file implements the 'nkipy-opt' tool, which is the nkipy analog of +// mlir-opt, used to drive compiler passes, e.g. for testing. 
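+// In addition to the standard MLIR dialects and passes, main() below also +// registers the nkipy dialect and passes and the transform-dialect extensions +// used by the knob-driven-tiling pipeline.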
+// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" + +#include "nkipy/Dialect/NkipyDialect.h" +#include "nkipy/Transforms/Passes.h" +#include "nkipy/TransformOps/NkipyTransformOps.h" + +// Include Transform dialect for knob-driven-tiling pass +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/Transforms/Passes.h" +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" +#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" +#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" + +int main(int argc, char **argv) { + mlir::registerAllPasses(); + mlir::nkipy::registerNkipyPasses(); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + registry.insert<mlir::nkipy::NkipyDialect>(); + registry.insert<mlir::transform::TransformDialect>(); // Required for knob-driven-tiling + + // Note: bufferization dialect is included in registerAllDialects(), but + // we explicitly register the transform extension below + + mlir::scf::registerTransformDialectExtension(registry); + mlir::linalg::registerTransformDialectExtension(registry); + mlir::bufferization::registerTransformDialectExtension(registry); + mlir::nkipy::registerTransformDialectExtension(registry); + + return mlir::asMainReturnCode( + mlir::MlirOptMain(argc, argv, "Nkipy optimizer driver\n", registry)); +} diff --git a/kernelgen/mlir/tools/nkipy-opt/test_example.mlir b/kernelgen/mlir/tools/nkipy-opt/test_example.mlir new file mode 100644 index 0000000..c0847d5 --- /dev/null +++ b/kernelgen/mlir/tools/nkipy-opt/test_example.mlir @@ -0,0 +1,27 @@ +// Test example for nkipy-opt +// This file demonstrates a simple MLIR module with memref operations +// that can be optimized using the memref-dce pass + +module { + func.func @test_dce(%arg0: memref<10xf32>) -> f32 { + // This allocation is never loaded from - should be removed by memref-dce + %dead_alloc = memref.alloc() : memref<10xf32> + + // This allocation is used - should be kept + %used_alloc = memref.alloc() : memref<5xf32> + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // Load from input argument + %val = memref.load %arg0[%c0] : memref<10xf32> + + // Store to used allocation + memref.store %val, %used_alloc[%c0] : memref<5xf32> + + // Load from used allocation + %result = memref.load %used_alloc[%c1] : memref<5xf32> + + return %result : f32 + } +} diff --git a/kernelgen/nkipy_kernelgen/__init__.py b/kernelgen/nkipy_kernelgen/__init__.py new file mode 100644 index 0000000..4b328c9 --- /dev/null +++ b/kernelgen/nkipy_kernelgen/__init__.py @@ -0,0 +1,27 @@ +""" +NKIPyKernelGen - Lowering from NumPy to NKI compiler + +This package provides tools to trace Python functions with NumPy operations +and convert them to MLIR for compilation with neuronxcc. +""" + +from .trace import trace +from .traced_array import TracedArray +from .custom_op import CustomOp +from .execution import verify_against_numpy +from .pass_manager import apply_passes +from . import apis +from .
import transforms + +__version__ = "0.1.0" +__author__ = "Your Name" + +__all__ = [ + "trace", + "TracedArray", + "CustomOp", + "verify_against_numpy", + "apply_passes", + "apis", + "transforms", +] diff --git a/kernelgen/nkipy_kernelgen/apis.py b/kernelgen/nkipy_kernelgen/apis.py new file mode 100644 index 0000000..4de9b22 --- /dev/null +++ b/kernelgen/nkipy_kernelgen/apis.py @@ -0,0 +1,11 @@ +""" +APIs for use inside traced functions. + +This module contains APIs that are meant to be used within functions +decorated with @trace, such as control flow constructs and optimization hints. +""" + +from .knob import knob +from .control_flow import fori_loop + +__all__ = ["knob", "fori_loop"] diff --git a/kernelgen/nkipy_kernelgen/builder.py b/kernelgen/nkipy_kernelgen/builder.py new file mode 100644 index 0000000..0860148 --- /dev/null +++ b/kernelgen/nkipy_kernelgen/builder.py @@ -0,0 +1,1802 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Builder API for nkipy kernelgen backend. + +Provides an opaque IR construction interface so that nkipy never imports +``mlir`` directly. All MLIR types (``ir.Value``, ``ir.Type``, dialect ops) +are encapsulated behind :class:`TensorHandle`, :class:`LoopIndexHandle`, +and :class:`IRBuilder`. +""" + +from __future__ import annotations + +import math +from typing import Callable, Optional, Union + +import numpy as np +from mlir import ir, passmanager +from mlir.dialects import arith, func, linalg, scf, tensor +from mlir.dialects import math as mlir_math + +from nkipy_kernelgen._mlir.dialects import nkipy as nkipy_d +from nkipy_kernelgen.mlir_utils import ( + const_scalar, + make_empty, + make_filled, + make_zeros, + ranked_tensor_of, + to_mlir_type, +) + +Scalar = Union[int, float] + +_MEM_SPACE_CONSTANT = 4 + +# --------------------------------------------------------------------------- +# Type helpers +# --------------------------------------------------------------------------- + +_MLIR_TO_NP: dict[str, np.dtype] = { + "f16": np.dtype("float16"), + "bf16": np.dtype("bfloat16"), + "f32": np.dtype("float32"), + "f64": np.dtype("float64"), + "i32": np.dtype("int32"), + "i64": np.dtype("int64"), +} + + +def _mlir_type_to_np(mlir_ty: ir.Type) -> np.dtype: + s = str(mlir_ty) + if s in _MLIR_TO_NP: + return _MLIR_TO_NP[s] + raise KeyError(f"Cannot convert MLIR type {mlir_ty} to numpy dtype") + + +def _np_to_mlir(dtype) -> ir.Type: + if isinstance(dtype, ir.Type): + return dtype + return to_mlir_type(dtype) + + +def _is_float(elem_ty: ir.Type) -> bool: + return isinstance(elem_ty, ir.FloatType) + + +def _is_int(elem_ty: ir.Type) -> bool: + return isinstance(elem_ty, ir.IntegerType) + + +# --------------------------------------------------------------------------- +# TensorHandle — opaque tensor value +# --------------------------------------------------------------------------- + + +class TensorHandle: + """Opaque handle to a traced tensor value. + + Users see ``.shape`` and ``.dtype`` only. The internal ``_value`` + (an ``ir.Value``) is never accessed from nkipy. 
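+ + Handles are produced by :meth:`IRBuilder.begin_function` and by the op + helpers below (``add``, ``matmul``, ``reduce_sum``, ...); user code simply + passes them between those calls.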
+ """ + + __slots__ = ("_value", "shape", "dtype", "_elem_ty") + + def __init__( + self, value: ir.Value, shape: tuple, dtype: np.dtype, elem_ty: ir.Type + ): + self._value = value + self.shape = tuple(shape) + self.dtype = np.dtype(dtype) if not isinstance(dtype, np.dtype) else dtype + self._elem_ty = elem_ty + + +def _make_handle(value, shape, elem_ty) -> TensorHandle: + return TensorHandle(value, shape, _mlir_type_to_np(elem_ty), elem_ty) + + +def _loc() -> ir.Location: + return ir.Location.unknown() + + +# --------------------------------------------------------------------------- +# LoopIndexHandle — opaque loop induction variable +# --------------------------------------------------------------------------- + + +class LoopIndexHandle: + """Opaque handle to a loop induction variable with factor/offset tracking.""" + + __slots__ = ("_value", "mul_factor", "add_offset") + + def __init__(self, value: ir.Value, mul_factor: int = 1, add_offset: int = 0): + self._value = value + self.mul_factor = mul_factor + self.add_offset = add_offset + + +# --------------------------------------------------------------------------- +# IRBuilder — lifecycle management +# --------------------------------------------------------------------------- + + +class IRBuilder: + """Manages MLIR context, module, and function lifecycle.""" + + def __init__(self, source_file: str = "nkipy_kernel"): + self._ctx = ir.Context() + nkipy_d.register_dialect(self._ctx) + self._ctx.__enter__() + + self._file_loc = ir.Location.file(source_file, 0, 0, context=self._ctx) + self._file_loc.__enter__() + + self._module = ir.Module.create() + self._parameters: list[TensorHandle] = [] + self._func_op = None + self._entry_block = None + self._ip = None + + @property + def context(self): + return self._ctx + + @property + def module(self): + return self._module + + def begin_function( + self, + name: str, + arg_shapes: list[tuple], + arg_dtypes: list, + ) -> list[TensorHandle]: + arg_types = [ + ranked_tensor_of(shape, _np_to_mlir(dtype)) + for shape, dtype in zip(arg_shapes, arg_dtypes) + ] + fn_type = ir.FunctionType.get(arg_types, [arg_types[0]]) + with ir.InsertionPoint(self._module.body): + self._func_op = func.FuncOp(name=name, type=fn_type, loc=self._file_loc) + self._entry_block = self._func_op.add_entry_block() + self._ip = ir.InsertionPoint(self._entry_block) + self._ip.__enter__() + + handles: list[TensorHandle] = [] + for arg, (shape, dtype) in zip( + self._entry_block.arguments, zip(arg_shapes, arg_dtypes) + ): + elem_ty = _np_to_mlir(dtype) + h = TensorHandle(arg, shape, np.dtype(dtype) if not isinstance(dtype, str) else _mlir_type_to_np(elem_ty), elem_ty) + self._parameters.append(h) + handles.append(h) + return handles + + def finish_function(self, result_handles: list[TensorHandle]): + values = [h._value for h in result_handles] + func.ReturnOp(values, loc=self._file_loc) + self._ip.__exit__(None, None, None) + self._ip = None + + arg_types = [ranked_tensor_of(p.shape, p._elem_ty) for p in self._parameters] + res_types = [v.type for v in values] + self._func_op.attributes["function_type"] = ir.TypeAttr.get( + ir.FunctionType.get(arg_types, res_types) + ) + + def emit_custom_op_declarations(self, custom_ops: list): + """Emit func.func private declarations and stash NISA bodies for custom ops.""" + if not custom_ops: + return + with ir.InsertionPoint(self._module.body): + for custom in custom_ops: + input_types = [ + ranked_tensor_of(s, _np_to_mlir(d)) + for s, d in zip(custom.input_shapes, custom.input_dtypes) + ] + 
result_types = [ + ranked_tensor_of(s, _np_to_mlir(d)) + for s, d in zip(custom.output_shapes, custom.output_dtypes) + ] + fn_type = ir.FunctionType.get(input_types, result_types) + fn = func.FuncOp(name=custom.func_name, type=fn_type) + fn.attributes["sym_visibility"] = ir.StringAttr.get("private") + fn.attributes["nkipy.custom_op"] = ir.UnitAttr.get() + + bodies = {custom.func_name: custom.nisa_mlir for custom in custom_ops} + self._module.operation.attributes["nkipy.custom_op_bodies"] = ( + ir.DictAttr.get({k: ir.StringAttr.get(v) for k, v in bodies.items()}) + ) + + def run_canonicalize(self): + pm = passmanager.PassManager.parse("builtin.module(func.func(canonicalize))") + pm.run(self._module.operation) + + def get_ir_text(self) -> str: + return str(self._module) + + def cleanup(self): + if self._ip is not None: + self._ip.__exit__(None, None, None) + self._ip = None + self._file_loc.__exit__(None, None, None) + self._ctx.__exit__(None, None, None) + + +# --------------------------------------------------------------------------- +# Broadcasting helpers (internal) +# --------------------------------------------------------------------------- + + +def _broadcast_shape(sa: tuple, sb: tuple) -> tuple: + mr = max(len(sa), len(sb)) + pa = (1,) * (mr - len(sa)) + tuple(sa) + pb = (1,) * (mr - len(sb)) + tuple(sb) + result = [] + for da, db in zip(pa, pb): + if da == db: + result.append(da) + elif da == 1: + result.append(db) + elif db == 1: + result.append(da) + else: + raise ValueError(f"Incompatible shapes for broadcasting: {sa} vs {sb}") + return tuple(result) + + +def _broadcast_indexing_map(in_shape: tuple, out_shape: tuple) -> ir.AffineMap: + rd = len(out_shape) - len(in_shape) + exprs = [] + for oi in range(len(out_shape)): + ii = oi - rd + if ii < 0: + continue + if in_shape[ii] == 1 and out_shape[oi] > 1: + exprs.append(ir.AffineConstantExpr.get(0)) + else: + exprs.append(ir.AffineDimExpr.get(oi)) + return ir.AffineMap.get(len(out_shape), 0, exprs) + + +# --------------------------------------------------------------------------- +# Cast helpers (internal) +# --------------------------------------------------------------------------- + + +def _cast_to_float(val, shape, elem_ty, loc): + out_elem = ir.F32Type.get() + result_type = ranked_tensor_of(shape, out_elem) + out = make_empty(loc, shape, out_elem) + nd = len(shape) + imap = ir.AffineMap.get_identity(nd) + g = linalg.GenericOp( + [result_type], + [val], + [out], + ir.ArrayAttr.get([ir.AffineMapAttr.get(imap)] * 2), + ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * nd), + loc=loc, + ) + blk = g.regions[0].blocks.append(elem_ty, out_elem) + with ir.InsertionPoint(blk): + r = arith.SIToFPOp(out_elem, blk.arguments[0], loc=loc).result + linalg.YieldOp([r], loc=loc) + return g.results[0], shape, out_elem + + +# --------------------------------------------------------------------------- +# Generic element-wise helpers (internal) +# --------------------------------------------------------------------------- + + +def _scalar_binary( + tensor_val, + tensor_shape, + tensor_elem, + scalar_val, + arith_fn, + loc, + scalar_is_lhs=False, +): + rt = ranked_tensor_of(tensor_shape, tensor_elem) + out = make_empty(loc, tensor_shape, tensor_elem) + nd = len(tensor_shape) + imap = ir.AffineMap.get_identity(nd) + g = linalg.GenericOp( + [rt], + [tensor_val], + [out], + ir.ArrayAttr.get([ir.AffineMapAttr.get(imap)] * 2), + ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * nd), + loc=loc, + ) + blk =
g.regions[0].blocks.append(tensor_elem, tensor_elem) + with ir.InsertionPoint(blk): + cst = const_scalar(scalar_val, tensor_elem, loc) + if scalar_is_lhs: + r = arith_fn(cst, blk.arguments[0]) + else: + r = arith_fn(blk.arguments[0], cst) + linalg.YieldOp([r], loc=loc) + return g.results[0], tensor_shape, tensor_elem + + +def _broadcast_binary(a_val, a_shape, a_elem, b_val, b_shape, b_elem, body_fn, loc): + if str(a_elem) != str(b_elem): + raise TypeError(f"Element type mismatch: {a_elem} vs {b_elem}") + elem = a_elem + out_shape = _broadcast_shape(a_shape, b_shape) + rt = ranked_tensor_of(out_shape, elem) + out = make_empty(loc, out_shape, elem) + ma = _broadcast_indexing_map(a_shape, out_shape) + mb = _broadcast_indexing_map(b_shape, out_shape) + mo = ir.AffineMap.get_identity(len(out_shape)) + g = linalg.GenericOp( + [rt], + [a_val, b_val], + [out], + ir.ArrayAttr.get( + [ + ir.AffineMapAttr.get(ma), + ir.AffineMapAttr.get(mb), + ir.AffineMapAttr.get(mo), + ] + ), + ir.ArrayAttr.get( + [ir.Attribute.parse("#linalg.iterator_type<parallel>")] * len(out_shape) + ), + loc=loc, + ) + blk = g.regions[0].blocks.append(elem, elem, elem) + with ir.InsertionPoint(blk): + r = body_fn(blk.arguments[0], blk.arguments[1]) + linalg.YieldOp([r], loc=loc) + return g.results[0], out_shape, elem + + +# --------------------------------------------------------------------------- +# Internal: unary generic +# --------------------------------------------------------------------------- + + +def _unary_named(x: TensorHandle, named_cls, body_fn, loc) -> TensorHandle: + val, shape, elem = x._value, x.shape, x._elem_ty + rt = ranked_tensor_of(shape, elem) + out = make_empty(loc, shape, elem) + op = named_cls([rt], [val], [out], loc=loc) + blk = op.regions[0].blocks.append(elem, elem) + with ir.InsertionPoint(blk): + r = body_fn(blk.arguments[0], elem, loc) + linalg.YieldOp([r], loc=loc) + return _make_handle(op.results[0], shape, elem) + + +def _unary_generic(x: TensorHandle, body_fn, loc) -> TensorHandle: + val, shape, elem = x._value, x.shape, x._elem_ty + rt = ranked_tensor_of(shape, elem) + out = make_empty(loc, shape, elem) + nd = len(shape) + imap = ir.AffineMap.get_identity(nd) + g = linalg.GenericOp( + [rt], + [val], + [out], + ir.ArrayAttr.get([ir.AffineMapAttr.get(imap)] * 2), + ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * nd), + loc=loc, + ) + blk = g.regions[0].blocks.append(elem, elem) + with ir.InsertionPoint(blk): + r = body_fn(blk.arguments[0], elem, loc) + linalg.YieldOp([r], loc=loc) + return _make_handle(g.results[0], shape, elem) + + +# --------------------------------------------------------------------------- +# Internal: binary dispatch (tensor-tensor, tensor-scalar, scalar-tensor) +# --------------------------------------------------------------------------- + + +def _binary_dispatch(x, y, float_cls, int_cls, named_cls, loc): + x_is_t = isinstance(x, TensorHandle) + y_is_t = isinstance(y, TensorHandle) + + if not x_is_t and not y_is_t: + raise TypeError("At least one operand must be a TensorHandle") + + float_op = lambda a, b: float_cls(a, b, loc=loc).result + int_op = lambda a, b: int_cls(a, b, loc=loc).result + + if not x_is_t or not y_is_t: + if x_is_t: + tv, ts, te = x._value, x.shape, x._elem_ty + sv, slhs = y, False + else: + tv, ts, te = y._value, y.shape, y._elem_ty + sv, slhs = x, True + fn = float_op if _is_float(te) else int_op + rv, rs, re = _scalar_binary(tv, ts, te, sv, fn, loc, slhs) + return _make_handle(rv, rs, re) + + xv, xs, xe = x._value, x.shape, x._elem_ty +
yv, ys, ye = y._value, y.shape, y._elem_ty + if _is_float(xe) and not _is_float(ye): + yv, ys, ye = _cast_to_float(yv, ys, ye, loc) + elif _is_float(ye) and not _is_float(xe): + xv, xs, xe = _cast_to_float(xv, xs, xe, loc) + + if xs != ys or named_cls is None: + fn = float_op if _is_float(xe) else int_op + rv, rs, re = _broadcast_binary(xv, xs, xe, yv, ys, ye, fn, loc) + return _make_handle(rv, rs, re) + + elem = xe + rt = ranked_tensor_of(xs, elem) + out = make_empty(loc, xs, elem) + nop = named_cls([rt], [xv, yv], [out], loc=loc) + blk = nop.regions[0].blocks.append(elem, elem, elem) + with ir.InsertionPoint(blk): + fn = float_op if _is_float(elem) else int_op + r = fn(blk.arguments[0], blk.arguments[1]) + linalg.YieldOp([r], loc=loc) + return _make_handle(nop.results[0], xs, elem) + + +# --------------------------------------------------------------------------- +# Internal: reshape +# --------------------------------------------------------------------------- + + +def _emit_reshape(loc, value, old_shape, new_shape, elem_ty): + dst_ty = ranked_tensor_of(tuple(new_shape), elem_ty) + if tuple(old_shape) == tuple(new_shape): + return value + + def _reassoc(from_shape, to_shape): + reassoc = [] + to_idx = 0 + to_rank = len(to_shape) + for from_dim in from_shape: + group = [] + product = 1 + while to_idx < to_rank and product < from_dim: + product *= to_shape[to_idx] + group.append(to_idx) + to_idx += 1 + if from_dim == 1 and not group: + if to_idx < to_rank and to_shape[to_idx] == 1: + group.append(to_idx) + to_idx += 1 + product = 1 + else: + return None + if product != from_dim or not group: + return None + reassoc.append(group) + while to_idx < to_rank and to_shape[to_idx] == 1: + reassoc[-1].append(to_idx) + to_idx += 1 + return reassoc if to_idx == to_rank else None + + if len(old_shape) >= len(new_shape): + r = _reassoc(new_shape, old_shape) + if r is not None: + return tensor.CollapseShapeOp(dst_ty, value, r, loc=loc).result + if len(old_shape) <= len(new_shape): + r = _reassoc(old_shape, new_shape) + if r is not None: + return tensor.ExpandShapeOp( + dst_ty, + value, + r, + output_shape=[], + static_output_shape=list(new_shape), + loc=loc, + ).result + + idx_ty = ir.IndexType.get() + shape_ty = ir.RankedTensorType.get([len(new_shape)], idx_ty) + shape_vals = [ + arith.ConstantOp(idx_ty, ir.IntegerAttr.get(idx_ty, d), loc=loc).result + for d in new_shape + ] + fe = tensor.FromElementsOp(shape_ty, shape_vals, loc=loc) + return tensor.ReshapeOp(dst_ty, value, fe.result, loc=loc).result + + +# --------------------------------------------------------------------------- +# Internal: matmul body helper +# --------------------------------------------------------------------------- + + +def _matmul_body(op, elem_ty, loc): + blk = op.regions[0].blocks.append(elem_ty, elem_ty, elem_ty) + with ir.InsertionPoint(blk): + a, b, c = blk.arguments + if _is_float(elem_ty): + p = arith.MulFOp(a, b, loc=loc).result + r = arith.AddFOp(c, p, loc=loc).result + else: + p = arith.MulIOp(a, b, loc=loc).result + r = arith.AddIOp(c, p, loc=loc).result + linalg.YieldOp([r], loc=loc) + + +# =================================================================== +# PUBLIC OP API +# =================================================================== + +# --------------------------------------------------------------------------- +# Binary arithmetic +# --------------------------------------------------------------------------- + + +def add(x: Union[TensorHandle, Scalar], y: Union[TensorHandle, Scalar], loc=None) -> 
TensorHandle: + return _binary_dispatch(x, y, arith.AddFOp, arith.AddIOp, linalg.AddOp, loc or _loc()) + + +def subtract(x, y, loc=None) -> TensorHandle: + return _binary_dispatch(x, y, arith.SubFOp, arith.SubIOp, linalg.SubOp, loc or _loc()) + + +def multiply(x, y, loc=None) -> TensorHandle: + return _binary_dispatch(x, y, arith.MulFOp, arith.MulIOp, linalg.MulOp, loc or _loc()) + + +def divide(x, y, loc=None) -> TensorHandle: + return _binary_dispatch(x, y, arith.DivFOp, arith.DivSIOp, linalg.DivOp, loc or _loc()) + + +def maximum(x, y, loc=None) -> TensorHandle: + return _binary_dispatch(x, y, arith.MaximumFOp, arith.MaxSIOp, linalg.MaxOp, loc or _loc()) + + +def minimum(x, y, loc=None) -> TensorHandle: + return _binary_dispatch(x, y, arith.MinimumFOp, arith.MinSIOp, linalg.MinOp, loc or _loc()) + + +def power(x, y, loc=None) -> TensorHandle: + x_is_t = isinstance(x, TensorHandle) + y_is_t = isinstance(y, TensorHandle) + + if x_is_t and not y_is_t: + if isinstance(y, (int, float)): + if y == 2: + return multiply(x, x, loc=loc) + if y == 0.5: + return sqrt(x, loc=loc) + log_x = log(x, loc=loc) + scaled = multiply(log_x, float(y), loc=loc) + return exp(scaled, loc=loc) + elif not x_is_t and y_is_t: + log_scalar = math.log(float(x)) + scaled = multiply(y, log_scalar, loc=loc) + return exp(scaled, loc=loc) + else: + loc = loc or _loc() + xv, xs, xe = x._value, x.shape, x._elem_ty + yv, ys, ye = y._value, y.shape, y._elem_ty + pow_fn = lambda a, b: mlir_math.PowFOp(a, b, loc=loc).result + rv, rs, re = _broadcast_binary(xv, xs, xe, yv, ys, ye, pow_fn, loc) + return _make_handle(rv, rs, re) + + +# --------------------------------------------------------------------------- +# Unary math +# --------------------------------------------------------------------------- + + +def exp(x: TensorHandle, loc=None) -> TensorHandle: + return _unary_named( + x, linalg.ExpOp, lambda v, _, l: mlir_math.ExpOp(v, loc=l).result, loc or _loc() + ) + + +def log(x: TensorHandle, loc=None) -> TensorHandle: + return _unary_named( + x, linalg.LogOp, lambda v, _, l: mlir_math.LogOp(v, loc=l).result, loc or _loc() + ) + + +def sqrt(x: TensorHandle, loc=None) -> TensorHandle: + return _unary_named( + x, linalg.SqrtOp, lambda v, _, l: mlir_math.SqrtOp(v, loc=l).result, loc or _loc() + ) + + +def tanh(x: TensorHandle, loc=None) -> TensorHandle: + return _unary_named( + x, linalg.TanhOp, lambda v, _, l: mlir_math.TanhOp(v, loc=l).result, loc or _loc() + ) + + +def abs_(x: TensorHandle, loc=None) -> TensorHandle: + def _body(v, elem, l): + if _is_float(elem): + return mlir_math.AbsFOp(v, loc=l).result + return mlir_math.AbsIOp(v, loc=l).result + + return _unary_named(x, linalg.AbsOp, _body, loc or _loc()) + + +def ceil_(x: TensorHandle, loc=None) -> TensorHandle: + return _unary_named( + x, linalg.CeilOp, lambda v, _, l: mlir_math.CeilOp(v, loc=l).result, loc or _loc() + ) + + +def floor_(x: TensorHandle, loc=None) -> TensorHandle: + return _unary_named( + x, linalg.FloorOp, lambda v, _, l: mlir_math.FloorOp(v, loc=l).result, loc or _loc() + ) + + +def sin(x: TensorHandle, loc=None) -> TensorHandle: + return _unary_generic( + x, lambda v, _, l: mlir_math.SinOp(v, loc=l).result, loc or _loc() + ) + + +def cos(x: TensorHandle, loc=None) -> TensorHandle: + shifted = add(x, math.pi / 2, loc=loc) + return sin(shifted, loc=loc) + + +def sign(x: TensorHandle, loc=None) -> TensorHandle: + def _body(v, elem, l): + one = arith.ConstantOp(elem, ir.FloatAttr.get(elem, 1.0), loc=l).result + return mlir_math.CopySignOp(one, v, 
loc=l).result + + return _unary_generic(x, _body, loc or _loc()) + + +def square(x: TensorHandle, loc=None) -> TensorHandle: + def _body(v, elem, l): + if _is_float(elem): + return arith.MulFOp(v, v, loc=l).result + return arith.MulIOp(v, v, loc=l).result + + return _unary_named(x, linalg.SquareOp, _body, loc or _loc()) + + +def reciprocal(x: TensorHandle, loc=None) -> TensorHandle: + def _body(v, elem, l): + one = arith.ConstantOp(elem, ir.FloatAttr.get(elem, 1.0), loc=l).result + return arith.DivFOp(one, v, loc=l).result + + return _unary_named(x, linalg.ReciprocalOp, _body, loc or _loc()) + + +def negative(x: TensorHandle, loc=None) -> TensorHandle: + return multiply(x, -1.0, loc=loc) + + +def copy_(x: TensorHandle, loc=None) -> TensorHandle: + return _unary_named(x, linalg.CopyOp, lambda v, _, __: v, loc or _loc()) + + +# --------------------------------------------------------------------------- +# Comparison +# --------------------------------------------------------------------------- + + +def _comparison(x, y, pred_name, loc) -> TensorHandle: + loc = loc or _loc() + x_is_t = isinstance(x, TensorHandle) + y_is_t = isinstance(y, TensorHandle) + + pred = ir.IntegerAttr.get( + ir.IntegerType.get_signless(64), + arith.CmpFPredicate.__members__[pred_name].value, + ).value + + def cmp_fn(lhs, rhs): + c = arith.CmpFOp(pred, lhs, rhs, loc=loc).result + return arith.UIToFPOp(lhs.type, c, loc=loc).result + + if not x_is_t or not y_is_t: + if x_is_t: + tv, ts, te = x._value, x.shape, x._elem_ty + sv, slhs = y, False + else: + tv, ts, te = y._value, y.shape, y._elem_ty + sv, slhs = x, True + if _is_int(te): + tv, ts, te = _cast_to_float(tv, ts, te, loc) + rv, rs, re = _scalar_binary(tv, ts, te, sv, cmp_fn, loc, slhs) + return _make_handle(rv, rs, re) + + xv, xs, xe = x._value, x.shape, x._elem_ty + yv, ys, ye = y._value, y.shape, y._elem_ty + if _is_int(xe): + xv, xs, xe = _cast_to_float(xv, xs, xe, loc) + if _is_int(ye): + yv, ys, ye = _cast_to_float(yv, ys, ye, loc) + + rv, rs, re = _broadcast_binary(xv, xs, xe, yv, ys, ye, cmp_fn, loc) + return _make_handle(rv, rs, re) + + +def equal(x, y, loc=None) -> TensorHandle: + return _comparison(x, y, "OEQ", loc) + + +def not_equal(x, y, loc=None) -> TensorHandle: + return _comparison(x, y, "UNE", loc) + + +def greater(x, y, loc=None) -> TensorHandle: + return _comparison(x, y, "OGT", loc) + + +def greater_equal(x, y, loc=None) -> TensorHandle: + return _comparison(x, y, "OGE", loc) + + +def less(x, y, loc=None) -> TensorHandle: + return _comparison(x, y, "OLT", loc) + + +def less_equal(x, y, loc=None) -> TensorHandle: + return _comparison(x, y, "OLE", loc) + + +# --------------------------------------------------------------------------- +# Bitwise / logical +# --------------------------------------------------------------------------- + + +def _logical_binary(x, y, int_cls, loc) -> TensorHandle: + loc = loc or _loc() + x_is_t = isinstance(x, TensorHandle) + y_is_t = isinstance(y, TensorHandle) + + ref = x if x_is_t else y + elem = ref._elem_ty + is_fp = _is_float(elem) + + if is_fp: + i1 = ir.IntegerType.get_signless(1) + + def body_fn(lhs, rhs): + l1 = arith.FPToUIOp(i1, lhs, loc=loc).result + r1 = arith.FPToUIOp(i1, rhs, loc=loc).result + ri = int_cls(l1, r1, loc=loc).result + return arith.UIToFPOp(lhs.type, ri, loc=loc).result + else: + + def body_fn(lhs, rhs): + return int_cls(lhs, rhs, loc=loc).result + + if not x_is_t or not y_is_t: + if x_is_t: + bt = x + sv, slhs = y, False + else: + bt = y + sv, slhs = x, True + rv, rs, re = 
_scalar_binary( + bt._value, bt.shape, bt._elem_ty, sv, body_fn, loc, slhs + ) + return _make_handle(rv, rs, re) + + rv, rs, re = _broadcast_binary( + x._value, x.shape, x._elem_ty, + y._value, y.shape, y._elem_ty, + body_fn, loc, + ) + return _make_handle(rv, rs, re) + + +def bitwise_and(x, y, loc=None) -> TensorHandle: + return _logical_binary(x, y, arith.AndIOp, loc) + + +def bitwise_or(x, y, loc=None) -> TensorHandle: + return _logical_binary(x, y, arith.OrIOp, loc) + + +def bitwise_xor(x, y, loc=None) -> TensorHandle: + return _logical_binary(x, y, arith.XOrIOp, loc) + + +def logical_not(x, loc=None) -> TensorHandle: + return subtract(1, x, loc=loc) + + +def mod(x, y, loc=None) -> TensorHandle: + return _binary_dispatch(x, y, arith.RemFOp, arith.RemSIOp, None, loc or _loc()) + + +# --------------------------------------------------------------------------- +# Matmul +# --------------------------------------------------------------------------- + + +def matmul(x: TensorHandle, y: TensorHandle, loc=None) -> TensorHandle: + loc = loc or _loc() + xv, xs, xe = x._value, x.shape, x._elem_ty + yv, ys, ye = y._value, y.shape, y._elem_ty + + m, k = xs[-2], xs[-1] + k2, n = ys[-2], ys[-1] + if k != k2: + raise ValueError(f"Incompatible shapes for matmul: {xs} @ {ys}") + + if len(xs) == 2 and len(ys) == 2: + out_shape = (m, n) + out_val = make_zeros(loc, out_shape, xe) + rt = ranked_tensor_of(out_shape, xe) + mm = linalg.MatmulOp([rt], [xv, yv], [out_val], loc=loc) + _matmul_body(mm, xe, loc) + return _make_handle(mm.results[0], out_shape, xe) + + ba = xs[:-2] + bb = ys[:-2] + ml = max(len(ba), len(bb)) + pa = (1,) * (ml - len(ba)) + ba + pb = (1,) * (ml - len(bb)) + bb + bs = tuple(max(da, db) for da, db in zip(pa, pb)) + + bsz = 1 + for d in bs: + bsz *= d + + a3 = (bsz, m, k) + b3 = (bsz, k, n) + o3 = (bsz, m, n) + + av = xv if xs == a3 else _emit_reshape(loc, xv, xs, a3, xe) + bv = yv if ys == b3 else _emit_reshape(loc, yv, ys, b3, ye) + + out_val = make_zeros(loc, o3, xe) + rt = ranked_tensor_of(o3, xe) + mm = linalg.BatchMatmulOp([rt], [av, bv], [out_val], loc=loc) + _matmul_body(mm, xe, loc) + + final = bs + (m, n) + if final == o3: + rv = mm.results[0] + else: + rv = _emit_reshape(loc, mm.results[0], o3, final, xe) + return _make_handle(rv, final, xe) + + +# --------------------------------------------------------------------------- +# Transpose +# --------------------------------------------------------------------------- + + +def transpose(x: TensorHandle, axes=None, loc=None) -> TensorHandle: + loc = loc or _loc() + val, shape, elem = x._value, x.shape, x._elem_ty + rank = len(shape) + if axes is None: + perm = list(range(rank - 1, -1, -1)) + else: + perm = [ax if ax >= 0 else ax + rank for ax in axes] + new_shape = tuple(shape[p] for p in perm) + out = make_empty(loc, new_shape, elem) + rt = ranked_tensor_of(new_shape, elem) + top = linalg.TransposeOp([rt], val, out, perm, loc=loc) + blk = top.regions[0].blocks.append(elem, elem) + with ir.InsertionPoint(blk): + linalg.YieldOp([blk.arguments[0]], loc=loc) + return _make_handle(top.results[0], new_shape, elem) + + +# --------------------------------------------------------------------------- +# Reshape / expand_dims +# --------------------------------------------------------------------------- + + +def reshape(x: TensorHandle, newshape, loc=None) -> TensorHandle: + loc = loc or _loc() + val, shape, elem = x._value, x.shape, x._elem_ty + newshape = list(newshape) + if -1 in newshape: + if newshape.count(-1) > 1: + raise ValueError("can 
only specify one unknown dimension (-1)") + total = 1 + for d in shape: + total *= d + known = 1 + neg = -1 + for i, d in enumerate(newshape): + if d == -1: + neg = i + else: + known *= d + newshape[neg] = total // known + ns = tuple(newshape) + rv = _emit_reshape(loc, val, shape, ns, elem) + return _make_handle(rv, ns, elem) + + +def expand_dims(x: TensorHandle, axis, loc=None) -> TensorHandle: + shape = x.shape + if isinstance(axis, int): + axis = (axis,) + ns = list(shape) + for ax in sorted(axis): + if ax < 0: + ax = len(ns) + ax + 1 + ns.insert(ax, 1) + return reshape(x, tuple(ns), loc=loc) + + +# --------------------------------------------------------------------------- +# Reductions +# --------------------------------------------------------------------------- + + +def _normalize_axis(axis, rank): + if axis is None: + return None + if isinstance(axis, int): + axis = [axis] + return sorted([ax % rank for ax in axis]) + + +def _reduce(x: TensorHandle, axis, keepdims, init_fn, body_fn, loc=None) -> TensorHandle: + loc = loc or _loc() + val, shape, elem = x._value, x.shape, x._elem_ty + rank = len(shape) + if axis is None: + axis = list(range(rank)) + + if keepdims: + out_shape = tuple(1 if i in axis else shape[i] for i in range(rank)) + out_exprs = [ + ir.AffineConstantExpr.get(0) if i in axis else ir.AffineDimExpr.get(i) + for i in range(rank) + ] + else: + out_shape = tuple(shape[i] for i in range(rank) if i not in axis) + out_exprs = [ir.AffineDimExpr.get(i) for i in range(rank) if i not in axis] + + rt = ranked_tensor_of(out_shape, elem) + init = init_fn(loc, out_shape, elem) + imap = ir.AffineMap.get_identity(rank) + omap = ir.AffineMap.get(rank, 0, out_exprs) + + g = linalg.GenericOp( + [rt], + [val], + [init], + ir.ArrayAttr.get([ir.AffineMapAttr.get(imap), ir.AffineMapAttr.get(omap)]), + ir.ArrayAttr.get( + [ + ir.Attribute.parse("#linalg.iterator_type<reduction>") + if i in axis + else ir.Attribute.parse("#linalg.iterator_type<parallel>") + for i in range(rank) + ] + ), + loc=loc, + ) + blk = g.regions[0].blocks.append(elem, elem) + with ir.InsertionPoint(blk): + r = body_fn(blk.arguments[0], blk.arguments[1], elem, loc) + linalg.YieldOp([r], loc=loc) + return _make_handle(g.results[0], out_shape, elem) + + +def reduce_sum(x: TensorHandle, axis=None, keepdims=False, loc=None) -> TensorHandle: + na = _normalize_axis(axis, len(x.shape)) + + def body(inp, acc, elem, l): + if _is_float(elem): + return arith.AddFOp(acc, inp, loc=l).result + return arith.AddIOp(acc, inp, loc=l).result + + return _reduce(x, na, keepdims, make_zeros, body, loc) + + +def reduce_prod(x: TensorHandle, axis=None, keepdims=False, loc=None) -> TensorHandle: + na = _normalize_axis(axis, len(x.shape)) + init_fn = lambda loc, shape, elem: make_filled(loc, shape, elem, 1.0) + + def body(inp, acc, elem, l): + if _is_float(elem): + return arith.MulFOp(acc, inp, loc=l).result + return arith.MulIOp(acc, inp, loc=l).result + + return _reduce(x, na, keepdims, init_fn, body, loc) + + +def reduce_max(x: TensorHandle, axis=None, keepdims=False, loc=None) -> TensorHandle: + na = _normalize_axis(axis, len(x.shape)) + init_fn = lambda loc, shape, elem: make_filled(loc, shape, elem, float("-inf")) + + def body(inp, acc, elem, l): + if _is_float(elem): + return arith.MaximumFOp(acc, inp, loc=l).result + return arith.MaxSIOp(acc, inp, loc=l).result + + return _reduce(x, na, keepdims, init_fn, body, loc) + + +def reduce_min(x: TensorHandle, axis=None, keepdims=False, loc=None) -> TensorHandle: + na = _normalize_axis(axis, len(x.shape)) + init_fn
= lambda loc, shape, elem: make_filled(loc, shape, elem, float("inf")) + + def body(inp, acc, elem, l): + if _is_float(elem): + return arith.MinimumFOp(acc, inp, loc=l).result + return arith.MinSIOp(acc, inp, loc=l).result + + return _reduce(x, na, keepdims, init_fn, body, loc) + + +def reduce_mean(x: TensorHandle, axis=None, keepdims=False, loc=None) -> TensorHandle: + shape = x.shape + s = reduce_sum(x, axis=axis, keepdims=keepdims, loc=loc) + if axis is None: + count = int(np.prod(shape)) + else: + axes = [axis] if isinstance(axis, int) else list(axis) + count = int(np.prod([shape[i] for i in axes])) + return divide(s, float(count), loc=loc) + + +def reduce_std(x: TensorHandle, axis=None, keepdims=False, loc=None) -> TensorHandle: + mean_val = reduce_mean(x, axis=axis, keepdims=True, loc=loc) + diff = subtract(x, mean_val, loc=loc) + sq = multiply(diff, diff, loc=loc) + variance = reduce_mean(sq, axis=axis, keepdims=keepdims, loc=loc) + return sqrt(variance, loc=loc) + + +def reduce_var(x: TensorHandle, axis=None, keepdims=False, loc=None) -> TensorHandle: + mean_val = reduce_mean(x, axis=axis, keepdims=True, loc=loc) + diff = subtract(x, mean_val, loc=loc) + sq = multiply(diff, diff, loc=loc) + return reduce_mean(sq, axis=axis, keepdims=keepdims, loc=loc) + + +# --------------------------------------------------------------------------- +# Creation +# --------------------------------------------------------------------------- + + +def zeros(shape: tuple, dtype, loc=None) -> TensorHandle: + loc = loc or _loc() + elem = _np_to_mlir(dtype) + v = make_filled(loc, tuple(shape), elem, 0.0) + return _make_handle(v, tuple(shape), elem) + + +def full(shape: tuple, fill_value, dtype, loc=None) -> TensorHandle: + loc = loc or _loc() + elem = _np_to_mlir(dtype) + v = make_filled(loc, tuple(shape), elem, fill_value) + return _make_handle(v, tuple(shape), elem) + + +def empty(shape: tuple, dtype, loc=None) -> TensorHandle: + loc = loc or _loc() + elem = _np_to_mlir(dtype) + v = make_empty(loc, tuple(shape), elem) + return _make_handle(v, tuple(shape), elem) + + +def constant_tensor(val: Scalar, shape: tuple, elem_ty, loc=None) -> TensorHandle: + """Create a filled tensor annotated with CONSTANT memory space.""" + loc = loc or _loc() + if not isinstance(elem_ty, ir.Type): + elem_ty = _np_to_mlir(elem_ty) + v = make_filled(loc, tuple(shape), elem_ty, val) + ms_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), _MEM_SPACE_CONSTANT) + nkipy_d.AnnotateOp( + target=v, + mem_space=ms_attr, + partition_dim=None, + tile_size=None, + reduction_tile=None, + loc=loc, + ) + return _make_handle(v, tuple(shape), elem_ty) + + +# --------------------------------------------------------------------------- +# Concatenate +# --------------------------------------------------------------------------- + + +def concatenate(arrays: list[TensorHandle], axis: int = 0, loc=None) -> TensorHandle: + loc = loc or _loc() + if not arrays: + raise ValueError("need at least one array to concatenate") + if len(arrays) == 1: + return arrays[0] + + first = arrays[0] + elem = first._elem_ty + out_shape = list(first.shape) + out_shape[axis] = sum(a.shape[axis] for a in arrays) + out_shape = tuple(out_shape) + + output = make_empty(loc, out_shape, elem) + offset = 0 + for a in arrays: + offsets = [0] * len(out_shape) + offsets[axis] = offset + sizes = list(a.shape) + strides = [1] * len(out_shape) + output = tensor.InsertSliceOp( + a._value, + output, + [], + [], + [], + offsets, + sizes, + strides, + loc=loc, + ).result + offset 
+= a.shape[axis] + + return _make_handle(output, out_shape, elem) + + +# --------------------------------------------------------------------------- +# Broadcast +# --------------------------------------------------------------------------- + + +def broadcast_to(x: TensorHandle, shape: tuple, loc=None) -> TensorHandle: + loc = loc or _loc() + val, xs, elem = x._value, x.shape, x._elem_ty + ts = tuple(shape) + if xs == ts: + return x + rt = ranked_tensor_of(ts, elem) + out = make_empty(loc, ts, elem) + im = _broadcast_indexing_map(xs, ts) + om = ir.AffineMap.get_identity(len(ts)) + g = linalg.GenericOp( + [rt], + [val], + [out], + ir.ArrayAttr.get([ir.AffineMapAttr.get(im), ir.AffineMapAttr.get(om)]), + ir.ArrayAttr.get( + [ir.Attribute.parse("#linalg.iterator_type<parallel>")] * len(ts) + ), + loc=loc, + ) + blk = g.regions[0].blocks.append(elem, elem) + with ir.InsertionPoint(blk): + linalg.YieldOp([blk.arguments[0]], loc=loc) + return _make_handle(g.results[0], ts, elem) + + +# --------------------------------------------------------------------------- +# Where (conditional select) +# --------------------------------------------------------------------------- + + +def where(condition, x, y, loc=None) -> TensorHandle: + cond_is_t = isinstance(condition, TensorHandle) + x_is_t = isinstance(x, TensorHandle) + y_is_t = isinstance(y, TensorHandle) + + ref = x if x_is_t else y if y_is_t else None + if ref is None: + raise TypeError("where requires at least one tensor for x or y") + out_elem = ref._elem_ty + + needs_cast = cond_is_t and _is_int(condition._elem_ty) and _is_float(out_elem) + + # cond * x + (1 - cond) * y — works when condition values are 0/1. + if not needs_cast: + cx = multiply(condition, x, loc=loc) + inv = subtract(1, condition, loc=loc) + iy = multiply(inv, y, loc=loc) + return add(cx, iy, loc=loc) + + # Integer condition with float output: fuse cast + select into one generic.
+ loc = loc or _loc() + + shapes = [condition.shape] + if x_is_t: + shapes.append(x.shape) + if y_is_t: + shapes.append(y.shape) + out_shape = shapes[0] + for s in shapes[1:]: + out_shape = _broadcast_shape(out_shape, s) + + rt = ranked_tensor_of(out_shape, out_elem) + out = make_empty(loc, out_shape, out_elem) + + inputs = [condition._value] + in_maps = [_broadcast_indexing_map(condition.shape, out_shape)] + in_elem_tys = [condition._elem_ty] + + if x_is_t: + inputs.append(x._value) + in_maps.append(_broadcast_indexing_map(x.shape, out_shape)) + in_elem_tys.append(x._elem_ty) + if y_is_t: + inputs.append(y._value) + in_maps.append(_broadcast_indexing_map(y.shape, out_shape)) + in_elem_tys.append(y._elem_ty) + + om = ir.AffineMap.get_identity(len(out_shape)) + all_maps = [ir.AffineMapAttr.get(m) for m in in_maps] + all_maps.append(ir.AffineMapAttr.get(om)) + + g = linalg.GenericOp( + [rt], + inputs, + [out], + ir.ArrayAttr.get(all_maps), + ir.ArrayAttr.get( + [ir.Attribute.parse("#linalg.iterator_type<parallel>")] * len(out_shape) + ), + loc=loc, + ) + + blk_types = in_elem_tys + [out_elem] + blk = g.regions[0].blocks.append(*blk_types) + with ir.InsertionPoint(blk): + tensor_args = iter(blk.arguments[1:]) + ci = blk.arguments[0] + cf = arith.SIToFPOp(out_elem, ci, loc=loc).result + xv = next(tensor_args) if x_is_t else const_scalar(x, out_elem, loc) + yv = next(tensor_args) if y_is_t else const_scalar(y, out_elem, loc) + one = const_scalar(1.0, out_elem, loc) + inv = arith.SubFOp(one, cf, loc=loc).result + t1 = arith.MulFOp(cf, xv, loc=loc).result + t2 = arith.MulFOp(inv, yv, loc=loc).result + r = arith.AddFOp(t1, t2, loc=loc).result + linalg.YieldOp([r], loc=loc) + + return _make_handle(g.results[0], tuple(out_shape), out_elem) + + +# --------------------------------------------------------------------------- +# Take (gather) +# --------------------------------------------------------------------------- + + +def take(a: TensorHandle, indices: TensorHandle, axis: int = 0, loc=None) -> TensorHandle: + loc = loc or _loc() + av, a_shape, a_elem = a._value, a.shape, a._elem_ty + iv, i_shape, i_elem = indices._value, indices.shape, indices._elem_ty + + if axis != 0: + raise NotImplementedError("Only axis=0 gather is currently supported") + + out_shape = i_shape + a_shape[1:] + rt = ranked_tensor_of(out_shape, a_elem) + output = make_empty(loc, out_shape, a_elem) + + gather = nkipy_d.GatherOp(rt, av, iv, output, loc=loc) + + src_type = av.type + idx_type = iv.type + # Linalg fallback for LLVM JIT; NISA path lowers GatherOp directly to DMA.
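+ # Reference semantics for axis=0 (matches np.take(a, indices, axis=0)): + # out[i0, ..., ik, j1, ..., jm] = a[indices[i0, ..., ik], j1, ..., jm] + # The generic below extracts one source element per output element.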
+    region = gather.reference_impl
+    blk = region.blocks.append(src_type, idx_type)
+    with ir.InsertionPoint(blk):
+        src_arg, idx_arg = blk.arguments
+        rank = len(out_shape)
+        irank = len(i_shape)
+        out2 = make_empty(loc, out_shape, a_elem)
+
+        ie = [ir.AffineDimExpr.get(i) for i in range(irank)]
+        im = ir.AffineMap.get(rank, 0, ie)
+        om = ir.AffineMap.get_identity(rank)
+
+        g = linalg.GenericOp(
+            [rt],
+            [idx_arg],
+            [out2],
+            ir.ArrayAttr.get([ir.AffineMapAttr.get(im), ir.AffineMapAttr.get(om)]),
+            ir.ArrayAttr.get(
+                [ir.Attribute.parse("#linalg.iterator_type<parallel>")] * rank
+            ),
+            loc=loc,
+        )
+        with ir.InsertionPoint(g.regions[0].blocks.append(i_elem, a_elem)):
+            index_val = g.regions[0].blocks[0].arguments[0]
+            if str(i_elem) != "index":
+                index_val = arith.IndexCastOp(
+                    ir.IndexType.get(), index_val, loc=loc
+                ).result
+            ext_idx = [index_val]
+            for di in range(1, len(a_shape)):
+                ext_idx.append(linalg.IndexOp(irank + di - 1, loc=loc).result)
+            extracted = tensor.ExtractOp(src_arg, ext_idx, loc=loc).result
+            linalg.YieldOp([extracted], loc=loc)
+
+        nkipy_d.YieldOp(values=[g.results[0]], loc=loc)
+
+    return _make_handle(gather.result, tuple(out_shape), a_elem)
+
+
+# ---------------------------------------------------------------------------
+# Astype (cast)
+# ---------------------------------------------------------------------------
+
+
+def astype(x: TensorHandle, dtype, loc=None) -> TensorHandle:
+    loc = loc or _loc()
+    val, shape, src_elem = x._value, x.shape, x._elem_ty
+    dst_elem = _np_to_mlir(dtype)
+
+    if str(src_elem) == str(dst_elem):
+        return x
+
+    rt = ranked_tensor_of(shape, dst_elem)
+    out = make_empty(loc, shape, dst_elem)
+    nd = len(shape)
+    imap = ir.AffineMap.get_identity(nd)
+    g = linalg.GenericOp(
+        [rt],
+        [val],
+        [out],
+        ir.ArrayAttr.get([ir.AffineMapAttr.get(imap)] * 2),
+        ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * nd),
+        loc=loc,
+    )
+    blk = g.regions[0].blocks.append(src_elem, dst_elem)
+    with ir.InsertionPoint(blk):
+        in_e = blk.arguments[0]
+        sf = _is_float(src_elem)
+        df = _is_float(dst_elem)
+        if sf and df:
+            if dst_elem.width > src_elem.width:
+                oe = arith.ExtFOp(dst_elem, in_e, loc=loc).result
+            elif dst_elem.width < src_elem.width:
+                oe = arith.TruncFOp(dst_elem, in_e, loc=loc).result
+            else:
+                oe = in_e
+        elif sf and not df:
+            oe = arith.FPToSIOp(dst_elem, in_e, loc=loc).result
+        elif not sf and df:
+            oe = arith.SIToFPOp(dst_elem, in_e, loc=loc).result
+        else:
+            if dst_elem.width > src_elem.width:
+                oe = arith.ExtSIOp(dst_elem, in_e, loc=loc).result
+            elif dst_elem.width < src_elem.width:
+                oe = arith.TruncIOp(dst_elem, in_e, loc=loc).result
+            else:
+                oe = in_e
+        linalg.YieldOp([oe], loc=loc)
+    return _make_handle(g.results[0], shape, dst_elem)
+
+
+# ---------------------------------------------------------------------------
+# Static / dynamic slicing
+# ---------------------------------------------------------------------------
+
+
+def static_slice(
+    x: TensorHandle,
+    start_indices,
+    limit_indices,
+    strides,
+    squeeze_dims,
+    loc=None,
+) -> TensorHandle:
+    loc = loc or _loc()
+    val, shape, elem = x._value, x.shape, x._elem_ty
+
+    slice_shape = []
+    for s, l, st in zip(start_indices, limit_indices, strides):
+        slice_shape.append((l - s + st - 1) // st)
+
+    rt = ranked_tensor_of(tuple(slice_shape), elem)
+    sliced = tensor.ExtractSliceOp(
+        rt,
+        val,
+        [],
+        [],
+        [],
+        start_indices,
+        slice_shape,
+        strides,
+        loc=loc,
+    ).result
+
+    if squeeze_dims:
+        out_shape = tuple(s for i, s in enumerate(slice_shape) if i not in
squeeze_dims) + if out_shape != tuple(slice_shape): + sliced = _emit_reshape(loc, sliced, tuple(slice_shape), out_shape, elem) + slice_shape = list(out_shape) + + return _make_handle(sliced, tuple(slice_shape), elem) + + +def _parse_dynamic_indices(indices, shape, loc): + """Parse a tuple of indices (LoopIndexHandle, slice, int) into static/dynamic components. + + Returns (static_offsets, static_sizes, static_strides, dynamic_offsets, full_indices). + """ + if not isinstance(indices, tuple): + indices = (indices,) + + full_indices = list(indices) + [slice(None)] * (len(shape) - len(indices)) + + static_offsets = [] + static_sizes = [] + static_strides = [] + dynamic_offsets = [] + + DYNAMIC = ir.ShapedType.get_dynamic_size() + + for idx, dim_size in zip(full_indices, shape): + if isinstance(idx, LoopIndexHandle): + ov = arith.IndexCastOp(ir.IndexType.get(), idx._value, loc=loc).result + dynamic_offsets.append(ov) + static_offsets.append(DYNAMIC) + static_sizes.append(1) + static_strides.append(1) + elif isinstance(idx, slice): + start = idx.start if idx.start is not None else 0 + stop = idx.stop if idx.stop is not None else dim_size + step = idx.step if idx.step is not None else 1 + + start_is_li = isinstance(start, LoopIndexHandle) + stop_is_li = isinstance(stop, LoopIndexHandle) + + if start_is_li: + ov = arith.IndexCastOp(ir.IndexType.get(), start._value, loc=loc).result + dynamic_offsets.append(ov) + static_offsets.append(DYNAMIC) + else: + static_offsets.append(int(start)) + + if start_is_li and stop_is_li and start.mul_factor == stop.mul_factor: + static_sizes.append(stop.add_offset - start.add_offset) + elif not start_is_li and not stop_is_li: + static_sizes.append(int(stop) - int(start)) + else: + raise NotImplementedError( + "Mixed static/dynamic slice sizes not yet supported" + ) + + static_strides.append(int(step)) + elif isinstance(idx, int): + static_offsets.append(int(idx)) + static_sizes.append(1) + static_strides.append(1) + else: + raise TypeError(f"Unsupported index type: {type(idx)}") + + return static_offsets, static_sizes, static_strides, dynamic_offsets, full_indices + + +def dynamic_slice(x: TensorHandle, indices, loc=None) -> TensorHandle: + loc = loc or _loc() + val, shape, elem = x._value, x.shape, x._elem_ty + + static_offsets, static_sizes, static_strides, dynamic_offsets, full_indices = \ + _parse_dynamic_indices(indices, shape, loc) + + DYNAMIC = ir.ShapedType.get_dynamic_size() + result_shape = [] + for idx, size, stride in zip(full_indices, static_sizes, static_strides): + if isinstance(idx, LoopIndexHandle) or isinstance(idx, int): + continue + if size != DYNAMIC: + result_shape.append(size // stride if stride > 1 else size) + + rt = ranked_tensor_of(tuple(result_shape), elem) + extract = tensor.ExtractSliceOp( + rt, + val, + dynamic_offsets, + [], + [], + static_offsets, + static_sizes, + static_strides, + loc=loc, + ) + return _make_handle(extract.result, tuple(result_shape), elem) + + +# --------------------------------------------------------------------------- +# Insert slice (setitem) +# --------------------------------------------------------------------------- + + +def static_insert_slice( + dest: TensorHandle, + src: TensorHandle, + offsets: list[int], + sizes: list[int], + strides: list[int], + loc=None, +) -> TensorHandle: + loc = loc or _loc() + new_tensor = tensor.InsertSliceOp( + src._value, + dest._value, + [], + [], + [], + offsets, + sizes, + strides, + loc=loc, + ).result + return _make_handle(new_tensor, dest.shape, dest._elem_ty) + + 
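+# Illustration of the shared index parsing used by dynamic_insert_slice below
+# (shapes are made up): for dest of shape (4, 128) and indices (i, slice(0, 128))
+# with a LoopIndexHandle i, _parse_dynamic_indices returns
+#   static_offsets = [DYNAMIC, 0], static_sizes = [1, 128], static_strides = [1, 1]
+# plus one dynamic offset (i cast to index), which lowers to roughly
+#   tensor.insert_slice %src into %dest[%i, 0] [1, 128] [1, 1]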
+def dynamic_insert_slice( + dest: TensorHandle, + src: TensorHandle, + indices, + loc=None, +) -> TensorHandle: + loc = loc or _loc() + + static_offsets, static_sizes, static_strides, dynamic_offsets, _ = \ + _parse_dynamic_indices(indices, dest.shape, loc) + + new_tensor = tensor.InsertSliceOp( + src._value, + dest._value, + dynamic_offsets, + [], + [], + static_offsets, + static_sizes, + static_strides, + loc=loc, + ).result + return _make_handle(new_tensor, dest.shape, dest._elem_ty) + + +# --------------------------------------------------------------------------- +# Split +# --------------------------------------------------------------------------- + + +def split(x: TensorHandle, sections: int, axis: int = 0, loc=None) -> list[TensorHandle]: + shape = x.shape + size = shape[axis] + if size % sections != 0: + raise ValueError("array split does not result in an equal division") + section_size = size // sections + results = [] + for i in range(sections): + start_indices = [0] * len(shape) + start_indices[axis] = i * section_size + limit_indices = list(shape) + limit_indices[axis] = (i + 1) * section_size + strides = [1] * len(shape) + results.append(static_slice(x, start_indices, limit_indices, strides, [], loc=loc)) + return results + + +# --------------------------------------------------------------------------- +# Annotations (knob) +# --------------------------------------------------------------------------- + + +def annotate( + x: TensorHandle, + *, + partition_dim: Optional[int] = None, + mem_space: Optional[str] = None, + tile_size: Optional[list[int]] = None, + reduction_tile: Optional[list[int]] = None, +) -> TensorHandle: + value = x._value + defining_op = value.owner + if defining_op is None: + return x + + loc = _loc() + + if isinstance(tile_size, int): + tile_size = [tile_size] + if isinstance(reduction_tile, int): + reduction_tile = [reduction_tile] + + valid = {"Hbm", "Psum", "Sbuf", "SharedHbm"} + if mem_space is not None and mem_space not in valid: + raise ValueError(f"Invalid mem_space '{mem_space}'. Must be one of: {valid}") + + ms_attr = None + if mem_space is not None: + ms_map = {"Hbm": 1, "Psum": 2, "Sbuf": 3, "SharedHbm": 4} + ms_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), ms_map[mem_space]) + + pd_attr = None + if partition_dim is not None: + pd_attr = ir.IntegerAttr.get(ir.IntegerType.get_unsigned(32), partition_dim) + + ts_attr = None + if tile_size is not None: + ts_attr = ir.DenseI64ArrayAttr.get(tile_size) + + rt_attr = None + if reduction_tile is not None: + rt_attr = ir.DenseI64ArrayAttr.get(reduction_tile) + + nkipy_d.AnnotateOp( + target=value, + mem_space=ms_attr, + partition_dim=pd_attr, + tile_size=ts_attr, + reduction_tile=rt_attr, + loc=loc, + ) + return x + + +# --------------------------------------------------------------------------- +# Control flow: fori_loop +# --------------------------------------------------------------------------- + + +def fori_loop( + lower: int, + upper: int, + body_fn: Callable, + init_handles: list[TensorHandle], +) -> list[TensorHandle]: + """Build an ``scf.for`` loop with loop-carried tensor accumulators. + + Args: + lower: inclusive lower bound + upper: exclusive upper bound + body_fn: ``(LoopIndexHandle, list[TensorHandle]) -> list[TensorHandle]`` + init_handles: initial accumulator values + + Returns: + Final accumulator values after all iterations. 
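+
+    Example (sketch; ``add`` is this module's elementwise op and ``acc_init``
+    is any TensorHandle):
+
+        def body(i, accs):
+            return [add(accs[0], accs[0])]
+
+        final, = fori_loop(0, 4, body, [acc_init])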
+ """ + loc = _loc() + i32 = ir.IntegerType.get_signless(32) + lb = arith.ConstantOp(i32, lower, loc=loc) + ub = arith.ConstantOp(i32, upper, loc=loc) + step = arith.ConstantOp(i32, 1, loc=loc) + + loop_op = scf.ForOp( + lb.result, + ub.result, + step.result, + [h._value for h in init_handles], + loc=loc, + ) + + loop_block = loop_op.body + loop_idx_value = loop_block.arguments[0] + loop_acc_values = loop_block.arguments[1:] + + loop_idx = LoopIndexHandle(loop_idx_value) + + acc_handles = [ + TensorHandle(av, ih.shape, ih.dtype, ih._elem_ty) + for av, ih in zip(loop_acc_values, init_handles) + ] + + with ir.InsertionPoint(loop_block): + results = body_fn(loop_idx, acc_handles) + result_values = [r._value for r in results] + + # Rewire linalg ops to use loop accumulators as their output operand. + # Without this, bufferization can't see the loop-carried dependence + # and may allocate a fresh buffer instead of updating in place. + for rv, ia in zip(result_values, loop_acc_values): + producer = rv.owner + if not producer.name.startswith("linalg."): + continue + if len(list(producer.results)) != 1: + continue + operands = list(producer.operands) + if not operands or operands[-1] == ia: + continue + if len(producer.regions) == 1 and len(producer.regions[0].blocks) == 1: + bb = producer.regions[0].blocks[0] + ba = list(bb.arguments) + if ba: + oe = ba[-1] + used = any( + op_arg == oe for op in bb.operations for op_arg in op.operands + ) + if used: + continue + producer.operands[-1] = ia + + scf.YieldOp(result_values, loc=loc) + + return [ + TensorHandle( + loop_op.results[i], + init_handles[i].shape, + init_handles[i].dtype, + init_handles[i]._elem_ty, + ) + for i in range(len(init_handles)) + ] + + +def lift_scalar_to_tensor(val, dtype_hint: str = "float") -> TensorHandle: + """Lift a Python scalar to a 0-d tensor handle.""" + loc = _loc() + if dtype_hint == "float" or isinstance(val, float): + elem = ir.F32Type.get() + else: + elem = ir.IntegerType.get_signless(32) + v = make_filled(loc, (), elem, val) + return _make_handle(v, (), elem) + + +# --------------------------------------------------------------------------- +# LoopIndex arithmetic +# --------------------------------------------------------------------------- + + +def loop_index_mul(idx: LoopIndexHandle, factor: int) -> LoopIndexHandle: + loc = _loc() + i32 = ir.IntegerType.get_signless(32) + cst = arith.ConstantOp(i32, factor, loc=loc).result + rv = arith.MulIOp(idx._value, cst, loc=loc).result + return LoopIndexHandle(rv, idx.mul_factor * factor, idx.add_offset * factor) + + +def loop_index_add(idx: LoopIndexHandle, offset: int) -> LoopIndexHandle: + loc = _loc() + i32 = ir.IntegerType.get_signless(32) + cst = arith.ConstantOp(i32, offset, loc=loc).result + rv = arith.AddIOp(idx._value, cst, loc=loc).result + return LoopIndexHandle(rv, idx.mul_factor, idx.add_offset + offset) + + +def loop_index_add_loop_index( + a: LoopIndexHandle, b: LoopIndexHandle +) -> LoopIndexHandle: + loc = _loc() + rv = arith.AddIOp(a._value, b._value, loc=loc).result + return LoopIndexHandle(rv, a.mul_factor + b.mul_factor, a.add_offset + b.add_offset) + + +# --------------------------------------------------------------------------- +# Custom ops +# --------------------------------------------------------------------------- + + +def apply_custom_op(kernel_builder, reference_fn, input_specs, output_specs, args): + """Compile a kernel_builder function and call it during tracing. 
+ + Handles the nki.compiler.kernel_builder spec translation that was + previously in nkipy's KernelGenTraceContext. + + Args: + kernel_builder: NKI kernel_builder function. + reference_fn: NumPy reference (for fallback). + input_specs: List of (shape, dtype_str) tuples. + output_specs: List of (shape, dtype_str) tuples. + args: Traced tensor arguments to pass to the custom op. + + Returns: + Result from the custom op call. + """ + import nki.compiler.kernel_builder as nb + from nkipy_kernelgen.custom_op import CustomOp + + _dtype_map = {"f32": nb.float32, "f16": nb.float16, "bf16": nb.bfloat16} + nb_input_specs = { + f"input_{i}": nb.Tensor(shape, _dtype_map[dtype], nb.shared_hbm) + for i, (shape, dtype) in enumerate(input_specs) + } + nb_output_specs = { + f"output_{i}": nb.Tensor(shape, _dtype_map[dtype], nb.shared_hbm) + for i, (shape, dtype) in enumerate(output_specs) + } + internal = CustomOp.from_kernel_builder( + kernel_func=kernel_builder, + input_specs=nb_input_specs, + output_specs=nb_output_specs, + reference_fn=reference_fn, + ) + return internal(*args) diff --git a/kernelgen/nkipy_kernelgen/compile.py b/kernelgen/nkipy_kernelgen/compile.py new file mode 100644 index 0000000..8784bae --- /dev/null +++ b/kernelgen/nkipy_kernelgen/compile.py @@ -0,0 +1,90 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Compile MLIR to NEFF via the nkipy-opt pass pipeline and nki.compiler. + +This module encapsulates all NKI compiler internals so that callers only +need to provide MLIR text, tensor metadata, and a target string. +""" + +from __future__ import annotations + +from typing import List, Tuple + +import numpy as np + + +def compile_to_neff( + mlir_text: str, + func_name: str, + input_specs: List[Tuple[str, Tuple[int, ...], np.dtype]], + output_specs: List[Tuple[str, Tuple[int, ...], np.dtype]], + *, + target: str, + output_path: str, + artifacts_dir: str | None = None, + neuronx_cc_args: Tuple[str, ...] = (), +) -> None: + """Compile MLIR text to a NEFF file. + + Two-stage pipeline: + 1. Run the nkipy-opt MLIR pass pipeline to produce NISA IR. + 2. Compile NISA IR to NEFF via ``nki.compiler``. + + Args: + mlir_text: MLIR module text (linalg-on-tensor level). + func_name: Entry-point function name in the MLIR module. + input_specs: ``[(name, shape, dtype), ...]`` for each input tensor. + output_specs: ``[(name, shape, dtype), ...]`` for each output tensor. + target: Hardware target string (e.g. ``"trn2"``). + output_path: Where to write the NEFF file. + artifacts_dir: Optional directory for intermediate compilation artifacts. + neuronx_cc_args: Extra arguments forwarded to ``neuronx-cc``. 
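+
+    Example (illustrative; names, shapes, and paths are made up):
+        compile_to_neff(
+            mlir_text,
+            "matmul_kernel",
+            input_specs=[("a", (128, 128), np.float32), ("b", (128, 128), np.float32)],
+            output_specs=[("c", (128, 128), np.float32)],
+            target="trn2",
+            output_path="/tmp/matmul_kernel.neff",
+        )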
+ """ + from nkipy_kernelgen.transforms.nkipy_opt import apply_complete_knob_pipeline + + dump_dir = f"{artifacts_dir}/mlir_passes" if artifacts_dir else None + nisa_mlir = apply_complete_knob_pipeline( + mlir_text, + target=target, + dump_dir=dump_dir, + ) + + from nki.compiler._internal import ir as nki_ir + from nki.compiler._internal import register_all_dialects + from nki.compiler.ncc_driver import CompileOptions, compile_mlir_to_neff + + nki_ctx = nki_ir.Context() + register_all_dialects(nki_ctx) + + with nki_ctx: + mlir_module = nki_ir.Module.parse(nisa_mlir, nki_ctx) + + all_arrays = [ + np.zeros(shape, dtype=dtype) for _, shape, dtype in input_specs + ] + [ + np.zeros(shape, dtype=dtype) for _, shape, dtype in output_specs + ] + + argument_names = [ + name for name, _, _ in input_specs + ] + [ + name for name, _, _ in output_specs + ] + output_arg_names = [name for name, _, _ in output_specs] + + opts = CompileOptions( + target=target, + verbose=False, + output_path=output_path, + neuronx_cc_args=neuronx_cc_args, + artifacts_dir=artifacts_dir, + ) + + compile_mlir_to_neff( + mlir_module, + func_name, + all_arrays, + argument_names, + output_arg_names, + opts, + ) diff --git a/kernelgen/nkipy_kernelgen/control_flow.py b/kernelgen/nkipy_kernelgen/control_flow.py new file mode 100644 index 0000000..647822c --- /dev/null +++ b/kernelgen/nkipy_kernelgen/control_flow.py @@ -0,0 +1,163 @@ +""" +Control flow operations for traced execution. +""" + +from typing import Callable, Union + +from . import builder +from .builder import LoopIndexHandle +from .traced_array import TracedArray + + +class LoopIndex: + """Wrapper for loop index that supports arithmetic operations and tracks constants.""" + + def __init__(self, value_or_handle, mul_factor: int = 1, add_offset: int = 0): + if isinstance(value_or_handle, LoopIndexHandle): + self._handle = value_or_handle + else: + self._handle = LoopIndexHandle(value_or_handle, mul_factor, add_offset) + + @property + def value(self): + return self._handle._value + + @property + def mul_factor(self): + return self._handle.mul_factor + + @property + def add_offset(self): + return self._handle.add_offset + + def __mul__(self, other): + if isinstance(other, int): + return LoopIndex(builder.loop_index_mul(self._handle, other)) + raise TypeError( + f"LoopIndex * {type(other).__name__} is not supported; only integer constants are allowed" + ) + + def __add__(self, other): + if isinstance(other, int): + return LoopIndex(builder.loop_index_add(self._handle, other)) + elif isinstance(other, LoopIndex): + return LoopIndex(builder.loop_index_add_loop_index(self._handle, other._handle)) + raise TypeError( + f"LoopIndex + {type(other).__name__} is not supported; use int or LoopIndex" + ) + + def __sub__(self, other): + if isinstance(other, int): + return LoopIndex(builder.loop_index_add(self._handle, -other)) + raise TypeError( + f"LoopIndex - {type(other).__name__} is not supported; only integer constants are allowed" + ) + + def __rsub__(self, other): + if isinstance(other, int): + neg = LoopIndex(builder.loop_index_mul(self._handle, -1)) + return LoopIndex(builder.loop_index_add(neg._handle, other)) + raise TypeError( + f"{type(other).__name__} - LoopIndex is not supported; only integer constants are allowed" + ) + + def __rmul__(self, other): + return self.__mul__(other) + + def __radd__(self, other): + return self.__add__(other) + + +def _to_handle(val): + """Convert a TracedArray, scalar, or value to a TensorHandle for builder calls.""" + from .op_vtable import 
_to_handle as vtable_to_handle + + if isinstance(val, TracedArray): + return vtable_to_handle(val) + if isinstance(val, (int, float)): + return builder.lift_scalar_to_tensor(val) + raise TypeError(f"init_val must be TracedArray, int, or float, got {type(val)}") + + +def _from_handle(h, source_file): + """Convert a TensorHandle back to a TracedArray.""" + from .op_vtable import _from_handle as vtable_from_handle + + return vtable_from_handle(h, source_file) + + +def fori_loop( + lower_bound: Union[int, TracedArray], + upper_bound: Union[int, TracedArray], + body_fn: Callable, + init_val: Union[TracedArray, float, int, tuple], +) -> Union[TracedArray, tuple]: + is_tuple = isinstance(init_val, tuple) + if is_tuple: + is_tracing = any(isinstance(v, TracedArray) for v in init_val) + else: + is_tracing = isinstance(init_val, TracedArray) + + if not is_tracing: + acc = init_val + for i in range(lower_bound, upper_bound): + acc = body_fn(i, acc) + return acc + + if is_tuple: + init_handles = [_to_handle(v) for v in init_val] + source_files = [ + v.source_file if isinstance(v, TracedArray) else "unknown" + for v in init_val + ] + else: + init_handles = [_to_handle(init_val)] + source_files = [ + init_val.source_file if isinstance(init_val, TracedArray) else "unknown" + ] + + if not isinstance(lower_bound, int): + raise TypeError("Dynamic lower bounds not yet supported") + if not isinstance(upper_bound, int): + raise TypeError("Dynamic upper bounds not yet supported") + + from .op_vtable import _to_handle as vtable_to_handle + + def wrapped_body(loop_idx_handle, acc_handles): + loop_idx = LoopIndex(loop_idx_handle) + + if is_tuple: + accs = tuple( + _from_handle(ah, sf) for ah, sf in zip(acc_handles, source_files) + ) + result = body_fn(loop_idx, accs) + if not isinstance(result, tuple): + raise TypeError( + f"body_fn must return a tuple of {len(init_handles)} elements, " + f"got {type(result).__name__}" + ) + if len(result) != len(init_handles): + raise TypeError( + f"body_fn must return a tuple of {len(init_handles)} elements, " + f"got tuple of {len(result)}" + ) + return [vtable_to_handle(r) for r in result] + else: + acc = _from_handle(acc_handles[0], source_files[0]) + result = body_fn(loop_idx, acc) + if not isinstance(result, TracedArray): + raise TypeError( + f"body_fn must return a TracedArray, got {type(result).__name__}" + ) + return [vtable_to_handle(result)] + + result_handles = builder.fori_loop( + lower_bound, upper_bound, wrapped_body, init_handles + ) + + if is_tuple: + return tuple( + _from_handle(rh, sf) for rh, sf in zip(result_handles, source_files) + ) + else: + return _from_handle(result_handles[0], source_files[0]) diff --git a/kernelgen/nkipy_kernelgen/custom_op.py b/kernelgen/nkipy_kernelgen/custom_op.py new file mode 100644 index 0000000..6fa0e33 --- /dev/null +++ b/kernelgen/nkipy_kernelgen/custom_op.py @@ -0,0 +1,228 @@ +""" +CustomOp: Wrap kernel_builder NISA functions for use in @trace-decorated kernels. + +A CustomOp represents a pre-compiled NISA function (from kernel_builder) that can +be called during tracing. It emits a func.call in the traced IR, and the NISA body +is stashed as a string attribute for late resolution by the resolve-custom-ops pass. +""" + +from typing import Optional, Callable, List, Tuple +from mlir import ir +from mlir.dialects import func + +from .traced_array import TracedArray +from .mlir_utils import to_mlir_type, ranked_tensor_of + +# Module-level registry for custom ops used during tracing. 
+# No thread safety needed -- tracing is always single-threaded. +_custom_op_registry: list = [] + + +def _get_registry() -> list: + return _custom_op_registry + + +def _clear_registry(): + _custom_op_registry.clear() + + +def _nb_dtype_to_str(nb_dtype) -> str: + """Convert kernel_builder dtype to string (e.g., nb.float32 -> 'f32').""" + import nki.compiler.kernel_builder as nb + mapping = { + nb.float32: "f32", + nb.float16: "f16", + nb.bfloat16: "bf16", + } + if nb_dtype in mapping: + return mapping[nb_dtype] + raise ValueError(f"Unsupported kernel_builder dtype: {nb_dtype}") + + +class CustomOp: + """A kernel_builder function compiled to NISA, usable during KernelGen tracing. + + Wraps the output of nb.build_kernel() and provides a callable interface + that emits func.call during tracing. Falls back to reference_fn for + NumPy execution (testing). + """ + + def __init__( + self, + nisa_mlir: str, + func_name: str, + input_names: List[str], + output_names: List[str], + input_shapes: List[Tuple[int, ...]], + output_shapes: List[Tuple[int, ...]], + input_dtypes: List[str], + output_dtypes: List[str], + reference_fn: Optional[Callable] = None, + ): + self.nisa_mlir = nisa_mlir + self.func_name = f"__custom_op__{func_name}" + self.input_names = input_names + self.output_names = output_names + self.input_shapes = input_shapes + self.output_shapes = output_shapes + self.input_dtypes = input_dtypes + self.output_dtypes = output_dtypes + self.reference_fn = reference_fn + + @classmethod + def from_kernel_builder( + cls, + kernel_func: Callable, + input_specs: dict, + output_specs: dict, + reference_fn: Optional[Callable] = None, + **hyperparams, + ) -> "CustomOp": + """Compile a kernel_builder function to a CustomOp. + + Args: + kernel_func: Function using nki.compiler.kernel_builder APIs. + input_specs: Dict of name -> nb.Tensor specs for inputs. + output_specs: Dict of name -> nb.Tensor specs for outputs. + reference_fn: NumPy reference implementation for testing. + **hyperparams: Compile-time constants passed to kernel_func. + + Returns: + CustomOp ready to use inside @trace-decorated functions. 
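+
+        Example (sketch; the kernel and tensor specs are illustrative):
+            import nki.compiler.kernel_builder as nb
+
+            op = CustomOp.from_kernel_builder(
+                my_nisa_kernel,
+                input_specs={"x": nb.Tensor((128, 128), nb.float32, nb.shared_hbm)},
+                output_specs={"y": nb.Tensor((128, 128), nb.float32, nb.shared_hbm)},
+                reference_fn=lambda x: x,
+            )
+            y = op(traced_x)  # emits func.call while tracing, or runs reference_fn on NumPy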
+ """ + import nki.compiler.kernel_builder as nb + + module = nb.build_kernel( + kernel_func, + input_specs=input_specs, + output_specs=output_specs, + **hyperparams, + ) + # Use generic MLIR form -- NISA custom assembly isn't round-trippable + nisa_mlir = module.operation.get_asm(print_generic_op_form=True) + + input_names = list(input_specs.keys()) + output_names = list(output_specs.keys()) + input_shapes = [spec.shape for spec in input_specs.values()] + output_shapes = [spec.shape for spec in output_specs.values()] + input_dtypes = [_nb_dtype_to_str(spec.dtype) for spec in input_specs.values()] + output_dtypes = [_nb_dtype_to_str(spec.dtype) for spec in output_specs.values()] + + # Include shape signature and hyperparams in func_name to deduplicate + # same function compiled with different shapes or hyperparams + shape_sig = "_".join( + "x".join(str(d) for d in s) + for s in list(input_shapes) + list(output_shapes) + ) + if hyperparams: + import hashlib + hp_hash = hashlib.md5( + repr(sorted(hyperparams.items())).encode() + ).hexdigest()[:8] + func_name = f"{kernel_func.__name__}_{shape_sig}_{hp_hash}" + else: + func_name = f"{kernel_func.__name__}_{shape_sig}" + + return cls( + nisa_mlir=nisa_mlir, + func_name=func_name, + input_names=input_names, + output_names=output_names, + input_shapes=input_shapes, + output_shapes=output_shapes, + input_dtypes=input_dtypes, + output_dtypes=output_dtypes, + reference_fn=reference_fn, + ) + + def __call__(self, *args): + """Call during tracing or NumPy execution. + + During tracing: emits func.call and returns TracedArray(s). + Outside tracing: calls reference_fn with numpy inputs. + """ + # Not tracing -- use NumPy reference + if not any(isinstance(a, TracedArray) for a in args): + if self.reference_fn is None: + raise RuntimeError( + f"CustomOp '{self.func_name}' called outside tracing " + f"without a reference_fn." + ) + return self.reference_fn(*args) + + # --- Tracing mode --- + if len(args) != len(self.input_shapes): + raise ValueError( + f"CustomOp '{self.func_name}' expects {len(self.input_shapes)} " + f"inputs, got {len(args)}." 
+ ) + + for i, (arg, expected_shape) in enumerate(zip(args, self.input_shapes)): + if not isinstance(arg, TracedArray): + raise TypeError( + f"Argument {i} ('{self.input_names[i]}') must be TracedArray, " + f"got {type(arg)}" + ) + if tuple(arg.shape) != tuple(expected_shape): + raise ValueError( + f"Shape mismatch for '{self.input_names[i]}': " + f"expected {expected_shape}, got {arg.shape}" + ) + + # Register for module-level emission (deduplicate by func_name) + registry = _get_registry() + if not any(op.func_name == self.func_name for op in registry): + registry.append(self) + + loc = args[0]._get_caller_location() + input_values = [a.value for a in args] + + # Build result types: return-value style (tensor types) + result_types = [ + ranked_tensor_of(shape, to_mlir_type(dtype)) + for shape, dtype in zip(self.output_shapes, self.output_dtypes) + ] + + # Emit: %result = func.call @name(%in0, %in1) -> tensor<...> + call_op = func.CallOp(result_types, self.func_name, input_values, loc=loc) + + # Return output TracedArray(s) + if len(self.output_shapes) == 1: + return TracedArray( + call_op.results[0], + self.output_shapes[0], + to_mlir_type(self.output_dtypes[0]), + source_file=args[0].source_file, + ) + return tuple( + TracedArray( + call_op.results[i], shape, + to_mlir_type(dtype), + source_file=args[0].source_file, + ) + for i, (shape, dtype) in enumerate( + zip(self.output_shapes, self.output_dtypes) + ) + ) + + +def emit_custom_op_declaration(custom: CustomOp): + """Emit func.func private @name(tensor<...>) -> tensor<...> + attributes {nkipy.custom_op} + + Return-value style: inputs only as arguments, outputs as return values. + resolve-custom-ops will later convert to output-as-argument. + """ + input_types = [ + ranked_tensor_of(s, to_mlir_type(d)) + for s, d in zip(custom.input_shapes, custom.input_dtypes) + ] + result_types = [ + ranked_tensor_of(s, to_mlir_type(d)) + for s, d in zip(custom.output_shapes, custom.output_dtypes) + ] + fn_type = ir.FunctionType.get(input_types, result_types) + fn = func.FuncOp(name=custom.func_name, type=fn_type) + fn.attributes["sym_visibility"] = ir.StringAttr.get("private") + fn.attributes["nkipy.custom_op"] = ir.UnitAttr.get() + return fn diff --git a/kernelgen/nkipy_kernelgen/execution.py b/kernelgen/nkipy_kernelgen/execution.py new file mode 100644 index 0000000..e8a2976 --- /dev/null +++ b/kernelgen/nkipy_kernelgen/execution.py @@ -0,0 +1,27 @@ +""" +Execution engine for running MLIR on CPU and verifying against NumPy. +""" + +import numpy as np +from typing import Callable, List, Tuple, Any +from nkipy_kernelgen.llvm import LLVMModule + + +def verify_against_numpy( + traced_func: Callable, numpy_func: Callable, input_arrays: List[np.ndarray], + rtol: float = 1e-6, + atol: float = 1e-6, +) -> Tuple[bool, Any, Any]: + + # Get NumPy result + numpy_result = numpy_func(*input_arrays) + + # Get MLIR module + module = traced_func.to_mlir() + + runner = LLVMModule(module, traced_func.__name__) + mlir_result = runner(*input_arrays) + + matches = np.allclose(mlir_result, numpy_result, rtol=rtol, atol=atol) + # Return results without raising an error - let the test decide what to do + return matches, mlir_result, numpy_result \ No newline at end of file diff --git a/kernelgen/nkipy_kernelgen/knob.py b/kernelgen/nkipy_kernelgen/knob.py new file mode 100644 index 0000000..c05aa27 --- /dev/null +++ b/kernelgen/nkipy_kernelgen/knob.py @@ -0,0 +1,222 @@ +""" +Knob API for annotating tensors with transformation hints. 
+ +This module provides a knob() function similar to OpenMP pragmas in C++, +allowing users to annotate tensors with transformation directives that +get recorded as MLIR attributes on operations or injected as custom ops. +""" + +from typing import Union, Any, Optional, List +from mlir import ir +from .traced_array import TracedArray +from nkipy_kernelgen._mlir.dialects import nkipy as nkipy_d + + +def knob( + tensor: Union[TracedArray, Any], + partition_dim: Optional[int] = None, + mem_space: Optional[str] = None, + tile_size: Optional[List[int]] = None, + reduction_tile: Optional[List[int]] = None, +) -> Union[TracedArray, Any]: + """ + Annotate a tensor with transformation hints. + + This function acts like OpenMP pragmas in C++, marking tensors with + transformation directives that get recorded as MLIR nkipy.annotate operations. + + Args: + tensor: The tensor to annotate (TracedArray or regular array) + partition_dim: Dimension to partition (int, must be 0 for NISA compatibility) + mem_space: Memory space placement (must be "Hbm", "Psum", "Sbuf", or "SharedHbm", optional) + tile_size: Tile sizes for each dimension (list of ints, optional). + Must have exactly the same number of elements as the tensor rank. + E.g., for a 3D tensor [16, 128, 512], use tile_size=[1, 128, 128]. + reduction_tile: Tile sizes for reduction dimensions (list of ints, optional). + Used for contraction ops like matmul where the iteration space has more + dimensions than the output tensor. For matmul C=A@B, this is the K tile. + E.g., tile_size=[128, 128] for output dims + reduction_tile=[128] for K. + + Returns: + The same tensor (pass-through), but with annotate op injected if parameters specified + + Raises: + ValueError: If partition_dim is >= tensor rank, mem_space is invalid, + tile_size doesn't match tensor rank, or reduction_tile has negative values. + + Examples: + # Specify only memory space + tensor = knob(tensor, mem_space="Hbm") + + # Specify only partition dimension + tensor = knob(tensor, partition_dim=1) + + # Specify memory space and partition dimension + tensor = knob(tensor, partition_dim=1, mem_space="Sbuf") + + # Specify tile size for a 2D tensor [256, 256] + tensor = knob(tensor, tile_size=[128, 128]) + + # Specify tile size for a 3D tensor [16, 128, 512] + tensor = knob(tensor, tile_size=[1, 128, 128]) + + # Matmul with separate reduction tile: C[M,N] = A[M,K] @ B[K,N] + output = knob(output, mem_space="Psum", tile_size=[128, 128], reduction_tile=[128]) + """ + # If not a TracedArray, just return as-is (for regular NumPy execution) + if not isinstance(tensor, TracedArray): + return tensor + + # If no parameters are specified, just return (no-op) + if ( + mem_space is None + and partition_dim is None + and tile_size is None + and reduction_tile is None + ): + return tensor + + # Get the MLIR value from the traced array + value = tensor.value + + # Get the operation that defines this value + defining_op = value.owner + + if defining_op is None: + # Value is a block argument, cannot annotate + return tensor + + # Normalize scalar tile_size/reduction_tile to lists + if isinstance(tile_size, int): + tile_size = [tile_size] + if isinstance(reduction_tile, int): + reduction_tile = [reduction_tile] + + # Validate mem_space + if mem_space is not None: + valid_mem_spaces = {"Hbm", "Psum", "Sbuf", "SharedHbm"} + if mem_space not in valid_mem_spaces: + raise ValueError( + f"Invalid mem_space '{mem_space}'. 
Must be one of: {valid_mem_spaces}" + ) + + # Validate partition_dim against tensor rank + if partition_dim is not None: + # Get the tensor type from the MLIR value + tensor_type = value.type + if hasattr(tensor_type, "shape"): + rank = len(tensor_type.shape) + if partition_dim >= rank: + raise ValueError( + f"partition_dim {partition_dim} must be less than tensor rank {rank}" + ) + # Also validate it's non-negative + if partition_dim < 0: + raise ValueError(f"partition_dim must be non-negative, got {partition_dim}") + + # Validate tile_size against tensor rank + if tile_size is not None: + tensor_type = value.type + if hasattr(tensor_type, "shape"): + rank = len(tensor_type.shape) + # tile_size must either match the full tensor rank, or when + # reduction_tile is also provided, tile_size + reduction_tile + # together must cover all dimensions (e.g. for reductions with + # keepdims=True where tile_size covers non-reduction dims and + # reduction_tile covers reduction dims). + n_reduction = len(reduction_tile) if reduction_tile is not None else 0 + if len(tile_size) != rank and len(tile_size) + n_reduction != rank: + raise ValueError( + f"tile_size has {len(tile_size)} elements but tensor has rank {rank}; " + f"tile_size must have exactly one element per dimension" + ) + # Validate non-negative + if any(t <= 0 for t in tile_size): + raise ValueError(f"tile_size values must be positive, got {tile_size}") + + # Validate reduction_tile + if reduction_tile is not None: + if any(t <= 0 for t in reduction_tile): + raise ValueError( + f"reduction_tile values must be positive, got {reduction_tile}" + ) + + # Inject nkipy.annotate op when at least one parameter is specified + _inject_annotate_op(value, mem_space, partition_dim, tile_size, reduction_tile) + + # Return the tensor unchanged (knob is just an annotation) + return tensor + + +def _inject_annotate_op( + value: ir.Value, + mem_space: Optional[str], + partition_dim: Optional[int], + tile_size: Optional[List[int]], + reduction_tile: Optional[List[int]] = None, +) -> None: + """ + Inject a nkipy.annotate op into the IR after the tensor's value. + + Args: + value: The MLIR value to annotate + mem_space: Memory space (Hbm, Psum, Sbuf, or SharedHbm), optional + partition_dim: Dimension to partition, optional + tile_size: Tile sizes for each dimension, optional + reduction_tile: Tile sizes for reduction dimensions, optional + """ + # Get the defining operation for location info + defining_op = value.owner + if defining_op is None: + return + + loc = defining_op.location + + # Build attributes dict with only specified parameters + if mem_space is not None: + # Map mem_space string to enum value. + # Values MUST match mlir/include/nkipy/Dialect/NkipyAttrs.td. + # Zero is intentionally reserved: MemRefType::get drops an + # IntegerAttr(0) memorySpace, so zero-valued enum cases cannot be + # attached to a memref. See the comment in NkipyAttrs.td. 
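+        # For reference, knob(t, mem_space="Sbuf") injects roughly
+        # (attribute syntax abbreviated):
+        #   "nkipy.annotate"(%t) {mem_space = 3 : i32} : (tensor<...>) -> ()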
+ mem_space_map = { + "Hbm": 1, + "Psum": 2, + "Sbuf": 3, + "SharedHbm": 4, + } + mem_space_value = mem_space_map[mem_space] + mem_space_attr = ir.IntegerAttr.get( + ir.IntegerType.get_signless(32), mem_space_value + ) + mem_space = mem_space_attr + + if partition_dim is not None: + partition_dim_attr = ir.IntegerAttr.get( + ir.IntegerType.get_unsigned(32), partition_dim + ) + partition_dim = partition_dim_attr + + if tile_size is not None: + # Convert to DenseI64ArrayAttr as defined in NkipyOps.td + tile_size_attr = ir.DenseI64ArrayAttr.get(tile_size) + tile_size = tile_size_attr + + if reduction_tile is not None: + # Convert to DenseI64ArrayAttr as defined in NkipyOps.td + reduction_tile_attr = ir.DenseI64ArrayAttr.get(reduction_tile) + reduction_tile = reduction_tile_attr + + # Simply create the operation at the current insertion point + # Since we're being called during tracing, the insertion point is already + # set correctly to insert after the current operation + # Note: The context has allow_unregistered_dialects=True, so we can create + # nkipy.annotate ops directly without needing to register the dialect + nkipy_d.AnnotateOp( + target=value, + mem_space=mem_space, + partition_dim=partition_dim, + tile_size=tile_size, + reduction_tile=reduction_tile, + loc=loc, + ) diff --git a/kernelgen/nkipy_kernelgen/llvm.py b/kernelgen/nkipy_kernelgen/llvm.py new file mode 100644 index 0000000..ce94cf8 --- /dev/null +++ b/kernelgen/nkipy_kernelgen/llvm.py @@ -0,0 +1,642 @@ +import os +import ctypes +import ml_dtypes +import numpy as np + +from mlir.ir import ( + Context, + Location, + Module, + UnitAttr, + InsertionPoint, + FloatAttr, + IntegerAttr, + F32Type, + IntegerType, + MemRefType, + FunctionType, + TypeAttr, +) +from mlir.dialects import tensor, arith, linalg + +from mlir.passmanager import PassManager +from mlir.execution_engine import ExecutionEngine +from mlir.runtime import ( + get_ranked_memref_descriptor, + make_nd_memref_descriptor, +) + +from nkipy_kernelgen.utils import ( + get_func_inputs_outputs, + find_func_in_module, + get_bitwidth_from_type, + ctype_map, + np_supported_types, + np_type_to_str, + get_np_struct_type, + create_output_struct, + extract_out_np_arrays_from_out_struct, + ranked_memref_to_numpy, +) + +# Import nkipy dialect for registration +from nkipy_kernelgen._mlir.dialects import nkipy as nkipy_d + + +def extract_and_clean_func_from_module(mlir_module_str: str): + """ + Extract func.func operation from MLIR module and remove nkipy annotations. + + This utility is particularly useful for MLIR modules that have been through + knob-driven-tiling or other passes that add transform dialect IR and nkipy + annotations that need to be stripped before LLVM execution. + + Steps performed: + 1. Strip #nisa.mem<...> memory spaces from memref type syntax (text pre-processing) + 2. Parse the MLIR module + 3. Find the func.func operation (ignoring transform.named_sequence) + 4. Remove nkipy.annotate operations + 5. Remove nkipy.op_id and memory_space operation attributes + 6. Strip any remaining memory spaces from memref types programmatically + 7. 
Create a clean MLIR module with just the function + + Args: + mlir_module_str: MLIR module as a string, potentially containing + transform.named_sequence and nkipy annotations + + Returns: + tuple: (clean_mlir_str, function_name) + - clean_mlir_str: MLIR module string with only the cleaned func.func + - function_name: The name of the extracted function + + Raises: + ValueError: If no func.func operation is found in the module + + Example: + clean_mlir, func_name = extract_and_clean_func_from_module(tiled_mlir) + runner = LLVMModule(clean_mlir, func_name) + """ + import re + # Strip memory-space attributes from memref type syntax before parsing. + # We accept two forms and drop both: + # 1. #nisa.mem<...> — post-annotate-memory-space NISA dialect form + # 2. N : i32 — raw IntegerAttr form that + # nkipy's MemSpaceEnumAttr prints as + # + # The MLIR parser can't handle the NISA-dialect attribute without the + # dialect registered. More importantly, a non-zero integer memspace + # becomes `!llvm.ptr` during memref-to-llvm lowering, and the + # runtime helpers (`@free` etc.) only accept `!llvm.ptr` in address + # space 0 — leaving the memspace on would trip the LLVM verifier + # with `operand type mismatch … '!llvm.ptr' != '!llvm.ptr'` when + # lowering `memref.dealloc` on an HBM/SBUF buffer. + mlir_module_str = re.sub(r',\s*#nisa\.mem<[^>]+>', '', mlir_module_str) + mlir_module_str = re.sub(r',\s*\d+\s*:\s*i32>', '>', mlir_module_str) + + with Context() as ctx: + # Register nkipy dialect to handle nkipy operations + nkipy_d.register_dialect(ctx) + + # Allow unregistered dialects temporarily to parse the module + ctx.allow_unregistered_dialects = True + + # Parse the MLIR module and clean it in-place (preserving all + # module-level declarations like memref.global that the function + # may reference). + new_module = Module.parse(mlir_module_str, ctx) + + # Inline reference_impl regions from nkipy ops (e.g. nkipy.gather) + # so the LLVM JIT only sees standard linalg/tensor ops. + # The inline pass also folds tensor.extract(to_tensor(memref)) → + # memref.load(memref) patterns left after inlining into post- + # bufferization IR. Canonicalize then folds remaining + # to_buffer(to_tensor(x)) chains. + from nkipy_kernelgen.transforms.nkipy_opt import run_nkipy_opt_passes + inlined_str = run_nkipy_opt_passes( + str(new_module), + ["inline-nkipy-reference", "canonicalize"], + ) + new_module = Module.parse(inlined_str, ctx) + + # Strip the transform.with_named_sequence module attribute if present + module_op = new_module.operation + if "transform.with_named_sequence" in module_op.attributes: + del module_op.attributes["transform.with_named_sequence"] + + # Walk module body: find the func, and collect ops to erase + # (transform sequences, nkipy.annotate, etc.) 
+ actual_func_name = None + ops_to_erase = [] + + def walk_and_mark(op): + if op.name == "nkipy.annotate": + ops_to_erase.append(op) + if "nkipy.op_id" in op.attributes: + del op.attributes["nkipy.op_id"] + if "memory_space" in op.attributes: + del op.attributes["memory_space"] + for region in op.regions: + for block in region: + for nested_op in block: + walk_and_mark(nested_op) + + for op in new_module.body.operations: + op_name = op.operation.name + if op_name == "func.func" and actual_func_name is None: + actual_func_name = str(op.attributes["sym_name"]).strip('"') + walk_and_mark(op.operation) + elif op_name == "transform.named_sequence": + ops_to_erase.append(op.operation) + + if actual_func_name is None: + raise ValueError("Could not find func.func operation in MLIR module") + + # Erase marked operations + for op in ops_to_erase: + op.erase() + + # For CPU simulation: Zero-fill all tensor.empty operations + _zero_fill_empty_tensors_ir(new_module) + + # For CPU simulation: Zero-fill all memref.alloc operations + # (needed for post-bufferization IR, e.g. --stop=5+) + _zero_fill_alloc_memrefs_ir(new_module) + + # Get clean MLIR string + clean_mlir = str(new_module) + + return clean_mlir, actual_func_name + + +def _zero_fill_empty_tensors_ir(module: Module): + """ + Walk the IR and replace tensor.empty operations with zero-filled tensors. + + For each tensor.empty: + 1. Create a zero constant of the appropriate element type + 2. Create a linalg.fill operation to fill the tensor with zero + 3. Replace all uses of tensor.empty with the filled tensor + + This ensures CPU simulation matches target ASIC behavior (empty tensors are zero). + """ + # Collect all tensor.empty operations to process + empty_ops = [] + + def collect_empty_ops(op): + if op.name == "tensor.empty": + empty_ops.append(op) + for region in op.regions: + for block in region: + for nested_op in block: + collect_empty_ops(nested_op) + + # Walk the module to collect all tensor.empty ops + for op in module.body.operations: + collect_empty_ops(op.operation) + + # Process each tensor.empty operation + for empty_op in empty_ops: + # Get the result type (should be a tensor type) + result = empty_op.results[0] + tensor_type = result.type + + # Extract element type from the tensor type + try: + elem_type = tensor_type.element_type + except: + # If we can't get element type, skip this operation + continue + + # Get location from the empty op + loc = empty_op.location + + # Create zero constant based on element type + with loc, InsertionPoint.at_block_begin(empty_op.operation.block): + # Determine zero value based on type + if str(elem_type).startswith('f'): + # Float type - create 0.0 + zero_const = arith.ConstantOp(elem_type, FloatAttr.get(elem_type, 0.0)) + elif str(elem_type).startswith('i'): + # Integer type - create 0 + zero_const = arith.ConstantOp(elem_type, IntegerAttr.get(elem_type, 0)) + else: + # Unknown type, skip + continue + + # Move the insertion point right after tensor.empty + with loc, InsertionPoint(empty_op): + new_empty = tensor.EmptyOp(list(tensor_type.shape), tensor_type.element_type, loc=loc) + + fill_op = linalg.FillOp([tensor_type], [zero_const.result], [new_empty.result], loc=loc) + + region = fill_op.regions[0] + if len(region.blocks) == 0: + block = region.blocks.append(elem_type, elem_type) + with InsertionPoint(block): + linalg.YieldOp([block.arguments[0]], loc=loc) + + # Replace all uses of the original tensor.empty with linalg.fill result + result.replace_all_uses_with(fill_op.results[0]) + + +def 
_zero_fill_alloc_memrefs_ir(module: Module): + """ + Walk the IR and zero-fill all memref.alloc operations. + + After bufferization (e.g., --stop=5+), tensor.empty becomes memref.alloc. + Unlike tensor.empty (semantically undefined), memref.alloc produces truly + uninitialized memory on CPU which may contain garbage/NaN. This function + inserts linalg.fill operations right after each memref.alloc to + zero-initialize the buffer for correct CPU simulation. + """ + # Collect all memref.alloc operations + alloc_ops = [] + + def collect_alloc_ops(op): + if op.name == "memref.alloc": + alloc_ops.append(op) + for region in op.regions: + for block in region: + for nested_op in block: + collect_alloc_ops(nested_op) + + for op in module.body.operations: + collect_alloc_ops(op.operation) + + # Process each memref.alloc: insert linalg.fill right after it + for alloc_op in alloc_ops: + result = alloc_op.results[0] + memref_type = result.type + + try: + elem_type = memref_type.element_type + except: + continue + + loc = alloc_op.location + + # Find the next operation after this alloc in its parent block. + # We insert the fill before the next op (effectively after the alloc). + parent_block = alloc_op.operation.block + found_alloc = False + next_op = None + for block_op in parent_block: + if found_alloc: + next_op = block_op + break + if block_op.operation == alloc_op.operation: + found_alloc = True + + if next_op is None: + continue + + # Insert zero constant + linalg.fill before next_op (= after alloc) + with loc, InsertionPoint(next_op): + # Create zero constant based on element type + if str(elem_type).startswith('f'): + zero_const = arith.ConstantOp( + elem_type, FloatAttr.get(elem_type, 0.0) + ) + elif str(elem_type).startswith('i'): + zero_const = arith.ConstantOp( + elem_type, IntegerAttr.get(elem_type, 0) + ) + else: + continue + + fill_op = linalg.FillOp( + [], [zero_const.result], [result], loc=loc + ) + + region = fill_op.regions[0] + if len(region.blocks) == 0: + fill_block = region.blocks.append(elem_type, elem_type) + with InsertionPoint(fill_block): + linalg.YieldOp([fill_block.arguments[0]], loc=loc) + + +class LLVMModule: + def __init__(self, mod, top_func_name, ext_libs=None): + # Copy the module to avoid modifying the original one + with Context() as ctx: + # Register the nkipy dialect to handle nkipy.annotate ops + nkipy_d.register_dialect(ctx) + + self.module = Module.parse(str(mod), ctx) + self.top_func_name = top_func_name + func = find_func_in_module(self.module, top_func_name) + ext_libs = [] if ext_libs is None else ext_libs + + # Get input/output types + self.in_types, self.out_types = get_func_inputs_outputs(func) + + # Run through lowering passes + pm = PassManager.parse( + # "builtin.module(" + # # used for lowering tensor.empty + # "empty-tensor-to-alloc-tensor," + # # translate tensor dialect (virtual) to memref dialect (physical) + # "one-shot-bufferize{bufferize-function-boundaries}," + # # used for lowering memref.subview + # "expand-strided-metadata," + # # common lowering passes + # "func.func(convert-linalg-to-affine-loops),lower-affine" + # ")" + "builtin.module(" + "one-shot-bufferize{bufferize-function-boundaries=1}," + "func.func(convert-linalg-to-loops)," + "func.func(lower-affine)" + ")" + ) + + pm.run(self.module.operation) + # self.intermediate_module = self.module.operation.clone() + + # Attach necessary attributes + func = find_func_in_module(self.module, top_func_name) + if func is None: + raise RuntimeError( + "No top-level function found in the built 
MLIR module" + ) + func.attributes["llvm.emit_c_interface"] = UnitAttr.get() + func.attributes["top"] = UnitAttr.get() + + # https://github.com/llvm/llvm-project/issues/52945 + # Final lowering + pm = PassManager.parse( + "builtin.module(" + "func.func(convert-scf-to-cf)," + "func.func(arith-expand)," + "expand-strided-metadata," + "lower-affine," + "convert-math-to-llvm," + "convert-arith-to-llvm," + "finalize-memref-to-llvm," + "convert-func-to-llvm," + "convert-cf-to-llvm," + "reconcile-unrealized-casts" + ")" + ) + pm.run(self.module.operation) + + # Add shared library for MLIR runner utils (provides memrefCopy etc.) + # Resolve the LLVM lib directory: LLVM_INST env var takes priority, + # then llvm-config --libdir, then no shared libs. + llvm_lib_dir = None + if os.getenv("LLVM_INST"): + llvm_lib_dir = os.path.join(os.getenv("LLVM_INST"), "lib") + else: + try: + import subprocess + result = subprocess.run( + ["llvm-config", "--libdir"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + llvm_lib_dir = result.stdout.strip() + except (FileNotFoundError, subprocess.TimeoutExpired): + pass + if llvm_lib_dir is not None: + shared_libs = [ + os.path.join(llvm_lib_dir, "libmlir_runner_utils.so"), + os.path.join(llvm_lib_dir, "libmlir_c_runner_utils.so"), + ] + else: + shared_libs = [] + + self.execution_engine = ExecutionEngine( + self.module, opt_level=2, shared_libs=shared_libs + ) + + # pylint: disable=too-many-branches + def __call__(self, *args): + """ + Reference: + * https://github.com/llvm/llvm-project/blob/llvmorg-15.0.0/mlir/test/python/execution_engine.py + * https://github.com/llvm/llvm-project/blob/llvmorg-15.0.0/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py + """ + input_types = self.in_types + arg_ptrs = [] + new_args = [] + assert len(args) == len( + input_types + ), f"# of input arguments mismatch, got {len(args)} but expected {len(input_types)}" + + # 1. Construct argument pointers + for arg, (target_in_type, shape, is_memref) in zip(args, input_types): + if not is_memref: # scalar + if isinstance(arg, int): + if target_in_type != "i32": + raise RuntimeError( + f"Input type mismatch: {target_in_type} vs i32. Please use NumPy array" + " to wrap the data to avoid possible result mismatch" + ) + bitwidth = get_bitwidth_from_type(target_in_type) + signed = "i" if target_in_type.startswith("i") else "ui" + dtype = ctype_map[f"{signed}{bitwidth}"] + c_int_p = dtype * 1 + arg_ptrs.append(c_int_p(arg)) + + elif isinstance(arg, float): + if target_in_type != "f32": + raise Warning( + f"Input type mismatch: {target_in_type} vs f32. Please use NumPy array" + " to wrap the data to avoid possible result mismatch" + ).warn() + if target_in_type == "f16": + c_float_p = ctypes.c_int16 * 1 + arg = np.float16(arg).view(np.int16) + elif target_in_type == "bf16": + c_float_p = ctypes.c_int16 * 1 + arg = ml_dtypes.bfloat16(arg).view(np.int16) + elif target_in_type == "f32": + c_float_p = ctypes.c_float * 1 + else: # f64 + c_float_p = ctypes.c_double * 1 + arg_ptrs.append(c_float_p(arg)) + + else: + raise RuntimeError( + "Unsupported input type. Please use NumPy array to wrap the data if other" + " data types are needed as inputs." + ) + + else: # memref + if not arg.flags["C_CONTIGUOUS"]: + raise RuntimeError( + "The input data is not contiguous. Please use np.ascontiguousarray to change the layout first." 
+ ) + if not isinstance(arg.dtype, np.dtypes.VoidDType): + np_type = np_type_to_str(arg.dtype) + if np_type != target_in_type: + import warnings + warnings.warn( + f"Input type mismatch: {np_type} vs {target_in_type}", + RuntimeWarning, + ) + + if target_in_type in np_supported_types: + target_np_type = np_supported_types[target_in_type] + if arg.dtype != target_np_type: + # avoid changing the address of the original array + arg = arg.astype(target_np_type) + else: + raise RuntimeError( + f"Unsupported input type: {target_in_type}, " + f"please use a supported type or wrap the scalar as an array" + ) + arg_ptrs.append( + ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg))) + ) + new_args.append(arg) + + # 2. Construct return pointers + # Need to verify the return variable is not the same as the input + result_types = self.out_types + # Returns as arguments: no return value from the top function + if len(result_types) == 0: + self.execution_engine.invoke(self.top_func_name, *arg_ptrs) + for arg, new_arg, (target_in_type, shape, is_memref) in zip( + args, new_args, input_types + ): + if is_memref: + arg[:] = new_arg + return + # Return inner variables: return one or more values allocated inside kernel + # For two or more return values, llvm.emit_c_interface will return a struct + # Therefore, for functions that return values, we need to separate two cases: + # 1. return one value: no need to create a struct + # 2. return two or more values: need to create a struct + # In any case, we prepare a pointer of pointer to the return object + # which is ready to be passed to the invoke function. + if len(result_types) == 1: # exactly one return value + result_type, shape, is_memref = result_types[0] + if is_memref: + # After bufferization, tensors (including rank-0 tensors like tensor) + # become memrefs (including rank-0 memrefs like memref) + if result_type in ctype_map: + dtype = ctype_map[result_type] + elif result_type.startswith("i") or result_type.startswith("ui"): + bitwidth = get_bitwidth_from_type(result_type) + dtype = np.ctypeslib.as_ctypes_type(get_np_struct_type(bitwidth)) + else: + raise RuntimeError("Unsupported return type") + # Create an empty memref descriptor (rank-0 for scalars, rank-N for tensors) + return_desc = make_nd_memref_descriptor(len(shape), dtype)() + return_ptr = ctypes.pointer(ctypes.pointer(return_desc)) + else: # bare scalar + if result_type in ctype_map: + dtype = ctype_map[result_type] + else: + signed = "i" if result_type.startswith("i") else "ui" + bitwidth = get_bitwidth_from_type(result_type) + dtype = ctype_map[f"{signed}{bitwidth}"] + + dtype_p = dtype * 1 + # -1/-1.0 is a placeholder + return_ptr = dtype_p(-1 if not result_type in {"f32", "f64"} else 1.0) + + else: # multiple return values + # we assume all return values are memrefs + out_memref_descs = [] + for elt_res_type, elt_shape, is_memref in result_types: + if not is_memref: + raise RuntimeError( + "When returning multiple values, we only support all tensors/memrefs." 
+ ) + if elt_res_type in ctype_map: + dtype = ctype_map[elt_res_type] + elif elt_res_type.startswith("i") or elt_res_type.startswith("ui"): + bitwidth = get_bitwidth_from_type(elt_res_type) + dtype = np.ctypeslib.as_ctypes_type(get_np_struct_type(bitwidth)) + else: + raise RuntimeError("Unsupported return type") + # Create an empty tensor + return_desc = make_nd_memref_descriptor(len(elt_shape), dtype)() + out_memref_descs.append(return_desc) + # Create a struct + out_struct = create_output_struct(out_memref_descs) + return_ptr = ctypes.pointer(ctypes.pointer(out_struct)) + + # 3. Invoke the function and return the result + if len(result_types) == 1: + result_type, shape, is_memref = result_types[0] + if is_memref: + # INVOKE - memref return (including rank-0 memrefs) + self.execution_engine.invoke(self.top_func_name, return_ptr, *arg_ptrs) + ret = ranked_memref_to_numpy(return_ptr[0][0]) + if result_type == "f16": + ret = np.array(ret, dtype=np.int16).view(np.float16) + elif result_type == "bf16": + ret = np.array(ret, dtype=np.int16).view(ml_dtypes.bfloat16) + + # For rank-0 tensors, extract the scalar value + if len(shape) == 0: + ret = ret.item() + else: + # INVOKE - bare scalar return + self.execution_engine.invoke(self.top_func_name, *arg_ptrs, return_ptr) + ret = return_ptr[0] + if result_type == "f16": + ret = np.int16(ret).view(np.float16) + elif result_type == "bf16": + ret = np.int16(ret).view(ml_dtypes.bfloat16) + else: # multiple returns, assume all memref + # INVOKE + self.execution_engine.invoke(self.top_func_name, return_ptr, *arg_ptrs) + ret_raw_np = extract_out_np_arrays_from_out_struct( + return_ptr, len(result_types) + ) + # pylint: disable=redefined-variable-type + ret = [] + for np_arr, (res_type, _, _) in zip(ret_raw_np, result_types): + if res_type == "f16": + ret_i = np.array(np_arr, dtype=np.int16).view(np.float16) + elif res_type == "bf16": + ret_i = np.array(np_arr, dtype=np.int16).view(ml_dtypes.bfloat16) + else: + ret_i = np_arr + ret.append(ret_i) + return ret + + +if __name__ == "__main__": + # Minimal example: parse MLIR from string, JIT-run it, and compare with NumPy. + mlir_src = r""" + module { + func.func @top(%A: tensor<4x4xf32>, %B: tensor<4x4xf32>) -> tensor<4x4xf32> { + %init = tensor.empty() : tensor<4x4xf32> + %C = linalg.generic + { indexing_maps = [ affine_map<(i,j)->(i,j)>, affine_map<(i,j)->(i,j)>, affine_map<(i,j)->(i,j)> ], + iterator_types = ["parallel", "parallel"] } + ins(%A, %B : tensor<4x4xf32>, tensor<4x4xf32>) + outs(%init : tensor<4x4xf32>) { + ^bb0(%a: f32, %b: f32, %acc: f32): + %sum = arith.addf %a, %b : f32 + linalg.yield %sum : f32 + } -> tensor<4x4xf32> + return %C : tensor<4x4xf32> + } + } + """ + + # Build module and engine + with Context(): + mod = Module.parse(mlir_src) + runner = LLVMModule(mod, "top") + + # Inputs and NumPy reference + A = np.random.rand(4, 4).astype(np.float32) + B = np.random.rand(4, 4).astype(np.float32) + ref = A + B + + # Run and compare + out = runner(A.copy(), B.copy()) + ok = np.allclose(out, ref, rtol=1e-5, atol=1e-6) + + print("Match with NumPy:", ok) + if not ok: + print("MLIR:", out) + print("NumPy:", ref) diff --git a/kernelgen/nkipy_kernelgen/mlir_utils.py b/kernelgen/nkipy_kernelgen/mlir_utils.py new file mode 100644 index 0000000..d5392c4 --- /dev/null +++ b/kernelgen/nkipy_kernelgen/mlir_utils.py @@ -0,0 +1,122 @@ +""" +MLIR utility functions for building IR constructs. 
+""" + +from typing import Callable, Dict, Tuple, Union +from mlir import ir +from mlir.dialects import arith, tensor, linalg +import ml_dtypes +import numpy as np + +# ---------- Type Mapping ---------- +_BASE_TYPE_MAP: Dict[str, Callable[[], ir.Type]] = { + "float16": lambda: ir.F16Type.get(), + "bfloat16": lambda: ir.BF16Type.get(), + "float32": lambda: ir.F32Type.get(), + "float64": lambda: ir.F64Type.get(), + "int8": lambda: ir.IntegerType.get_signless(8), + "int16": lambda: ir.IntegerType.get_signless(16), + "int32": lambda: ir.IntegerType.get_signless(32), + "int64": lambda: ir.IntegerType.get_signless(64), + "uint8": lambda: ir.IntegerType.get_signless(8), + "uint16": lambda: ir.IntegerType.get_signless(16), +} + +_MLIR_ALIASES: Dict[str, str] = { + "f16": "float16", + "bf16": "bfloat16", + "f32": "float32", + "f64": "float64", + "i8": "int8", + "i16": "int16", + "i32": "int32", + "i64": "int64", +} + +NP_TO_MLIR_TYPE_MAP: Dict[str, Callable[[], ir.Type]] = { + **_BASE_TYPE_MAP, + **{alias: _BASE_TYPE_MAP[canonical] for alias, canonical in _MLIR_ALIASES.items()}, +} + +NUMPY_TO_STR_MAP = { + np.float16: "float16", + ml_dtypes.bfloat16: "bfloat16", + np.float32: "float32", + np.float64: "float64", + np.int8: "int8", + np.int16: "int16", + np.int32: "int32", + np.int64: "int64", + np.uint8: "uint8", + np.uint16: "uint16", +} + +def to_mlir_type(dtype: Union[str, np.dtype, type]) -> ir.Type: + """ + Convert either: + - a string (e.g. "float32", "i32") + - a numpy dtype object (np.float32, np.dtype("float32"), etc.) + into the corresponding MLIR type. + """ + # Case 1: string + if isinstance(dtype, str): + key = dtype.lower() + if key not in NP_TO_MLIR_TYPE_MAP: + raise KeyError(f"Unsupported dtype string: {dtype}") + return NP_TO_MLIR_TYPE_MAP[key]() + + # Case 2: NumPy dtype + if isinstance(dtype, (np.dtype, type)): + # Normalize to a numpy.dtype + dtype = np.dtype(dtype) + if dtype.type not in NUMPY_TO_STR_MAP: + raise KeyError(f"Unsupported numpy dtype: {dtype}") + key = NUMPY_TO_STR_MAP[dtype.type] + return NP_TO_MLIR_TYPE_MAP[key]() + + raise TypeError(f"Unsupported dtype input type: {type(dtype)}") + + +def ranked_tensor_of(shape: Tuple[int, ...], elem_ty: ir.Type) -> ir.RankedTensorType: + """Create a ranked tensor type.""" + return ir.RankedTensorType.get(shape, elem_ty) + + +def make_empty(loc: ir.Location, shape: Tuple[int, ...], elem_ty: ir.Type) -> ir.Value: + """Create an empty tensor with given shape and element type (uninitialized).""" + return tensor.EmptyOp(list(shape), elem_ty, loc=loc).result + + +def make_filled(loc: ir.Location, shape: Tuple[int, ...], elem_ty: ir.Type, + fill_value: Union[int, float]) -> ir.Value: + """Create a tensor with given shape and element type, filled with a scalar.""" + result_type = ranked_tensor_of(shape, elem_ty) + empty_tensor = make_empty(loc, shape, elem_ty) + cst = const_scalar(fill_value, elem_ty, loc) + + filled = linalg.FillOp([result_type], [cst], [empty_tensor], loc=loc) + + region = filled.regions[0] + if len(region.blocks) == 0: + block = region.blocks.append(elem_ty, elem_ty) + with ir.InsertionPoint(block): + linalg.YieldOp([block.arguments[0]], loc=loc) + + return filled.results[0] + + +def make_zeros(loc: ir.Location, shape: Tuple[int, ...], elem_ty: ir.Type) -> ir.Value: + """Create a tensor with given shape and element type, initialized to zero.""" + return make_filled(loc, shape, elem_ty, 0.0) + + +def const_scalar(val: Union[int, float], elem_ty: ir.Type, loc: ir.Location) -> ir.Value: + """Create a scalar 
constant.""" + if isinstance(elem_ty, ir.FloatType): + attr = ir.FloatAttr.get(elem_ty, float(val)) + return arith.ConstantOp(elem_ty, attr, loc=loc).result + elif isinstance(elem_ty, ir.IntegerType): + attr = ir.IntegerAttr.get(elem_ty, int(val)) + return arith.ConstantOp(elem_ty, attr, loc=loc).result + else: + raise TypeError(f"Unsupported element type: {elem_ty}") diff --git a/kernelgen/nkipy_kernelgen/op_vtable.py b/kernelgen/nkipy_kernelgen/op_vtable.py new file mode 100644 index 0000000..11c4384 --- /dev/null +++ b/kernelgen/nkipy_kernelgen/op_vtable.py @@ -0,0 +1,411 @@ +""" +VTable (Virtual Table) for mapping NumPy operations to MLIR operations. + +This module contains the mapping between NumPy ufuncs/functions and their +corresponding MLIR implementations. It delegates to builder.py for IR +construction via TracedArray↔TensorHandle conversion. +""" + +from collections.abc import Callable + +import numpy as np +from mlir import ir + +from . import builder +from .builder import TensorHandle + + +# --------------------------------------------------------------------------- +# TracedArray ↔ TensorHandle bridge +# --------------------------------------------------------------------------- + + +def _to_handle(a): + """Convert a TracedArray to a TensorHandle for builder.py calls. + + Scalars pass through unchanged — builder's binary dispatch handles them. + """ + from .traced_array import TracedArray + + if isinstance(a, TracedArray): + return TensorHandle(a.value, a.shape, a.dtype, a.elem_ty) + return a + + +def _from_handle(h: TensorHandle, source_file: str = "unknown"): + """Convert a TensorHandle back to a TracedArray.""" + from .traced_array import TracedArray + + return TracedArray(h._value, h.shape, h._elem_ty, source_file) + + +def _source(a, b=None) -> str: + """Extract source_file from TracedArray operands.""" + from .traced_array import TracedArray + + if isinstance(a, TracedArray): + return a.source_file + if b is not None and isinstance(b, TracedArray): + return b.source_file + return "unknown" + + +# --------------------------------------------------------------------------- +# Ufunc operations (unary and binary) +# --------------------------------------------------------------------------- + + +def _add_op(a, b, loc): + return _from_handle(builder.add(_to_handle(a), _to_handle(b), loc=loc), _source(a, b)) + + +def _subtract_op(a, b, loc): + return _from_handle(builder.subtract(_to_handle(a), _to_handle(b), loc=loc), _source(a, b)) + + +def _multiply_op(a, b, loc): + return _from_handle(builder.multiply(_to_handle(a), _to_handle(b), loc=loc), _source(a, b)) + + +def _divide_op(a, b, loc): + return _from_handle(builder.divide(_to_handle(a), _to_handle(b), loc=loc), _source(a, b)) + + +def _maximum_op(a, b, loc): + return _from_handle(builder.maximum(_to_handle(a), _to_handle(b), loc=loc), _source(a, b)) + + +def _minimum_op(a, b, loc): + return _from_handle(builder.minimum(_to_handle(a), _to_handle(b), loc=loc), _source(a, b)) + + +def _power_op(a, b, loc): + return _from_handle(builder.power(_to_handle(a), _to_handle(b), loc=loc), _source(a, b)) + + +def _mod_op(a, b, loc): + return _from_handle(builder.mod(_to_handle(a), _to_handle(b), loc=loc), _source(a, b)) + + +def _square_op(a, loc): + return _from_handle(builder.square(_to_handle(a), loc=loc), _source(a)) + + +def _sqrt_op(a, loc): + return _from_handle(builder.sqrt(_to_handle(a), loc=loc), _source(a)) + + +def _exp_op(a, loc): + return _from_handle(builder.exp(_to_handle(a), loc=loc), _source(a)) + + +def 
_log_op(a, loc): + return _from_handle(builder.log(_to_handle(a), loc=loc), _source(a)) + + +def _tanh_op(a, loc): + return _from_handle(builder.tanh(_to_handle(a), loc=loc), _source(a)) + + +def _ceil_op(a, loc): + return _from_handle(builder.ceil_(_to_handle(a), loc=loc), _source(a)) + + +def _floor_op(a, loc): + return _from_handle(builder.floor_(_to_handle(a), loc=loc), _source(a)) + + +def _sin_op(a, loc): + return _from_handle(builder.sin(_to_handle(a), loc=loc), _source(a)) + + +def _cos_op(a, loc): + return _from_handle(builder.cos(_to_handle(a), loc=loc), _source(a)) + + +def _sign_op(a, loc): + return _from_handle(builder.sign(_to_handle(a), loc=loc), _source(a)) + + +def _abs_op(a, loc): + return _from_handle(builder.abs_(_to_handle(a), loc=loc), _source(a)) + + +def _negative_op(a, loc): + return _from_handle(builder.negative(_to_handle(a), loc=loc), _source(a)) + + +def _reciprocal_op(a, loc): + return _from_handle(builder.reciprocal(_to_handle(a), loc=loc), _source(a)) + + +def _greater_equal_op(a, b, loc): + return _from_handle(builder.greater_equal(_to_handle(a), _to_handle(b), loc=loc), _source(a, b)) + + +def _less_op(a, b, loc): + return _from_handle(builder.less(_to_handle(a), _to_handle(b), loc=loc), _source(a, b)) + + +def _equal_op(a, b, loc): + return _from_handle(builder.equal(_to_handle(a), _to_handle(b), loc=loc), _source(a, b)) + + +def _bitwise_and_op(a, b, loc): + return _from_handle(builder.bitwise_and(_to_handle(a), _to_handle(b), loc=loc), _source(a, b)) + + +def _bitwise_or_op(a, b, loc): + return _from_handle(builder.bitwise_or(_to_handle(a), _to_handle(b), loc=loc), _source(a, b)) + + +def _logical_not_op(a, loc): + return _from_handle(builder.logical_not(_to_handle(a), loc=loc), _source(a)) + + +# VTable for NumPy ufuncs +NUMPY_UFUNC_VTABLE: dict[str, Callable] = { + "add": _add_op, + "subtract": _subtract_op, + "multiply": _multiply_op, + "divide": _divide_op, + "square": _square_op, + "sqrt": _sqrt_op, + "exp": _exp_op, + "negative": _negative_op, + "reciprocal": _reciprocal_op, + "absolute": _abs_op, + "abs": _abs_op, + "log": _log_op, + "sin": _sin_op, + "cos": _cos_op, + "tanh": _tanh_op, + "ceil": _ceil_op, + "floor": _floor_op, + "sign": _sign_op, + "maximum": _maximum_op, + "minimum": _minimum_op, + "power": _power_op, + "remainder": _mod_op, + "mod": _mod_op, + "greater_equal": _greater_equal_op, + "less": _less_op, + "equal": _equal_op, + "bitwise_and": _bitwise_and_op, + "bitwise_or": _bitwise_or_op, + "logical_not": _logical_not_op, + "matmul": lambda a, b, loc: _matmul_op((a, b), {}, loc), +} + + +# --------------------------------------------------------------------------- +# NumPy function vtable operations +# --------------------------------------------------------------------------- + + +def _matmul_op(args: tuple, kwargs: dict, loc): + from .traced_array import TracedArray + + A, B = args + if not isinstance(A, TracedArray) or not isinstance(B, TracedArray): + raise TypeError("matmul requires TracedArray inputs") + h = builder.matmul(_to_handle(A), _to_handle(B), loc=loc) + return _from_handle(h, A.source_file) + + +def _transpose_op(args: tuple, kwargs: dict, loc): + from .traced_array import TracedArray + + A = args[0] + if not isinstance(A, TracedArray): + raise TypeError("transpose requires TracedArray input") + axes = args[1] if len(args) > 1 else kwargs.get("axes", None) + h = builder.transpose(_to_handle(A), axes=axes, loc=loc) + return _from_handle(h, A.source_file) + + +def _reshape_op(args: tuple, kwargs: dict, loc): + from 
.traced_array import TracedArray + + A, newshape = args + if not isinstance(A, TracedArray): + raise TypeError("reshape requires TracedArray input") + h = builder.reshape(_to_handle(A), newshape, loc=loc) + return _from_handle(h, A.source_file) + + +def _sum_op(args: tuple, kwargs: dict, loc): + from .traced_array import TracedArray + + A = args[0] + if not isinstance(A, TracedArray): + raise TypeError("sum requires TracedArray input") + axis = kwargs.get("axis", None) + keepdims = kwargs.get("keepdims", False) + h = builder.reduce_sum(_to_handle(A), axis=axis, keepdims=keepdims, loc=loc) + return _from_handle(h, A.source_file) + + +def _prod_op(args: tuple, kwargs: dict, loc): + from .traced_array import TracedArray + + A = args[0] + if not isinstance(A, TracedArray): + raise TypeError("prod requires TracedArray input") + axis = kwargs.get("axis", None) + keepdims = kwargs.get("keepdims", False) + h = builder.reduce_prod(_to_handle(A), axis=axis, keepdims=keepdims, loc=loc) + return _from_handle(h, A.source_file) + + +def _max_op(args: tuple, kwargs: dict, loc): + from .traced_array import TracedArray + + A = args[0] + if not isinstance(A, TracedArray): + raise TypeError("max requires TracedArray input") + axis = kwargs.get("axis", None) + keepdims = kwargs.get("keepdims", False) + h = builder.reduce_max(_to_handle(A), axis=axis, keepdims=keepdims, loc=loc) + return _from_handle(h, A.source_file) + + +def _min_op(args: tuple, kwargs: dict, loc): + from .traced_array import TracedArray + + A = args[0] + if not isinstance(A, TracedArray): + raise TypeError("min requires TracedArray input") + axis = kwargs.get("axis", None) + keepdims = kwargs.get("keepdims", False) + h = builder.reduce_min(_to_handle(A), axis=axis, keepdims=keepdims, loc=loc) + return _from_handle(h, A.source_file) + + +def _mean_op(args: tuple, kwargs: dict, loc): + from .traced_array import TracedArray + + A = args[0] + if not isinstance(A, TracedArray): + raise TypeError("mean requires TracedArray input") + axis = kwargs.get("axis", None) + keepdims = kwargs.get("keepdims", False) + h = builder.reduce_mean(_to_handle(A), axis=axis, keepdims=keepdims, loc=loc) + return _from_handle(h, A.source_file) + + +def _std_op(args: tuple, kwargs: dict, loc): + from .traced_array import TracedArray + + A = args[0] + if not isinstance(A, TracedArray): + raise TypeError("std requires TracedArray input") + axis = kwargs.get("axis", None) + keepdims = kwargs.get("keepdims", False) + h = builder.reduce_std(_to_handle(A), axis=axis, keepdims=keepdims, loc=loc) + return _from_handle(h, A.source_file) + + +def _concatenate_op(args: tuple, kwargs: dict, loc): + from .traced_array import TracedArray + + arrays = args[0] + axis = kwargs.get("axis", 0) + if not all(isinstance(a, TracedArray) for a in arrays): + raise TypeError("concatenate requires all inputs to be TracedArray") + handles = [_to_handle(a) for a in arrays] + h = builder.concatenate(handles, axis=axis, loc=loc) + return _from_handle(h, arrays[0].source_file) + + +def _split_op(args: tuple, kwargs: dict, loc): + from .traced_array import TracedArray + + A = args[0] + indices_or_sections = args[1] + axis = kwargs.get("axis", args[2] if len(args) > 2 else 0) + if not isinstance(A, TracedArray): + raise TypeError("split requires TracedArray input") + + if isinstance(indices_or_sections, int): + handles = builder.split(_to_handle(A), indices_or_sections, axis=axis, loc=loc) + return tuple(_from_handle(h, A.source_file) for h in handles) + else: + raise NotImplementedError("split with 
explicit indices not yet implemented") + + +def _expand_dims_op(args: tuple, kwargs: dict, loc): + from .traced_array import TracedArray + + A = args[0] + axis = args[1] if len(args) > 1 else kwargs.get("axis") + if not isinstance(A, TracedArray): + raise TypeError("expand_dims requires TracedArray input") + h = builder.expand_dims(_to_handle(A), axis, loc=loc) + return _from_handle(h, A.source_file) + + +def _broadcast_to_op(args: tuple, kwargs: dict, loc): + from .traced_array import TracedArray + + A = args[0] + target_shape = tuple(args[1]) + if not isinstance(A, TracedArray): + raise TypeError("broadcast_to requires TracedArray input") + h = builder.broadcast_to(_to_handle(A), target_shape, loc=loc) + return _from_handle(h, A.source_file) + + +def _copy_op(args: tuple, kwargs: dict, loc): + from .traced_array import TracedArray + + A = args[0] + if not isinstance(A, TracedArray): + raise TypeError("copy requires TracedArray input") + h = builder.copy_(_to_handle(A), loc=loc) + return _from_handle(h, A.source_file) + + +def _take_op(args: tuple, kwargs: dict, loc): + from .traced_array import TracedArray + + a = args[0] + indices = args[1] + axis = kwargs.get("axis", args[2] if len(args) > 2 else 0) + if isinstance(indices, TracedArray) and isinstance(indices.elem_ty, ir.FloatType): + indices_h = builder.astype(_to_handle(indices), np.int32, loc=loc) + indices = _from_handle(indices_h, indices.source_file) + h = builder.take(_to_handle(a), _to_handle(indices), axis=axis, loc=loc) + return _from_handle(h, a.source_file) + + +def _where_op(args: tuple, kwargs: dict, loc): + from .traced_array import TracedArray + + condition = args[0] + x = args[1] + y = args[2] + h = builder.where(_to_handle(condition), _to_handle(x), _to_handle(y), loc=loc) + sf = _source(condition, x if isinstance(x, TracedArray) else y) + return _from_handle(h, sf) + + +# VTable for NumPy array functions +NUMPY_FUNCTION_VTABLE: dict[Callable, Callable] = { + np.matmul: _matmul_op, + np.transpose: _transpose_op, + np.reshape: _reshape_op, + np.sum: _sum_op, + np.prod: _prod_op, + np.max: _max_op, + np.min: _min_op, + np.mean: _mean_op, + np.std: _std_op, + np.concatenate: _concatenate_op, + np.split: _split_op, + np.expand_dims: _expand_dims_op, + np.broadcast_to: _broadcast_to_op, + np.copy: _copy_op, + np.take: _take_op, + np.where: _where_op, +} diff --git a/kernelgen/nkipy_kernelgen/pass_manager.py b/kernelgen/nkipy_kernelgen/pass_manager.py new file mode 100644 index 0000000..4b3bf32 --- /dev/null +++ b/kernelgen/nkipy_kernelgen/pass_manager.py @@ -0,0 +1,119 @@ +""" +MLIR pass management utilities. + +This module provides utilities for applying MLIR transformation passes to modules. +""" + + +def apply_passes(mlir_module, passes): + """ + Apply MLIR passes and/or custom Python transformations to a module and return the transformed module. + + This function allows you to apply a sequence of transformation passes to an MLIR module. 
+ You can specify passes in multiple formats and mix them together: + + Args: + mlir_module: The MLIR module to transform + passes: Can be: + - A string: Complete pass pipeline (e.g., "builtin.module(func.func(...))") + - A list containing: + * Pass name strings (e.g., "linalg-generalize-named-ops") + * Callable functions that take and return a module + - A single callable function + + Returns: + The transformed MLIR module + + Examples: + >>> from nkipy_kernelgen import apply_passes + >>> + >>> # Method 1: Using a list of pass names + >>> transformed = apply_passes(module, ["linalg-generalize-named-ops", "linalg-fuse-elementwise-ops"]) + >>> + >>> # Method 2: Using a complete pipeline string + >>> transformed = apply_passes(module, "builtin.module(func.func(linalg-generalize-named-ops))") + """ + from mlir import passmanager + + # Module-level passes that should not be nested in func.func + MODULE_LEVEL_PASSES = { + "one-shot-bufferize", + "canonicalize", + "cse", + "symbol-dce", + "inline", + } + + # Note: convert-linalg-to-loops is a function-level pass + + def is_module_level_pass(pass_name): + """Check if a pass should run at module level rather than function level.""" + # Extract the base pass name (without options) + base_name = pass_name.split("{")[0].strip() + return base_name in MODULE_LEVEL_PASSES + + def apply_accumulated_passes(module, func_passes, module_passes): + """Apply accumulated function-level and module-level passes.""" + # Apply module-level passes FIRST (at module level) + # This is critical for passes like one-shot-bufferize that need to run + # before subsequent function-level passes like convert-linalg-to-loops + if module_passes: + pass_list_str = ",".join(module_passes) + pass_pipeline = f"builtin.module({pass_list_str})" + with module.context: + pm = passmanager.PassManager.parse(pass_pipeline) + pm.run(module.operation) + + # Apply function-level passes SECOND (nested in func.func) + if func_passes: + pass_list_str = ",".join(func_passes) + pass_pipeline = f"builtin.module(func.func({pass_list_str}))" + with module.context: + pm = passmanager.PassManager.parse(pass_pipeline) + pm.run(module.operation) + + return module + + # Handle single callable function + if callable(passes): + return passes(mlir_module) + + # Handle string (complete pipeline) + if isinstance(passes, str): + with mlir_module.context: + pm = passmanager.PassManager.parse(passes) + pm.run(mlir_module.operation) + return mlir_module + + # Handle list of passes (mix of strings and callables) + if isinstance(passes, list): + # Separate passes into function-level and module-level + func_passes = [] + module_passes = [] + + for item in passes: + if callable(item): + # If we have accumulated MLIR passes, apply them first + if func_passes or module_passes: + mlir_module = apply_accumulated_passes(mlir_module, func_passes, module_passes) + func_passes = [] + module_passes = [] + + # Apply the Python transformation + mlir_module = item(mlir_module) + elif isinstance(item, str): + # Categorize as module-level or function-level pass + if is_module_level_pass(item): + module_passes.append(item) + else: + func_passes.append(item) + else: + raise TypeError(f"Pass must be a string or callable, got {type(item)}") + + # Apply any remaining MLIR passes + if func_passes or module_passes: + mlir_module = apply_accumulated_passes(mlir_module, func_passes, module_passes) + + return mlir_module + + raise TypeError(f"Passes must be a string, callable, or list, got {type(passes)}") diff --git 
a/kernelgen/nkipy_kernelgen/trace.py b/kernelgen/nkipy_kernelgen/trace.py new file mode 100644 index 0000000..a562688 --- /dev/null +++ b/kernelgen/nkipy_kernelgen/trace.py @@ -0,0 +1,164 @@ +""" +Trace decorator for converting Python functions with NumPy operations to MLIR. +""" + +import contextlib +import functools +import inspect +from typing import Callable, Optional + +import numpy as np +from mlir import ir + +from .builder import IRBuilder +from . import builder +from .traced_array import TracedArray +from .op_vtable import _to_handle, _from_handle +from .custom_op import _get_registry, _clear_registry + + +def _normalize_shape(shape): + if isinstance(shape, TracedArray): + return tuple(shape.shape) + if isinstance(shape, (list, tuple)): + return tuple(int(d) for d in shape) + if isinstance(shape, int): + return (shape,) + return tuple(shape) + + +@contextlib.contextmanager +def _numpy_constructor_patch(source_file: str): + """Patch NumPy constructors during tracing so they emit MLIR and return TracedArray.""" + originals = {} + + def _patch(name, fn): + originals[name] = getattr(np, name) + setattr(np, name, fn) + + def _make_loc(): + return ir.Location.file(source_file, 0, 0, context=ir.Context.current) + + def ones(shape, dtype=None, **kw): + shp = _normalize_shape(shape) + h = builder.full(shp, 1.0, dtype or np.float32, loc=_make_loc()) + return _from_handle(h, source_file) + + def zeros(shape, dtype=None, **kw): + shp = _normalize_shape(shape) + h = builder.zeros(shp, dtype or np.float32, loc=_make_loc()) + return _from_handle(h, source_file) + + def full(shape, fill_value, dtype=None, **kw): + shp = _normalize_shape(shape) + h = builder.full(shp, fill_value, dtype or np.float32, loc=_make_loc()) + return _from_handle(h, source_file) + + def empty(shape, dtype=None, **kw): + shp = _normalize_shape(shape) + h = builder.empty(shp, dtype or np.float32, loc=_make_loc()) + return _from_handle(h, source_file) + + def array(obj, dtype=None, **kw): + if isinstance(obj, TracedArray): + return obj + return originals["array"](obj, dtype=dtype, **kw) + + def asarray(obj, dtype=None, **kw): + return array(obj, dtype=dtype, **kw) + + try: + for name, fn in ( + ("ones", ones), + ("zeros", zeros), + ("empty", empty), + ("full", full), + ("array", array), + ("asarray", asarray), + ): + _patch(name, fn) + yield + finally: + for k, v in originals.items(): + setattr(np, k, v) + + +def trace( + func_to_trace: Optional[Callable] = None, + *, + input_specs: Optional[list] = None, + name: Optional[str] = None, +) -> Callable: + """Decorator to trace a Python function with NumPy APIs into MLIR.""" + + def decorator(f: Callable) -> Callable: + func_name = name or f.__name__ + source_file = inspect.getsourcefile(f) or "unknown" + + @functools.wraps(f) + def wrapper(*args, **kwargs): + return f(*args, **kwargs) + + def to_mlir(specs: Optional[list] = None, debug: bool = False): + """Generate MLIR module from the traced function.""" + nonlocal input_specs + specs = specs or input_specs + if not specs: + raise ValueError( + "input_specs must be provided either in decorator or to_mlir()" + ) + + _clear_registry() + + b = IRBuilder(source_file=source_file) + arg_shapes = [s for s, _ in specs] + arg_dtypes = [d for _, d in specs] + handles = b.begin_function(func_name, arg_shapes, arg_dtypes) + + traced_args = [ + TracedArray(h._value, h.shape, h._elem_ty, source_file=source_file) + for h in handles + ] + + try: + with _numpy_constructor_patch(source_file): + result = f(*traced_args) + + if isinstance(result, 
tuple): + results = list(result) + elif isinstance(result, TracedArray): + results = [result] + else: + raise TypeError( + f"Result must be a TracedArray or tuple of TracedArrays, got {type(result)}" + ) + + for i, r in enumerate(results): + if not isinstance(r, TracedArray): + raise TypeError( + f"Result element {i} must be a TracedArray, got {type(r)}" + ) + + result_handles = [_to_handle(r) for r in results] + b.finish_function(result_handles) + + custom_ops = _get_registry() + b.emit_custom_op_declarations(custom_ops) + + b.run_canonicalize() + module = b.module + return module + finally: + _clear_registry() + b.cleanup() + + wrapper.to_mlir = to_mlir + wrapper.__traced__ = True + wrapper.input_specs = input_specs + + return wrapper + + if func_to_trace is None: + return decorator + else: + return decorator(func_to_trace) diff --git a/kernelgen/nkipy_kernelgen/traced_array.py b/kernelgen/nkipy_kernelgen/traced_array.py new file mode 100644 index 0000000..127a388 --- /dev/null +++ b/kernelgen/nkipy_kernelgen/traced_array.py @@ -0,0 +1,350 @@ +""" +TracedArray class for intercepting NumPy operations. +""" + +import inspect +import os +from typing import Tuple, Optional +import numpy as np +from mlir import ir + +from .op_vtable import NUMPY_UFUNC_VTABLE, NUMPY_FUNCTION_VTABLE + + +class TracedArray: + """Represents a traced array that builds MLIR operations.""" + + # Enable NumPy's __array_function__ protocol + __array_priority__ = 1000 + + def __init__(self, value: ir.Value, shape: Tuple[int, ...], elem_ty: ir.Type, source_file: Optional[str] = None): + self.value = value + self.shape = tuple(shape) + self.elem_ty = elem_ty + self.source_file = source_file or "unknown" + + @property + def dtype(self): + """Return the numpy-compatible dtype of the TracedArray.""" + # Map MLIR type to numpy dtype + if isinstance(self.elem_ty, ir.FloatType): + if self.elem_ty.width == 32: + return np.float32 + elif self.elem_ty.width == 64: + return np.float64 + elif self.elem_ty.width == 16: + return np.float16 + elif isinstance(self.elem_ty, ir.IntegerType): + if self.elem_ty.width == 32: + return np.int32 if self.elem_ty.is_signed else np.uint32 + elif self.elem_ty.width == 64: + return np.int64 if self.elem_ty.is_signed else np.uint64 + elif self.elem_ty.width == 16: + return np.int16 if self.elem_ty.is_signed else np.uint16 + elif self.elem_ty.width == 8: + return np.int8 if self.elem_ty.is_signed else np.uint8 + return self.elem_ty # Fallback to MLIR type + + @property + def ndim(self): + """Return the number of dimensions.""" + return len(self.shape) + + def __repr__(self) -> str: + """Return a detailed string representation of the TracedArray.""" + return f"TracedArray(shape={self.shape}, dtype={self.elem_ty}, source={self.source_file})" + + def __str__(self) -> str: + """Return a user-friendly string representation of the TracedArray.""" + return f"TracedArray{self.shape} of type {self.elem_ty}" + + def _get_caller_location(self) -> ir.Location: + """Get MLIR location from Python call stack.""" + # Walk up the stack to find the user's code (skip internal frames) + frame = inspect.currentframe() + try: + # Walk through all frames and find the first one that matches our source file + while frame is not None: + frame = frame.f_back + if frame is None: + break + + filename = frame.f_code.co_filename + lineno = frame.f_lineno + + if self.source_file != "unknown": + source_basename = os.path.basename(self.source_file) + frame_basename = os.path.basename(filename) + + # If the basenames match, or if the 
source_file is contained in filename + if source_basename == frame_basename or self.source_file in filename: + ctx = ir.Context.current + return ir.Location.file(self.source_file, lineno, 0, context=ctx) + finally: + del frame # Avoid reference cycles + + # Fallback to the value's location + return self.value.location + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + """Intercept NumPy ufuncs like np.add, np.multiply, np.square, etc.""" + if method != "__call__": + raise NotImplementedError(f"Method {method} not supported") + + # Get location from Python call stack + loc = self._get_caller_location() + + ufunc_name = ufunc.__name__ + + # Check if this ufunc is supported + if ufunc_name not in NUMPY_UFUNC_VTABLE: + raise NotImplementedError(f"ufunc {ufunc_name} not supported") + + # Check if this is a unary operation + if len(inputs) == 1: + A = inputs[0] + if not isinstance(A, TracedArray): + A = lift_constant(A, self.elem_ty, self.shape, loc, self.source_file) + return NUMPY_UFUNC_VTABLE[ufunc_name](A, loc) + + # Check if this is a binary operation + elif len(inputs) == 2: + # Pass inputs directly — scalar constants are handled by each op + A, B = inputs[0], inputs[1] + return NUMPY_UFUNC_VTABLE[ufunc_name](A, B, loc) + + else: + raise NotImplementedError(f"ufunc {ufunc_name} with {len(inputs)} inputs not supported") + + def __array_function__(self, func, types, args, kwargs): + """Intercept NumPy array functions like matmul, transpose, etc.""" + if func not in NUMPY_FUNCTION_VTABLE: + return NotImplemented + + # Get location from Python call stack + loc = self._get_caller_location() + return NUMPY_FUNCTION_VTABLE[func](args, kwargs, loc) + + def __getitem__(self, key): + """Support indexing/slicing using tensor.extract_slice or gather.""" + from .control_flow import LoopIndex + from .op_vtable import _to_handle, _from_handle + from . 
import builder + + if not isinstance(key, tuple): + key = (key,) + + # Gather path: first dim is a TracedArray + if len(key) > 0 and isinstance(key[0], TracedArray): + for i, slice_spec in enumerate(key[1:], start=1): + if not isinstance(slice_spec, slice): + raise TypeError("Gather with TracedArray indices only supports full slices on remaining dimensions") + if slice_spec != slice(None, None, None): + raise NotImplementedError("Partial slicing with gather not yet supported") + result_h = builder.take(_to_handle(self), _to_handle(key[0]), axis=0) + return _from_handle(result_h, self.source_file) + + # Check for dynamic indices + has_dynamic = any( + isinstance(s, LoopIndex) or ( + isinstance(s, slice) and ( + isinstance(s.start, LoopIndex) or isinstance(s.stop, LoopIndex) + ) + ) + for s in key + ) + + self_h = _to_handle(self) + if has_dynamic: + indices = self._convert_key_to_builder_indices(key, LoopIndex) + result_h = builder.dynamic_slice(self_h, indices) + else: + start_indices, limit_indices, strides_list, squeeze_dims = \ + self._convert_key_to_static(key) + result_h = builder.static_slice( + self_h, start_indices, limit_indices, strides_list, squeeze_dims + ) + return _from_handle(result_h, self.source_file) + + def _convert_key_to_builder_indices(self, key, LoopIndex): + """Convert __getitem__ key to builder.dynamic_slice indices tuple.""" + from .builder import LoopIndexHandle + indices = [] + for slice_spec in key: + if isinstance(slice_spec, LoopIndex): + indices.append(LoopIndexHandle( + slice_spec.value, slice_spec.mul_factor, slice_spec.add_offset + )) + elif isinstance(slice_spec, slice): + start = slice_spec.start + stop = slice_spec.stop + step = slice_spec.step + new_start = start + new_stop = stop + if isinstance(start, LoopIndex): + new_start = LoopIndexHandle( + start.value, start.mul_factor, start.add_offset + ) + if isinstance(stop, LoopIndex): + new_stop = LoopIndexHandle( + stop.value, stop.mul_factor, stop.add_offset + ) + indices.append(slice(new_start, new_stop, step)) + else: + indices.append(slice_spec) + return tuple(indices) + + def _convert_key_to_static(self, key): + """Convert __getitem__ key to static_slice parameters.""" + start_indices = [] + limit_indices = [] + strides_list = [] + squeeze_dims = [] + + for dim_idx, (slice_spec, dim_size) in enumerate(zip(key, self.shape)): + if isinstance(slice_spec, slice): + start = slice_spec.start if slice_spec.start is not None else 0 + stop = slice_spec.stop if slice_spec.stop is not None else dim_size + step = slice_spec.step if slice_spec.step is not None else 1 + start_indices.append(int(start)) + limit_indices.append(int(stop)) + strides_list.append(int(step)) + elif isinstance(slice_spec, int): + start_indices.append(int(slice_spec)) + limit_indices.append(int(slice_spec) + 1) + strides_list.append(1) + squeeze_dims.append(dim_idx) + else: + raise TypeError(f"Unsupported index type: {type(slice_spec)}") + + for dim_idx in range(len(key), len(self.shape)): + start_indices.append(0) + limit_indices.append(self.shape[dim_idx]) + strides_list.append(1) + + return start_indices, limit_indices, strides_list, squeeze_dims + + def __setitem__(self, key, value): + """Support item assignment using tensor.insert_slice. + + Since MLIR tensors are SSA values (immutable), this updates self.value + to a new tensor with the slice inserted. + """ + from .control_flow import LoopIndex + from .op_vtable import _to_handle, _from_handle + from . 
import builder + + if not isinstance(value, TracedArray): + raise TypeError(f"Can only assign TracedArray values, got {type(value)}") + + if not isinstance(key, tuple): + key = (key,) + + # Expand value for any single-index dims that were collapsed. + insert_val = value + for dim_idx, slice_spec in enumerate(key): + if isinstance(slice_spec, (int, LoopIndex)): + insert_val = np.expand_dims(insert_val, axis=dim_idx) + + has_dynamic = any( + isinstance(s, LoopIndex) or ( + isinstance(s, slice) and ( + isinstance(s.start, LoopIndex) or isinstance(s.stop, LoopIndex) + ) + ) + for s in key + ) + + self_h = _to_handle(self) + insert_h = _to_handle(insert_val) + + if has_dynamic: + indices = self._convert_key_to_builder_indices(key, LoopIndex) + result_h = builder.dynamic_insert_slice(self_h, insert_h, indices) + else: + offsets, sizes, strides_list = [], [], [] + + for slice_spec, dim_size in zip(key, self.shape): + if isinstance(slice_spec, slice): + start = slice_spec.start if slice_spec.start is not None else 0 + stop = slice_spec.stop if slice_spec.stop is not None else dim_size + step = slice_spec.step if slice_spec.step is not None else 1 + offsets.append(int(start)) + sizes.append(int(stop - start)) + strides_list.append(int(step)) + elif isinstance(slice_spec, int): + offsets.append(int(slice_spec)) + sizes.append(1) + strides_list.append(1) + else: + raise TypeError(f"Unsupported index type: {type(slice_spec)}") + + for dim_idx in range(len(key), len(self.shape)): + offsets.append(0) + sizes.append(self.shape[dim_idx]) + strides_list.append(1) + + result_h = builder.static_insert_slice( + self_h, insert_h, offsets, sizes, strides_list + ) + + self.value = result_h._value + + def astype(self, dtype, **kwargs): + """Cast array to specified dtype.""" + from .op_vtable import _to_handle, _from_handle + from . 
import builder + + result_h = builder.astype(_to_handle(self), dtype) + return _from_handle(result_h, self.source_file) + + def reshape(self, *shape): + """Reshape the array using np.reshape.""" + # Flatten shape if it's passed as a tuple + if len(shape) == 1 and isinstance(shape[0], (tuple, list)): + shape = shape[0] + # Handle -1 in shape (infer dimension) + if -1 in shape: + # Compute the inferred dimension + known_size = 1 + unknown_idx = None + for i, dim in enumerate(shape): + if dim == -1: + if unknown_idx is not None: + raise ValueError("can only specify one unknown dimension") + unknown_idx = i + else: + known_size *= dim + total_size = np.prod(self.shape) + inferred_dim = total_size // known_size + shape = list(shape) + shape[unknown_idx] = inferred_dim + shape = tuple(shape) + return np.reshape(self, shape) + + # Python operator overloads to trigger NumPy ufuncs + def __add__(self, other): return np.add(self, other) + def __sub__(self, other): return np.subtract(self, other) + def __mul__(self, other): return np.multiply(self, other) + def __truediv__(self, other): return np.divide(self, other) + def __radd__(self, other): return np.add(other, self) + def __rsub__(self, other): return np.subtract(other, self) + def __rmul__(self, other): return np.multiply(other, self) + def __rtruediv__(self, other): return np.divide(other, self) + def __neg__(self): return np.negative(self) + def __matmul__(self, other): return np.matmul(self, other) + def __rmatmul__(self, other): return np.matmul(other, self) + + +def lift_constant(val: float, elem_ty: ir.Type, shape: Tuple[int, ...], loc: ir.Location, source_file: str = "unknown") -> TracedArray: + """Convert a scalar constant to a TracedArray by filling a tensor. + + The result is annotated with CONSTANT memory space (mem_space=5) to indicate + that this tensor represents a broadcasted scalar constant. This allows later + passes to optimize tensor-scalar operations (e.g., use nisa.tensor_scalar_arith + instead of nisa.tensor_tensor_arith). + """ + from . import builder + + h = builder.constant_tensor(val, shape, elem_ty) + return TracedArray(h._value, shape, elem_ty, source_file) diff --git a/kernelgen/nkipy_kernelgen/transforms/__init__.py b/kernelgen/nkipy_kernelgen/transforms/__init__.py new file mode 100644 index 0000000..b6d00ba --- /dev/null +++ b/kernelgen/nkipy_kernelgen/transforms/__init__.py @@ -0,0 +1,36 @@ +""" +MLIR Transformation Passes for NKIPyKernelGen + +This package provides Python implementations of MLIR transformation passes +for lowering traced NumPy functions to Neuron hardware. +""" + +__all__ = [] + +# Import pass management +from ..pass_manager import apply_passes + +__all__.extend([ + "apply_passes", +]) + +# Import nkipy-opt wrapper for passes that can't run from Python +try: + from .nkipy_opt import ( + get_nkipy_opt_path, + run_nkipy_opt_passes, + apply_complete_knob_pipeline, + ) + + __all__.extend([ + "get_nkipy_opt_path", + "run_nkipy_opt_passes", + "apply_complete_knob_pipeline", + ]) +except ImportError as e: + # nkipy-opt wrapper not available + import warnings + warnings.warn( + f"nkipy-opt wrapper not available: {e}", + ImportWarning + ) diff --git a/kernelgen/nkipy_kernelgen/transforms/linalg_to_nisa_py.py b/kernelgen/nkipy_kernelgen/transforms/linalg_to_nisa_py.py new file mode 100644 index 0000000..eb3e461 --- /dev/null +++ b/kernelgen/nkipy_kernelgen/transforms/linalg_to_nisa_py.py @@ -0,0 +1,2720 @@ +""" +Python replacement for the (deleted) C++ `linalg-to-nisa` pass. 
+ +Reads post-Phase-4 MLIR (linalg+memref+scf+arith+func, with integer-encoded +nkipy memory spaces) and emits NISA MLIR using the `nki` wheel's Python +bindings. Everything downstream (resolve-custom-ops, prepare-for-nki, +`nki-opt-pipeline`) then consumes standard NISA IR. + +Design +------ + +This mirrors the (pre-open-source) C++ `LinalgToNisa.cpp` architecture: + +1. **Parse in upstream ctx, re-parse in NKI ctx.** Memref/scf/linalg exist + only in upstream MLIR; NISA exists only in the NKI wheel's context. We + print the module as generic IR in the upstream context, rewrite integer + memspace markers to `#nisa.mem<...>`, and re-parse in the NKI context + (with `allow_unregistered_dialects`) so unrecognised ops survive as + opaque ops we can still walk. + +2. **Walk + rewrite.** For each supported op we compute a ``MemRefAccess`` + per operand by tracing back through subview/collapse_shape/expand_shape + to the base memref (materialising `arith.constant` / `arith.addi` / + `arith.muli` for the offset math along the way). We then build a plain + ``AffineMap`` per operand following ``createStandardNisaMap`` (d0 at the + first kept dim, d1 at the last kept dim, symbols everywhere else), flatten + it via ``nisa.flatten_affine_map``, and hand the base+indices+map tuple + straight to the Python ``nisa.(...)`` builder. No ``prepare_operand`` + layer — that path inside ``_nki_irbuilder`` runs a linearisation that can + merge per-dim expressions into a single multi-symbol flat_affine_expr, + which the NISA verifier rejects. + +3. **Post-pass cleanup.** DCE the dead view ops, fold `reinterpret_cast` + on fresh allocs, fold HBM reshape chains into the alloc type. + +Validation +---------- + +Per the plan's 2026-04-20 decision: we do not byte-diff against pre-refactor +golden MLIR. Each e2e test is expected to simulate correctly through BIRSim +(and/or run on HW). The pass is a success when every fixture's numerical +output matches NumPy. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from typing import Callable + +from mlir import ir as up_ir # type: ignore[import-not-found] + +from nki.compiler._internal import ir as nk_ir # type: ignore[import-not-found] +from nki.compiler._internal._mlir_libs import _nki # type: ignore[import-not-found] +from nki.compiler._internal._mlir_libs import _nki_irbuilder # type: ignore[import-not-found] +from nki.compiler._internal.dialects import nisa # type: ignore[import-not-found] + + +# NKIPy emits memref memory-space annotations as `N : i32` integers (matching +# the `MemSpaceEnum` in `NkipyAttrs.td`). NKI's parser needs them as +# `#nisa.mem<...>` attribute syntax. The enum values start at 1 (not 0) — see +# the comment in NkipyAttrs.td for why 0 cannot be used. +_NKIPY_TO_NISA_MEMSPACE = { + 1: "hbm", + 2: "psum", + 3: "sbuf", + 4: "shared_hbm", +} + +_INT_MEMSPACE_RE = re.compile(r", (\d+) : i32>") + + +def _rewrite_memspace_text(generic: str) -> str: + def repl(m: re.Match[str]) -> str: + n = int(m.group(1)) + name = _NKIPY_TO_NISA_MEMSPACE.get(n) + if name is None: + return m.group(0) + return f", #nisa.mem<{name}>>" + + return _INT_MEMSPACE_RE.sub(repl, generic) + + +def _to_nki_module(src: str) -> tuple[nk_ir.Context, nk_ir.Module]: + # Re-serialize through nkipy-opt with `--mlir-print-op-generic` so any + # nkipy-dialect ops that survive into this phase (currently just + # `nkipy.gather`, which we lower below) arrive in generic form + # `"nkipy.gather"(...)`. 
The upstream MLIR Python bindings don't know + # about the nkipy dialect; pretty-form `nkipy.gather(...)` would fail + # to parse (`allow_unregistered_dialects` only covers generic form). + from .nkipy_opt import run_nkipy_opt_passes # avoid circular import + src = run_nkipy_opt_passes(src, passes=[], print_generic=True) + + up_ctx = up_ir.Context() + up_ctx.load_all_available_dialects() + up_ctx.allow_unregistered_dialects = True + with up_ctx: + up_mod = up_ir.Module.parse(src) + generic = up_mod.operation.get_asm( + print_generic_op_form=True, assume_verified=True + ) + generic = _rewrite_memspace_text(generic) + + nk_ctx = nk_ir.Context() + _nki.register_all_dialects(nk_ctx) + nk_ctx.allow_unregistered_dialects = True + with nk_ctx: + nk_mod = nk_ir.Module.parse(generic) + return nk_ctx, nk_mod + + +# --------------------------------------------------------------------------- +# MemRef access trace +# --------------------------------------------------------------------------- + + +DYN_SENTINEL = -(1 << 63) + + +@dataclass +class _Access: + """Mirror of C++ ``MemRefAccess``. + + ``indices`` contains one ``arith`` SSA value per base-rank dim (zero for + dims with no subview offset). ``dropped_dims`` marks rank-reducing or + unit-collapsed dims that should carry only a symbol offset in the NISA + affine map (no iteration dim). + + The flat-affine map is built from this triple via ``_build_nisa_map``. + """ + + base: nk_ir.Value + indices: list[nk_ir.Value] + base_type: nk_ir.MemRefType + dropped_dims: list[bool] # len == base_rank + + @property + def base_rank(self) -> int: + return self.base_type.rank # type: ignore[attr-defined] + + +def _reassoc_groups(attr: nk_ir.Attribute) -> list[list[int]]: + outer = nk_ir.ArrayAttr(attr) + groups: list[list[int]] = [] + for g in outer: + inner = nk_ir.ArrayAttr(g) + groups.append([int(nk_ir.IntegerAttr(x).value) for x in inner]) + return groups + + +def _const_int(v: nk_ir.Value) -> int | None: + owner = getattr(v, "owner", None) + if owner is None: + return None + op = owner.opview if hasattr(owner, "opview") else owner + if getattr(op, "name", None) != "arith.constant": + return None + try: + attr = op.attributes["value"] + except KeyError: + return None + try: + return nk_ir.IntegerAttr(attr).value + except Exception: + return None + + +def _emit_const_index(ctx: nk_ir.Context, value: int, loc: nk_ir.Location) -> nk_ir.Value: + idx_ty = nk_ir.IndexType.get(ctx) + attr = nk_ir.IntegerAttr.get(idx_ty, value) + op = nk_ir.Operation.create( + "arith.constant", results=[idx_ty], attributes={"value": attr}, loc=loc + ) + return op.result + + +def _emit_addi(a: nk_ir.Value, b: nk_ir.Value, ctx: nk_ir.Context, + loc: nk_ir.Location) -> nk_ir.Value: + idx_ty = nk_ir.IndexType.get(ctx) + op = nk_ir.Operation.create( + "arith.addi", results=[idx_ty], operands=[a, b], loc=loc + ) + return op.result + + +def _emit_muli(a: nk_ir.Value, b: nk_ir.Value, ctx: nk_ir.Context, + loc: nk_ir.Location) -> nk_ir.Value: + idx_ty = nk_ir.IndexType.get(ctx) + op = nk_ir.Operation.create( + "arith.muli", results=[idx_ty], operands=[a, b], loc=loc + ) + return op.result + + +def _emit_divui(a: nk_ir.Value, b: nk_ir.Value, ctx: nk_ir.Context, + loc: nk_ir.Location) -> nk_ir.Value: + idx_ty = nk_ir.IndexType.get(ctx) + op = nk_ir.Operation.create( + "arith.divui", results=[idx_ty], operands=[a, b], loc=loc + ) + return op.result + + +def _get_base_and_offsets(ctx: nk_ir.Context, operand: nk_ir.Value, + loc: nk_ir.Location) -> _Access: + """Port of C++ ``getBaseAndOffsets``. 
Walks subview/collapse/expand chains + back to the base alloc or block arg, materialising arith ops as needed. + + The current insertion point must be positioned where new arith ops can be + safely emitted (typically the op being rewritten). + """ + base = operand + base_type = operand.type + indices: list[nk_ir.Value] = [] + dropped_dims: list[bool] = [] + + changed = True + while changed: + changed = False + owner = getattr(base, "owner", None) + if owner is None: + break + op = owner.opview if hasattr(owner, "opview") else owner + name = getattr(op, "name", None) + + if name == "memref.subview": + source = op.operands[0] + source_ty = source.type + if not isinstance(source_ty, nk_ir.MemRefType): + break + src_rank = source_ty.rank + try: + static_offsets = [int(x) for x in op.attributes["static_offsets"]] + except (KeyError, ValueError): + break + if len(static_offsets) != src_rank: + break + dyn_ops = list(op.operands)[1:] + dyn_idx = 0 + subview_offsets: list[nk_ir.Value] = [] + for i in range(src_rank): + if static_offsets[i] == DYN_SENTINEL: + subview_offsets.append(dyn_ops[dyn_idx]) + dyn_idx += 1 + else: + subview_offsets.append( + _emit_const_index(ctx, static_offsets[i], loc) + ) + + # Rank-reducing detection via static_sizes==1 vs result shape. + try: + static_sizes = [int(x) for x in op.attributes["static_sizes"]] + except (KeyError, ValueError): + break + result_shape = list(getattr(op.results[0].type, "shape", ())) + # Determine dropped dims: rank-reducing means result_rank < src_rank. + # Heuristic: dims with static_sizes[i]==1 that are NOT present in + # the result shape are dropped. We align by order: walk source dims, + # match to result in order, preferring size equality. + dropped = [False] * src_rank + if len(result_shape) < src_rank: + ri = 0 + for si in range(src_rank): + if ri < len(result_shape) and static_sizes[si] == result_shape[ri]: + ri += 1 + else: + # If size == 1 and we haven't matched yet, drop it. + if static_sizes[si] == 1: + dropped[si] = True + elif ri < len(result_shape) and static_sizes[si] == 1: + dropped[si] = True + else: + # fall back: mark as dropped if no remaining result dim + # matches + dropped[si] = True + # If we ended up dropping too many or too few, bail + non_dropped = sum(1 for d in dropped if not d) + if non_dropped != len(result_shape): + break + + # Accumulate offsets: + if not indices: + # No carried indices yet. Usually this is the first subview in + # the chain, but it can also follow a collapse_shape that + # reduced a multi-dim HBM operand to a lower-rank view — in + # which case `dropped_dims` is already populated and must be + # preserved. The subview is same-rank here (src_rank == + # result_rank) for a collapse→subview chain, so just seed + # indices with the subview offsets and merge `dropped` with + # any preserved `dropped_dims` element-wise. + indices = subview_offsets + if dropped_dims and len(dropped_dims) == src_rank: + dropped_dims = [a or b for a, b in zip(dropped_dims, dropped)] + else: + dropped_dims = dropped + else: + # Nested subview. Expand current indices (in result-rank space + # after prior ops) to source rank. + if any(dropped): + # Rank-reducing: dropped dims get pure subview offset; + # kept dims get accumulated + subview offset. 
+                    expanded: list[nk_ir.Value] = []
+                    kept_idx = 0
+                    for si in range(src_rank):
+                        if dropped[si]:
+                            expanded.append(subview_offsets[si])
+                        else:
+                            assert kept_idx < len(indices)
+                            expanded.append(
+                                _emit_addi(indices[kept_idx], subview_offsets[si],
+                                           ctx, loc)
+                            )
+                            kept_idx += 1
+                    indices = expanded
+                    merged_dropped = [False] * src_rank
+                    kept_idx = 0
+                    for si in range(src_rank):
+                        if dropped[si]:
+                            merged_dropped[si] = True
+                        else:
+                            if kept_idx < len(dropped_dims) and dropped_dims[kept_idx]:
+                                merged_dropped[si] = True
+                            kept_idx += 1
+                    dropped_dims = merged_dropped
+                else:
+                    # Same-rank subview: add element-wise.
+                    if len(indices) != src_rank:
+                        break
+                    indices = [
+                        _emit_addi(indices[i], subview_offsets[i], ctx, loc)
+                        for i in range(src_rank)
+                    ]
+
+            base = source
+            base_type = source_ty
+            changed = True
+            continue
+
+        if name == "memref.collapse_shape":
+            source = op.operands[0]
+            source_ty = source.type
+            if not isinstance(source_ty, nk_ir.MemRefType):
+                break
+            try:
+                groups = _reassoc_groups(op.attributes["reassociation"])
+            except KeyError:
+                break
+            src_shape = list(getattr(source_ty, "shape", ()))
+            if any(s < 0 for s in src_shape):
+                break
+            src_rank = len(src_shape)
+
+            # Determine whether any group has multiple non-unit dims.
+            has_multi_non_unit = any(
+                sum(1 for d in grp if src_shape[d] != 1) > 1 for grp in groups
+            )
+
+            ms = getattr(source_ty, "memory_space", None)
+            is_hbm = ms is not None and (
+                "<hbm>" in str(ms) or "<shared_hbm>" in str(ms)
+            )
+
+            if has_multi_non_unit and is_hbm:
+                # Stop tracing — HBM collapse is fine as-is; NCC handles it.
+                break
+
+            # Find primary dim (largest size) per group — for multi-non-unit.
+            def primary_dim(group: list[int]) -> int:
+                best_idx, best_size = -1, 0
+                for i, d in enumerate(group):
+                    if src_shape[d] > best_size:
+                        best_size = src_shape[d]
+                        best_idx = i
+                return best_idx
+
+            # Expand droppedDims from collapsed to source rank.
+            expanded_dropped = [False] * src_rank
+            if is_hbm:
+                for gi, grp in enumerate(groups):
+                    collapsed_dropped = (
+                        gi < len(dropped_dims) and dropped_dims[gi] if dropped_dims
+                        else False
+                    )
+                    if collapsed_dropped:
+                        for d in grp:
+                            expanded_dropped[d] = True
+                    elif len(grp) > 1:
+                        for d in grp:
+                            if src_shape[d] == 1:
+                                expanded_dropped[d] = True
+            else:
+                for gi, grp in enumerate(groups):
+                    collapsed_dropped = (
+                        gi < len(dropped_dims) and dropped_dims[gi] if dropped_dims
+                        else False
+                    )
+                    if collapsed_dropped:
+                        for d in grp:
+                            expanded_dropped[d] = True
+                    elif len(grp) > 1 and has_multi_non_unit:
+                        p = primary_dim(grp)
+                        for i, d in enumerate(grp):
+                            if i != p:
+                                expanded_dropped[d] = True
+                    elif len(grp) > 1:
+                        # Drop only the unit dims, but keep one per group so
+                        # the group still has a home for its iteration dim.
+                        # For single-non-unit groups that dim is obviously
+                        # the keeper; for all-unit groups we keep the first
+                        # position so downstream affine-map construction has
+                        # somewhere to place d_gi.
+                        non_unit = [d for d in grp if src_shape[d] != 1]
+                        keeper = non_unit[0] if non_unit else grp[0]
+                        for d in grp:
+                            if d != keeper:
+                                expanded_dropped[d] = True
+
+            # Expand indices from collapsed rank to source rank.
+ if indices: + expanded_indices: list[nk_ir.Value] = [] + for gi, grp in enumerate(groups): + if len(grp) == 1: + expanded_indices.append(indices[gi]) + else: + non_unit_count = sum(1 for d in grp if src_shape[d] != 1) + if non_unit_count <= 1: + zero = _emit_const_index(ctx, 0, loc) + for d in grp: + if src_shape[d] != 1: + expanded_indices.append(indices[gi]) + else: + expanded_indices.append(zero) + else: + p = primary_dim(grp) + primary_size = src_shape[grp[p]] + zero = _emit_const_index(ctx, 0, loc) + size_val = _emit_const_index(ctx, primary_size, loc) + batch = _emit_divui(indices[gi], size_val, ctx, loc) + for i, d in enumerate(grp): + if i == p: + expanded_indices.append(zero) + elif src_shape[d] == 1: + expanded_indices.append(zero) + else: + expanded_indices.append(batch) + indices = expanded_indices + + dropped_dims = expanded_dropped + base = source + base_type = source_ty + changed = True + continue + + if name == "memref.expand_shape": + source = op.operands[0] + source_ty = source.type + if not isinstance(source_ty, nk_ir.MemRefType): + break + try: + groups = _reassoc_groups(op.attributes["reassociation"]) + except KeyError: + break + dst_shape = list(getattr(op.results[0].type, "shape", ())) + src_rank = source_ty.rank + + if indices: + src_indices: list[nk_ir.Value] = [] + for grp in groups: + combined = indices[grp[0]] + for k in range(1, len(grp)): + inner_size = dst_shape[grp[k]] + scale = _emit_const_index(ctx, inner_size, loc) + combined = _emit_muli(combined, scale, ctx, loc) + combined = _emit_addi(combined, indices[grp[k]], ctx, loc) + src_indices.append(combined) + indices = src_indices + + if dropped_dims: + new_dropped = [False] * src_rank + for gi, grp in enumerate(groups): + all_dropped = all( + d < len(dropped_dims) and dropped_dims[d] for d in grp + ) + if all_dropped: + new_dropped[gi] = True + dropped_dims = new_dropped + + base = source + base_type = source_ty + changed = True + continue + + if name == "memref.reinterpret_cast": + # Pass-through when it preserves rank and has zero static offsets. + source = op.operands[0] + source_ty = source.type + if not isinstance(source_ty, nk_ir.MemRefType): + break + src_shape = list(getattr(source_ty, "shape", ())) + dst_shape = list(getattr(op.results[0].type, "shape", ())) + if src_shape != dst_shape: + break + try: + st_off = [int(x) for x in op.attributes["static_offsets"]] + except (KeyError, ValueError): + st_off = [] + if any(x != 0 for x in st_off): + break + base = source + base_type = source_ty + changed = True + continue + + break + + # Ensure we have rank-many indices (fresh allocs / block args need zeros). + if not indices: + rank = base_type.rank # type: ignore[attr-defined] + indices = [_emit_const_index(ctx, 0, loc) for _ in range(rank)] + + # Normalise dropped_dims length to base rank. 
+ rank = base_type.rank # type: ignore[attr-defined] + if len(dropped_dims) < rank: + dropped_dims = dropped_dims + [False] * (rank - len(dropped_dims)) + elif len(dropped_dims) > rank: + dropped_dims = dropped_dims[:rank] + + return _Access( + base=base, + indices=indices, + base_type=base_type, # type: ignore[arg-type] + dropped_dims=dropped_dims, + ) + + +# --------------------------------------------------------------------------- +# Affine map construction +# --------------------------------------------------------------------------- + + +def _create_standard_nisa_map( + ctx: nk_ir.Context, + num_iter_dims: int, + num_symbols: int, + num_results: int, + dropped_dims: list[bool], +) -> nk_ir.AffineMap: + """Python mirror of C++ ``createStandardNisaMap``. + + d0 -> first kept position, d(N-1) -> last kept position, middle iter dims + map to middle kept positions in order. Dropped dims get pure symbol or + constant expressions; any remaining result has symbol or 0. + """ + kept_positions = [ + i for i in range(num_results) + if not (dropped_dims and i < len(dropped_dims) and dropped_dims[i]) + ] + + position_to_dim: dict[int, int] = {} + if kept_positions and num_iter_dims > 0: + position_to_dim[kept_positions[0]] = 0 + if num_iter_dims > 1 and len(kept_positions) > 1: + position_to_dim[kept_positions[-1]] = num_iter_dims - 1 + mid_dim = 1 + for k in range(1, len(kept_positions) - 1): + if mid_dim + 1 >= num_iter_dims: + break + position_to_dim[kept_positions[k]] = mid_dim + mid_dim += 1 + + exprs: list[nk_ir.AffineExpr] = [] + for i in range(num_results): + is_dropped = bool(dropped_dims) and i < len(dropped_dims) and dropped_dims[i] + if is_dropped: + if i < num_symbols: + exprs.append(nk_ir.AffineSymbolExpr.get(i)) + else: + exprs.append(nk_ir.AffineConstantExpr.get(0)) + elif i in position_to_dim: + dim_expr = nk_ir.AffineDimExpr.get(position_to_dim[i]) + if i < num_symbols: + exprs.append(dim_expr + nk_ir.AffineSymbolExpr.get(i)) + else: + exprs.append(dim_expr) + else: + if i < num_symbols: + exprs.append(nk_ir.AffineSymbolExpr.get(i)) + else: + exprs.append(nk_ir.AffineConstantExpr.get(0)) + return nk_ir.AffineMap.get(num_iter_dims, num_symbols, exprs) + + +def _build_nisa_map( + ctx: nk_ir.Context, num_iter_dims: int, access: _Access +) -> nk_ir.Attribute: + amap = _create_standard_nisa_map( + ctx, + num_iter_dims, + num_symbols=len(access.indices), + num_results=access.base_rank, + dropped_dims=access.dropped_dims, + ) + return nisa.flatten_affine_map(amap, ctx) + + +def _operand_kwargs( + prefix: str, + access: _Access, + flat_map: nk_ir.Attribute, + tile_shape: list[int], + tile_par_dims: int = 1, +) -> dict: + return { + f"{prefix}_memloc": access.base, + f"{prefix}_indices": access.indices, + f"{prefix}_ap": flat_map, + f"{prefix}_static_tile_shape": list(tile_shape), + f"{prefix}_tile_par_dims": tile_par_dims, + } + + +def _empty_operand_kwargs(prefix: str) -> dict: + """Kwargs for an 'omitted' optional operand (ap=None, memloc=None).""" + return { + f"{prefix}_memloc": None, + f"{prefix}_indices": [], + f"{prefix}_ap": None, + f"{prefix}_static_tile_shape": [], + f"{prefix}_tile_par_dims": 0, + } + + +def _scalar_operand_kwargs(prefix: str, scalar: nk_ir.Value) -> dict: + return { + f"{prefix}_memloc": scalar, + f"{prefix}_indices": [], + f"{prefix}_ap": None, + f"{prefix}_static_tile_shape": [], + f"{prefix}_tile_par_dims": 0, + } + + +# --------------------------------------------------------------------------- +# Rewrite context +# 
--------------------------------------------------------------------------- + + +_Pattern = Callable[["_RewriteContext", nk_ir.OpView], None] +_PATTERNS: dict[str, _Pattern] = {} + + +def pattern(*op_names: str) -> Callable[[_Pattern], _Pattern]: + def register(fn: _Pattern) -> _Pattern: + for name in op_names: + _PATTERNS[name] = fn + return fn + + return register + + +_LINALG_TO_ARITH_OP = { + "linalg.add": nisa.ArithOp.Add, + "linalg.sub": nisa.ArithOp.Subtract, + "linalg.mul": nisa.ArithOp.Multiply, + "linalg.max": nisa.ArithOp.Max, + "linalg.min": nisa.ArithOp.Min, +} + +_REDUCE_BODY_OP_TO_ARITH = { + "arith.addf": nisa.ArithOp.Add, + "arith.addi": nisa.ArithOp.Add, + "arith.mulf": nisa.ArithOp.Multiply, + "arith.muli": nisa.ArithOp.Multiply, + "arith.maximumf": nisa.ArithOp.Max, + "arith.minimumf": nisa.ArithOp.Min, +} + +_ARITH_TO_CROSS_LANE = { + nisa.ArithOp.Add: nisa.CrossLaneReduceArithOp.Add, + nisa.ArithOp.Max: nisa.CrossLaneReduceArithOp.Max, +} + + +class _RewriteContext: + def __init__(self, ctx: nk_ir.Context, module: nk_ir.Module): + self.ctx = ctx + self.module = module + self.loc = nk_ir.Location.unknown(ctx) + self._f32_const_cache: dict[tuple[int, float], nk_ir.Value] = {} + + def f32_const(self, block: nk_ir.Block, value: float) -> nk_ir.Value: + key = (id(block), value) + cached = self._f32_const_cache.get(key) + if cached is not None: + return cached + with self.loc: + f32 = nk_ir.F32Type.get(self.ctx) + attr = nk_ir.FloatAttr.get(f32, value) + with nk_ir.InsertionPoint.at_block_begin(block): + const_op = nk_ir.Operation.create( + "arith.constant", + results=[f32], + attributes={"value": attr}, + loc=self.loc, + ) + self._f32_const_cache[key] = const_op.result + return const_op.result + + +def _enclosing_block(op: nk_ir.OpView) -> nk_ir.Block: + return op.operation.block # type: ignore[attr-defined] + + +def _is_memspace(ty: nk_ir.Type, name: str) -> bool: + ms = getattr(ty, "memory_space", None) + if ms is None: + return False + return f"<{name}>" in str(ms) + + +def _is_hbm(ty: nk_ir.Type) -> bool: + return _is_memspace(ty, "hbm") or _is_memspace(ty, "shared_hbm") + + +def _is_sbuf(ty: nk_ir.Type) -> bool: + return _is_memspace(ty, "sbuf") + + +def _is_psum(ty: nk_ir.Type) -> bool: + return _is_memspace(ty, "psum") + + +def _static_shape(ty: nk_ir.Type) -> list[int] | None: + shape = list(getattr(ty, "shape", ())) + if any(s < 0 for s in shape): + return None + return shape + + +# --------------------------------------------------------------------------- +# Elementwise: linalg.add/sub/mul/max/min -> nisa.tensor_tensor_arith +# --------------------------------------------------------------------------- + + +@pattern("linalg.add", "linalg.sub", "linalg.mul", "linalg.max", "linalg.min") +def _rewrite_elementwise(rctx: _RewriteContext, op: nk_ir.OpView) -> None: + arith_kind = _LINALG_TO_ARITH_OP[op.operation.name] + operands = list(op.operation.operands) + if len(operands) < 3: + return + lhs, rhs, dst = operands[0], operands[1], operands[2] + + shape = _static_shape(dst.type) + if shape is None: + return + + with nk_ir.InsertionPoint(op), rctx.loc: + dst_acc = _get_base_and_offsets(rctx.ctx, dst, rctx.loc) + lhs_acc = _get_base_and_offsets(rctx.ctx, lhs, rctx.loc) + rhs_acc = _get_base_and_offsets(rctx.ctx, rhs, rctx.loc) + rank = len(shape) + dst_map = _build_nisa_map(rctx.ctx, rank, dst_acc) + lhs_map = _build_nisa_map(rctx.ctx, rank, lhs_acc) + rhs_map = _build_nisa_map(rctx.ctx, rank, rhs_acc) + + kwargs: dict = {} + kwargs.update(_operand_kwargs("dst", 
dst_acc, dst_map, shape))
+        kwargs.update(_operand_kwargs("lhs", lhs_acc, lhs_map, shape))
+        kwargs.update(_operand_kwargs("rhs", rhs_acc, rhs_map, shape))
+        nisa.tensor_tensor_arith(
+            op=arith_kind, engine=nisa.Engine.Vector, **kwargs
+        )
+    op.operation.erase()
+
+
+# ---------------------------------------------------------------------------
+# memref.copy + linalg.copy
+# ---------------------------------------------------------------------------
+
+
+def _pad_shape_to_2d(shape: list[int]) -> list[int]:
+    if len(shape) < 2:
+        return shape + [1] * (2 - len(shape))
+    return shape
+
+
+@pattern("memref.copy")
+def _rewrite_memref_copy(rctx: _RewriteContext, op: nk_ir.OpView) -> None:
+    src = op.operation.operands[0]
+    dst = op.operation.operands[1]
+    src_ty = src.type
+    dst_ty = dst.type
+
+    src_hbm, dst_hbm = _is_hbm(src_ty), _is_hbm(dst_ty)
+    src_sbuf, dst_sbuf = _is_sbuf(src_ty), _is_sbuf(dst_ty)
+    src_psum, dst_psum = _is_psum(src_ty), _is_psum(dst_ty)
+
+    needs_dma = src_hbm or dst_hbm
+    on_tpb = (
+        (src_sbuf and dst_sbuf) or (src_sbuf and dst_psum) or (src_psum and dst_sbuf)
+    )
+    if not (needs_dma or on_tpb):
+        return
+
+    shape = _static_shape(dst_ty)
+    if shape is None:
+        return
+    shape = _pad_shape_to_2d(shape)
+
+    with nk_ir.InsertionPoint(op), rctx.loc:
+        src_acc = _get_base_and_offsets(rctx.ctx, src, rctx.loc)
+        dst_acc = _get_base_and_offsets(rctx.ctx, dst, rctx.loc)
+
+        num_iter = len(shape)
+        src_map = _build_nisa_map(rctx.ctx, num_iter, src_acc)
+        dst_map = _build_nisa_map(rctx.ctx, num_iter, dst_acc)
+
+        needs_psum_hop = needs_dma and (
+            (src_hbm and dst_psum) or (src_psum and dst_hbm)
+        )
+
+        if needs_psum_hop:
+            sbuf_attr = nk_ir.Attribute.parse("#nisa.mem<sbuf>")
+            inter_ty = nk_ir.MemRefType.get(
+                shape, dst_ty.element_type, memory_space=sbuf_attr  # type: ignore[attr-defined]
+            )
+            inter_val = nisa.alloc(memref_type=inter_ty, alignment=64)
+            inter_acc = _get_base_and_offsets(rctx.ctx, inter_val, rctx.loc)
+            inter_map = _build_nisa_map(rctx.ctx, num_iter, inter_acc)
+            if src_hbm and dst_psum:
+                nisa.dma_copy(
+                    **_operand_kwargs("dst", inter_acc, inter_map, shape),
+                    **_operand_kwargs("src", src_acc, src_map, shape),
+                )
+                nisa.tensor_copy(
+                    **_operand_kwargs("dst", dst_acc, dst_map, shape),
+                    **_operand_kwargs("src", inter_acc, inter_map, shape),
+                    engine=nisa.Engine.Vector,
+                )
+            else:
+                nisa.tensor_copy(
+                    **_operand_kwargs("dst", inter_acc, inter_map, shape),
+                    **_operand_kwargs("src", src_acc, src_map, shape),
+                    engine=nisa.Engine.Vector,
+                )
+                nisa.dma_copy(
+                    **_operand_kwargs("dst", dst_acc, dst_map, shape),
+                    **_operand_kwargs("src", inter_acc, inter_map, shape),
+                )
+        elif needs_dma:
+            nisa.dma_copy(
+                **_operand_kwargs("dst", dst_acc, dst_map, shape),
+                **_operand_kwargs("src", src_acc, src_map, shape),
+            )
+        else:
+            nisa.tensor_copy(
+                **_operand_kwargs("dst", dst_acc, dst_map, shape),
+                **_operand_kwargs("src", src_acc, src_map, shape),
+                engine=nisa.Engine.Vector,
+            )
+    op.operation.erase()
+
+
+@pattern("linalg.copy")
+def _rewrite_linalg_copy(rctx: _RewriteContext, op: nk_ir.OpView) -> None:
+    operands = list(op.operation.operands)
+    if len(operands) < 2:
+        return
+    src, dst = operands[0], operands[1]
+    src_ty, dst_ty = src.type, dst.type
+    src_sbuf, dst_sbuf = _is_sbuf(src_ty), _is_sbuf(dst_ty)
+    src_psum, dst_psum = _is_psum(src_ty), _is_psum(dst_ty)
+    if not ((src_sbuf and dst_sbuf) or (src_sbuf and dst_psum) or (src_psum and dst_sbuf)):
+        return
+    shape = _static_shape(dst_ty)
+    if shape is None:
+        return
+    with nk_ir.InsertionPoint(op), rctx.loc:
+
src_acc = _get_base_and_offsets(rctx.ctx, src, rctx.loc) + dst_acc = _get_base_and_offsets(rctx.ctx, dst, rctx.loc) + rank = len(shape) + src_map = _build_nisa_map(rctx.ctx, rank, src_acc) + dst_map = _build_nisa_map(rctx.ctx, rank, dst_acc) + nisa.tensor_copy( + **_operand_kwargs("dst", dst_acc, dst_map, shape), + **_operand_kwargs("src", src_acc, src_map, shape), + engine=nisa.Engine.Vector, + ) + op.operation.erase() + + +# --------------------------------------------------------------------------- +# memref.alloc / memref.dealloc +# --------------------------------------------------------------------------- + + +@pattern("memref.alloc") +def _rewrite_memref_alloc(rctx: _RewriteContext, op: nk_ir.OpView) -> None: + result_ty = op.operation.results[0].type + if not (_is_sbuf(result_ty) or _is_psum(result_ty) or _is_hbm(result_ty)): + return + + alignment = 0 + attrs = op.operation.attributes + if "alignment" in attrs: + alignment = nk_ir.IntegerAttr(attrs["alignment"]).value + + with nk_ir.InsertionPoint(op), rctx.loc: + new_val = nisa.alloc(memref_type=result_ty, alignment=alignment) + + op.operation.results[0].replace_all_uses_with(new_val) + op.operation.erase() + + +def _fold_reinterpret_casts(rctx: _RewriteContext) -> None: + casts: list[nk_ir.OpView] = [] + + def visit(op_handle: nk_ir.Operation) -> nk_ir.WalkResult: + if op_handle.name == "memref.reinterpret_cast": + casts.append(op_handle.opview) + return nk_ir.WalkResult.ADVANCE + + rctx.module.operation.walk(visit) + + for cast_op in casts: + src = cast_op.operation.operands[0] + src_owner = getattr(src, "owner", None) + if src_owner is None: + continue + src_op = src_owner.opview if hasattr(src_owner, "opview") else src_owner + if getattr(src_op, "name", None) != "nisa.alloc": + continue + try: + st_off = [int(x) for x in cast_op.operation.attributes["static_offsets"]] + if any(x != 0 for x in st_off): + continue + except (KeyError, ValueError): + continue + + new_ty = cast_op.operation.results[0].type + + alignment = 0 + if "alignment" in src_op.attributes: + alignment = nk_ir.IntegerAttr(src_op.attributes["alignment"]).value + + with nk_ir.InsertionPoint(src_op), rctx.loc: + new_alloc = nisa.alloc(memref_type=new_ty, alignment=alignment) + + cast_op.operation.results[0].replace_all_uses_with(new_alloc) + cast_op.operation.erase() + if list(src.uses): + src.replace_all_uses_with(new_alloc) + src_op.erase() + + +@pattern("memref.dealloc") +def _rewrite_memref_dealloc(rctx: _RewriteContext, op: nk_ir.OpView) -> None: + target = op.operation.operands[0] + target_ty = target.type + if not (_is_sbuf(target_ty) or _is_psum(target_ty)): + return + with nk_ir.InsertionPoint(op), rctx.loc: + nisa.release(memref=target) + op.operation.erase() + + +# --------------------------------------------------------------------------- +# linalg.transpose +# --------------------------------------------------------------------------- + + +def _non_unit_dims(shape: list[int]) -> list[int]: + return [i for i, s in enumerate(shape) if s != 1] + + +@pattern("linalg.transpose") +def _rewrite_linalg_transpose(rctx: _RewriteContext, op: nk_ir.OpView) -> None: + operands = list(op.operation.operands) + if len(operands) < 2: + return + src, dst = operands[0], operands[1] + src_ty, dst_ty = src.type, dst.type + src_shape = _static_shape(src_ty) + dst_shape = _static_shape(dst_ty) + if src_shape is None or dst_shape is None: + return + + attrs = op.operation.attributes + if "permutation" not in attrs: + return + perm_str = str(attrs["permutation"]) + try: + 
inside = perm_str.split(":", 1)[1].rstrip(">").strip() + perm = [int(x.strip()) for x in inside.split(",") if x.strip()] + except (IndexError, ValueError): + return + + non_unit_src = _non_unit_dims(src_shape) + if len(non_unit_src) > 2: + return + + needs_transpose = False + if len(non_unit_src) == 2: + s0, s1 = non_unit_src[0], non_unit_src[1] + d0 = perm.index(s0) + d1 = perm.index(s1) + needs_transpose = d0 > d1 + + src_hbm = _is_hbm(src_ty) + src_sbuf = _is_sbuf(src_ty) + dst_sbuf = _is_sbuf(dst_ty) + dst_hbm = _is_hbm(dst_ty) + + # 2D tile shapes derived from non-unit dims. + src_tile = [src_shape[d] for d in non_unit_src] + dst_tile = [dst_shape[d] for d in _non_unit_dims(dst_shape)] + while len(src_tile) < 2: + src_tile.append(1) + while len(dst_tile) < 2: + dst_tile.append(1) + + with nk_ir.InsertionPoint(op), rctx.loc: + src_acc = _get_base_and_offsets(rctx.ctx, src, rctx.loc) + dst_acc = _get_base_and_offsets(rctx.ctx, dst, rctx.loc) + num_iter = 2 + src_map = _build_nisa_map(rctx.ctx, num_iter, src_acc) + dst_map = _build_nisa_map(rctx.ctx, num_iter, dst_acc) + + if needs_transpose: + if not ((src_hbm or src_sbuf) and dst_sbuf): + return + nisa.dma_transpose( + **_operand_kwargs("dst", dst_acc, dst_map, dst_tile), + **_operand_kwargs("src", src_acc, src_map, src_tile), + permutation=[1, 0], + dge_mode=nisa.DGEType.NoDGE, + oob_is_err=True, + engine=nisa.Engine.DMA, + ) + op.operation.erase() + return + + cross = (src_hbm and dst_sbuf) or (src_sbuf and dst_hbm) + if cross: + nisa.dma_copy( + **_operand_kwargs("dst", dst_acc, dst_map, dst_tile), + **_operand_kwargs("src", src_acc, src_map, src_tile), + ) + op.operation.erase() + return + + +# --------------------------------------------------------------------------- +# linalg.matmul_transpose_a +# --------------------------------------------------------------------------- + + +def _index_const(rctx: _RewriteContext, value: int) -> nk_ir.Value: + with rctx.loc: + idx_ty = nk_ir.IndexType.get(rctx.ctx) + attr = nk_ir.IntegerAttr.get(idx_ty, value) + const_op = nk_ir.Operation.create( + "arith.constant", + results=[idx_ty], + attributes={"value": attr}, + loc=rctx.loc, + ) + return const_op.result + + +@pattern("linalg.matmul_transpose_a") +def _rewrite_matmul_transpose_a(rctx: _RewriteContext, op: nk_ir.OpView) -> None: + operands = list(op.operation.operands) + if len(operands) != 3: + return + mat_a, mat_b, mat_c = operands + a_ty, b_ty, c_ty = mat_a.type, mat_b.type, mat_c.type + + a_shape = _static_shape(a_ty) + b_shape = _static_shape(b_ty) + c_shape = _static_shape(c_ty) + if a_shape is None or b_shape is None or c_shape is None: + return + if len(a_shape) != 2 or len(b_shape) != 2 or len(c_shape) != 2: + return + K, M = a_shape + if b_shape[0] != K or c_shape[0] != M or c_shape[1] != b_shape[1]: + return + + if not (_is_sbuf(a_ty) and _is_sbuf(b_ty) and _is_psum(c_ty)): + return + + N = b_shape[1] + + with nk_ir.InsertionPoint(op), rctx.loc: + a_acc = _get_base_and_offsets(rctx.ctx, mat_a, rctx.loc) + b_acc = _get_base_and_offsets(rctx.ctx, mat_b, rctx.loc) + c_acc = _get_base_and_offsets(rctx.ctx, mat_c, rctx.loc) + a_map = _build_nisa_map(rctx.ctx, 2, a_acc) + b_map = _build_nisa_map(rctx.ctx, 2, b_acc) + c_map = _build_nisa_map(rctx.ctx, 2, c_acc) + + row_pos = _index_const(rctx, 0) + col_pos = _index_const(rctx, 0) + nisa.matmul( + **_operand_kwargs("dst", c_acc, c_map, [M, N]), + **_operand_kwargs("stationary", a_acc, a_map, [K, M]), + **_operand_kwargs("moving", b_acc, b_map, [K, N]), + row_pos=row_pos, + 
col_pos=col_pos, + psum_accumulate_flags=None, + is_transpose=False, + perf_opt=nisa.PerfOptMode.None_, + psum_zero_region=nisa.MatmulZeroRegion.Size2048, + engine=nisa.Engine.Tensor, + ) + op.operation.erase() + + +# --------------------------------------------------------------------------- +# linalg.reciprocal +# --------------------------------------------------------------------------- + + +@pattern("linalg.reciprocal") +def _rewrite_reciprocal(rctx: _RewriteContext, op: nk_ir.OpView) -> None: + operands = list(op.operation.operands) + if len(operands) < 2: + return + src, dst = operands[0], operands[1] + if not ( + (_is_sbuf(src.type) or _is_psum(src.type)) + and (_is_sbuf(dst.type) or _is_psum(dst.type)) + ): + return + shape = _static_shape(dst.type) + if shape is None: + return + with nk_ir.InsertionPoint(op), rctx.loc: + src_acc = _get_base_and_offsets(rctx.ctx, src, rctx.loc) + dst_acc = _get_base_and_offsets(rctx.ctx, dst, rctx.loc) + rank = len(shape) + src_map = _build_nisa_map(rctx.ctx, rank, src_acc) + dst_map = _build_nisa_map(rctx.ctx, rank, dst_acc) + nisa.reciprocal( + **_operand_kwargs("dst", dst_acc, dst_map, shape), + **_operand_kwargs("src", src_acc, src_map, shape), + engine=nisa.Engine.Vector, + ) + op.operation.erase() + + +# --------------------------------------------------------------------------- +# linalg.{exp, square, sqrt, abs, log, tanh} -> nisa.activation +# --------------------------------------------------------------------------- + + +_LINALG_TO_ACTIVATION = { + "linalg.exp": nisa.ActivationFunction.exp, + "linalg.square": nisa.ActivationFunction.square, + "linalg.sqrt": nisa.ActivationFunction.sqrt, + "linalg.abs": nisa.ActivationFunction.abs, + "linalg.log": nisa.ActivationFunction.log, + "linalg.tanh": nisa.ActivationFunction.tanh, +} + + +def _emit_activation( + rctx: _RewriteContext, + op: nk_ir.OpView, + src: nk_ir.Value, + dst: nk_ir.Value, + act_kind, +) -> bool: + if not (_is_sbuf(src.type) and _is_sbuf(dst.type)): + return False + shape = _static_shape(dst.type) + if shape is None: + return False + + block = _enclosing_block(op) + bias = rctx.f32_const(block, 0.0) + scale = rctx.f32_const(block, 1.0) + + with nk_ir.InsertionPoint(op), rctx.loc: + src_acc = _get_base_and_offsets(rctx.ctx, src, rctx.loc) + dst_acc = _get_base_and_offsets(rctx.ctx, dst, rctx.loc) + rank = len(shape) + src_map = _build_nisa_map(rctx.ctx, rank, src_acc) + dst_map = _build_nisa_map(rctx.ctx, rank, dst_acc) + nisa.activation( + **_operand_kwargs("dst", dst_acc, dst_map, shape), + **_empty_operand_kwargs("reduce_res"), + **_operand_kwargs("src", src_acc, src_map, shape), + **_scalar_operand_kwargs("bias", bias), + **_scalar_operand_kwargs("scale", scale), + **_empty_operand_kwargs("alpha"), + op=act_kind, + engine=nisa.Engine.Scalar, + ) + return True + + +@pattern(*_LINALG_TO_ACTIVATION.keys()) +def _rewrite_linalg_activation(rctx: _RewriteContext, op: nk_ir.OpView) -> None: + act_kind = _LINALG_TO_ACTIVATION[op.operation.name] + operands = list(op.operation.operands) + if len(operands) < 2: + return + if _emit_activation(rctx, op, operands[0], operands[1], act_kind): + op.operation.erase() + + +# --------------------------------------------------------------------------- +# linalg.fill -> nisa.memset +# --------------------------------------------------------------------------- + + +@pattern("linalg.fill") +def _rewrite_linalg_fill(rctx: _RewriteContext, op: nk_ir.OpView) -> None: + operands = list(op.operation.operands) + if len(operands) < 2: + return + scalar, 
dst = operands[0], operands[1] + dst_ty = dst.type + if not (_is_sbuf(dst_ty) or _is_psum(dst_ty)): + return + shape = _static_shape(dst_ty) + if shape is None: + return + with nk_ir.InsertionPoint(op), rctx.loc: + dst_acc = _get_base_and_offsets(rctx.ctx, dst, rctx.loc) + rank = len(shape) + dst_map = _build_nisa_map(rctx.ctx, rank, dst_acc) + nisa.memset( + **_operand_kwargs("dst", dst_acc, dst_map, shape), + value=scalar, + engine=nisa.Engine.Vector, + ) + op.operation.erase() + + +# --------------------------------------------------------------------------- +# linalg.generic (scalar / broadcast / same-shape / type cast / reduction / +# unary math / powf) +# --------------------------------------------------------------------------- + + +_BODY_OP_TO_ARITH = { + "arith.addf": nisa.ArithOp.Add, + "arith.addi": nisa.ArithOp.Add, + "arith.subf": nisa.ArithOp.Subtract, + "arith.subi": nisa.ArithOp.Subtract, + "arith.mulf": nisa.ArithOp.Multiply, + "arith.muli": nisa.ArithOp.Multiply, + "arith.divf": nisa.ArithOp.Divide, + "arith.divsi": nisa.ArithOp.DivideInt, + "arith.divui": nisa.ArithOp.DivideInt, + "arith.remf": nisa.ArithOp.Mod, + "arith.remsi": nisa.ArithOp.ModInt, +} + +_CMPF_PRED_TO_ARITH = { + 1: nisa.ArithOp.IsEQ, + 2: nisa.ArithOp.IsGT, + 3: nisa.ArithOp.IsGE, + 4: nisa.ArithOp.IsLT, + 5: nisa.ArithOp.IsLE, + 6: nisa.ArithOp.IsNE, +} + +_CMPI_PRED_TO_ARITH = { + 0: nisa.ArithOp.IsEQ, + 1: nisa.ArithOp.IsNE, + 2: nisa.ArithOp.IsLT, + 3: nisa.ArithOp.IsLE, + 4: nisa.ArithOp.IsGT, + 5: nisa.ArithOp.IsGE, +} + +_BODY_MATH_TO_ACTIVATION = { + "math.sin": nisa.ActivationFunction.sin, + "math.copysign": nisa.ActivationFunction.sign, +} + + +def _defining_op(v: nk_ir.Value): + owner = getattr(v, "owner", None) + if owner is None: + return None + return owner.opview if hasattr(owner, "opview") else owner + + +def _predicate_int(op): + attrs = op.attributes + if "predicate" not in attrs: + return None + return nk_ir.IntegerAttr(attrs["predicate"]).value + + +def _is_constant_value(v: nk_ir.Value) -> bool: + owner = getattr(v, "owner", None) + if owner is None: + return False + op = owner.opview if hasattr(owner, "opview") else owner + return getattr(op, "name", None) == "arith.constant" + + +def _shape_match_broadcast(in_shape, out_shape): + if len(in_shape) != len(out_shape): + return False + had_broadcast = False + for i, o in zip(in_shape, out_shape): + if i == 1 and o > 1: + had_broadcast = True + elif i != o: + return False + return had_broadcast + + +def _analyze_generic_body(op: nk_ir.OpView): + region = op.regions[0] + block = region.blocks[0] + ops = list(block.operations) + if not ops: + return None + yield_op = ops[-1] + if yield_op.name != "linalg.yield": + return None + yielded = list(yield_op.operands) + if len(yielded) != 1: + return None + root = _defining_op(yielded[0]) + if root is None or not hasattr(root, "name"): + return None + name = root.name + kind = _BODY_OP_TO_ARITH.get(name) + if kind is not None: + return kind, root.operands[0], root.operands[1] + if name == "arith.uitofp": + inner = _defining_op(root.operands[0]) + if inner is None: + return None + if inner.name == "arith.cmpf": + pred = _predicate_int(inner) + k = _CMPF_PRED_TO_ARITH.get(pred) if pred is not None else None + if k is not None: + return k, inner.operands[0], inner.operands[1] + if inner.name in ("arith.andi", "arith.ori"): + logical_kind = ( + nisa.ArithOp.LogicalAnd if inner.name == "arith.andi" + else nisa.ArithOp.LogicalOr + ) + lhs_cast = _defining_op(inner.operands[0]) + rhs_cast = 
_defining_op(inner.operands[1]) + if (lhs_cast and rhs_cast and + lhs_cast.name == "arith.fptoui" and + rhs_cast.name == "arith.fptoui"): + return logical_kind, lhs_cast.operands[0], rhs_cast.operands[0] + return None + if name == "arith.extui": + inner = _defining_op(root.operands[0]) + if inner is None or inner.name != "arith.cmpi": + return None + pred = _predicate_int(inner) + k = _CMPI_PRED_TO_ARITH.get(pred) if pred is not None else None + if k is not None: + return k, inner.operands[0], inner.operands[1] + return None + + +def _match_generic_powf(op: nk_ir.OpView): + region = op.regions[0] + block = region.blocks[0] + ops = list(block.operations) + if len(ops) != 2: + return None + inner, yield_op = ops[0], ops[1] + if inner.name != "math.powf" or yield_op.name != "linalg.yield": + return None + if list(yield_op.operands) != [inner.results[0]]: + return None + return inner.operands[0], inner.operands[1] + + +def _match_reduction_body(op: nk_ir.OpView): + region = op.regions[0] + block = region.blocks[0] + ops = list(block.operations) + if not ops or ops[-1].name != "linalg.yield": + return None + body_ops = ops[:-1] + found = None + for bo in body_ops: + k = _REDUCE_BODY_OP_TO_ARITH.get(bo.name) + if k is not None: + if found is not None: + return None + found = (k, bo) + return found + + +def _classify_reduction(op: nk_ir.OpView): + attrs = op.operation.attributes + it_str = str(attrs["iterator_types"]) + kinds = [] + for token in it_str.split("#linalg.iterator_type<"): + close = token.find(">") + if close < 0: + continue + k = token[:close] + if k in ("parallel", "reduction"): + kinds.append(k) + if not kinds: + return None + num_red = sum(1 for k in kinds if k == "reduction") + if num_red == 0: + return None + is_left = all(kinds[i] == "reduction" for i in range(num_red)) + is_right = all(kinds[-1 - i] == "reduction" for i in range(num_red)) + if not (is_left or is_right): + return None + return num_red, is_left, is_right + + +def _rewrite_linalg_generic_reduction( + rctx: _RewriteContext, op: nk_ir.OpView, num_ins: int +) -> bool: + if num_ins != 1: + return False + operands = list(op.operation.operands) + src = operands[0] + dst = operands[1] + + match = _match_reduction_body(op) + if match is None: + return False + arith_kind, inner = match + + block = op.regions[0].blocks[0] + block_args = list(block.arguments) + out_block_arg = block_args[-1] + in0 = inner.operands[0] + in1 = inner.operands[1] + if not (str(in0) == str(out_block_arg) or str(in1) == str(out_block_arg)): + return False + + classified = _classify_reduction(op) + if classified is None: + return False + num_red_dims, is_left, is_right = classified + + src_ty = src.type + dst_ty = dst.type + if not _is_sbuf(src_ty): + return False + if not (_is_sbuf(dst_ty) or _is_psum(dst_ty)): + return False + dst_shape = _static_shape(dst_ty) + src_shape = _static_shape(src_ty) + if dst_shape is None or src_shape is None: + return False + + with nk_ir.InsertionPoint(op), rctx.loc: + src_acc = _get_base_and_offsets(rctx.ctx, src, rctx.loc) + dst_acc = _get_base_and_offsets(rctx.ctx, dst, rctx.loc) + + if is_left: + cross_op = _ARITH_TO_CROSS_LANE.get(arith_kind) + if cross_op is None: + return False + # cross_lane_reduce_arith uses each operand's data shape as its + # iteration/tile domain. src spans the full input (parallel + + # partition reduction dim), dst spans the output. Using dst_shape + # for src tile (as before) made the hardware reduce a 1-wide + # slice, returning 0 for axis=0 sums. 
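+            # Illustrative example (hypothetical shapes): summing a
+            # [128, 512] SBUF tensor over axis 0 into a [1, 512] result
+            # passes src tile [128, 512] and dst tile [1, 512], each with
+            # an affine map built over its own operand's rank.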
+ src_map = _build_nisa_map(rctx.ctx, len(src_shape), src_acc) + dst_map = _build_nisa_map(rctx.ctx, len(dst_shape), dst_acc) + nisa.cross_lane_reduce_arith( + **_operand_kwargs("dst", dst_acc, dst_map, dst_shape), + **_operand_kwargs("src", src_acc, src_map, src_shape), + reduce_op=cross_op, + num_r_dim=0, + engine=nisa.Engine.Gpsimd, + ) + op.operation.erase() + return True + + # Rightmost reduction: tensor_reduce_arith into temp, then + # tensor_tensor_arith to accumulate into dst. + # + # tensor_reduce_arith's iteration domain spans the SOURCE shape + # (parallel dims + reduction dims), so affine maps for both src and + # temp_dst have `numSrcIterDims = len(src_shape)` dimensions. Tile + # shapes reflect each operand's own shape: src_shape for src, + # dst_shape for the dst/temp. Using dst_shape for src (as before) + # made the hardware only reduce a 1-wide slice, giving wrong sums. + num_src_iter_dims = len(src_shape) + + src_map = _build_nisa_map(rctx.ctx, num_src_iter_dims, src_acc) + dst_reduce_map = _build_nisa_map(rctx.ctx, num_src_iter_dims, dst_acc) + + temp_ty = nk_ir.MemRefType.get( + dst_shape, + dst_ty.element_type, # type: ignore[attr-defined] + memory_space=dst_ty.memory_space, # type: ignore[attr-defined] + ) + temp_val = nisa.alloc(memref_type=temp_ty, alignment=0) + temp_acc = _get_base_and_offsets(rctx.ctx, temp_val, rctx.loc) + temp_reduce_map = _build_nisa_map( + rctx.ctx, num_src_iter_dims, temp_acc, + ) + + nisa.tensor_reduce_arith( + **_operand_kwargs("dst", temp_acc, temp_reduce_map, dst_shape), + **_operand_kwargs("src", src_acc, src_map, src_shape), + op=arith_kind, + negated=False, + num_r_dim=num_red_dims, + engine=nisa.Engine.Vector, + ) + + # Accumulation uses the dst iteration domain only (parallel dims). + num_dst_iter_dims = len(dst_shape) + dst_accum_map = _build_nisa_map(rctx.ctx, num_dst_iter_dims, dst_acc) + temp_accum_map = _build_nisa_map( + rctx.ctx, num_dst_iter_dims, temp_acc, + ) + nisa.tensor_tensor_arith( + **_operand_kwargs("dst", dst_acc, dst_accum_map, dst_shape), + **_operand_kwargs("lhs", dst_acc, dst_accum_map, dst_shape), + **_operand_kwargs("rhs", temp_acc, temp_accum_map, dst_shape), + op=arith_kind, + engine=nisa.Engine.Vector, + ) + nisa.release(memref=temp_val) + + op.operation.erase() + return True + + +def _match_generic_unary_activation(op: nk_ir.OpView): + region = op.regions[0] + block = region.blocks[0] + ops = list(block.operations) + if len(ops) != 2: + return None + inner, yield_op = ops[0], ops[1] + if yield_op.name != "linalg.yield": + return None + if list(yield_op.operands) != [inner.results[0]]: + return None + return _BODY_MATH_TO_ACTIVATION.get(inner.name) + + +def _match_generic_identity_body(op: nk_ir.OpView) -> bool: + """True if the generic's body yields the first block argument directly. + + Mirrors the C++ LinalgGenericIdentityCopyPattern. Uses `walk()` to find + the yield op — touching `block.operations` (by iteration or index) + corrupts the NKI Python binding's iterator state and breaks later + matchers on the same generic. 
+ """ + block = op.regions[0].blocks[0] + yield_op: list[nk_ir.Operation] = [] + + def visit(o: nk_ir.Operation) -> nk_ir.WalkResult: + if o.name == "linalg.yield": + yield_op.append(o) + return nk_ir.WalkResult.INTERRUPT + return nk_ir.WalkResult.ADVANCE + + op.operation.walk(visit) + if not yield_op: + return False + terminator = yield_op[0] + operands = list(terminator.operands) + if len(operands) != 1: + return False + args = list(block.arguments) + if not args: + return False + return operands[0] == args[0] + + +def _emit_copy_from_identity_generic( + rctx: _RewriteContext, + op: nk_ir.OpView, + src: nk_ir.Value, + dst: nk_ir.Value, +) -> None: + """Lower an identity-body linalg.generic to nisa.tensor_copy / dma_copy. + + Same lowering strategy as _rewrite_memref_copy / _rewrite_linalg_copy. + """ + src_ty, dst_ty = src.type, dst.type + src_hbm, dst_hbm = _is_hbm(src_ty), _is_hbm(dst_ty) + src_sbuf, dst_sbuf = _is_sbuf(src_ty), _is_sbuf(dst_ty) + src_psum, dst_psum = _is_psum(src_ty), _is_psum(dst_ty) + + needs_dma = src_hbm or dst_hbm + on_tpb = ( + (src_sbuf and dst_sbuf) or (src_sbuf and dst_psum) or (src_psum and dst_sbuf) + ) + if not (needs_dma or on_tpb): + return + + shape = _static_shape(dst_ty) + if shape is None: + return + + with nk_ir.InsertionPoint(op), rctx.loc: + src_acc = _get_base_and_offsets(rctx.ctx, src, rctx.loc) + dst_acc = _get_base_and_offsets(rctx.ctx, dst, rctx.loc) + rank = len(shape) + src_map = _build_nisa_map(rctx.ctx, rank, src_acc) + dst_map = _build_nisa_map(rctx.ctx, rank, dst_acc) + if needs_dma: + nisa.dma_copy( + **_operand_kwargs("dst", dst_acc, dst_map, shape), + **_operand_kwargs("src", src_acc, src_map, shape), + ) + else: + nisa.tensor_copy( + **_operand_kwargs("dst", dst_acc, dst_map, shape), + **_operand_kwargs("src", src_acc, src_map, shape), + engine=nisa.Engine.Vector, + ) + op.operation.erase() + + +def _match_generic_type_cast(op: nk_ir.OpView) -> bool: + region = op.regions[0] + block = region.blocks[0] + ops = list(block.operations) + if len(ops) != 2: + return False + inner, yield_op = ops[0], ops[1] + if inner.name not in ("arith.sitofp", "arith.fptosi"): + return False + if yield_op.name != "linalg.yield": + return False + if list(yield_op.operands) != [inner.results[0]]: + return False + return True + + +@pattern("linalg.generic") +def _rewrite_linalg_generic(rctx: _RewriteContext, op: nk_ir.OpView) -> None: + attrs = op.operation.attributes + if "operandSegmentSizes" not in attrs: + return + seg_str = str(attrs["operandSegmentSizes"]) + try: + nums = [int(x.strip()) for x in seg_str.split(":")[1].strip(" >").split(",")] + num_ins, num_outs = nums[0], nums[1] + except (ValueError, IndexError): + return + + if num_outs != 1: + return + + if "iterator_types" not in attrs: + return + if "reduction" in str(attrs["iterator_types"]): + _rewrite_linalg_generic_reduction(rctx, op, num_ins) + return + + operands = list(op.operation.operands) + inputs = operands[:num_ins] + output = operands[num_ins] + out_ty = output.type + out_shape = _static_shape(out_ty) + + # Identity-body generic (body is just `linalg.yield %arg0`) — lower to + # a copy. Ported from the pre-open-source LinalgGenericIdentityCopyPattern + # in LinalgToNisa.cpp. Arises from broadcast_to in the tracer, and from + # trivial transposes reconstructed by legalize-layout. 
+ if (num_ins == 1 + and "reduction" not in str(attrs["iterator_types"]) + and _match_generic_identity_body(op)): + _emit_copy_from_identity_generic(rctx, op, inputs[0], output) + return + + if num_ins == 1: + act_kind = _match_generic_unary_activation(op) + if act_kind is not None: + if _emit_activation(rctx, op, inputs[0], output, act_kind): + op.operation.erase() + return + + if num_ins == 1 and _match_generic_type_cast(op): + if ( + "parallel" in str(attrs["iterator_types"]) + and "reduction" not in str(attrs["iterator_types"]) + and out_shape is not None + and _is_sbuf(inputs[0].type) + and _is_sbuf(out_ty) + ): + zero = rctx.f32_const(_enclosing_block(op), 0.0) + with nk_ir.InsertionPoint(op), rctx.loc: + src_acc = _get_base_and_offsets(rctx.ctx, inputs[0], rctx.loc) + dst_acc = _get_base_and_offsets(rctx.ctx, output, rctx.loc) + rank = len(out_shape) + src_map = _build_nisa_map(rctx.ctx, rank, src_acc) + dst_map = _build_nisa_map(rctx.ctx, rank, dst_acc) + nisa.tensor_scalar_arith( + **_operand_kwargs("dst", dst_acc, dst_map, out_shape), + **_operand_kwargs("src", src_acc, src_map, out_shape), + **_scalar_operand_kwargs("operand0", zero), + **_empty_operand_kwargs("operand1"), + op0=nisa.ArithOp.Add, + op1=None, + reverse_operands=nisa.TensScalarRevOps.None_, + engine=nisa.Engine.Vector, + ) + op.operation.erase() + return + + pow_match = _match_generic_powf(op) if num_ins == 2 else None + if pow_match is not None and out_shape is not None: + base, exp_v = pow_match + block_arg0 = op.regions[0].blocks[0].arguments[0] + swap = str(base) != str(block_arg0) + base_v = inputs[1] if swap else inputs[0] + exp_v = inputs[0] if swap else inputs[1] + if (_is_sbuf(base_v.type) and _is_sbuf(exp_v.type) + and _is_sbuf(out_ty)): + with nk_ir.InsertionPoint(op), rctx.loc: + base_acc = _get_base_and_offsets(rctx.ctx, base_v, rctx.loc) + exp_acc = _get_base_and_offsets(rctx.ctx, exp_v, rctx.loc) + dst_acc = _get_base_and_offsets(rctx.ctx, output, rctx.loc) + rank = len(out_shape) + nisa.tensor_tensor_power( + **_operand_kwargs( + "dst", dst_acc, _build_nisa_map(rctx.ctx, rank, dst_acc), + out_shape, + ), + **_operand_kwargs( + "lhs", base_acc, _build_nisa_map(rctx.ctx, rank, base_acc), + out_shape, + ), + **_operand_kwargs( + "rhs", exp_acc, _build_nisa_map(rctx.ctx, rank, exp_acc), + out_shape, + ), + engine=nisa.Engine.Gpsimd, + ) + op.operation.erase() + return + + analysis = _analyze_generic_body(op) + if analysis is None or out_shape is None: + return + arith_kind, body_lhs, body_rhs = analysis + + if not _is_sbuf(out_ty): + return + + if num_ins == 1: + input_v = inputs[0] + if not _is_sbuf(input_v.type): + return + lhs_const = _is_constant_value(body_lhs) + rhs_const = _is_constant_value(body_rhs) + if lhs_const == rhs_const: + return + scalar_v = body_lhs if lhs_const else body_rhs + reverse = ( + nisa.TensScalarRevOps.First if lhs_const else nisa.TensScalarRevOps.None_ + ) + with nk_ir.InsertionPoint(op), rctx.loc: + src_acc = _get_base_and_offsets(rctx.ctx, input_v, rctx.loc) + dst_acc = _get_base_and_offsets(rctx.ctx, output, rctx.loc) + rank = len(out_shape) + src_map = _build_nisa_map(rctx.ctx, rank, src_acc) + dst_map = _build_nisa_map(rctx.ctx, rank, dst_acc) + nisa.tensor_scalar_arith( + **_operand_kwargs("dst", dst_acc, dst_map, out_shape), + **_operand_kwargs("src", src_acc, src_map, out_shape), + **_scalar_operand_kwargs("operand0", scalar_v), + **_empty_operand_kwargs("operand1"), + op0=arith_kind, + op1=None, + reverse_operands=reverse, + engine=nisa.Engine.Vector, + ) + 
op.operation.erase() + return + + # num_ins == 2 + in0, in1 = inputs[0], inputs[1] + in0_shape = _static_shape(in0.type) + in1_shape = _static_shape(in1.type) + if in0_shape is None or in1_shape is None: + return + if not (_is_sbuf(in0.type) and _is_sbuf(in1.type)): + return + if _is_constant_value(body_lhs) or _is_constant_value(body_rhs): + return + + in0_bcast = _shape_match_broadcast(in0_shape, out_shape) + in1_bcast = _shape_match_broadcast(in1_shape, out_shape) + + if in0_shape == out_shape and in1_shape == out_shape: + block_arg0 = op.regions[0].blocks[0].arguments[0] + swap = str(body_lhs) != str(block_arg0) + lhs_v = in1 if swap else in0 + rhs_v = in0 if swap else in1 + with nk_ir.InsertionPoint(op), rctx.loc: + lhs_acc = _get_base_and_offsets(rctx.ctx, lhs_v, rctx.loc) + rhs_acc = _get_base_and_offsets(rctx.ctx, rhs_v, rctx.loc) + dst_acc = _get_base_and_offsets(rctx.ctx, output, rctx.loc) + rank = len(out_shape) + lhs_map = _build_nisa_map(rctx.ctx, rank, lhs_acc) + rhs_map = _build_nisa_map(rctx.ctx, rank, rhs_acc) + dst_map = _build_nisa_map(rctx.ctx, rank, dst_acc) + nisa.tensor_tensor_arith( + **_operand_kwargs("dst", dst_acc, dst_map, out_shape), + **_operand_kwargs("lhs", lhs_acc, lhs_map, out_shape), + **_operand_kwargs("rhs", rhs_acc, rhs_map, out_shape), + op=arith_kind, + engine=nisa.Engine.Vector, + ) + op.operation.erase() + return + + if in0_bcast != in1_bcast: + tensor_v = in1 if in0_bcast else in0 + vec_v = in0 if in0_bcast else in1 + vec_shape = in0_shape if in0_bcast else in1_shape + block_arg0 = op.regions[0].blocks[0].arguments[0] + block_arg1 = op.regions[0].blocks[0].arguments[1] + vec_arg = block_arg0 if in0_bcast else block_arg1 + vec_is_lhs = str(body_lhs) == str(vec_arg) + reverse = ( + nisa.TensScalarRevOps.First if vec_is_lhs else nisa.TensScalarRevOps.None_ + ) + with nk_ir.InsertionPoint(op), rctx.loc: + src_acc = _get_base_and_offsets(rctx.ctx, tensor_v, rctx.loc) + vec_acc = _get_base_and_offsets(rctx.ctx, vec_v, rctx.loc) + dst_acc = _get_base_and_offsets(rctx.ctx, output, rctx.loc) + rank = len(out_shape) + src_map = _build_nisa_map(rctx.ctx, rank, src_acc) + vec_map = _build_nisa_map(rctx.ctx, rank, vec_acc) + dst_map = _build_nisa_map(rctx.ctx, rank, dst_acc) + # Match the deleted C++ pattern: all operands use + # tile_par_dims = rank - 1 so the broadcast operand's + # free-dim product = 1 (the broadcast dimension), which + # NISA's tensor_scalar_arith verifier requires. The vec + # operand carries its own shape, not the full out_shape. + par_dims = rank - 1 + nisa.tensor_scalar_arith( + **_operand_kwargs("dst", dst_acc, dst_map, out_shape, par_dims), + **_operand_kwargs("src", src_acc, src_map, out_shape, par_dims), + **_operand_kwargs("operand0", vec_acc, vec_map, vec_shape, par_dims), + **_empty_operand_kwargs("operand1"), + op0=arith_kind, + op1=None, + reverse_operands=reverse, + engine=nisa.Engine.Vector, + ) + op.operation.erase() + return + + +def _build_dma_copy_indirect_op( + *, + ctx: nk_ir.Context, + loc: nk_ir.Location, + dst_memloc: nk_ir.Value, + dst_indices: list[nk_ir.Value], + dst_ap: nk_ir.Attribute, + dst_tile_shape: list[int], + src_memloc: nk_ir.Value, + src_indices: list[nk_ir.Value], + src_ap: nk_ir.Attribute, + src_tile_shape: list[int], + src_index_memloc: nk_ir.Value, + src_index_ap: nk_ir.Attribute, + src_index_tile_shape: list[int], + src_indirect_max_index: int, + tile_par_dims: int = 1, +) -> None: + """Raw builder for nisa.dma_copy_indirect. 
+ + We bypass the generated Python builder because it translates + ``dst_indirect_max_index=None`` into ``[0 : i32]``, which the + verifier then rejects since there's no matching ``dst_index`` + operand. Here we set the attribute to an empty ArrayAttr so the + verifier sees both ``dst_index`` and ``dst_indirect_max_index`` as + absent. + """ + i32_ty = nk_ir.IntegerType.get_signless(32, ctx) + bool_true = nk_ir.BoolAttr.get(True, ctx) + + operands = [dst_memloc, *dst_indices, src_memloc, *src_indices, + src_index_memloc] + # operandSegmentSizes: 13 segments in declaration order. + seg_sizes = [ + 1, len(dst_indices), 0, # dst + 1, len(src_indices), 0, # src + 1, 0, 0, # src_index + 0, 0, 0, # dst_index (absent) + 0, # dma_qos (absent) + ] + attrs = { + "dst_ap": dst_ap, + "dst_static_tile_shape": nk_ir.DenseI64ArrayAttr.get( + dst_tile_shape, ctx + ), + "dst_tile_par_dims": nk_ir.IntegerAttr.get(i32_ty, tile_par_dims), + "src_ap": src_ap, + "src_static_tile_shape": nk_ir.DenseI64ArrayAttr.get( + src_tile_shape, ctx + ), + "src_tile_par_dims": nk_ir.IntegerAttr.get(i32_ty, tile_par_dims), + "src_index_ap": src_index_ap, + "src_index_static_tile_shape": nk_ir.DenseI64ArrayAttr.get( + src_index_tile_shape, ctx + ), + "src_index_tile_par_dims": nk_ir.IntegerAttr.get(i32_ty, tile_par_dims), + "dst_index_static_tile_shape": nk_ir.DenseI64ArrayAttr.get([], ctx), + "dst_index_tile_par_dims": nk_ir.IntegerAttr.get(i32_ty, 0), + "src_indirect_max_index": nk_ir.ArrayAttr.get( + [nk_ir.IntegerAttr.get(i32_ty, src_indirect_max_index)], ctx, + ), + "dst_indirect_max_index": nk_ir.ArrayAttr.get([], ctx), + "dst_rmw_op": nk_ir.ArrayAttr.get([], ctx), + "oob_is_err": bool_true, + "unique_indices": bool_true, + "engine": nk_ir.IntegerAttr.get(i32_ty, nisa.Engine.DMA.value), + "operandSegmentSizes": nk_ir.DenseI32ArrayAttr.get(seg_sizes, ctx), + } + nk_ir.Operation.create( + "nisa.dma_copy_indirect", + results=[], + operands=operands, + attributes=attrs, + loc=loc, + ) + + +# --------------------------------------------------------------------------- +# nkipy.gather -> nisa.dma_copy_indirect +# --------------------------------------------------------------------------- +# +# Port of the deleted C++ NkipyGatherToNisaPattern. The source table stays +# in HBM and each iteration gathers one row (or one partition's worth of +# rows) into SBUF via dma_copy_indirect, then DMAs the result back to the +# output HBM tensor. +# +# Layouts: +# 2D output [N, H]: single gather of N rows (one index per partition). +# 3D output [N, I, H]: wrap in scf.for over I, gathering one row/partition +# per iteration; output[:, i, :] = source[indices[:, i]]. +# +# The indirect DMA uses three specially-shaped affine maps: +# dst : standard [d0, d1] — fill the SBUF output tile +# src : [s0, d1 + s1] — look up one row; d0 is unused, the +# row index comes from the index buffer +# index : standard [d0, d1] — read N indices, one per partition + + +def _gather_src_indirect_map(ctx: nk_ir.Context) -> nk_ir.Attribute: + # (d0, d1)[s0, s1] -> [s0, d1 + s1] + s0 = nk_ir.AffineSymbolExpr.get(0) + d1 = nk_ir.AffineDimExpr.get(1) + s1 = nk_ir.AffineSymbolExpr.get(1) + amap = nk_ir.AffineMap.get(2, 2, [s0, d1 + s1]) + return nisa.flatten_affine_map(amap, ctx) + + +def _gather_standard_2d_map(ctx: nk_ir.Context, num_symbols: int) -> nk_ir.Attribute: + # Standard 2D map with d0/d1 on each dim, plus per-dim symbol offsets. 
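+    # e.g. num_symbols=2 yields (d0, d1)[s0, s1] -> (d0 + s0, d1 + s1);
+    # num_symbols=0 degenerates to the identity map (d0, d1) -> (d0, d1).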
+ d0 = nk_ir.AffineDimExpr.get(0) + d1 = nk_ir.AffineDimExpr.get(1) + exprs: list[nk_ir.AffineExpr] = [] + if num_symbols >= 1: + exprs.append(d0 + nk_ir.AffineSymbolExpr.get(0)) + else: + exprs.append(d0) + if num_symbols >= 2: + exprs.append(d1 + nk_ir.AffineSymbolExpr.get(1)) + else: + exprs.append(d1) + return nisa.flatten_affine_map( + nk_ir.AffineMap.get(2, num_symbols, exprs), ctx + ) + + +def _emit_gather_iteration( + rctx: _RewriteContext, + indices_sbuf: nk_ir.Value, + sbuf_output: nk_ir.Value, + idx_access: _Access, + src_access: _Access, + out_access: _Access, + zero_idx: nk_ir.Value, + idx_offset: nk_ir.Value, + out_indices: list[nk_ir.Value], + N: int, + H: int, + base_v: int, + tile_par_dims: int = 1, +) -> None: + ctx = rctx.ctx + loc = rctx.loc + + # --- 1. DMA one "column" of indices from HBM/SBUF into the indices SBUF --- + dst_map_std = _gather_standard_2d_map(ctx, num_symbols=2) + if not idx_access.indices: + idx_src_indices = [zero_idx, idx_offset] + else: + idx_src_indices = list(idx_access.indices) + # Add idx_offset to the innermost index + idx_src_indices[-1] = _emit_addi( + idx_src_indices[-1], idx_offset, ctx, loc + ) + idx_src_map = _build_nisa_map( + ctx, 2, + _Access( + base=idx_access.base, + indices=idx_src_indices, + base_type=idx_access.base_type, + dropped_dims=idx_access.dropped_dims, + ), + ) + nisa.dma_copy( + dst_memloc=indices_sbuf, + dst_indices=[zero_idx, zero_idx], + dst_ap=dst_map_std, + dst_static_tile_shape=[N, 1], + dst_tile_par_dims=tile_par_dims, + src_memloc=idx_access.base, + src_indices=idx_src_indices, + src_ap=idx_src_map, + src_static_tile_shape=[N, 1], + src_tile_par_dims=tile_par_dims, + oob_is_err=True, + engine=nisa.Engine.DMA, + ) + + # --- 2. dma_copy_indirect: gather H elements from HBM using SBUF indices --- + # dst: standard [d0, d1] 2D map (no symbols — direct to SBUF alloc base). + gather_dst_map = nisa.flatten_affine_map( + nk_ir.AffineMap.get( + 2, 0, + [nk_ir.AffineDimExpr.get(0), nk_ir.AffineDimExpr.get(1)], + ), + ctx, + ) + # src: (d0, d1)[s0, s1] -> [s0, d1 + s1] — s0 is the indirect row index + # supplied by the index buffer; d1+s1 covers the gather's free dim. + gather_src_map = _gather_src_indirect_map(ctx) + col_offset = ( + src_access.indices[1] + if len(src_access.indices) > 1 else zero_idx + ) + src_indirect_indices = [col_offset] + # index: standard [d0, d1] 2D map, no symbol offsets + index_map = nisa.flatten_affine_map( + nk_ir.AffineMap.get( + 2, 0, + [nk_ir.AffineDimExpr.get(0), nk_ir.AffineDimExpr.get(1)], + ), + ctx, + ) + + # Emit the op via builder and then repair its inherent properties. + # The Python builder sets `dst_indirect_max_index = [0 : i32]` when + # no dst_index is provided, which the verifier rejects. MLIR's + # `dma_op.attributes[...]` assignment adds a *discardable* attribute + # but leaves the op's inherent `Properties` struct untouched, so we + # rebuild the op by printing it, textually clearing the bogus attr, + # and parsing back in place. Same trick the C++ side used implicitly + # by constructing the op in one shot. 
+    dma_op = nisa.dma_copy_indirect(
+        dst_memloc=sbuf_output,
+        dst_indices=[],
+        dst_ap=gather_dst_map,
+        dst_static_tile_shape=[N, H],
+        dst_tile_par_dims=tile_par_dims,
+        src_memloc=src_access.base,
+        src_indices=src_indirect_indices,
+        src_ap=gather_src_map,
+        src_static_tile_shape=[N, H],
+        src_tile_par_dims=tile_par_dims,
+        src_index_memloc=indices_sbuf,
+        src_index_indices=[],
+        src_index_ap=index_map,
+        src_index_static_tile_shape=[N, 1],
+        src_index_tile_par_dims=tile_par_dims,
+        dst_index_memloc=None,
+        dst_index_indices=[],
+        dst_index_ap=None,
+        dst_index_static_tile_shape=[],
+        dst_index_tile_par_dims=0,
+        oob_is_err=True,
+        src_indirect_max_index=base_v,
+        unique_indices=True,
+        engine=nisa.Engine.DMA,
+    )
+
+    # --- 3. DMA the gathered SBUF tile back out to the output HBM buffer ---
+    src_copy_map = _gather_standard_2d_map(ctx, num_symbols=2)
+    dst_copy_map = _build_nisa_map(
+        ctx, 2,
+        _Access(
+            base=out_access.base,
+            indices=out_indices,
+            base_type=out_access.base_type,
+            dropped_dims=out_access.dropped_dims,
+        ),
+    )
+    nisa.dma_copy(
+        dst_memloc=out_access.base,
+        dst_indices=out_indices,
+        dst_ap=dst_copy_map,
+        dst_static_tile_shape=[N, H],
+        dst_tile_par_dims=tile_par_dims,
+        src_memloc=sbuf_output,
+        src_indices=[zero_idx, zero_idx],
+        src_ap=src_copy_map,
+        src_static_tile_shape=[N, H],
+        src_tile_par_dims=tile_par_dims,
+        oob_is_err=True,
+        engine=nisa.Engine.DMA,
+    )
+
+
+@pattern("nkipy.gather")
+def _rewrite_nkipy_gather(rctx: _RewriteContext, op: nk_ir.OpView) -> None:
+    operands = list(op.operation.operands)
+    if len(operands) < 3:
+        return
+    source = operands[0]
+    indices = operands[1]
+    output = operands[2]
+
+    src_ty = source.type
+    idx_ty = indices.type
+    out_ty = output.type
+    if not (isinstance(src_ty, nk_ir.MemRefType)
+            and isinstance(idx_ty, nk_ir.MemRefType)
+            and isinstance(out_ty, nk_ir.MemRefType)):
+        return
+
+    src_shape = _static_shape(src_ty)
+    out_shape = _static_shape(out_ty)
+    if src_shape is None or out_shape is None:
+        return
+    if len(src_shape) != 2:
+        return
+    out_rank = len(out_shape)
+    if out_rank < 2 or out_rank > 3:
+        return
+
+    H = src_shape[1]
+    N = out_shape[0]
+    base_v = src_shape[0]
+    ctx = rctx.ctx
+
+    sbuf_attr = nk_ir.Attribute.parse("#nisa.mem<sbuf>")
+    i32_ty = nk_ir.IntegerType.get_signless(32, ctx)
+    idx_elt_ty = idx_ty.element_type  # type: ignore[attr-defined]
+    out_elt_ty = out_ty.element_type  # type: ignore[attr-defined]
+
+    with nk_ir.InsertionPoint(op), rctx.loc:
+        src_acc = _get_base_and_offsets(ctx, source, rctx.loc)
+        idx_acc = _get_base_and_offsets(ctx, indices, rctx.loc)
+        out_acc = _get_base_and_offsets(ctx, output, rctx.loc)
+
+        zero_idx = _emit_const_index(ctx, 0, rctx.loc)
+
+        indices_sbuf_ty = nk_ir.MemRefType.get(
+            [N, 1], idx_elt_ty, memory_space=sbuf_attr,
+        )
+        indices_sbuf = nisa.alloc(memref_type=indices_sbuf_ty, alignment=0)
+        sbuf_output_ty = nk_ir.MemRefType.get(
+            [N, H], out_elt_ty, memory_space=sbuf_attr,
+        )
+        sbuf_output = nisa.alloc(memref_type=sbuf_output_ty, alignment=0)
+
+        if out_rank == 3:
+            I_size = out_shape[1]
+            upper = _emit_const_index(ctx, I_size, rctx.loc)
+            one = _emit_const_index(ctx, 1, rctx.loc)
+
+            # scf.for_ is a generator that yields the induction variable
+            # inside its body's insertion point. With no iter_args the
+            # single yield is the IV; the body terminates with scf.yield.
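+            # Sketch of the loop emitted below:
+            #   for iv in _scf.for_(lb, ub, step, iter_args=[]):
+            #       ...emit one gather iteration at offset iv...
+            #       _scf.yield_([])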
+ from nki.compiler._internal.dialects import scf as _scf + for iv in _scf.for_(zero_idx, upper, one, iter_args=[]): + out_indices = list(out_acc.indices) if out_acc.indices \ + else [zero_idx, zero_idx, zero_idx] + # Insert the loop IV at the second dim (I) of the output. + if len(out_indices) == 3: + out_indices[1] = _emit_addi( + out_indices[1], iv, ctx, rctx.loc + ) + _emit_gather_iteration( + rctx, indices_sbuf, sbuf_output, + idx_acc, src_acc, out_acc, + zero_idx, iv, out_indices, + N, H, base_v, + ) + _scf.yield_([]) + else: + out_indices = list(out_acc.indices) if out_acc.indices \ + else [zero_idx, zero_idx] + _emit_gather_iteration( + rctx, indices_sbuf, sbuf_output, + idx_acc, src_acc, out_acc, + zero_idx, zero_idx, out_indices, + N, H, base_v, + ) + + nisa.release(memref=indices_sbuf) + nisa.release(memref=sbuf_output) + + op.operation.erase() + + +# --------------------------------------------------------------------------- +# Post-pass: fold HBM collapse_shape/expand_shape into the nisa.alloc. +# --------------------------------------------------------------------------- + + +def _alloc_defining_op(v: nk_ir.Value): + owner = getattr(v, "owner", None) + if owner is None: + return None + op = owner.opview if hasattr(owner, "opview") else owner + if getattr(op, "name", None) != "nisa.alloc": + return None + return op + + +def _is_block_arg(v: nk_ir.Value) -> bool: + owner = getattr(v, "owner", None) + return isinstance(owner, nk_ir.Block) + + +def _try_fold_hbm_reshape_alloc(rctx: _RewriteContext, op: nk_ir.OpView) -> bool: + src = op.operation.operands[0] + src_ty = src.type + if not _is_hbm(src_ty): + return False + alloc_op = _alloc_defining_op(src) + if alloc_op is None: + return False + dst_ty = op.operation.results[0].type + if not isinstance(dst_ty, nk_ir.MemRefType): + return False + + alignment = 0 + if "alignment" in alloc_op.attributes: + alignment = nk_ir.IntegerAttr(alloc_op.attributes["alignment"]).value + + src_shape = list(getattr(src_ty, "shape", ())) + src_elt = src_ty.element_type # type: ignore[attr-defined] + + with nk_ir.InsertionPoint(alloc_op), rctx.loc: + new_alloc = nisa.alloc(memref_type=dst_ty, alignment=alignment) + + for user in list(alloc_op.result.uses): + user_op = user.owner + if user_op.name != "nisa.dma_copy": + continue + if user_op.operands[0] != alloc_op.result: + continue + user_op.operands[0] = new_alloc + existing = ( + user_op.attributes["dst_shape"] + if "dst_shape" in user_op.attributes else None + ) + if existing is None or str(existing) in ("", "array"): + user_op.attributes["dst_shape"] = nk_ir.DenseI64ArrayAttr.get(src_shape) + user_op.attributes["dst_elt_ty"] = nk_ir.TypeAttr.get(src_elt) + + op.operation.results[0].replace_all_uses_with(new_alloc) + op.operation.erase() + if not list(alloc_op.result.uses): + alloc_op.erase() + return True + + +def _try_fold_hbm_reshape_arg(rctx: _RewriteContext, op: nk_ir.OpView) -> bool: + src = op.operation.operands[0] + src_ty = src.type + if not _is_hbm(src_ty): + return False + if not _is_block_arg(src): + return False + src_shape = list(getattr(src_ty, "shape", ())) + src_elt = src_ty.element_type # type: ignore[attr-defined] + users = list(op.operation.results[0].uses) + if not users: + op.operation.erase() + return True + for user in users: + user_op = user.owner + if user_op.name != "nisa.dma_copy": + return False + if user_op.operands[1] != op.operation.results[0]: + return False + + for user in users: + user_op = user.owner + user_op.operands[1] = src + existing = ( + 
user_op.attributes["src_shape"] + if "src_shape" in user_op.attributes else None + ) + if existing is None or str(existing) in ("", "array"): + user_op.attributes["src_shape"] = nk_ir.DenseI64ArrayAttr.get(src_shape) + user_op.attributes["src_elt_ty"] = nk_ir.TypeAttr.get(src_elt) + + op.operation.erase() + return True + + +def _fold_hbm_reshapes(rctx: _RewriteContext) -> None: + while True: + candidates: list[nk_ir.OpView] = [] + + def visit(op_handle: nk_ir.Operation) -> nk_ir.WalkResult: + if op_handle.name in ("memref.collapse_shape", "memref.expand_shape"): + candidates.append(op_handle.opview) + return nk_ir.WalkResult.ADVANCE + + rctx.module.operation.walk(visit) + progressed = False + for op in candidates: + if _try_fold_hbm_reshape_alloc(rctx, op): + progressed = True + continue + if _try_fold_hbm_reshape_arg(rctx, op): + progressed = True + if not progressed: + return + + +# --------------------------------------------------------------------------- +# Top-level driver +# --------------------------------------------------------------------------- + + +_DEAD_VIEW_OPS = ( + "memref.subview", + "memref.collapse_shape", + "memref.expand_shape", + "memref.reinterpret_cast", +) + + +def _dce_dead_view_ops(rctx: _RewriteContext) -> None: + while True: + dead: list[nk_ir.OpView] = [] + + def visit(op_handle: nk_ir.Operation) -> nk_ir.WalkResult: + if ( + op_handle.name in _DEAD_VIEW_OPS + and not list(op_handle.results[0].uses) + ): + dead.append(op_handle.opview) + return nk_ir.WalkResult.ADVANCE + + rctx.module.operation.walk(visit) + if not dead: + return + for op in dead: + op.operation.erase() + + +def _walk_and_rewrite(rctx: _RewriteContext) -> None: + candidates: list[nk_ir.OpView] = [] + + def visit(op_handle: nk_ir.Operation) -> nk_ir.WalkResult: + name = op_handle.name + if name in _PATTERNS: + candidates.append(op_handle.opview) + return nk_ir.WalkResult.ADVANCE + + rctx.module.operation.walk(visit) + + for op in candidates: + _PATTERNS[op.operation.name](rctx, op) + + _dce_dead_view_ops(rctx) + _fold_reinterpret_casts(rctx) + _dce_dead_view_ops(rctx) + _fold_hbm_reshapes(rctx) + + +def _resolve_custom_ops(module: nk_ir.Module, ctx: nk_ir.Context) -> None: + """Python port of the deleted C++ ResolveCustomOpsPass. + + Replaces calls to custom-op declarations with the inlined NISA body + stashed in the module attribute ``nkipy.custom_op_bodies`` (a dict of + funcname → MLIR-text NISA body). Matches both conventions: + + - *Output-as-argument* (`func @f(%in, %out) { return }`): allocate + buffers for the trailing arguments and replace call results with + the allocated buffers. + - *Return-value* (kernel_builder, `func @f(%in) -> %out`): splice the + body and rewire `func.return` operands to call results. + + Must run before ``_finalize_for_nki`` which strips ``nkipy.*`` attrs. + """ + module_op = module.operation + if "nkipy.custom_op_bodies" not in module_op.attributes: + return + bodies_attr = nk_ir.DictAttr(module_op.attributes["nkipy.custom_op_bodies"]) + + # Collect custom-op declarations (body-less func.func with + # `nkipy.custom_op` attr). + decls: list[tuple[str, nk_ir.OpView]] = [] + for op in module.body.operations: + if op.operation.name != "func.func": + continue + if "nkipy.custom_op" not in op.attributes: + continue + # A declaration has an empty region body. + regions = list(op.operation.regions) + if regions and list(regions[0].blocks): + # Not a pure declaration; skip (body already present). 
+ continue + sym = nk_ir.StringAttr(op.attributes["sym_name"]).value + decls.append((sym, op)) + + if not decls: + del module_op.attributes["nkipy.custom_op_bodies"] + return + + for func_name, decl_op in decls: + # Look up the stashed NISA body string. + if func_name not in bodies_attr: + raise RuntimeError( + f"no stashed NISA body for custom op '{func_name}'" + ) + body_text = nk_ir.StringAttr(bodies_attr[func_name]).value + body_module = nk_ir.Module.parse(body_text, ctx) + + # Find the (single) non-declaration func in the parsed body. + nisa_func = None + for bop in body_module.body.operations: + if bop.operation.name != "func.func": + continue + regions = list(bop.operation.regions) + if regions and list(regions[0].blocks): + nisa_func = bop + break + if nisa_func is None: + raise RuntimeError( + f"no function body in stashed NISA module for '{func_name}'" + ) + + func_ty = nk_ir.TypeAttr(nisa_func.attributes["function_type"]).value + num_results = len(list(func_ty.results)) # type: ignore[attr-defined] + num_args = len(list(func_ty.inputs)) # type: ignore[attr-defined] + # Output-names drives the output-as-argument split. + if "nki.output_names" in nisa_func.attributes: + num_outputs = len(list( + nk_ir.ArrayAttr(nisa_func.attributes["nki.output_names"]) + )) + else: + num_outputs = 0 + + is_return_value_style = num_results > 0 + if is_return_value_style: + num_inputs = num_args + num_outputs = num_results + else: + num_inputs = num_args - num_outputs + + # Collect all call sites for this custom op across the module. + call_sites: list[nk_ir.OpView] = [] + + def collect_calls(op: nk_ir.Operation) -> nk_ir.WalkResult: + if op.name == "func.call": + callee = nk_ir.FlatSymbolRefAttr( + op.attributes["callee"] + ).value + if callee == func_name: + call_sites.append(op.opview) + return nk_ir.WalkResult.ADVANCE + + module.operation.walk(collect_calls) + + nisa_block = list(nisa_func.operation.regions[0].blocks)[0] + nisa_args = list(nisa_block.arguments) + + for call_op in call_sites: + call_operation = call_op.operation + call_operands = list(call_operation.operands) + + with nk_ir.InsertionPoint(call_operation), call_operation.location: + + def maybe_cast(arg: nk_ir.Value, expected: nk_ir.Type) -> nk_ir.Value: + if str(arg.type) == str(expected): + return arg + v = arg + while True: + owner = getattr(v, "owner", None) + if owner is None: + break + owner_op = owner.opview if hasattr(owner, "opview") else owner + if getattr(owner_op, "name", None) != "memref.cast": + break + src = owner_op.operands[0] + if str(src.type) == str(expected): + return src + v = src + cast_op = nk_ir.Operation.create( + "memref.cast", + results=[expected], + operands=[arg], + loc=call_operation.location, + ) + return cast_op.result + + # `pairs` maps every NISA-body Value that must be + # rewritten (block args + cloned-op results) to its + # replacement in the caller. Lookups are O(len(pairs)) + # which is fine for the short bodies we inline. 
+ pairs: list[tuple[nk_ir.Value, nk_ir.Value]] = [] + + if is_return_value_style: + for nisa_arg, call_arg in zip(nisa_args, call_operands): + pairs.append( + (nisa_arg, maybe_cast(call_arg, nisa_arg.type)) + ) + return_operands: list[nk_ir.Value] = [] + for body_op in list(nisa_block.operations): + if body_op.operation.name == "func.return": + for op_v in body_op.operation.operands: + replacement = op_v + for old, new in pairs: + if op_v == old: + replacement = new + break + return_operands.append(replacement) + continue + _clone_op_with_map(body_op, pairs, pairs, ctx) + for i, retv in enumerate(return_operands): + call_op.results[i].replace_all_uses_with(retv) + else: + # Input args come first. + for i in range(num_inputs): + pairs.append( + (nisa_args[i], + maybe_cast(call_operands[i], nisa_args[i].type)) + ) + # Allocate outputs for trailing NISA args. + out_allocs: list[nk_ir.Value] = [] + for i in range(num_outputs): + nisa_out_ty = nisa_args[num_inputs + i].type + alloc_op = nk_ir.Operation.create( + "memref.alloc", + results=[nisa_out_ty], + operands=[], + attributes={ + "operandSegmentSizes": + nk_ir.DenseI32ArrayAttr.get([0, 0], ctx), + }, + loc=call_operation.location, + ) + pairs.append( + (nisa_args[num_inputs + i], alloc_op.result) + ) + out_allocs.append(alloc_op.result) + for body_op in list(nisa_block.operations): + if body_op.operation.name == "func.return": + continue + _clone_op_with_map(body_op, pairs, pairs, ctx) + for i in range(num_outputs): + call_op.results[i].replace_all_uses_with(out_allocs[i]) + + call_operation.erase() + + # Fix enclosing function return types after inlining — NISA-body + # result types may differ from the caller's declared return type + # (e.g. non-strided vs strided memrefs). + for op in module.body.operations: + if op.operation.name != "func.func": + continue + if "nkipy.custom_op" in op.attributes: + continue + regions = list(op.operation.regions) + if not regions or not list(regions[0].blocks): + continue + block = list(regions[0].blocks)[0] + term = block.operations[len(list(block.operations)) - 1] + if term.operation.name != "func.return": + continue + ret_types = [v.type for v in term.operation.operands] + func_ty_attr = op.attributes["function_type"] + cur_ty = nk_ir.TypeAttr(func_ty_attr).value + cur_results = list(cur_ty.results) # type: ignore[attr-defined] + if len(ret_types) == len(cur_results) and all( + str(a) == str(b) for a, b in zip(ret_types, cur_results) + ): + continue + new_inputs = list(cur_ty.inputs) # type: ignore[attr-defined] + new_ty = nk_ir.FunctionType.get(new_inputs, ret_types) + op.attributes["function_type"] = nk_ir.TypeAttr.get(new_ty) + + decl_op.operation.erase() + + del module_op.attributes["nkipy.custom_op_bodies"] + + +def _clone_op_with_map( + src_op: nk_ir.OpView, + pairs: list[tuple[nk_ir.Value, nk_ir.Value]], + results_pairs: list[tuple[nk_ir.Value, nk_ir.Value]], + ctx: nk_ir.Context, +) -> None: + """Clone ``src_op`` at the current insertion point, rebinding any + operand appearing in ``pairs`` to its partner. Each entry in + ``pairs`` is (original_value, replacement_value); we do a linear + scan on the clone's operands comparing by MLIR Value equality. + Newly-produced results are appended to ``results_pairs`` so later + operations in the same body can find them. + """ + operation = src_op.operation + # Clone inserts at the currently-active InsertionPoint. 
+ cloned = operation.clone() + + def find_replacement(v: nk_ir.Value) -> nk_ir.Value | None: + for old, new in pairs: + if v == old: + return new + return None + + def patch(op: nk_ir.Operation) -> nk_ir.WalkResult: + for i in range(len(op.operands)): + repl = find_replacement(op.operands[i]) + if repl is not None: + op.operands[i] = repl + return nk_ir.WalkResult.ADVANCE + + cloned.walk(patch) + for i, r in enumerate(operation.results): + results_pairs.append((r, cloned.results[i])) + + +def _finalize_for_nki(module: nk_ir.Module, ctx: nk_ir.Context, target: str) -> None: + def strip_nkipy_attrs(op: nk_ir.Operation) -> nk_ir.WalkResult: + to_remove = [ + named.name for named in op.attributes + if named.name.startswith("nkipy.") + ] + for name in to_remove: + del op.attributes[name] + return nk_ir.WalkResult.ADVANCE + + module.operation.walk(strip_nkipy_attrs) + + for op in module.body.operations: + if op.operation.name != "func.func": + continue + if "nki.output_names" in op.attributes: + continue + func_ty_attr = op.attributes["function_type"] + func_ty = nk_ir.TypeAttr(func_ty_attr).value + num_results = len(list(func_ty.results)) # type: ignore[attr-defined] + if num_results == 0: + continue + names = [ + f"output_{i}" if num_results > 1 else "output" + for i in range(num_results) + ] + op.attributes["nki.output_names"] = nk_ir.ArrayAttr.get( + [nk_ir.StringAttr.get(n) for n in names] + ) + + target_attr = nk_ir.Attribute.parse(f"#nisa.target<{target}>") + module.operation.attributes["nisa.target"] = target_attr + + +def linalg_to_nisa( + mlir_text: str, target: str = "trn2", print_generic: bool = True, +) -> str: + """Translate post-Phase-4 MLIR to NISA MLIR (text -> text). + + Defaults to generic form because the NISA pretty-printer omits the element + type for ``nisa.dma_copy``'s ``view(...)`` syntax, which its own parser + then rejects. Generic form roundtrips cleanly through the downstream NKI + parser. Callers that consume the output as text (STRING_CHECK/FILECHECK) + should pass ``print_generic=False``. + """ + ctx, module = _to_nki_module(mlir_text) + with ctx: + rctx = _RewriteContext(ctx, module) + _walk_and_rewrite(rctx) + _resolve_custom_ops(module, ctx) + _finalize_for_nki(module, ctx, target) + out = module.operation.get_asm( + print_generic_op_form=print_generic, assume_verified=True + ) + # Strip the bogus `dst_indirect_max_index = [0 : i32]` the Python + # builder for nisa.dma_copy_indirect injects on gather-only ops. + # The verifier requires it to be absent when dst_index is absent, + # but we can't clear the inherent property from Python — just edit + # the text on the way out. + out = re.sub( + r",\s*dst_indirect_max_index\s*=\s*\[0\s*:\s*i32\]", + "", + out, + ) + return out + + +__all__ = [ + "linalg_to_nisa", + "pattern", +] diff --git a/kernelgen/nkipy_kernelgen/transforms/nkipy_opt.py b/kernelgen/nkipy_kernelgen/transforms/nkipy_opt.py new file mode 100644 index 0000000..2276639 --- /dev/null +++ b/kernelgen/nkipy_kernelgen/transforms/nkipy_opt.py @@ -0,0 +1,419 @@ +""" +Python wrapper for nkipy-opt passes that can't run from Python. + +These passes use NISA dialect which has global constructors that conflict with +Python bindings. They must be run via the nkipy-opt C++ tool instead. +""" + +import os +import subprocess +import tempfile +from pathlib import Path + + +def _pass_to_arg(pass_name: str) -> str: + """Convert a pass spec to a CLI argument. 
+ + Examples: + 'prepare-arithmetic' -> '--prepare-arithmetic' + 'one-shot-bufferize="opt1 opt2"' -> '--one-shot-bufferize=opt1 opt2' + 'insert-spill-reload="target=trn2"' -> '--insert-spill-reload=target=trn2' + """ + if '=' in pass_name: + name, opts = pass_name.split('=', 1) + opts = opts.strip('"').strip("'") + return f'--{name}={opts}' + return f'--{pass_name}' + + +def get_nkipy_opt_path(): + """Get the path to the nkipy-opt executable.""" + # Assumes we're in the NKIPyKernelGen package + package_dir = Path(__file__).parent.parent.parent + nkipy_opt = package_dir / "build" / "bin" / "nkipy-opt" + + if not nkipy_opt.exists(): + raise FileNotFoundError( + f"nkipy-opt not found at {nkipy_opt}. " + "Please build the project first." + ) + + return str(nkipy_opt) + + +def run_nkipy_opt_passes( + mlir_module, + passes: list[str], + print_ir_after_all: bool = False, + print_stderr: bool = False, + print_debuginfo: bool = False, + print_generic: bool = False, +) -> str: + """ + Run nkipy-opt passes on an MLIR module. + + Args: + mlir_module: MLIR module (string or Module object) + passes: List of pass names (e.g., ['cleanup-bufferization-artifacts']) + print_ir_after_all: If True, print IR after each pass (adds --mlir-print-ir-after-all) + print_stderr: If True, print stderr output (useful for debugging pass diagnostics) + print_debuginfo: If True, include source locations in output (adds --mlir-print-debuginfo) + + Returns: + Transformed MLIR module text (when print_ir_after_all=False) + Or IR dumps from all passes followed by final module (when print_ir_after_all=True) + + Raises: + RuntimeError: If nkipy-opt fails + """ + nkipy_opt = get_nkipy_opt_path() + + # Convert Module object to string if needed + mlir_text = str(mlir_module) if not isinstance(mlir_module, str) else mlir_module + + # Create temporary files for input and output + with tempfile.NamedTemporaryFile(mode='w', suffix='.mlir', delete=False) as f_in: + f_in.write(mlir_text) + input_file = f_in.name + + try: + # Build command as a list to avoid shell injection + cmd = [nkipy_opt] + if print_ir_after_all: + cmd.append('--mlir-print-ir-after-all') + if print_debuginfo: + cmd.append('--mlir-print-debuginfo') + if print_generic: + cmd.append('--mlir-print-op-generic') + cmd.extend(_pass_to_arg(p) for p in passes) + cmd.append(input_file) + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=False, + ) + + if result.returncode != 0: + cmd_str = ' '.join(cmd) + error_msg = f"nkipy-opt failed with return code {result.returncode}\n" + error_msg += f"Command: {cmd_str}\n" + error_msg += f"Stdout:\n{result.stdout}\n" + error_msg += f"Stderr:\n{result.stderr}" + raise RuntimeError(error_msg) + + # Print stderr output if requested or when print_ir_after_all is enabled. + # MLIR outputs IR dumps and pass diagnostics to stderr. + # Always return just the clean module (stdout) so output can be parsed. 
+ if result.stderr and (print_stderr or print_ir_after_all): + if print_ir_after_all: + print("// === IR After Each Pass ===") + else: + print("// === stderr output ===") + print(result.stderr) + if print_ir_after_all: + print("// === End IR Dumps ===\n") + else: + print("// === end stderr ===\n") + + return result.stdout + + finally: + # Clean up temporary file + if os.path.exists(input_file): + os.unlink(input_file) + + +def apply_complete_knob_pipeline( + mlir_module: str, + target: str = "trn2", + print_ir_after_all: bool = False, + dump_dir: str = None, + stop_after=None, + print_debuginfo: bool = False, + print_generic: bool = False, +) -> str: + """ + Apply the complete knob-driven compilation pipeline in a single pass. + + This avoids switching between Python bindings and nkipy-opt, running all + passes through nkipy-opt in sequence: + + Phase 0: Arithmetic Preparation + 1. remove-redundant-zero-fill: Remove linalg.fill(0) before matmul (NISA auto-zeros PSUM) + 2. prepare-arithmetic: Convert div to mul+reciprocal (NISA has no divide) + + Phase 1: Layout Inference, Partition Dim Canonicalization, and Tiling (on tensor IR) + 3. infer-layout: Infer tiling, placement, and partition_dim for unannotated ops + 4. canonicalize-partition-dim: Insert transposes to ensure partition_dim=0 everywhere + 5. assign-linalg-op-ids: Assign unique IDs to linalg ops (incl. new transposes) + 6. knob-driven-tiling: Rewrite linalg ops to tiled loops using transform dialect + 7. apply-and-strip-transforms: Apply the generated transforms, then erase + the transform module (so downstream passes — including the Python + linalg->NISA phase — see no transform-dialect ops). + 8. canonicalize-loop-step: Normalize loop steps to 1 + + Phase 2: Bufferization + 9. one-shot-bufferize: Convert tensors to memrefs + 10. canonicalize: Clean up memref operations + + Phase 3: Memory Space Annotation + Reshape Canonicalization + 11. eliminate-uninitialized-copies: Remove copies from uninitialized buffers + 12. canonicalize: Clean up dead subview chains + 13. annotate-memory-space: Apply memory space attributes + 14. canonicalize-reshape: Classify expand/collapse_shape by mem_space and partition_dim + 15. eliminate-same-memspace-copy: Remove redundant SBUF->SBUF copies + 16. canonicalize: Clean up dead allocs + + Phase 4: Memref Finalization + 17. legalize-layout: Transform SBUF tensors to physical 4D layout + 18. canonicalize: Clean up after layout legalization + 19. simplify-linalg: Decompose high-rank transposes, canonicalize trivial-broadcast generics + 20. insert-spill-reload: Insert spill/reload for SBUF overflow + 21. insert-memref-dealloc: Insert memref.dealloc at allocation scope end + 22. cse: Common subexpression elimination + 23. canonicalize: DCE for unused subviews and cleanup + + Note: nkipy.annotate ops are removed in annotate-memory-space (pass 13). + Note: The prior NISA-lowering steps (linalg-to-nisa, resolve-custom-ops, + prepare-for-nki) are currently stripped. They will be reimplemented in + Python using the public nki wheel as part of open-sourcing. + + Args: + mlir_module: MLIR module text with tensor operations and knob annotations + target: Hardware target (default "trn2") + print_ir_after_all: If True, print IR after each pass + dump_dir: If provided, save intermediate MLIR files after each pass to this directory + stop_after: Controls how many passes to run. Can be: + - None: run all passes (default) + - int: stop after pass N (1-indexed) + - str: stop after the first occurrence of the named pass. 
+ For passes that appear multiple times (e.g. "canonicalize"), + use "name:N" to stop at the Nth occurrence (1-indexed). + print_debuginfo: If True, include source locations in output (--mlir-print-debuginfo) + print_generic: If True, print ops in generic form (--mlir-print-op-generic) + + Returns: + Fully transformed MLIR module with NISA operations + """ + passes = [ + # Phase 0: Arithmetic preparation (pre-tiling) + # Remove zero fills before matmul — NISA matmul auto-zeros PSUM, so + # linalg.fill(0) feeding into matmul outs is redundant. Must run before + # tiling/bufferization to avoid generating unnecessary nisa.memset. + 'remove-redundant-zero-fill', # 1 + # Convert linalg.div to linalg.mul + linalg.reciprocal since NISA + # tensor_tensor_arith doesn't support DIVIDE + 'prepare-arithmetic', # 2 + + # Phase 1: Layout inference and partition_dim canonicalization + # InferLayout infers tiling, placement (mem_space), and partition_dim for + # elementwise ops that lack explicit annotations, by propagating from + # annotated neighbors + 'infer-layout', # 3 + # CanonicalizePartitionDim inserts transposes to ensure partition_dim=0 + # everywhere. Must run after infer-layout (so partition_dim is propagated) + # and before assign-linalg-op-ids (so new transposes get op IDs) + 'canonicalize-partition-dim', # 4 + # AssignLinalgOpIds assigns unique nkipy.op_id to each linalg op + # (including transposes inserted above) + 'assign-linalg-op-ids', # 5 + # KnobDrivenTiling generates Transform dialect IR; the fused pass + # applies it and then erases the transform module so downstream + # (including the Python linalg->NISA phase) sees no transform-dialect + # ops in the IR. + 'knob-driven-tiling', # 6 + 'apply-and-strip-transforms', # 7 + # CanonicalizeLoopStep normalizes loop steps to 1 (e.g., for %i = 0 to 512 step 128) + # This simplifies index expressions from %i*128/128 to just %i + 'canonicalize-loop-step', # 8 + + # Phase 2: Bufferization + 'one-shot-bufferize="bufferize-function-boundaries allow-unknown-ops"', # 9 + 'canonicalize', # 10 + + # Phase 3: Memory Space Annotation + Reshape Canonicalization + # Eliminate copies from uninitialized allocations (e.g., PSUM accumulator init) + # Must run after bufferization, before annotate-memory-space + 'eliminate-uninitialized-copies', # 11 + 'canonicalize', # Clean up dead subview chains from eliminated copies # 12 + 'annotate-memory-space', # 13 + # CanonicalizeReshape: classify expand/collapse_shape by mem_space and + # partition_dim. HBM reshapes and SBUF non-pdim reshapes stay as views. + # SBUF partition dim splits get alloc+copy (NISA has no modulo). + # Returned expand_shape views of func args and direct returns of func + # args get alloc+copy (NISA needs separate output allocations). 
+ 'canonicalize-reshape', # 14 + # Eliminate redundant SBUF->SBUF copies (when data is already in SBUF) + # This is needed after SBUF promotion of elementwise ops — if an input + # is already in SBUF (e.g., from a previous matmul), we don't need to copy it again + 'eliminate-same-memspace-copy', # 15 + 'canonicalize', # Clean up dead allocs and subviews from eliminated copies # 16 + + # Phase 4: NISA Lowering + # LegalizeLayout transforms SBUF tensors from 2D to 4D physical layout + # Runs here to inspect IR after bufferization + 'legalize-layout', # 17 + 'canonicalize', # 18 + # Simplify linalg ops before NISA lowering: decompose high-rank + # transposes to loops of 2D, collapse >2D SBUF transpose to 2D, + # canonicalize trivial-broadcast generics to named ops. + # Runs before insert-spill-reload so any SBUF temps it creates + # are accounted for in spill/reload memory budgeting. + 'simplify-linalg', # 19 + # Insert spill/reload for SBUF memory pressure. Runs after legalize-layout + # so SBUF allocs are already in physical per-partition layout and their + # total byte size equals the per-partition SBUF consumption. + f'insert-spill-reload="target={target}"', # 20 + 'insert-memref-dealloc', # Insert memref.dealloc ops at allocation scope end # 21 + 'cse', # Common subexpression elimination # 22 + 'canonicalize', # DCE for unused subviews and cleanup # 23 + + # Phase 5: NISA lowering (Python) — reimplementation of the deleted C++ + # linalg-to-nisa / resolve-custom-ops / prepare-for-nki passes using + # the `nki` wheel's Python bindings. Marked as Python-phase so the + # driver below dispatches to `linalg_to_nisa_py` instead of nkipy-opt. + 'py:linalg-to-nisa', # 24 + ] + + # Slice passes if stop_after is provided + if stop_after is not None: + if isinstance(stop_after, int): + passes = passes[:stop_after] + elif isinstance(stop_after, str): + # Support "name:N" to select the Nth occurrence (1-indexed). + # "py:" is also recognized — disambiguate by checking if + # the tail after the final ":" is an integer. + name = stop_after + nth = 1 + if ':' in stop_after: + head, tail = stop_after.rsplit(':', 1) + if tail.isdigit(): + name, nth = head, int(tail) + occurrence = 0 + found_idx = None + for i, p in enumerate(passes): + # Strip `py:` prefix for matching so users can request the + # same pass by either `py:linalg-to-nisa` or `linalg-to-nisa`. + raw = p[len('py:'):] if p.startswith('py:') else p + base_name = raw.split('=')[0].split('"')[0].strip() + req_name = name[len('py:'):] if name.startswith('py:') else name + if base_name == req_name: + occurrence += 1 + if occurrence == nth: + found_idx = i + break + if found_idx is None: + available = [ + (p[len('py:'):] if p.startswith('py:') else p) + .split('=')[0].split('"')[0].strip() + for p in passes + ] + raise ValueError( + f"Pass '{stop_after}' not found in pipeline. 
" + f"Available passes: {available}" + ) + passes = passes[:found_idx + 1] + else: + raise TypeError(f"stop_after must be int, str, or None, got {type(stop_after)}") + + return _run_passes_with_python_dispatch( + mlir_module, + passes, + target=target, + print_ir_after_all=print_ir_after_all, + dump_dir=dump_dir, + print_debuginfo=print_debuginfo, + print_generic=print_generic, + ) + + +def _run_passes_with_python_dispatch( + mlir_module: str, + passes: list[str], + target: str, + print_ir_after_all: bool, + dump_dir: str | None, + print_debuginfo: bool, + print_generic: bool, +) -> str: + """Run a pass list, batching consecutive nkipy-opt passes and dispatching + any `py:` entries to their Python implementation. + + Having a single driver keeps `dump_dir` numbering coherent across the + C++/Python boundary: every pass — whether it runs in nkipy-opt or in + Python — writes the same `NN_.mlir` artifact. + """ + current = mlir_module + + if dump_dir: + os.makedirs(dump_dir, exist_ok=True) + with open(os.path.join(dump_dir, "00_input.mlir"), 'w') as f: + f.write(str(current)) + + batch: list[str] = [] + batch_start_idx = 1 + + def flush_batch(next_idx: int) -> None: + nonlocal current, batch + if not batch: + return + if dump_dir: + # Run each pass separately when dumping so we save per-pass IR. + for j, p in enumerate(batch): + current = run_nkipy_opt_passes( + current, [p], print_ir_after_all, + print_debuginfo=print_debuginfo, print_generic=print_generic, + ) + simple_name = p.split('=')[0].split('"')[0].strip() + filename = f"{batch_start_idx + j:02d}_{simple_name}.mlir" + with open(os.path.join(dump_dir, filename), 'w') as f: + f.write(current) + else: + current = run_nkipy_opt_passes( + current, batch, print_ir_after_all, + print_debuginfo=print_debuginfo, print_generic=print_generic, + ) + batch = [] + + for i, pass_name in enumerate(passes, start=1): + if pass_name.startswith('py:'): + flush_batch(i) + py_name = pass_name[len('py:'):] + current = _run_python_pass( + py_name, current, target=target, print_generic=print_generic, + ) + if dump_dir: + filename = f"{i:02d}_{py_name}.mlir" + with open(os.path.join(dump_dir, filename), 'w') as f: + f.write(current) + batch_start_idx = i + 1 + else: + if not batch: + batch_start_idx = i + batch.append(pass_name) + + flush_batch(len(passes) + 1) + return current + + +def _run_python_pass( + name: str, mlir_text: str, target: str, print_generic: bool = False, +) -> str: + """Dispatch a `py:` pass to its Python implementation.""" + if name == 'linalg-to-nisa': + # Imported lazily because the NKI wheel and upstream `mlir` are only + # required for this pass; tests that stop before phase 5 do not need + # either installed. 
+ from .linalg_to_nisa_py import linalg_to_nisa + return linalg_to_nisa(mlir_text, target=target, print_generic=print_generic) + raise ValueError(f"Unknown Python pass: {name!r}") + + +# Export the main interface +__all__ = [ + 'get_nkipy_opt_path', + 'run_nkipy_opt_passes', + 'apply_complete_knob_pipeline', +] diff --git a/kernelgen/nkipy_kernelgen/utils.py b/kernelgen/nkipy_kernelgen/utils.py new file mode 100644 index 0000000..e54e6d1 --- /dev/null +++ b/kernelgen/nkipy_kernelgen/utils.py @@ -0,0 +1,218 @@ + +# Modified from https://github.com/cornell-zhang/allo/blob/e4ababde72803aaf156db2db86820ec817285f50/allo/utils.py + +import ctypes +import numpy as np +import ml_dtypes + +from mlir.runtime import to_numpy +from mlir.dialects import func as func_d +from mlir.ir import ( + MemRefType, + RankedTensorType, + IntegerType, + IndexType, + F16Type, + F32Type, + F64Type, + BF16Type, +) + +np_supported_types = { + "bf16": ml_dtypes.bfloat16, + "f16": np.float16, + "f32": np.float32, + "f64": np.float64, + "i8": np.int8, + "i16": np.int16, + "i32": np.int32, + "i64": np.int64, + "ui1": np.bool_, + "ui8": np.uint8, + "ui16": np.uint16, + "ui32": np.uint32, + "ui64": np.uint64, +} + +ctype_map = { + # ctypes.c_float16 does not exist + # similar implementation in _mlir/runtime/np_to_memref.py/F16 + "bf16": ctypes.c_int16, + "f16": ctypes.c_int16, + "f32": ctypes.c_float, + "f64": ctypes.c_double, + "i8": ctypes.c_int8, + "i16": ctypes.c_int16, + "i32": ctypes.c_int32, + "i64": ctypes.c_int64, + "ui1": ctypes.c_bool, + "ui8": ctypes.c_uint8, + "ui16": ctypes.c_uint16, + "ui32": ctypes.c_uint32, + "ui64": ctypes.c_uint64, +} + +def np_type_to_str(dtype): + return list(np_supported_types.keys())[ + list(np_supported_types.values()).index(dtype) + ] + +def get_bitwidth_from_type(dtype): + if dtype == "index": + return 64 + if dtype.startswith("i"): + bitwidth = int(dtype[1:]) + assert bitwidth in [8, 16, 32, 64] + return bitwidth + if dtype.startswith("ui"): + bitwidth = int(dtype[2:]) + assert bitwidth in [1, 8, 16, 32, 64] + return bitwidth + if dtype.startswith("f"): + bitwidth = int(dtype[1:]) + assert bitwidth in [16, 32, 64] + return bitwidth + raise RuntimeError("Unsupported type") + +# Locate top-level func.func entry +def find_func_in_module(module, func_name): + for op in module.body.operations: + if isinstance(op, func_d.FuncOp) and op.name.value == func_name: + return op + return None + +def extract_out_np_arrays_from_out_struct(out_struct_ptr_ptr, num_output): + out_np_arrays = [] + for i in range(num_output): + out_np_arrays.append( + ranked_memref_to_numpy(getattr(out_struct_ptr_ptr[0][0], f"memref{i}")) + ) + return out_np_arrays + +def get_np_struct_type(bitwidth): + n_bytes = int(np.ceil(bitwidth / 8)) + return np.dtype( + { + "names": [f"f{i}" for i in range(n_bytes)], + # all set to unsigned byte + "formats": ["u1"] * n_bytes, + "offsets": list(range(n_bytes)), + "itemsize": n_bytes, + } + ) + +def ranked_memref_to_numpy(ranked_memref): + """Converts ranked memrefs to numpy arrays.""" + # Check rank using _length_ to avoid triggering numpy ctypes warning + rank = ranked_memref.shape._length_ + + if rank == 0: + # Special handling for rank-0 memrefs (scalars) to avoid numpy ctypes warning + # For rank-0 memrefs, directly read the scalar value + contentPtr = ctypes.cast( + ctypes.addressof(ranked_memref.aligned.contents) + + ranked_memref.offset * ctypes.sizeof(ranked_memref.aligned.contents), + type(ranked_memref.aligned), + ) + # Return as a 0-d numpy array (scalar array) + 
return np.array(contentPtr[0]) + + # A temporary workaround for issue + # https://discourse.llvm.org/t/setting-memref-elements-in-python-callback/72759 + contentPtr = ctypes.cast( + ctypes.addressof(ranked_memref.aligned.contents) + + ranked_memref.offset * ctypes.sizeof(ranked_memref.aligned.contents), + type(ranked_memref.aligned), + ) + np_arr = np.ctypeslib.as_array(contentPtr, shape=ranked_memref.shape) + strided_arr = np.lib.stride_tricks.as_strided( + np_arr, + np.ctypeslib.as_array(ranked_memref.shape), + np.ctypeslib.as_array(ranked_memref.strides) * np_arr.itemsize, + ) + return to_numpy(strided_arr) + + +def get_signed_type_by_hint(dtype, hint): + if hint == "u" and (dtype.startswith("i") or dtype.startswith("fixed")): + return "u" + dtype + return dtype + + +def get_dtype_and_shape_from_type(dtype): + """ + Extract dtype, shape, and whether it's a memref from an MLIR type. + + Returns: + tuple: (element_type: str, shape: tuple, is_memref: bool) + """ + if MemRefType.isinstance(dtype): + dtype = MemRefType(dtype) + shape = dtype.shape + ele_type, _, _ = get_dtype_and_shape_from_type(dtype.element_type) + return ele_type, shape, True # is_memref=True + if RankedTensorType.isinstance(dtype): + dtype = RankedTensorType(dtype) + shape = dtype.shape + ele_type, _, _ = get_dtype_and_shape_from_type(dtype.element_type) + return ele_type, shape, True # is_memref=True (will become memref after bufferization) + if IndexType.isinstance(dtype): + return "index", tuple(), False + if IntegerType.isinstance(dtype): + return str(IntegerType(dtype)), tuple(), False + if F16Type.isinstance(dtype): + return str(F16Type(dtype)), tuple(), False + if F32Type.isinstance(dtype): + return str(F32Type(dtype)), tuple(), False + if F64Type.isinstance(dtype): + return str(F64Type(dtype)), tuple(), False + if BF16Type.isinstance(dtype): + return str(BF16Type(dtype)), tuple(), False + raise RuntimeError("Unsupported type") + + +# Get input and output type frmo func.func +def get_func_inputs_outputs(func): + """ + Extract input and output types from func.func. 
+ + Returns: + tuple: (inputs, outputs) where each element is a list of tuples + (dtype: str, shape: tuple, is_memref: bool) + """ + # Input types + inputs = [] + in_hints = ( + func.attributes["itypes"].value + if "itypes" in func.attributes + else "_" * len(func.type.inputs) + ) + for in_type, in_hint in zip(func.type.inputs, in_hints): + dtype, shape, is_memref = get_dtype_and_shape_from_type(in_type) + in_type = get_signed_type_by_hint(dtype, in_hint) + inputs.append((in_type, shape, is_memref)) + + # Output types + outputs = [] + out_hints = ( + func.attributes["otypes"].value + if "otypes" in func.attributes + else "_" * len(func.type.results) + ) + for out_type, out_hint in zip(func.type.results, out_hints): + dtype, shape, is_memref = get_dtype_and_shape_from_type(out_type) + out_type = get_signed_type_by_hint(dtype, out_hint) + outputs.append((out_type, shape, is_memref)) + return inputs, outputs + + +def create_output_struct(memref_descriptors): + fields = [ + (f"memref{i}", memref.__class__) for i, memref in enumerate(memref_descriptors) + ] + # Dynamically create and return the new class + OutputStruct = type("OutputStruct", (ctypes.Structure,), {"_fields_": fields}) + out_struct = OutputStruct() + for i, memref in enumerate(memref_descriptors): + setattr(out_struct, f"memref{i}", memref) + return out_struct \ No newline at end of file diff --git a/kernelgen/pyproject.toml b/kernelgen/pyproject.toml new file mode 100644 index 0000000..4cecede --- /dev/null +++ b/kernelgen/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools", "wheel", "cmake", "pybind11>=2.8.0", "nanobind>=2.4"] +build-backend = "setuptools.build_meta" diff --git a/kernelgen/pytest.ini b/kernelgen/pytest.ini new file mode 100644 index 0000000..f81555a --- /dev/null +++ b/kernelgen/pytest.ini @@ -0,0 +1,23 @@ +[pytest] +# Pytest configuration for NKIPyKernelGen + +# LIT FileCheck tests under tests/python/ are run via `lit`, not pytest +norecursedirs = tests/python tests/legacy + +# Test paths +testpaths = tests + +# Verbose output +addopts = -v --tb=short + +# Custom markers +markers = + llvm_sim: LLVM JIT simulation test (CPU, no device required) + bir_sim: BIR simulation test (CPU, no device required) + device: test requires Trainium hardware + trn1: test targets Trainium v1 + trn2: test targets Trainium v2 + trn3: test targets Trainium v3 + e2e: end-to-end pipeline test + passes: MLIR pass transformation test + unit: unit test for individual operations and features diff --git a/kernelgen/requirements.txt b/kernelgen/requirements.txt new file mode 100644 index 0000000..da96c74 --- /dev/null +++ b/kernelgen/requirements.txt @@ -0,0 +1,4 @@ +numpy +pytest==8.4.2 +PyYAML==6.0.3 +black==25.9.0 diff --git a/kernelgen/setup.py b/kernelgen/setup.py new file mode 100644 index 0000000..c610832 --- /dev/null +++ b/kernelgen/setup.py @@ -0,0 +1,141 @@ +import os +import sys +import subprocess +from setuptools import setup, Extension +from setuptools.command.build_ext import build_ext +from setuptools import find_packages + +class CMakeExtension(Extension): + def __init__(self, name, sourcedir=""): + Extension.__init__(self, name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + + +class CMakeBuild(build_ext): + def run(self): + # Clean stale build artifacts before building + import shutil + stale_dirs = ["nkipy_kernelgen/_mlir"] + for dir_path in stale_dirs: + if os.path.exists(dir_path): + print(f"Removing stale artifacts: {dir_path}") + shutil.rmtree(dir_path) + + # Ensure CMake is 
installed + try: + subprocess.check_call(["cmake", "--version"]) + except OSError: + raise RuntimeError( + "CMake must be installed to build the following extensions: " + + ", ".join(e.name for e in self.extensions) + ) + + # Call the build process for each extension + for ext in self.extensions: + self.build_extension(ext) + + # After building, copy the _mlir bindings to nkipy package + self.copy_mlir_bindings() + + def build_extension(self, ext): + # Retrieve LLVM_BUILD_DIR from environment variable (optional) + llvm_build_dir = os.environ.get("LLVM_BUILD_DIR") + + cmake_args = [ + f"-DPython3_EXECUTABLE={sys.executable}", + f"-DPython_EXECUTABLE={sys.executable}", + ] + + # Only set MLIR_DIR if LLVM_BUILD_DIR is provided + # Otherwise, let CMake auto-detect using find_package + if llvm_build_dir: + cmake_args += [f"-DMLIR_DIR={llvm_build_dir}/lib/cmake/mlir"] + + build_temp = os.path.abspath("build") + if not os.path.exists(build_temp): + os.makedirs(build_temp) + + BUILD_WITH = os.environ.get("BUILD_WITH") + if not BUILD_WITH or BUILD_WITH == "ninja": + subprocess.run( + ["cmake", "-G Ninja", ext.sourcedir] + cmake_args, + cwd=build_temp, + check=True, + ) + if NUM_THREADS := os.environ.get("NUM_THREADS"): + subprocess.run( + ["ninja", f"-j{NUM_THREADS}"], cwd=build_temp, check=True + ) + else: + subprocess.run(["ninja"], cwd=build_temp, check=True) + elif BUILD_WITH == "make": + subprocess.run( + ["cmake", "-G Unix Makefiles", ext.sourcedir] + cmake_args, + cwd=build_temp, + check=True, + ) + if NUM_THREADS := os.environ.get("NUM_THREADS"): + subprocess.run(["make", f"-j{NUM_THREADS}"], cwd=build_temp, check=True) + else: + subprocess.run(["make", "-j"], cwd=build_temp, check=True) + else: + raise RuntimeError(f"Unsupported BUILD_WITH={BUILD_WITH}") + + def copy_mlir_bindings(self): + """Copy built _mlir bindings from build/tools/nkipy/_mlir to nkipy_kernelgen/_mlir""" + import shutil + + src_dir = os.path.join("build", "tools", "nkipy", "_mlir") + local_dest = os.path.join("nkipy_kernelgen", "_mlir") + + if os.path.exists(src_dir): + print(f"Copying _mlir bindings from {src_dir} to {local_dest}") + + # Remove existing destination if it exists + if os.path.exists(local_dest) or os.path.islink(local_dest): + if os.path.islink(local_dest) or os.path.isfile(local_dest): + os.unlink(local_dest) + else: + shutil.rmtree(local_dest) + + # Copy to local directory + shutil.copytree(src_dir, local_dest, symlinks=True) + print(f"Successfully copied _mlir bindings to {local_dest}") + else: + print(f"Warning: {src_dir} not found. 
MLIR bindings may not be available.") + + +def parse_requirements(filename): + """Load requirements from a pip requirements file.""" + lineiter = (line.strip() for line in open(filename)) + return [line for line in lineiter if line and not line.startswith("#")] + + +if __name__ == "__main__": + with open("README.md", encoding="utf-8") as fp: + long_description = fp.read() + + setup( + name="nkipy-kernelgen", + description="Lowering from NKIPy to NKI compiler's NISA dialect", + version="0.1", + long_description=long_description, + long_description_content_type="text/markdown", + setup_requires=["pybind11>=2.8.0", "nanobind>=2.4"], + install_requires=parse_requirements("requirements.txt"), + packages=find_packages(), + ext_modules=[CMakeExtension("mlir", sourcedir="mlir")], + cmdclass={"build_ext": CMakeBuild}, + python_requires=">=3.10", + classifiers=[ + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: System :: Hardware", + "Operating System :: OS Independent", + ], + zip_safe=False, # Changed to False to allow including _mlir bindings + package_data={ + "nkipy_kernelgen._mlir": ["**/*.so", "**/*.py"], + }, + include_package_data=True, + ) diff --git a/kernelgen/tests/README.md b/kernelgen/tests/README.md new file mode 100644 index 0000000..efde6b7 --- /dev/null +++ b/kernelgen/tests/README.md @@ -0,0 +1,139 @@ +# NKIPyKernelGen Tests + +## Test Structure + +``` +tests/ +├── conftest.py # Root conftest: sys.path setup, --test-mode, --device +├── pytest.ini # (at project root) +│ +├── test_basic_ops.py # Element-wise ops (add, sub, mul, div) — LLVM JIT +├── test_broadcast_ops.py # Broadcasting ops — LLVM JIT +├── test_elementwise_ops.py # More elementwise ops — LLVM JIT +├── test_execution_engine.py # Execution engine basics — LLVM JIT +├── test_for_loops.py # Loop constructs — LLVM JIT +├── test_import_compatibility.py # Import smoke test +├── test_matrix_ops.py # Matmul, combined ops — LLVM JIT +├── test_unary_ops.py # Unary ops (exp, sin, sqrt) — LLVM JIT +├── test_reduction_ops.py # Reductions (sum, mean, max) — LLVM JIT + BIR sim +├── test_partial_tiling.py # Partial tiling — LLVM JIT + BIR sim +├── test_full_pipeline.py # Full NISA pipeline — BIR sim +├── test_qwen3_kernels.py # Qwen3 model kernels — LLVM JIT + BIR sim +│ +├── e2e/ # End-to-end pipeline tests (BIR simulation) +│ ├── conftest.py # Auto-marks tests with 'e2e' +│ ├── test_feedforward.py +│ ├── test_matmul_add.py +│ └── test_sigmoid.py +│ +├── passes/ # MLIR pass transformation tests +│ ├── conftest.py # Auto-marks tests with 'passes', shared fixtures +│ ├── pass_utils.py # Shared pass utilities (FileCheck, compilation) +│ ├── annotate_memory_space/ # Each pass has: test_*.py, utils.py, outputs/ +│ ├── canonicalize_loop_step/ +│ ├── cleanup_bufferization_artifacts/ +│ ├── eliminate_same_memspace_copy/ +│ ├── eliminate_uninitialized_copies/ +│ ├── knob_driven_tiling/ +│ └── legalize_layout/ +│ +└── python/ # LIT FileCheck tests (run via `lit`, NOT pytest) + ├── lit.cfg.py + ├── passes/ + └── rewrites/ +``` + +## Running Tests + +All commands run from the project root (`NKIPyKernelGen/`). 
+ +### Test modes + +The `--test-mode` flag controls which tests run: + +```bash +# CPU mode (default) — runs LLVM JIT + BIR simulation tests, no hardware needed +pytest +pytest --test-mode=cpu + +# Device mode — runs only tests marked 'device', targeting a specific Trainium generation +pytest --test-mode=device --device=trn1 +pytest --test-mode=device --device=trn2 +pytest --test-mode=device --device=trn3 + +# All mode — CPU tests + device tests (if --device is given) +pytest --test-mode=all --device=trn2 +``` + +### Filter by marker + +```bash +pytest -m bir_sim # only BIR simulation tests +pytest -m e2e # only end-to-end tests +pytest -m passes # only MLIR pass tests +pytest -m "not bir_sim" # skip BIR simulation tests +``` + +### Filter by path or name + +```bash +pytest tests/test_basic_ops.py -v # single file +pytest tests/passes/knob_driven_tiling/ -v # single pass directory +pytest tests/test_basic_ops.py::TestElementWiseOps -v # single class +pytest -k test_add_2d # name substring match +``` + +### LIT FileCheck tests + +LIT tests are excluded from pytest. Run them separately: + +```bash +lit tests/python/ -v +``` + +### Available markers + +| Marker | Meaning | +|----------|---------| +| `llvm_sim` | LLVM JIT simulation test (CPU) | +| `bir_sim` | BIR simulation test (CPU) | +| `device` | Requires Trainium hardware | +| `trn1` / `trn2` / `trn3` | Targets a specific Trainium generation | +| `e2e` | End-to-end pipeline test | +| `passes` | MLIR pass transformation test | + +## Test Design + +Each test: +1. Defines a function using NumPy operations +2. Traces the function with `@trace` decorator to generate MLIR +3. Uses `verify_against_numpy` to compare MLIR/LLVM execution with NumPy CPU execution +4. Asserts that results match within tolerance + +```python +def test_new_operation(self): + def my_func(A, B): + return np.some_operation(A, B) + + traced_func = trace(input_specs=[((4, 3), "f32"), ((4, 3), "f32")])(my_func) + + A = np.random.randn(4, 3).astype(np.float32) + B = np.random.randn(4, 3).astype(np.float32) + + matches, mlir_result, numpy_result = verify_against_numpy( + traced_func, my_func, [A, B] + ) + assert matches, "MLIR result does not match NumPy result" +``` + +## Import setup + +`tests/conftest.py` adds `tests/passes/` to `sys.path` at pytest startup. This means: +- Pass tests can do `from pass_utils import ...` +- All tests can do `from harness import ...` and `from nkipy_kernelgen import ...` (installed package) + +Each per-pass `conftest.py` adds its own directory to `sys.path` so `from utils import ...` resolves to the local `utils.py`. + +## Known issues + +Some test files import `convert_linalg_to_nisa` which may not be available depending on the build. These tests will show collection errors but are unrelated to the test infrastructure — the underlying transform is conditionally compiled. diff --git a/kernelgen/tests/__init__.py b/kernelgen/tests/__init__.py new file mode 100644 index 0000000..b522b24 --- /dev/null +++ b/kernelgen/tests/__init__.py @@ -0,0 +1,6 @@ +""" +Test suite for NKIPyKernelGen. + +This package contains tests to verify that MLIR/LLVM execution +matches NumPy CPU execution. +""" diff --git a/kernelgen/tests/conftest.py b/kernelgen/tests/conftest.py new file mode 100644 index 0000000..6452573 --- /dev/null +++ b/kernelgen/tests/conftest.py @@ -0,0 +1,59 @@ +""" +Root conftest.py for NKIPyKernelGen test suite. 
+ +Provides: + - Centralized sys.path setup (replaces per-file path hacks) + - ``--dump-ir`` CLI flag: dumps intermediate MLIR after each compiler pass + +Test modes (LLVM, BIR_SIM, HW, STRING_CHECK, FILECHECK) are declared +per-test via the @nkipy_kernelgen_test decorator or run_kernel_test(). +Hardware tests auto-skip when no Trainium device is detected. +Use standard pytest selection to run subsets: pytest tests/passes/, pytest -k "not e2e", etc. +""" + +import sys +from pathlib import Path + +# --------------------------------------------------------------------------- +# Centralized sys.path setup +# --------------------------------------------------------------------------- +# This replaces all the scattered sys.path.insert() hacks in individual test +# files. By adding these paths once here (loaded automatically by pytest), +# utility modules become importable without path manipulation in each file. + +_tests_dir = Path(__file__).parent + +# Allow 'from harness import ...' in all tests +sys.path.insert(0, str(_tests_dir)) + +# Allow 'from pass_utils import ...' in pass tests +sys.path.insert(0, str(_tests_dir / "passes")) + + +# --------------------------------------------------------------------------- +# --dump-ir CLI option +# --------------------------------------------------------------------------- +# When passed, run_kernel_test() automatically saves intermediate MLIR after +# every compiler pass so you can inspect the IR without modifying test code. +# +# Usage: +# pytest tests/e2e/test_rope.py::test_rope --dump-ir -v -s +# +# IR files are written to tests//outputs// +# e.g. 00_input.mlir, 01_prepare_arithmetic.mlir, ... + + +def pytest_addoption(parser): + parser.addoption( + "--dump-ir", + action="store_true", + default=False, + help="Dump intermediate MLIR after each compiler pass to tests/*/outputs//", + ) + + +def pytest_configure(config): + """Store the --dump-ir flag where harness.py can read it.""" + # Using a module-level global avoids passing config through every call site. + import harness as _harness + _harness._DUMP_IR_ENABLED = config.getoption("--dump-ir", default=False) diff --git a/kernelgen/tests/debug/.gitignore b/kernelgen/tests/debug/.gitignore new file mode 100644 index 0000000..d0a0795 --- /dev/null +++ b/kernelgen/tests/debug/.gitignore @@ -0,0 +1 @@ +artifacts_*/ diff --git a/kernelgen/tests/debug/README.md b/kernelgen/tests/debug/README.md new file mode 100644 index 0000000..0138cd9 --- /dev/null +++ b/kernelgen/tests/debug/README.md @@ -0,0 +1,79 @@ +# NISA MLIR Debug (`nisa_mlir_debug`) + +Tools for debugging pre-compiled NISA-level MLIR kernels by running them through BIRSim and comparing against NumPy references. + +## Quick Start + +```bash +# From this directory: +source ./run.sh +``` + +This will: +1. Set up the NKI environment and PYTHONPATH +2. Parse the MLIR to extract function signature (shapes, dtypes) +3. Generate deterministic random inputs (seed=42) +4. Compile the MLIR to NEFF with BIRSim enabled (`target=trn2`) +5. Compare BIRSim output against a NumPy reference in `kernel.py` + +## Files + +| File | Description | +|------|-------------| +| `run.sh` | Bash entry point — sets up env, resolves paths, calls `run_sim.py` | +| `run_sim.py` | Core harness — parses MLIR, compiles to NEFF, runs BIRSim, compares output | + +## Adding a Test Case + +Each test case lives in its own subdirectory alongside the MLIR file(s) it tests. 
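+
+As a reference point, a minimal `kernel.py` might look like the sketch below; the
+kernel name, argument count, and dtype are illustrative, and the exact requirements
+are spelled out under "`kernel.py` contract" further down:
+
+```python
+# kernel.py (sketch): NumPy reference for the MLIR under test.
+import numpy as np
+
+# The function name must match the `sym_name` of the MLIR function.
+def my_kernel(a, b):
+    # Takes the same number of numpy arrays as the MLIR function's inputs and
+    # returns one array whose shape/dtype matches the MLIR function's output.
+    return (a @ b).astype(np.float32)
+```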
+ +Required structure: +``` +my_bug_fix/ +├── kernel.py # NumPy reference: must define a function matching the MLIR func name +├── buggy.mlir # (optional) original broken MLIR +├── fixed.mlir # corrected MLIR +└── README.md # description of the bug and fix +``` + +### `kernel.py` contract +- Must contain a function whose name matches the `sym_name` in the MLIR +- Takes the same number of numpy arrays as the MLIR function's inputs +- Returns a single numpy array matching the MLIR function's output shape/dtype + +### Running a test case + +```bash +source ./run.sh my_bug_fix/fixed.mlir +``` + +Artifacts (NEFF, BIR, BIRSim outputs) are written to `artifacts_/` next to the MLIR file. These directories are git-ignored. + +## Output + +A successful run prints: + +``` +BIRSim output: shape=(256, 256), dtype=float32 + range: [-1.2345, 1.2345] + mean: 0.0012 + +BIRSim PASSED + +--- Running numpy reference from kernel.py --- + Max difference: 1.95e-02 + Mean difference: 3.40e-05 + Match: True + +SIMULATION PASSED +``` + +Failures exit non-zero with either: +- **`NCC_ISIM*` errors** — BIRSim detected an issue (e.g., uninitialized PSUM read) +- **`SIMULATION FAILED`** — BIRSim output doesn't match NumPy within tolerance (`atol=1e-2, rtol=1e-2`) + +## Existing Test Cases + +| Directory | Issue | +|-----------|-------| +| `psum_accumulate_flags_fix/` | Missing `psum_accumulate_flags` on matmul K-loops, unreleased SBUF, wrong elementwise op | diff --git a/kernelgen/tests/debug/__init__.py b/kernelgen/tests/debug/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernelgen/tests/debug/bmm/buggy.mlir b/kernelgen/tests/debug/bmm/buggy.mlir new file mode 100644 index 0000000..287a65e --- /dev/null +++ b/kernelgen/tests/debug/bmm/buggy.mlir @@ -0,0 +1,47 @@ +module attributes {nisa.target = #nisa.target} { + func.func @bmm_kernel(%arg0: memref<2x256x256xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem>, %arg1: memref<2x256x256xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem>) -> memref<2x256x256xf32, #nisa.mem> { + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %mem = nisa.alloc alignment=64 : memref<2x256x256xf32, #nisa.mem> + scf.for %arg2 = %c0 to %c2 step %c1 { + %mem_0 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg3 = %c0 to %c2 step %c1 { + scf.for %arg4 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg4, %c128 : index + %1 = arith.muli %arg3, %c128 : index + %2 = arith.addi %0, %arg2 : index + nisa.dma_transpose(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_0[d0, %arg3 + 0, %arg4 + 0, d1], src<128| 128>=memref<2x256x256xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg0[%2 + d0, %1 + 0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + %mem_1 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg3 = %c0 to %c2 step %c1 { + scf.for %arg4 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg3, %c128 : index + %1 = arith.muli %arg4, %c128 : index + %2 = arith.addi %0, %arg2 : index + nisa.dma_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_1[d0, %arg3 + 0, %arg4 + 0, d1], src<128| 128>=memref<2x256x256xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg1[%2 + d0, %1 + 0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + scf.for %arg3 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg3, %c128 : index + scf.for %arg4 = %c0 to %c2 step %c1 { + %1 = arith.muli %arg4, %c128 : index + %mem_2 = nisa.alloc 
alignment=64 : memref<128x128xf32, #nisa.mem> + scf.for %arg5 = %c0 to %c2 step %c1 { + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_2[d0, d1], stationary<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_0[d0, %arg5 + 0, %arg3 + 0, d1], moving<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_1[d0, %arg5 + 0, %arg4 + 0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + } + %2 = arith.addi %0, %arg2 : index + %mem_3 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_3[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_2[d0, d1]) engine=vector + nisa.dma_copy(dst<128| 128>=memref<2x256x256xf32, #nisa.mem> %mem[%2 + d0, %1 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_3[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_2 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_0 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_1 : memref<128x2x2x128xf32, #nisa.mem> + } + return %mem : memref<2x256x256xf32, #nisa.mem> + } +} diff --git a/kernelgen/tests/debug/bmm/fix_3d_dma_indices.mlir b/kernelgen/tests/debug/bmm/fix_3d_dma_indices.mlir new file mode 100644 index 0000000..22fb0b9 --- /dev/null +++ b/kernelgen/tests/debug/bmm/fix_3d_dma_indices.mlir @@ -0,0 +1,74 @@ +// ===== BMM 3D DMA index fix ===== +// +// Root cause: LinalgToNisa's getBaseAndOffsets + createCopyMap produce wrong +// affine maps for 3D HBM memrefs accessed through rank-reducing subviews. +// +// After KnobDrivenTiling, the batch dim is extracted via: +// tensor.extract_slice %t[%b, 0, 0] [1, 256, 256] ... : tensor<2x256x256> to tensor<256x256> +// +// After bufferization this becomes a rank-reducing memref.subview: +// memref.subview %arg0[%b, 0, 0] [1, 256, 256] ... : memref<2x256x256> to memref<256x256> +// +// Then tiling creates further 2D subviews: +// memref.subview %sv[%m_off, %n_off] [128, 128] ... : memref<256x256> to memref<128x128> +// +// When getBaseAndOffsets traces the chain: %tile -> %sv_2d -> %arg0_3d +// - From %tile subview: indices = [%m_off, %n_off] (2 entries) +// - From %sv_2d subview: source is 3D %arg0 with offsets [%b, 0, 0] +// - BUG: assert(indices.size() == offsets.size()) fails (2 != 3) +// - In release mode (no assert): indices wrongly accumulate as 2D, then +// createCopyMap maps d0->dim0 producing: %arg0[%b + d0, %m_off, d1] +// +// BUGGY pattern (from buggy.mlir): +// %2 = arith.addi %0, %arg2 // %0 = tile_offset, %arg2 = batch +// %arg0[%2 + d0, %1 + 0, d1] // d0 (128-range partition) added to batch dim (size 2)! +// +// CORRECT pattern (this file): +// %arg0[%arg2 + 0, %0 + d0, %1 + d1] // batch is pure offset, d0 in M dim, d1 in N dim +// +// The 4D SBUF memrefs (128x2x2x128) are fine - they don't use rank-reducing subviews. 
+// +module attributes {nisa.target = #nisa.target} { + func.func @bmm_kernel(%arg0: memref<2x256x256xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem>, %arg1: memref<2x256x256xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem>) -> memref<2x256x256xf32, #nisa.mem> { + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %mem = nisa.alloc alignment=64 : memref<2x256x256xf32, #nisa.mem> + scf.for %arg2 = %c0 to %c2 step %c1 { + %mem_0 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg3 = %c0 to %c2 step %c1 { + scf.for %arg4 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg4, %c128 : index + %1 = arith.muli %arg3, %c128 : index + nisa.dma_transpose(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_0[d0, %arg3 + 0, %arg4 + 0, d1], src<128| 128>=memref<2x256x256xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg0[%arg2 + 0, %0 + d0, %1 + d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + %mem_1 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg3 = %c0 to %c2 step %c1 { + scf.for %arg4 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg3, %c128 : index + %1 = arith.muli %arg4, %c128 : index + nisa.dma_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_1[d0, %arg3 + 0, %arg4 + 0, d1], src<128| 128>=memref<2x256x256xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg1[%arg2 + 0, %0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + scf.for %arg3 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg3, %c128 : index + scf.for %arg4 = %c0 to %c2 step %c1 { + %1 = arith.muli %arg4, %c128 : index + %mem_2 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + scf.for %arg5 = %c0 to %c2 step %c1 { + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_2[d0, d1], stationary<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_0[d0, %arg5 + 0, %arg3 + 0, d1], moving<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_1[d0, %arg5 + 0, %arg4 + 0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + } + %mem_3 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_3[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_2[d0, d1]) engine=vector + nisa.dma_copy(dst<128| 128>=memref<2x256x256xf32, #nisa.mem> %mem[%arg2 + 0, %0 + d0, %1 + d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_3[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_2 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_0 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_1 : memref<128x2x2x128xf32, #nisa.mem> + } + return %mem : memref<2x256x256xf32, #nisa.mem> + } +} diff --git a/kernelgen/tests/debug/bmm/kernel.py b/kernelgen/tests/debug/bmm/kernel.py new file mode 100644 index 0000000..3cd694a --- /dev/null +++ b/kernelgen/tests/debug/bmm/kernel.py @@ -0,0 +1,14 @@ +import numpy as np +from nkipy_kernelgen import trace, knob + +batch = 2 +M, N, K = 256, 256, 256 + +@trace(input_specs=[ + ((batch, M, K), "f32"), + ((batch, K, N), "f32"), +]) +def bmm_kernel(a, b): + result = np.matmul(a, b) + knob.knob(result, mem_space="SharedHbm", tile_size=[1, 128, 128], reduction_tile=[128]) + return result diff --git a/kernelgen/tests/debug/qwen3_layer/README.md b/kernelgen/tests/debug/qwen3_layer/README.md new file mode 100644 index 0000000..1f168f1 --- /dev/null +++ 
b/kernelgen/tests/debug/qwen3_layer/README.md @@ -0,0 +1,240 @@ +# Qwen3 Layer Debug: NISA Lowering Bugs + +## Overview + +`buggy.mlir` is the NISA-lowered output of `kernel.py` (a full Qwen3 transformer +decoder layer). It fails BIR verification and produces incorrect results due to +three categories of bugs in the linalg-to-nisa / simplify-linalg passes. + +`fix_rope_vector_partition.mlir` contains the corrected version (simulation passes). + +``` +source ../run.sh qwen3_layer/buggy.mlir # BIR verification failed +source ../run.sh qwen3_layer/fix_rope_vector_partition.mlir # SIMULATION PASSED +``` + +--- + +## Bug 1: Q/K/V reshape — wrong column interleaving + +**Symptom**: Silent numerical corruption (no compilation error). + +**Location in MLIR**: Lines 196-234 — `(256,256) shared_hbm → (2,2,128,128) hbm/shared_hbm` + +**What happens**: The pass emits a column-by-column 128x2 → 2x128 DMA transpose, +iterating `arg13 = 0..128` and loading 2 adjacent columns per iteration. It maps +`d0 ∈ [0,2)` to the head dimension. But adjacent columns in the `(BS, hidden_size)` +layout are **not** different heads — they are adjacent `head_dim` values within the +**same** head. Head 0 occupies cols `[0, 128)`, head 1 occupies cols `[128, 256)`. + +``` +# Python: (256,256) → (2,128,2,128) → transpose(0,2,1,3) → (2,2,128,128) +# Column j maps to: head = j // head_dim, hd = j % head_dim +# Adjacent cols j, j+1 are both head 0 (for j < 127) — NOT different heads! +``` + +**Fix**: Replace the 128x2 transpose with 128x128 block copies per `(batch, head)`: +```mlir +scf.for %batch = %c0 to %c2 step %c1 { + scf.for %head = %c0 to %c2 step %c1 { + %row = arith.muli %batch, %c128 + %col = arith.muli %head, %c128 + // Load 128x128 block: rows [batch*128, (batch+1)*128), cols [head*128, (head+1)*128) + dma_copy(src=Q_proj[%row+d0, %col+d1] → sbuf_tmp[d0, d1], tile <128|128>) + dma_copy(sbuf_tmp → Q_mh[%batch, %head, d0, d1], tile <128|128>) + } +} +``` + +**Pass to fix**: `simplify-linalg` or `linalg-to-nisa` — the reshape +`(BS, hidden) → (batch, seq, heads, hd)` + transpose `(0,2,1,3)` lowering. + +--- + +## Bug 2: RoPE subtract/add on multi-partition SBUF + +**Symptom**: BIR verification failure: +``` +Invalid access of 1 partitions starting at partition 1 +Opcode: TensorTensor +``` + +**Location in MLIR**: Lines 265-273, 304-312, 348-356, 387-395 — the +`q_rot0 = q0*cos - q1*sin` (and similar) final subtract/add loops. + +**What happens**: The pass lowers `q0*cos - q1*sin` as three separate loops: +1. Loop over 4 heads: `mem_24[i] = q0_slice * cos` → 4x128x64 SBUF +2. Loop over 4 heads: `mem_25[i] = q1_slice * sin` → 4x128x64 SBUF +3. Loop over 4 heads: `mem_26[i] = mem_24[i] - mem_25[i]` ← **BUG** + +Loop 3 uses `tensor_tensor_arith` (`engine=vector`) reading from `mem_24[%arg12+d0, ...]`. +The vector engine processes all 128 SBUF partitions simultaneously and **cannot** +selectively address partition N of a multi-partition tensor. When `%arg12=1`, BIR +verification rejects the access to partition 1. + +A simpler staging-DMA fix (copy each partition to a 1-partition temp) was tried but +causes **SBUF OOM** because all three 4-partition tensors must be live simultaneously. 
+ +**Fix**: Fuse the three loops into one, computing multiply and subtract in 1-partition +temps within a single iteration: +```mlir +scf.for %i = %c0 to %c4 step %c1 { + tmp_a = q0_slice[i] * cos // 1x128x64 sbuf — safe for vector engine + tmp_b = q1_slice[i] * sin // 1x128x64 sbuf + result = tmp_a - tmp_b // 1x128x64 sbuf — both operands are 1-partition + dma_copy(result → output[%i, ...]) +} +``` + +**Pass to fix**: `linalg-to-nisa` — when lowering element-wise binary ops on tensors +tiled with a small partition dim (e.g. `BH=4`), the pass should detect that both +operands of the subtract share the same loop structure and fuse them, avoiding +multi-partition SBUF intermediates. + +--- + +## Bug 3: Head-concat reshape — DMA transpose OOB + +**Symptom**: BIR verification failure: +``` +Access pattern out of bounds on instruction ... Pattern: [[32768,128],[1,1],[1,1],[1,2]] +``` + +**Location in MLIR**: Lines 519-534 — `(4,128,128) shared_hbm → (2,128,2,128) sbuf` +via `mem_49`, then `(2,128,2,128) → (128,2,2,128)` via second transpose. + +**What happens**: The first DMA transpose uses tile `<128|2>` on `memref<2x128x2x128xf32, sbuf>`. +This writes 128 elements into dimension 0 (size 2) — **out of bounds**. The entire +two-step transpose through `mem_49` is structurally wrong. + +**Fix**: Skip `mem_49`. Directly create `mem_51` by transposing each 128x128 head block: +```mlir +// mem_51[hd, head, batch, seq] = mem_48[batch*2+head, seq, hd] +scf.for %bh = %c0 to %c4 step %c1 { + tmp = dma_copy(mem_48[%bh, d0, d1]) // 128x128 sbuf: (seq, hd) + %batch = divui %bh, %c2 + %head = remui %bh, %c2 + dma_transpose(tmp → mem_51[d0, %head, %batch, d1], perm=[1,0]) // (hd, seq) +} +``` + +Note: `mem_51` dim1 = head, dim2 = batch (matching the downstream matmul's +`stationary[d0, %arg14, %arg12, d1]` where arg14 = reduction tile = head, +arg12 = row tile = batch). + +**Pass to fix**: `simplify-linalg` or `linalg-to-nisa` — the inverse reshape +`(BH, seq, hd) → (batch, seq, heads, hd) → (BS, hidden)` lowering. Same root +cause as Bug 1 (incorrect head/column interleaving in the transpose). + +--- + +## Root Cause Summary + +| Bug | Pass stage | Root cause | Compilation error? | +|-----|-----------|------------|-------------------| +| 1 | reshape lowering | Adjacent cols treated as different heads | No (silent corruption) | +| 2 | elementwise fusion | Vector op on multi-partition SBUF | Yes (BIR verification) | +| 3 | reshape lowering | Same as Bug 1, inverse direction | Yes (BIR verification) | + +Bugs 1 and 3 share the same root cause: the reshape `(BS, hidden) ↔ (batch, head, seq, hd)` +is lowered with a column-by-column transpose that conflates the head and head_dim dimensions. +The fix for both is to tile at the head granularity (128x128 blocks) rather than column +granularity (128x2 strips). + +--- + +## Proposed Compiler Pass Fixes + +### Fix A: `SimplifyLinalg.cpp` — `decomposeHighRankTranspose()` (Bugs 1 & 3) + +**Root cause in the pass**: `decomposeHighRankTranspose` handles `[0,2,1,3]` on +`(2,128,2,128)` by looping over identity dims {0(batch=2), 3(hd=128)} and doing a +2D `linalg.transpose [1,0]` on the swapped pair {1(seq=128), 2(heads=2)}. + +This produces 128x2 → 2x128 inner transposes, which get collapsed back through +`expand_shape → subview → collapse_shape` when `getBaseAndOffsets` in linalg-to-nisa +traces to the flat (256,256) base. The collapse_shape strides are lost — `d1*128 + hd` +becomes `hd + d1`, making adjacent columns look like different heads. 
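+
+The head/column mapping that the collapse loses can be checked directly in
+NumPy (illustrative sketch; hidden=256, heads=2, head_dim=128 as in this
+kernel):
+
+```python
+import numpy as np
+
+BS, hidden, heads, head_dim = 256, 256, 2, 128
+x = np.arange(BS * hidden, dtype=np.float32).reshape(BS, hidden)
+
+# Intended lowering: (BS, hidden) -> (batch, seq, heads, head_dim) -> transpose(0, 2, 1, 3)
+y = x.reshape(2, 128, heads, head_dim).transpose(0, 2, 1, 3)
+
+# Column j of x belongs to head j // head_dim at offset j % head_dim:
+j = 5
+assert np.array_equal(y[0, j // head_dim, :, j % head_dim], x[:128, j])
+# Adjacent columns j and j+1 stay in the same head (unless j % head_dim == 127),
+# so mapping d0 in [0,2) of a 128x2 strip to the head index reads the wrong elements.
+```
+
+The proposed change below preserves this mapping by looping over the small head
+dim, so each inner 128x128 block is exactly one `(batch, head)` slice of `y`.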
+ +**Proposed change** (in `decomposeHighRankTranspose`, ~line 131): + +When one of the two swapped dims is small (size << the other), move it to the outer +loop instead of transposing it. This converts the inner operation from a transpose to +a plain copy, avoiding the problematic stride through collapse_shape entirely. + +```cpp +// After identifying d0, d1 as the two swapped dims: +int64_t sizeD0 = srcShape[d0], sizeD1 = srcShape[d1]; + +// If one swapped dim is much smaller than the other, loop over the small +// dim instead of transposing. This makes the inner operation a copy +// (both remaining dims are identity), which avoids non-unit-stride +// collapse_shape tile maps that getBaseAndOffsets cannot preserve. +// +// Example: [0,2,1,3] on (2,128,2,128) +// Before: loop batch(2) * hd(128) = 256 iters, inner 128x2 transpose +// After: loop batch(2) * head(2) = 4 iters, inner 128x128 copy +constexpr int64_t kSmallDimThreshold = 16; // heads are typically 2-32 +int64_t smallSwapDim = -1; +if (sizeD0 < sizeD1 && sizeD0 <= kSmallDimThreshold) { + smallSwapDim = d0; // move d0 to outer loops +} else if (sizeD1 < sizeD0 && sizeD1 <= kSmallDimThreshold) { + smallSwapDim = d1; // move d1 to outer loops +} + +if (smallSwapDim >= 0) { + // Treat the small swapped dim as an identity dim (loop over it). + // The remaining inner dims are all identity → emit copy, not transpose. + identityDims.push_back(smallSwapDim); + // Remove from swappedDims so we don't try to transpose it + swappedDims.erase(std::remove(swappedDims.begin(), swappedDims.end(), + smallSwapDim), swappedDims.end()); + // ... continue to loop-nest creation below ... + // Inner operation becomes memref.copy (or linalg.copy) instead of + // linalg.transpose, since only 1 non-identity dim remains. +} +``` + +The key insight: for `[0,2,1,3]` on `(batch, seq, heads, hd)` where `heads=2`: +- Loop over batch, **head**: 2 x 2 = 4 iterations +- Inner: copy of `(seq=128, hd=128)` block — 128x128 DMA copy, no transpose +- Access: `src[batch*128+d0, head*128+d1]` — correct strides, no collapse_shape + +For the inverse (Bug 3, `(BH,seq,hd) → (batch,heads,seq,hd)`), the same optimization +applies: loop over the small head dim, inner is 128x128 transpose (seq,hd)→(hd,seq). + +The dst subview offset computation needs adjustment: for a swapped dim that's now +looped, the loop IV goes to `perm[dim]` position in the dst (not `dim` position), +since this dim's src→dst mapping differs from identity dims. + +### Fix B: Tiling / linalg-to-nisa — fuse RoPE elementwise chain (Bug 2) + +**Root cause**: The tiling pass creates three separate loops for `q0*cos - q1*sin`: +1. `tmp_a[i] = q0[i] * cos` → writes 4-partition SBUF +2. `tmp_b[i] = q1[i] * sin` → writes 4-partition SBUF +3. `result[i] = tmp_a[i] - tmp_b[i]` → reads 4-partition SBUF with vector engine (BUG) + +The vector engine cannot address individual partitions of a multi-partition SBUF tensor. + +**Proposed fix** (two options): + +**Option 1 — Tiling pass**: Detect element-wise chains like `sub(mul(a,b), mul(c,d))` +where all four operands share the same partition-dim tiling. Fuse into a single tiled +loop that keeps all intermediates at 1-partition granularity: + +``` +for i = 0..n_partitions: + tmp_a = a[i] * b[i] // 1-partition temp + tmp_b = c[i] * d[i] // 1-partition temp + out[i] = tmp_a - tmp_b // 1-partition vector op — safe +``` + +This is the approach used in the working fix. 
The key: the multiply and subtract +share the same outer loop and the intermediates never grow beyond 1 partition. + +**Option 2 — linalg-to-nisa verifier/rewriter**: Add a post-lowering check that +flags any `tensor_tensor_arith` (`engine=vector`) whose SBUF operand has a +loop-varying partition index. When detected, insert DMA staging copies to +1-partition temps. Note: this is less preferred since it increases SBUF pressure +and may cause OOM (as tested — the staging approach failed with SBUF OOM for this +kernel). diff --git a/kernelgen/tests/debug/qwen3_layer/buggy.mlir b/kernelgen/tests/debug/qwen3_layer/buggy.mlir new file mode 100644 index 0000000..5148a95 --- /dev/null +++ b/kernelgen/tests/debug/qwen3_layer/buggy.mlir @@ -0,0 +1,808 @@ +module attributes {nisa.target = #nisa.target} { + func.func @qwen3_layer(%arg0: memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg1: memref<256x1xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg2: memref<256x1xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg3: memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg4: memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg5: memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg6: memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg7: memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem>, %arg8: memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem>, %arg9: memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg10: memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg11: memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem>) -> memref<256x256xf32, #nisa.mem> attributes {nki.output_names = ["output"]} { + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 0xFF800000 : f32 + %cst_1 = arith.constant 0.0883883461 : f32 + %cst_2 = arith.constant 9.99999997E-7 : f32 + %cst_3 = arith.constant 3.906250e-03 : f32 + %cst_4 = arith.constant 0.000000e+00 : f32 + %mem = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %1 = arith.muli %arg13, %c128 : index + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg0[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.activation(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], bias=f32 %cst_4, scale=f32 %cst, op=square) engine=scalar + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1]) engine=vector + nisa.release %mem_79 : memref<128x128xf32, #nisa.mem> + } + } + %mem_5 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + nisa.memset(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_5[d0, %arg12 + 0, 0, d1], value=f32 %cst_4) engine=vector + } + scf.for 
%arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc : memref<128x1xf32, #nisa.mem> + nisa.tensor_reduce_arith(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem[d0, %arg12 + 0, %arg13 + 0, d1], op=add, negated=false, num_r_dim=1) engine=vector + nisa.tensor_tensor_arith(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_5[d0, %arg12 + 0, 0, d1], lhs<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_5[d0, %arg12 + 0, 0, d1], rhs<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], op=add) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + } + nisa.release %mem : memref<128x2x2x128xf32, #nisa.mem> + %mem_6 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_5[d0, %arg12 + 0, 0, d1], operand0=f32 %cst_3, op0=multiply, reverse_operands=none_) engine=vector + nisa.tensor_copy(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_6[d0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + nisa.release %mem_5 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + %mem_7 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_6[d0, %arg12 + 0, 0, d1], operand0=f32 %cst_2, op0=add, reverse_operands=none_) engine=vector + nisa.tensor_copy(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_7[d0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + nisa.release %mem_6 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + %mem_8 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.activation(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_7[d0, %arg12 + 0, 0, d1], bias=f32 %cst_4, scale=f32 %cst, op=sqrt) engine=scalar + nisa.tensor_copy(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_8[d0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + nisa.release %mem_7 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + %mem_9 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + %mem_10 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.reciprocal(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 
1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_8[d0, %arg12 + 0, 0, d1]) engine=vector + nisa.tensor_copy(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_10[d0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + nisa.release %mem_8 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %1 = arith.muli %arg13, %c128 : index + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg0[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], operand0<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_10[d0, %arg12 + 0, 0, d1], op0=multiply, reverse_operands=none_) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_9[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1]) engine=vector + nisa.release %mem_79 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_10 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + %mem_11 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.dma_copy(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<256x1xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg1[%0 + d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_9[d0, %arg12 + 0, %arg13 + 0, d1], operand0<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], op0=multiply, reverse_operands=none_) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_11[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1]) engine=vector + nisa.release %mem_79 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_9 : memref<128x2x2x128xf32, #nisa.mem> + %mem_12 = nisa.alloc alignment=64 : memref<256x256xf32, #nisa.mem> + %mem_13 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + nisa.dma_transpose(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_13[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_11[d0, %arg13 + 0, %arg12 + 0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + %mem_14 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, 
%c128 : index + %1 = arith.muli %arg13, %c128 : index + nisa.dma_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_14[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg3[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %1 = arith.muli %arg13, %c128 : index + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + scf.for %arg14 = %c0 to %c2 step %c1 { + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], stationary<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_13[d0, %arg14 + 0, %arg12 + 0, d1], moving<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_14[d0, %arg14 + 0, %arg13 + 0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + } + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.dma_copy(dst<128| 128>=memref<256x256xf32, #nisa.mem> %mem_12[%0 + d0, %1 + d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_14 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_13 : memref<128x2x2x128xf32, #nisa.mem> + %mem_15 = nisa.alloc alignment=64 : memref<256x256xf32, #nisa.mem> + %mem_16 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + nisa.dma_transpose(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_16[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_11[d0, %arg13 + 0, %arg12 + 0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + %mem_17 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + %1 = arith.muli %arg13, %c128 : index + nisa.dma_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_17[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg4[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %1 = arith.muli %arg13, %c128 : index + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + scf.for %arg14 = %c0 to %c2 step %c1 { + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], stationary<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_16[d0, %arg14 + 0, %arg12 + 0, d1], moving<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_17[d0, %arg14 + 0, %arg13 + 0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + } + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.dma_copy(dst<128| 128>=memref<256x256xf32, #nisa.mem> %mem_15[%0 + d0, %1 + d1], src<128| 
128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_17 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_16 : memref<128x2x2x128xf32, #nisa.mem> + %mem_18 = nisa.alloc alignment=64 : memref<256x256xf32, #nisa.mem> + %mem_19 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + nisa.dma_transpose(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_19[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_11[d0, %arg13 + 0, %arg12 + 0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + nisa.release %mem_11 : memref<128x2x2x128xf32, #nisa.mem> + %mem_20 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + %1 = arith.muli %arg13, %c128 : index + nisa.dma_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_20[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg5[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %1 = arith.muli %arg13, %c128 : index + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + scf.for %arg14 = %c0 to %c2 step %c1 { + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], stationary<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_19[d0, %arg14 + 0, %arg12 + 0, d1], moving<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_20[d0, %arg14 + 0, %arg13 + 0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + } + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.dma_copy(dst<128| 128>=memref<256x256xf32, #nisa.mem> %mem_18[%0 + d0, %1 + d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_20 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_19 : memref<128x2x2x128xf32, #nisa.mem> + %mem_21 = nisa.alloc alignment=64 : memref<2x2x128x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c128 step %c1 { + %mem_78 = nisa.alloc : memref<128x2xf32, #nisa.mem> + %mem_79 = nisa.alloc : memref<2x128xf32, #nisa.mem> + %0 = arith.muli %arg12, %c128 : index + nisa.dma_copy(dst<128| 2>=memref<128x2xf32, #nisa.mem> %mem_78[d0, d1], src<128| 2>=memref<256x256xf32, #nisa.mem> %mem_12[%0 + d0, %arg13 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.dma_transpose(dst<2| 128>=memref<2x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 2>=memref<128x2xf32, #nisa.mem> %mem_78[d0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x2xf32, #nisa.mem> + nisa.dma_copy(dst<2| 128>=memref<2x2x128x128xf32, #nisa.mem> %mem_21[%arg12 + 0, d0, d1, %arg13 + 0], src<2| 128>=memref<2x128xf32, #nisa.mem> %mem_79[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma 
+ nisa.release %mem_79 : memref<2x128xf32, #nisa.mem> + } + } + %mem_22 = nisa.alloc alignment=64 : memref<2x2x128x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c128 step %c1 { + %mem_78 = nisa.alloc : memref<128x2xf32, #nisa.mem> + %mem_79 = nisa.alloc : memref<2x128xf32, #nisa.mem> + %0 = arith.muli %arg12, %c128 : index + nisa.dma_copy(dst<128| 2>=memref<128x2xf32, #nisa.mem> %mem_78[d0, d1], src<128| 2>=memref<256x256xf32, #nisa.mem> %mem_15[%0 + d0, %arg13 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.dma_transpose(dst<2| 128>=memref<2x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 2>=memref<128x2xf32, #nisa.mem> %mem_78[d0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x2xf32, #nisa.mem> + nisa.dma_copy(dst<2| 128>=memref<2x2x128x128xf32, #nisa.mem> %mem_22[%arg12 + 0, d0, d1, %arg13 + 0], src<2| 128>=memref<2x128xf32, #nisa.mem> %mem_79[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_79 : memref<2x128xf32, #nisa.mem> + } + } + %mem_23 = nisa.alloc alignment=64 : memref<2x2x128x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c128 step %c1 { + %mem_78 = nisa.alloc : memref<128x2xf32, #nisa.mem> + %mem_79 = nisa.alloc : memref<2x128xf32, #nisa.mem> + %0 = arith.muli %arg12, %c128 : index + nisa.dma_copy(dst<128| 2>=memref<128x2xf32, #nisa.mem> %mem_78[d0, d1], src<128| 2>=memref<256x256xf32, #nisa.mem> %mem_18[%0 + d0, %arg13 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.dma_transpose(dst<2| 128>=memref<2x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 2>=memref<128x2xf32, #nisa.mem> %mem_78[d0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x2xf32, #nisa.mem> + nisa.dma_copy(dst<2| 128>=memref<2x2x128x128xf32, #nisa.mem> %mem_23[%arg12 + 0, d0, d1, %arg13 + 0], src<2| 128>=memref<2x128xf32, #nisa.mem> %mem_79[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_79 : memref<2x128xf32, #nisa.mem> + } + } + %mem_24 = nisa.alloc alignment=64 : memref<4x128x64xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + %0 = arith.divui %arg12, %c2 : index + %1 = arith.remui %arg12, %c2 : index + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], src<1| 128, 64>=memref<2x2x128x128xf32, #nisa.mem> %mem_21[%0 + 0, %1 + d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg7[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_80 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], op=multiply) engine=vector + nisa.release %mem_79 : memref<1x128x64xf32, #nisa.mem> + nisa.release %mem_78 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_24[%arg12 + d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release 
%mem_80 : memref<1x128x64xf32, #nisa.mem> + } + %mem_25 = nisa.alloc alignment=64 : memref<4x128x64xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + %0 = arith.divui %arg12, %c2 : index + %1 = arith.remui %arg12, %c2 : index + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], src<1| 128, 64>=memref<2x2x128x128xf32, #nisa.mem> %mem_21[%0 + 0, %1 + d0, d1, d2 + 64], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg8[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_80 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], op=multiply) engine=vector + nisa.release %mem_79 : memref<1x128x64xf32, #nisa.mem> + nisa.release %mem_78 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_25[%arg12 + d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_80 : memref<1x128x64xf32, #nisa.mem> + } + %mem_26 = nisa.alloc alignment=64 : memref<4x128x64xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], lhs<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_24[%arg12 + d0, d1, d2], rhs<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_25[%arg12 + d0, d1, d2], op=subtract) engine=vector + nisa.dma_copy(dst<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_26[%arg12 + d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<1x128x64xf32, #nisa.mem> + } + nisa.release %mem_25 : memref<4x128x64xf32, #nisa.mem> + nisa.release %mem_24 : memref<4x128x64xf32, #nisa.mem> + %mem_27 = nisa.alloc alignment=64 : memref<4x128x64xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + %0 = arith.divui %arg12, %c2 : index + %1 = arith.remui %arg12, %c2 : index + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], src<1| 128, 64>=memref<2x2x128x128xf32, #nisa.mem> %mem_21[%0 + 0, %1 + d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg8[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_80 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], op=multiply) engine=vector + nisa.release %mem_79 : memref<1x128x64xf32, #nisa.mem> + nisa.release 
%mem_78 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_27[%arg12 + d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_80 : memref<1x128x64xf32, #nisa.mem> + } + %mem_28 = nisa.alloc alignment=64 : memref<4x128x64xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + %0 = arith.divui %arg12, %c2 : index + %1 = arith.remui %arg12, %c2 : index + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], src<1| 128, 64>=memref<2x2x128x128xf32, #nisa.mem> %mem_21[%0 + 0, %1 + d0, d1, d2 + 64], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg7[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_80 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], op=multiply) engine=vector + nisa.release %mem_79 : memref<1x128x64xf32, #nisa.mem> + nisa.release %mem_78 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_28[%arg12 + d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_80 : memref<1x128x64xf32, #nisa.mem> + } + %mem_29 = nisa.alloc alignment=64 : memref<4x128x64xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], lhs<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_27[%arg12 + d0, d1, d2], rhs<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_28[%arg12 + d0, d1, d2], op=add) engine=vector + nisa.dma_copy(dst<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_29[%arg12 + d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<1x128x64xf32, #nisa.mem> + } + nisa.release %mem_28 : memref<4x128x64xf32, #nisa.mem> + nisa.release %mem_27 : memref<4x128x64xf32, #nisa.mem> + %mem_30 = nisa.alloc alignment=64 : memref<4x128x128xf32, #nisa.mem> + nisa.dma_copy(dst<4| 128, 64>=memref<4x128x128xf32, #nisa.mem> %mem_30[d0, d1, d2], src<4| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_26[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_26 : memref<4x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<4| 128, 64>=memref<4x128x128xf32, #nisa.mem> %mem_30[d0, d1, d2 + 64], src<4| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_29[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_29 : memref<4x128x64xf32, #nisa.mem> + %mem_31 = nisa.alloc alignment=64 : memref<4x128x64xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + %0 = arith.divui %arg12, %c2 : index + %1 = arith.remui %arg12, %c2 : index + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> 
%mem_78[d0, d1, d2], src<1| 128, 64>=memref<2x2x128x128xf32, #nisa.mem> %mem_22[%0 + 0, %1 + d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg7[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_80 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], op=multiply) engine=vector + nisa.release %mem_79 : memref<1x128x64xf32, #nisa.mem> + nisa.release %mem_78 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_31[%arg12 + d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_80 : memref<1x128x64xf32, #nisa.mem> + } + %mem_32 = nisa.alloc alignment=64 : memref<4x128x64xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + %0 = arith.divui %arg12, %c2 : index + %1 = arith.remui %arg12, %c2 : index + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], src<1| 128, 64>=memref<2x2x128x128xf32, #nisa.mem> %mem_22[%0 + 0, %1 + d0, d1, d2 + 64], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg8[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_80 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], op=multiply) engine=vector + nisa.release %mem_79 : memref<1x128x64xf32, #nisa.mem> + nisa.release %mem_78 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_32[%arg12 + d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_80 : memref<1x128x64xf32, #nisa.mem> + } + %mem_33 = nisa.alloc alignment=64 : memref<4x128x64xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], lhs<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_31[%arg12 + d0, d1, d2], rhs<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_32[%arg12 + d0, d1, d2], op=subtract) engine=vector + nisa.dma_copy(dst<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_33[%arg12 + d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<1x128x64xf32, #nisa.mem> + } + nisa.release %mem_32 : memref<4x128x64xf32, #nisa.mem> + nisa.release %mem_31 : memref<4x128x64xf32, #nisa.mem> + %mem_34 = nisa.alloc alignment=64 : 
memref<4x128x64xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + %0 = arith.divui %arg12, %c2 : index + %1 = arith.remui %arg12, %c2 : index + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], src<1| 128, 64>=memref<2x2x128x128xf32, #nisa.mem> %mem_22[%0 + 0, %1 + d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg8[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_80 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], op=multiply) engine=vector + nisa.release %mem_79 : memref<1x128x64xf32, #nisa.mem> + nisa.release %mem_78 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_34[%arg12 + d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_80 : memref<1x128x64xf32, #nisa.mem> + } + %mem_35 = nisa.alloc alignment=64 : memref<4x128x64xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + %0 = arith.divui %arg12, %c2 : index + %1 = arith.remui %arg12, %c2 : index + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], src<1| 128, 64>=memref<2x2x128x128xf32, #nisa.mem> %mem_22[%0 + 0, %1 + d0, d1, d2 + 64], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg7[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_80 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], op=multiply) engine=vector + nisa.release %mem_79 : memref<1x128x64xf32, #nisa.mem> + nisa.release %mem_78 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_35[%arg12 + d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_80 : memref<1x128x64xf32, #nisa.mem> + } + %mem_36 = nisa.alloc alignment=64 : memref<4x128x64xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], lhs<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_34[%arg12 + d0, d1, d2], rhs<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_35[%arg12 + d0, d1, d2], op=add) engine=vector + nisa.dma_copy(dst<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_36[%arg12 + d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, 
d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<1x128x64xf32, #nisa.mem> + } + nisa.release %mem_35 : memref<4x128x64xf32, #nisa.mem> + nisa.release %mem_34 : memref<4x128x64xf32, #nisa.mem> + %mem_37 = nisa.alloc alignment=64 : memref<4x128x128xf32, #nisa.mem> + nisa.dma_copy(dst<4| 128, 64>=memref<4x128x128xf32, #nisa.mem> %mem_37[d0, d1, d2], src<4| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_33[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_33 : memref<4x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<4| 128, 64>=memref<4x128x128xf32, #nisa.mem> %mem_37[d0, d1, d2 + 64], src<4| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_36[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_36 : memref<4x128x64xf32, #nisa.mem> + %mem_38 = nisa.alloc alignment=64 : memref<4x128x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<4x128x128xf32, #nisa.mem> %mem_37[%arg12 + 0, d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.dma_transpose(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + nisa.dma_copy(dst<128| 128>=memref<4x128x128xf32, #nisa.mem> %mem_38[%arg12 + 0, d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_79 : memref<128x128xf32, #nisa.mem> + } + %mem_39 = nisa.alloc alignment=64 : memref<128x1x4x1x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.dma_transpose(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<4x128x128xf32, #nisa.mem> %mem_30[%arg12 + 0, d0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_80 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_80[d0, d1], src<128| 128>=memref<4x128x128xf32, #nisa.mem> %mem_38[%arg12 + 0, d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_81 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_81[d0, d1], stationary<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], moving<128| 128>=memref<128x128xf32, #nisa.mem> %mem_80[d0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + nisa.release %mem_80 : memref<128x128xf32, #nisa.mem> + nisa.release %mem_79 : memref<128x128xf32, #nisa.mem> + %mem_82 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_82[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_81[d0, d1]) engine=vector + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_82[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_81 : memref<128x128xf32, #nisa.mem> + nisa.dma_copy(dst<128| 
128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_39[d0, 0, %arg12 + 0, 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + %mem_40 = nisa.alloc alignment=64 : memref<128x1x4x1x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x1x128xf32, #nisa.mem> %mem_78[d0, 0, d1], src<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_39[d0, 0, %arg12 + 0, 0, d1], operand0=f32 %cst_1, op0=multiply, reverse_operands=none_) engine=vector + nisa.tensor_copy(dst<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_40[d0, 0, %arg12 + 0, 0, d1], src<128| 128>=memref<128x1x128xf32, #nisa.mem> %mem_78[d0, 0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1x128xf32, #nisa.mem> + } + nisa.release %mem_39 : memref<128x1x4x1x128xf32, #nisa.mem> + %mem_41 = nisa.alloc alignment=64 : memref<128x1x4x1x1xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + nisa.memset(dst<128| 1, 1, 1, 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_41[d0, d1, %arg12 + d2, d3, d4], value=f32 %cst_0) engine=vector + } + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc : memref<128x1xf32, #nisa.mem> + nisa.tensor_reduce_arith(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_40[d0, 0, %arg12 + 0, 0, d1], op=max, negated=false, num_r_dim=1) engine=vector + nisa.tensor_tensor_arith(dst<128| 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_41[d0, 0, %arg12 + 0, 0, d1], lhs<128| 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_41[d0, 0, %arg12 + 0, 0, d1], rhs<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], op=max) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + %mem_42 = nisa.alloc alignment=64 : memref<128x1x4x1x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x1x128xf32, #nisa.mem> %mem_78[d0, 0, d1], src<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_40[d0, 0, %arg12 + 0, 0, d1], operand0<128| 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_41[d0, 0, %arg12 + 0, 0, d1], op0=subtract, reverse_operands=none_) engine=vector + nisa.tensor_copy(dst<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_42[d0, 0, %arg12 + 0, 0, d1], src<128| 128>=memref<128x1x128xf32, #nisa.mem> %mem_78[d0, 0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1x128xf32, #nisa.mem> + } + nisa.release %mem_41 : memref<128x1x4x1x1xf32, #nisa.mem> + nisa.release %mem_40 : memref<128x1x4x1x128xf32, #nisa.mem> + %mem_43 = nisa.alloc alignment=64 : memref<128x1x4x1x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1x128xf32, #nisa.mem> + nisa.activation(dst<128| 128>=memref<128x1x128xf32, #nisa.mem> %mem_78[d0, 0, d1], src<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_42[d0, 0, %arg12 + 0, 0, d1], bias=f32 %cst_4, scale=f32 %cst, op=exp) engine=scalar + nisa.tensor_copy(dst<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_43[d0, 0, %arg12 + 0, 0, d1], src<128| 128>=memref<128x1x128xf32, #nisa.mem> %mem_78[d0, 0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1x128xf32, #nisa.mem> + } + nisa.release %mem_42 : memref<128x1x4x1x128xf32, #nisa.mem> + %mem_44 = nisa.alloc alignment=64 : memref<128x1x4x1x1xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { 
+ nisa.memset(dst<128| 1, 1, 1, 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_44[d0, d1, %arg12 + d2, d3, d4], value=f32 %cst_4) engine=vector + } + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc : memref<128x1xf32, #nisa.mem> + nisa.tensor_reduce_arith(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_43[d0, 0, %arg12 + 0, 0, d1], op=add, negated=false, num_r_dim=1) engine=vector + nisa.tensor_tensor_arith(dst<128| 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_44[d0, 0, %arg12 + 0, 0, d1], lhs<128| 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_44[d0, 0, %arg12 + 0, 0, d1], rhs<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], op=add) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + %mem_45 = nisa.alloc alignment=64 : memref<128x4x128xf32, #nisa.mem> + %mem_46 = nisa.alloc alignment=64 : memref<128x1x4x1x1xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1x1xf32, #nisa.mem> + nisa.reciprocal(dst<128| 1>=memref<128x1x1xf32, #nisa.mem> %mem_78[d0, 0, d1], src<128| 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_44[d0, 0, %arg12 + 0, 0, d1]) engine=vector + nisa.tensor_copy(dst<128| 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_46[d0, 0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1x1xf32, #nisa.mem> %mem_78[d0, 0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1x1xf32, #nisa.mem> + } + nisa.release %mem_44 : memref<128x1x4x1x1xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x1x128xf32, #nisa.mem> %mem_78[d0, 0, d1], src<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_43[d0, 0, %arg12 + 0, 0, d1], operand0<128| 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_46[d0, 0, %arg12 + 0, 0, d1], op0=multiply, reverse_operands=none_) engine=vector + nisa.dma_copy(dst<128| 1, 128>=memref<128x4x128xf32, #nisa.mem> %mem_45[d0, %arg12 + d1, d2], src<128| 1, 128>=memref<128x1x128xf32, #nisa.mem> %mem_78[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x1x128xf32, #nisa.mem> + } + nisa.release %mem_46 : memref<128x1x4x1x1xf32, #nisa.mem> + nisa.release %mem_43 : memref<128x1x4x1x128xf32, #nisa.mem> + %mem_47 = nisa.alloc alignment=64 : memref<4x128x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + scf.for %arg13 = %c0 to %c128 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<1x1x128xf32, #nisa.mem> + nisa.dma_copy(dst<1| 1, 128>=memref<1x1x128xf32, #nisa.mem> %mem_78[d0, d1, d2], src<1| 1, 128>=memref<128x4x128xf32, #nisa.mem> %mem_45[%arg13 + d0, %arg12 + d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<1x1x128xf32, #nisa.mem> + nisa.tensor_copy(dst<1| 128>=memref<1x1x128xf32, #nisa.mem> %mem_79[d0, 0, d1], src<1| 128>=memref<1x1x128xf32, #nisa.mem> %mem_78[d0, 0, d1]) engine=vector + nisa.release %mem_78 : memref<1x1x128xf32, #nisa.mem> + nisa.dma_copy(dst<1| 1, 128>=memref<4x128x128xf32, #nisa.mem> %mem_47[%arg12 + d0, %arg13 + d1, d2], src<1| 1, 128>=memref<1x1x128xf32, #nisa.mem> %mem_79[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_79 : memref<1x1x128xf32, #nisa.mem> + } + } + %mem_48 = nisa.alloc alignment=64 : memref<4x128x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + 
nisa.dma_transpose(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<4x128x128xf32, #nisa.mem> %mem_47[%arg12 + 0, d0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + %0 = arith.divui %arg12, %c2 : index + %1 = arith.remui %arg12, %c2 : index + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<2x2x128x128xf32, #nisa.mem> %mem_23[%0 + 0, %1 + 0, d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_80 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_80[d0, d1], stationary<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], moving<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + nisa.release %mem_79 : memref<128x128xf32, #nisa.mem> + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + %mem_81 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_81[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_80[d0, d1]) engine=vector + nisa.dma_copy(dst<128| 128>=memref<4x128x128xf32, #nisa.mem> %mem_48[%arg12 + 0, d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_81[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_80 : memref<128x128xf32, #nisa.mem> + } + %mem_49 = nisa.alloc alignment=64 : memref<2x128x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c128 step %c1 { + %0 = arith.muli %arg12, %c2 : index + nisa.dma_transpose(dst<128| 2>=memref<2x128x2x128xf32, #nisa.mem> %mem_49[%arg12 + d0, 0, 0, %arg13 + d1], src<2| 128>=memref<4x128x128xf32, #nisa.mem> %mem_48[%0 + d0, 0, %arg13 + d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + %mem_50 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + %mem_51 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg13, %c128 : index + %1 = arith.muli %arg12, %c128 : index + nisa.dma_transpose(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_51[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<2x128x2x128xf32, #nisa.mem> %mem_49[%0 + d0, %1 + 0, 0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + nisa.release %mem_49 : memref<2x128x2x128xf32, #nisa.mem> + %mem_52 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + %1 = arith.muli %arg13, %c128 : index + nisa.dma_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_52[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg6[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + scf.for %arg14 = %c0 to %c2 step %c1 { + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], stationary<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_51[d0, %arg14 + 0, %arg12 + 0, d1], moving<128| 128>=memref<128x2x2x128xf32, 
#nisa.mem> %mem_52[d0, %arg14 + 0, %arg13 + 0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + } + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_50[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_52 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_51 : memref<128x2x2x128xf32, #nisa.mem> + %mem_53 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %1 = arith.muli %arg13, %c128 : index + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg0[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], lhs<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], rhs<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_50[d0, %arg12 + 0, %arg13 + 0, d1], op=add) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_53[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1]) engine=vector + nisa.release %mem_79 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_50 : memref<128x2x2x128xf32, #nisa.mem> + %mem_54 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.activation(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_53[d0, %arg12 + 0, %arg13 + 0, d1], bias=f32 %cst_4, scale=f32 %cst, op=square) engine=scalar + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_54[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + %mem_55 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + nisa.memset(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_55[d0, %arg12 + 0, 0, d1], value=f32 %cst_4) engine=vector + } + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc : memref<128x1xf32, #nisa.mem> + nisa.tensor_reduce_arith(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_54[d0, %arg12 + 0, %arg13 + 0, d1], op=add, negated=false, num_r_dim=1) engine=vector + nisa.tensor_tensor_arith(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_55[d0, %arg12 + 0, 0, d1], lhs<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_55[d0, %arg12 + 0, 0, d1], rhs<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], op=add) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + } + nisa.release %mem_54 : memref<128x2x2x128xf32, 
#nisa.mem> + %mem_56 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_55[d0, %arg12 + 0, 0, d1], operand0=f32 %cst_3, op0=multiply, reverse_operands=none_) engine=vector + nisa.tensor_copy(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_56[d0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + nisa.release %mem_55 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + %mem_57 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_56[d0, %arg12 + 0, 0, d1], operand0=f32 %cst_2, op0=add, reverse_operands=none_) engine=vector + nisa.tensor_copy(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_57[d0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + nisa.release %mem_56 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + %mem_58 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.activation(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_57[d0, %arg12 + 0, 0, d1], bias=f32 %cst_4, scale=f32 %cst, op=sqrt) engine=scalar + nisa.tensor_copy(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_58[d0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + nisa.release %mem_57 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + %mem_59 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + %mem_60 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.reciprocal(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_58[d0, %arg12 + 0, 0, d1]) engine=vector + nisa.tensor_copy(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_60[d0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + nisa.release %mem_58 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_53[d0, %arg12 + 0, 
%arg13 + 0, d1], operand0<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_60[d0, %arg12 + 0, 0, d1], op0=multiply, reverse_operands=none_) engine=vector + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_59[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_60 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + %mem_61 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.dma_copy(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<256x1xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg2[%0 + d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_59[d0, %arg12 + 0, %arg13 + 0, d1], operand0<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], op0=multiply, reverse_operands=none_) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_61[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1]) engine=vector + nisa.release %mem_79 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_59 : memref<128x2x2x128xf32, #nisa.mem> + %mem_62 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + %mem_63 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + nisa.dma_transpose(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_63[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_61[d0, %arg13 + 0, %arg12 + 0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + %mem_64 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + %1 = arith.muli %arg13, %c128 : index + nisa.dma_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_64[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg9[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + scf.for %arg14 = %c0 to %c2 step %c1 { + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], stationary<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_63[d0, %arg14 + 0, %arg12 + 0, d1], moving<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_64[d0, %arg14 + 0, %arg13 + 0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + } + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_62[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_64 
: memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_63 : memref<128x2x2x128xf32, #nisa.mem> + %mem_65 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + %mem_66 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + nisa.dma_transpose(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_66[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_61[d0, %arg13 + 0, %arg12 + 0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + nisa.release %mem_61 : memref<128x2x2x128xf32, #nisa.mem> + %mem_67 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + %1 = arith.muli %arg13, %c128 : index + nisa.dma_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_67[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg10[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + scf.for %arg14 = %c0 to %c2 step %c1 { + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], stationary<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_66[d0, %arg14 + 0, %arg12 + 0, d1], moving<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_67[d0, %arg14 + 0, %arg13 + 0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + } + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_65[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_67 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_66 : memref<128x2x2x128xf32, #nisa.mem> + %mem_68 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_62[d0, %arg12 + 0, %arg13 + 0, d1], operand0=f32 %cst_4, op0=subtract, reverse_operands=first) engine=vector + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_68[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + %mem_69 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.activation(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_68[d0, %arg12 + 0, %arg13 + 0, d1], bias=f32 %cst_4, scale=f32 %cst, op=exp) engine=scalar + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_69[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_68 : 
memref<128x2x2x128xf32, #nisa.mem> + %mem_70 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_69[d0, %arg12 + 0, %arg13 + 0, d1], operand0=f32 %cst, op0=add, reverse_operands=none_) engine=vector + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_70[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_69 : memref<128x2x2x128xf32, #nisa.mem> + %mem_71 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.reciprocal(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_70[d0, %arg12 + 0, %arg13 + 0, d1]) engine=vector + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_71[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_70 : memref<128x2x2x128xf32, #nisa.mem> + %mem_72 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], lhs<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_62[d0, %arg12 + 0, %arg13 + 0, d1], rhs<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_71[d0, %arg12 + 0, %arg13 + 0, d1], op=multiply) engine=vector + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_72[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_71 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_62 : memref<128x2x2x128xf32, #nisa.mem> + %mem_73 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], lhs<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_72[d0, %arg12 + 0, %arg13 + 0, d1], rhs<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_65[d0, %arg12 + 0, %arg13 + 0, d1], op=multiply) engine=vector + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_73[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_72 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_65 : memref<128x2x2x128xf32, #nisa.mem> + %mem_74 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + %mem_75 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + 
nisa.dma_transpose(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_75[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_73[d0, %arg13 + 0, %arg12 + 0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma
+ }
+ }
+ nisa.release %mem_73 : memref<128x2x2x128xf32, #nisa.mem>
+ %mem_76 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem>
+ scf.for %arg12 = %c0 to %c2 step %c1 {
+ scf.for %arg13 = %c0 to %c2 step %c1 {
+ %0 = arith.muli %arg12, %c128 : index
+ %1 = arith.muli %arg13, %c128 : index
+ nisa.dma_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_76[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg11[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma
+ }
+ }
+ scf.for %arg12 = %c0 to %c2 step %c1 {
+ scf.for %arg13 = %c0 to %c2 step %c1 {
+ %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem>
+ scf.for %arg14 = %c0 to %c2 step %c1 {
+ nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], stationary<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_75[d0, %arg14 + 0, %arg12 + 0, d1], moving<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_76[d0, %arg14 + 0, %arg13 + 0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor
+ }
+ nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_74[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector
+ nisa.release %mem_78 : memref<128x128xf32, #nisa.mem>
+ }
+ }
+ nisa.release %mem_76 : memref<128x2x2x128xf32, #nisa.mem>
+ nisa.release %mem_75 : memref<128x2x2x128xf32, #nisa.mem>
+ %mem_77 = nisa.alloc alignment=64 : memref<256x256xf32, #nisa.mem>
+ scf.for %arg12 = %c0 to %c2 step %c1 {
+ %0 = arith.muli %arg12, %c128 : index
+ scf.for %arg13 = %c0 to %c2 step %c1 {
+ %1 = arith.muli %arg13, %c128 : index
+ %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem>
+ nisa.tensor_tensor_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], lhs<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_53[d0, %arg12 + 0, %arg13 + 0, d1], rhs<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_74[d0, %arg12 + 0, %arg13 + 0, d1], op=add) engine=vector
+ nisa.dma_copy(dst<128| 128>=memref<256x256xf32, #nisa.mem> %mem_77[%0 + d0, %1 + d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma
+ nisa.release %mem_78 : memref<128x128xf32, #nisa.mem>
+ }
+ }
+ nisa.release %mem_74 : memref<128x2x2x128xf32, #nisa.mem>
+ nisa.release %mem_53 : memref<128x2x2x128xf32, #nisa.mem>
+ return %mem_77 : memref<256x256xf32, #nisa.mem>
+ }
+}
\ No newline at end of file
diff --git a/kernelgen/tests/debug/qwen3_layer/fix_rope_vector_partition.mlir b/kernelgen/tests/debug/qwen3_layer/fix_rope_vector_partition.mlir
new file mode 100644
index 0000000..4c752d7
--- /dev/null
+++ b/kernelgen/tests/debug/qwen3_layer/fix_rope_vector_partition.mlir
@@ -0,0 +1,782 @@
+module attributes {nisa.target = #nisa.target} {
+ func.func @qwen3_layer(%arg0: memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg1: memref<256x1xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg2: memref<256x1xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg3: memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg4: memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg5: memref<256x256xf32,
strided<[?, ?], offset: ?>, #nisa.mem>, %arg6: memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg7: memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem>, %arg8: memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem>, %arg9: memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg10: memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem>, %arg11: memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem>) -> memref<256x256xf32, #nisa.mem> attributes {nki.output_names = ["output"]} { + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 0xFF800000 : f32 + %cst_1 = arith.constant 0.0883883461 : f32 + %cst_2 = arith.constant 9.99999997E-7 : f32 + %cst_3 = arith.constant 3.906250e-03 : f32 + %cst_4 = arith.constant 0.000000e+00 : f32 + %mem = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %1 = arith.muli %arg13, %c128 : index + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg0[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.activation(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], bias=f32 %cst_4, scale=f32 %cst, op=square) engine=scalar + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1]) engine=vector + nisa.release %mem_79 : memref<128x128xf32, #nisa.mem> + } + } + %mem_5 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + nisa.memset(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_5[d0, %arg12 + 0, 0, d1], value=f32 %cst_4) engine=vector + } + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc : memref<128x1xf32, #nisa.mem> + nisa.tensor_reduce_arith(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem[d0, %arg12 + 0, %arg13 + 0, d1], op=add, negated=false, num_r_dim=1) engine=vector + nisa.tensor_tensor_arith(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_5[d0, %arg12 + 0, 0, d1], lhs<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_5[d0, %arg12 + 0, 0, d1], rhs<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], op=add) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + } + nisa.release %mem : memref<128x2x2x128xf32, #nisa.mem> + %mem_6 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_5[d0, %arg12 + 0, 0, d1], 
operand0=f32 %cst_3, op0=multiply, reverse_operands=none_) engine=vector + nisa.tensor_copy(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_6[d0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + nisa.release %mem_5 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + %mem_7 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_6[d0, %arg12 + 0, 0, d1], operand0=f32 %cst_2, op0=add, reverse_operands=none_) engine=vector + nisa.tensor_copy(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_7[d0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + nisa.release %mem_6 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + %mem_8 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.activation(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_7[d0, %arg12 + 0, 0, d1], bias=f32 %cst_4, scale=f32 %cst, op=sqrt) engine=scalar + nisa.tensor_copy(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_8[d0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + nisa.release %mem_7 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + %mem_9 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + %mem_10 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.reciprocal(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_8[d0, %arg12 + 0, 0, d1]) engine=vector + nisa.tensor_copy(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_10[d0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + nisa.release %mem_8 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %1 = arith.muli %arg13, %c128 : index + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg0[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], operand0<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 
128]>, #nisa.mem> %mem_10[d0, %arg12 + 0, 0, d1], op0=multiply, reverse_operands=none_) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_9[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1]) engine=vector + nisa.release %mem_79 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_10 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + %mem_11 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.dma_copy(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<256x1xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg1[%0 + d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_9[d0, %arg12 + 0, %arg13 + 0, d1], operand0<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], op0=multiply, reverse_operands=none_) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_11[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1]) engine=vector + nisa.release %mem_79 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_9 : memref<128x2x2x128xf32, #nisa.mem> + %mem_12 = nisa.alloc alignment=64 : memref<256x256xf32, #nisa.mem> + %mem_13 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + nisa.dma_transpose(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_13[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_11[d0, %arg13 + 0, %arg12 + 0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + %mem_14 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + %1 = arith.muli %arg13, %c128 : index + nisa.dma_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_14[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg3[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %1 = arith.muli %arg13, %c128 : index + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + scf.for %arg14 = %c0 to %c2 step %c1 { + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], stationary<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_13[d0, %arg14 + 0, %arg12 + 0, d1], moving<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_14[d0, %arg14 + 0, %arg13 + 0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + } + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) 
engine=vector + nisa.dma_copy(dst<128| 128>=memref<256x256xf32, #nisa.mem> %mem_12[%0 + d0, %1 + d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_14 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_13 : memref<128x2x2x128xf32, #nisa.mem> + %mem_15 = nisa.alloc alignment=64 : memref<256x256xf32, #nisa.mem> + %mem_16 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + nisa.dma_transpose(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_16[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_11[d0, %arg13 + 0, %arg12 + 0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + %mem_17 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + %1 = arith.muli %arg13, %c128 : index + nisa.dma_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_17[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg4[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %1 = arith.muli %arg13, %c128 : index + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + scf.for %arg14 = %c0 to %c2 step %c1 { + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], stationary<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_16[d0, %arg14 + 0, %arg12 + 0, d1], moving<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_17[d0, %arg14 + 0, %arg13 + 0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + } + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.dma_copy(dst<128| 128>=memref<256x256xf32, #nisa.mem> %mem_15[%0 + d0, %1 + d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_17 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_16 : memref<128x2x2x128xf32, #nisa.mem> + %mem_18 = nisa.alloc alignment=64 : memref<256x256xf32, #nisa.mem> + %mem_19 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + nisa.dma_transpose(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_19[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_11[d0, %arg13 + 0, %arg12 + 0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + nisa.release %mem_11 : memref<128x2x2x128xf32, #nisa.mem> + %mem_20 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + %1 = arith.muli %arg13, %c128 : index + nisa.dma_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_20[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 
128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg5[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %1 = arith.muli %arg13, %c128 : index + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + scf.for %arg14 = %c0 to %c2 step %c1 { + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], stationary<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_19[d0, %arg14 + 0, %arg12 + 0, d1], moving<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_20[d0, %arg14 + 0, %arg13 + 0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + } + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.dma_copy(dst<128| 128>=memref<256x256xf32, #nisa.mem> %mem_18[%0 + d0, %1 + d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_20 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_19 : memref<128x2x2x128xf32, #nisa.mem> + // FIX: Q reshape (256,256) → (2,2,128,128) + // Original used 128x2→2x128 transpose treating adjacent cols as different heads. + // Adjacent cols are actually different head_dim values within the SAME head. + // Head 0 occupies cols [0,128), head 1 occupies cols [128,256). + // Fix: copy 128x128 blocks per (batch, head) pair — no transpose needed. + %mem_21 = nisa.alloc alignment=64 : memref<2x2x128x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + %0 = arith.muli %arg12, %c128 : index + %1 = arith.muli %arg13, %c128 : index + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<256x256xf32, #nisa.mem> %mem_12[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.dma_copy(dst<128| 128>=memref<2x2x128x128xf32, #nisa.mem> %mem_21[%arg12 + 0, %arg13 + 0, d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + // FIX: K reshape — same fix as Q + %mem_22 = nisa.alloc alignment=64 : memref<2x2x128x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + %0 = arith.muli %arg12, %c128 : index + %1 = arith.muli %arg13, %c128 : index + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<256x256xf32, #nisa.mem> %mem_15[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.dma_copy(dst<128| 128>=memref<2x2x128x128xf32, #nisa.mem> %mem_22[%arg12 + 0, %arg13 + 0, d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + // FIX: V reshape — same fix as Q/K + %mem_23 = nisa.alloc alignment=64 : memref<2x2x128x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = 
nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem>
+ %0 = arith.muli %arg12, %c128 : index
+ %1 = arith.muli %arg13, %c128 : index
+ nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<256x256xf32, #nisa.mem> %mem_18[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma
+ nisa.dma_copy(dst<128| 128>=memref<2x2x128x128xf32, #nisa.mem> %mem_23[%arg12 + 0, %arg13 + 0, d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma
+ nisa.release %mem_78 : memref<128x128xf32, #nisa.mem>
+ }
+ }
+ // FIX: q_rot0 = q0 * freqs_cos - q1 * freqs_sin
+ // Fused into one loop so all vector ops use 1-partition temps.
+ // Original had 3 separate loops writing to 4-partition SBUF intermediates
+ // (mem_24, mem_25), then a tensor_tensor_arith reading individual partitions
+ // of those 4-partition tensors via engine=vector — which is illegal because
+ // the vector engine processes all partitions simultaneously and cannot
+ // selectively address partition N of a multi-partition SBUF tensor.
+ %mem_26 = nisa.alloc alignment=64 : memref<4x128x64xf32, #nisa.mem>
+ scf.for %arg12 = %c0 to %c4 step %c1 {
+ %0 = arith.divui %arg12, %c2 : index
+ %1 = arith.remui %arg12, %c2 : index
+ // q0 * freqs_cos
+ %mem_78 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem>
+ nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], src<1| 128, 64>=memref<2x2x128x128xf32, #nisa.mem> %mem_21[%0 + 0, %1 + d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma
+ %mem_79 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem>
+ nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg7[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma
+ %mem_80 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem>
+ nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], op=multiply) engine=vector
+ nisa.release %mem_79 : memref<1x128x64xf32, #nisa.mem>
+ nisa.release %mem_78 : memref<1x128x64xf32, #nisa.mem>
+ // q1 * freqs_sin
+ %mem_81 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem>
+ nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_81[d0, d1, d2], src<1| 128, 64>=memref<2x2x128x128xf32, #nisa.mem> %mem_21[%0 + 0, %1 + d0, d1, d2 + 64], dge_mode=no_dge, oob_is_err=true) engine=dma
+ %mem_82 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem>
+ nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_82[d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg8[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma
+ %mem_83 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem>
+ nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_83[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_81[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_82[d0, d1, d2], op=multiply) engine=vector
+ nisa.release %mem_82 : memref<1x128x64xf32, #nisa.mem>
+ nisa.release %mem_81 : memref<1x128x64xf32, #nisa.mem>
+ // q_rot0 = q0*cos - q1*sin (both operands are 1-partition, safe for vector engine)
+ %mem_84 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem>
+
nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_84[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_83[d0, d1, d2], op=subtract) engine=vector + nisa.release %mem_83 : memref<1x128x64xf32, #nisa.mem> + nisa.release %mem_80 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_26[%arg12 + d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_84[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_84 : memref<1x128x64xf32, #nisa.mem> + } + // FIX: q_rot1 = q0 * freqs_sin + q1 * freqs_cos (same fused-loop fix) + %mem_29 = nisa.alloc alignment=64 : memref<4x128x64xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %0 = arith.divui %arg12, %c2 : index + %1 = arith.remui %arg12, %c2 : index + // q0 * freqs_sin + %mem_78 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], src<1| 128, 64>=memref<2x2x128x128xf32, #nisa.mem> %mem_21[%0 + 0, %1 + d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg8[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_80 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], op=multiply) engine=vector + nisa.release %mem_79 : memref<1x128x64xf32, #nisa.mem> + nisa.release %mem_78 : memref<1x128x64xf32, #nisa.mem> + // q1 * freqs_cos + %mem_81 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_81[d0, d1, d2], src<1| 128, 64>=memref<2x2x128x128xf32, #nisa.mem> %mem_21[%0 + 0, %1 + d0, d1, d2 + 64], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_82 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_82[d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg7[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_83 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_83[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_81[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_82[d0, d1, d2], op=multiply) engine=vector + nisa.release %mem_82 : memref<1x128x64xf32, #nisa.mem> + nisa.release %mem_81 : memref<1x128x64xf32, #nisa.mem> + // q_rot1 = q0*sin + q1*cos + %mem_84 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_84[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_83[d0, d1, d2], op=add) engine=vector + nisa.release %mem_83 : memref<1x128x64xf32, #nisa.mem> + nisa.release %mem_80 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_29[%arg12 
+ d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_84[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_84 : memref<1x128x64xf32, #nisa.mem> + } + %mem_30 = nisa.alloc alignment=64 : memref<4x128x128xf32, #nisa.mem> + nisa.dma_copy(dst<4| 128, 64>=memref<4x128x128xf32, #nisa.mem> %mem_30[d0, d1, d2], src<4| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_26[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_26 : memref<4x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<4| 128, 64>=memref<4x128x128xf32, #nisa.mem> %mem_30[d0, d1, d2 + 64], src<4| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_29[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_29 : memref<4x128x64xf32, #nisa.mem> + // FIX: k_rot0 = k0 * freqs_cos - k1 * freqs_sin (same fused-loop fix) + %mem_33 = nisa.alloc alignment=64 : memref<4x128x64xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %0 = arith.divui %arg12, %c2 : index + %1 = arith.remui %arg12, %c2 : index + // k0 * freqs_cos + %mem_78 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], src<1| 128, 64>=memref<2x2x128x128xf32, #nisa.mem> %mem_22[%0 + 0, %1 + d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg7[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_80 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], op=multiply) engine=vector + nisa.release %mem_79 : memref<1x128x64xf32, #nisa.mem> + nisa.release %mem_78 : memref<1x128x64xf32, #nisa.mem> + // k1 * freqs_sin + %mem_81 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_81[d0, d1, d2], src<1| 128, 64>=memref<2x2x128x128xf32, #nisa.mem> %mem_22[%0 + 0, %1 + d0, d1, d2 + 64], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_82 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_82[d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg8[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_83 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_83[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_81[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_82[d0, d1, d2], op=multiply) engine=vector + nisa.release %mem_82 : memref<1x128x64xf32, #nisa.mem> + nisa.release %mem_81 : memref<1x128x64xf32, #nisa.mem> + // k_rot0 = k0*cos - k1*sin + %mem_84 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_84[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_83[d0, d1, d2], op=subtract) engine=vector + nisa.release %mem_83 : memref<1x128x64xf32, 
#nisa.mem> + nisa.release %mem_80 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_33[%arg12 + d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_84[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_84 : memref<1x128x64xf32, #nisa.mem> + } + // FIX: k_rot1 = k0 * freqs_sin + k1 * freqs_cos (same fused-loop fix) + %mem_36 = nisa.alloc alignment=64 : memref<4x128x64xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %0 = arith.divui %arg12, %c2 : index + %1 = arith.remui %arg12, %c2 : index + // k0 * freqs_sin + %mem_78 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], src<1| 128, 64>=memref<2x2x128x128xf32, #nisa.mem> %mem_22[%0 + 0, %1 + d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg8[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_80 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_78[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_79[d0, d1, d2], op=multiply) engine=vector + nisa.release %mem_79 : memref<1x128x64xf32, #nisa.mem> + nisa.release %mem_78 : memref<1x128x64xf32, #nisa.mem> + // k1 * freqs_cos + %mem_81 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_81[d0, d1, d2], src<1| 128, 64>=memref<2x2x128x128xf32, #nisa.mem> %mem_22[%0 + 0, %1 + d0, d1, d2 + 64], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_82 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_82[d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, strided<[?, ?, ?], offset: ?>, #nisa.mem> %arg7[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_83 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_83[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_81[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_82[d0, d1, d2], op=multiply) engine=vector + nisa.release %mem_82 : memref<1x128x64xf32, #nisa.mem> + nisa.release %mem_81 : memref<1x128x64xf32, #nisa.mem> + // k_rot1 = k0*sin + k1*cos + %mem_84 = nisa.alloc alignment=64 : memref<1x128x64xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_84[d0, d1, d2], lhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_80[d0, d1, d2], rhs<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_83[d0, d1, d2], op=add) engine=vector + nisa.release %mem_83 : memref<1x128x64xf32, #nisa.mem> + nisa.release %mem_80 : memref<1x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<1| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_36[%arg12 + d0, d1, d2], src<1| 128, 64>=memref<1x128x64xf32, #nisa.mem> %mem_84[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_84 : memref<1x128x64xf32, #nisa.mem> + } + %mem_37 = nisa.alloc alignment=64 : memref<4x128x128xf32, #nisa.mem> + nisa.dma_copy(dst<4| 128, 
64>=memref<4x128x128xf32, #nisa.mem> %mem_37[d0, d1, d2], src<4| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_33[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_33 : memref<4x128x64xf32, #nisa.mem> + nisa.dma_copy(dst<4| 128, 64>=memref<4x128x128xf32, #nisa.mem> %mem_37[d0, d1, d2 + 64], src<4| 128, 64>=memref<4x128x64xf32, #nisa.mem> %mem_36[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_36 : memref<4x128x64xf32, #nisa.mem> + %mem_38 = nisa.alloc alignment=64 : memref<4x128x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<4x128x128xf32, #nisa.mem> %mem_37[%arg12 + 0, d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.dma_transpose(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + nisa.dma_copy(dst<128| 128>=memref<4x128x128xf32, #nisa.mem> %mem_38[%arg12 + 0, d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_79 : memref<128x128xf32, #nisa.mem> + } + %mem_39 = nisa.alloc alignment=64 : memref<128x1x4x1x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.dma_transpose(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<4x128x128xf32, #nisa.mem> %mem_30[%arg12 + 0, d0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_80 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_80[d0, d1], src<128| 128>=memref<4x128x128xf32, #nisa.mem> %mem_38[%arg12 + 0, d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_81 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_81[d0, d1], stationary<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], moving<128| 128>=memref<128x128xf32, #nisa.mem> %mem_80[d0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + nisa.release %mem_80 : memref<128x128xf32, #nisa.mem> + nisa.release %mem_79 : memref<128x128xf32, #nisa.mem> + %mem_82 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_82[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_81[d0, d1]) engine=vector + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_82[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_81 : memref<128x128xf32, #nisa.mem> + nisa.dma_copy(dst<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_39[d0, 0, %arg12 + 0, 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + %mem_40 = nisa.alloc alignment=64 : memref<128x1x4x1x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : 
memref<128x1x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x1x128xf32, #nisa.mem> %mem_78[d0, 0, d1], src<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_39[d0, 0, %arg12 + 0, 0, d1], operand0=f32 %cst_1, op0=multiply, reverse_operands=none_) engine=vector + nisa.tensor_copy(dst<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_40[d0, 0, %arg12 + 0, 0, d1], src<128| 128>=memref<128x1x128xf32, #nisa.mem> %mem_78[d0, 0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1x128xf32, #nisa.mem> + } + nisa.release %mem_39 : memref<128x1x4x1x128xf32, #nisa.mem> + %mem_41 = nisa.alloc alignment=64 : memref<128x1x4x1x1xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + nisa.memset(dst<128| 1, 1, 1, 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_41[d0, d1, %arg12 + d2, d3, d4], value=f32 %cst_0) engine=vector + } + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc : memref<128x1xf32, #nisa.mem> + nisa.tensor_reduce_arith(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_40[d0, 0, %arg12 + 0, 0, d1], op=max, negated=false, num_r_dim=1) engine=vector + nisa.tensor_tensor_arith(dst<128| 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_41[d0, 0, %arg12 + 0, 0, d1], lhs<128| 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_41[d0, 0, %arg12 + 0, 0, d1], rhs<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], op=max) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + %mem_42 = nisa.alloc alignment=64 : memref<128x1x4x1x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x1x128xf32, #nisa.mem> %mem_78[d0, 0, d1], src<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_40[d0, 0, %arg12 + 0, 0, d1], operand0<128| 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_41[d0, 0, %arg12 + 0, 0, d1], op0=subtract, reverse_operands=none_) engine=vector + nisa.tensor_copy(dst<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_42[d0, 0, %arg12 + 0, 0, d1], src<128| 128>=memref<128x1x128xf32, #nisa.mem> %mem_78[d0, 0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1x128xf32, #nisa.mem> + } + nisa.release %mem_41 : memref<128x1x4x1x1xf32, #nisa.mem> + nisa.release %mem_40 : memref<128x1x4x1x128xf32, #nisa.mem> + %mem_43 = nisa.alloc alignment=64 : memref<128x1x4x1x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1x128xf32, #nisa.mem> + nisa.activation(dst<128| 128>=memref<128x1x128xf32, #nisa.mem> %mem_78[d0, 0, d1], src<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_42[d0, 0, %arg12 + 0, 0, d1], bias=f32 %cst_4, scale=f32 %cst, op=exp) engine=scalar + nisa.tensor_copy(dst<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_43[d0, 0, %arg12 + 0, 0, d1], src<128| 128>=memref<128x1x128xf32, #nisa.mem> %mem_78[d0, 0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1x128xf32, #nisa.mem> + } + nisa.release %mem_42 : memref<128x1x4x1x128xf32, #nisa.mem> + %mem_44 = nisa.alloc alignment=64 : memref<128x1x4x1x1xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + nisa.memset(dst<128| 1, 1, 1, 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_44[d0, d1, %arg12 + d2, d3, d4], value=f32 %cst_4) engine=vector + } + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc : memref<128x1xf32, #nisa.mem> + nisa.tensor_reduce_arith(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 
128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_43[d0, 0, %arg12 + 0, 0, d1], op=add, negated=false, num_r_dim=1) engine=vector + nisa.tensor_tensor_arith(dst<128| 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_44[d0, 0, %arg12 + 0, 0, d1], lhs<128| 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_44[d0, 0, %arg12 + 0, 0, d1], rhs<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], op=add) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + %mem_45 = nisa.alloc alignment=64 : memref<128x4x128xf32, #nisa.mem> + %mem_46 = nisa.alloc alignment=64 : memref<128x1x4x1x1xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1x1xf32, #nisa.mem> + nisa.reciprocal(dst<128| 1>=memref<128x1x1xf32, #nisa.mem> %mem_78[d0, 0, d1], src<128| 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_44[d0, 0, %arg12 + 0, 0, d1]) engine=vector + nisa.tensor_copy(dst<128| 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_46[d0, 0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1x1xf32, #nisa.mem> %mem_78[d0, 0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1x1xf32, #nisa.mem> + } + nisa.release %mem_44 : memref<128x1x4x1x1xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x1x128xf32, #nisa.mem> %mem_78[d0, 0, d1], src<128| 128>=memref<128x1x4x1x128xf32, #nisa.mem> %mem_43[d0, 0, %arg12 + 0, 0, d1], operand0<128| 1>=memref<128x1x4x1x1xf32, #nisa.mem> %mem_46[d0, 0, %arg12 + 0, 0, d1], op0=multiply, reverse_operands=none_) engine=vector + nisa.dma_copy(dst<128| 1, 128>=memref<128x4x128xf32, #nisa.mem> %mem_45[d0, %arg12 + d1, d2], src<128| 1, 128>=memref<128x1x128xf32, #nisa.mem> %mem_78[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x1x128xf32, #nisa.mem> + } + nisa.release %mem_46 : memref<128x1x4x1x1xf32, #nisa.mem> + nisa.release %mem_43 : memref<128x1x4x1x128xf32, #nisa.mem> + %mem_47 = nisa.alloc alignment=64 : memref<4x128x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + scf.for %arg13 = %c0 to %c128 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<1x1x128xf32, #nisa.mem> + nisa.dma_copy(dst<1| 1, 128>=memref<1x1x128xf32, #nisa.mem> %mem_78[d0, d1, d2], src<1| 1, 128>=memref<128x4x128xf32, #nisa.mem> %mem_45[%arg13 + d0, %arg12 + d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<1x1x128xf32, #nisa.mem> + nisa.tensor_copy(dst<1| 128>=memref<1x1x128xf32, #nisa.mem> %mem_79[d0, 0, d1], src<1| 128>=memref<1x1x128xf32, #nisa.mem> %mem_78[d0, 0, d1]) engine=vector + nisa.release %mem_78 : memref<1x1x128xf32, #nisa.mem> + nisa.dma_copy(dst<1| 1, 128>=memref<4x128x128xf32, #nisa.mem> %mem_47[%arg12 + d0, %arg13 + d1, d2], src<1| 1, 128>=memref<1x1x128xf32, #nisa.mem> %mem_79[d0, d1, d2], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_79 : memref<1x1x128xf32, #nisa.mem> + } + } + %mem_48 = nisa.alloc alignment=64 : memref<4x128x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.dma_transpose(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<4x128x128xf32, #nisa.mem> %mem_47[%arg12 + 0, d0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + %0 = arith.divui %arg12, %c2 : index + %1 = arith.remui %arg12, 
%c2 : index + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<2x2x128x128xf32, #nisa.mem> %mem_23[%0 + 0, %1 + 0, d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_80 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_80[d0, d1], stationary<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], moving<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + nisa.release %mem_79 : memref<128x128xf32, #nisa.mem> + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + %mem_81 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_81[d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_80[d0, d1]) engine=vector + nisa.dma_copy(dst<128| 128>=memref<4x128x128xf32, #nisa.mem> %mem_48[%arg12 + 0, d0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_81[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_80 : memref<128x128xf32, #nisa.mem> + } + // FIX: head-concat reshape: (4,128,128) → (128,2,2,128) for matmul stationary. + // Original used a broken 2-step transpose via mem_49 (2,128,2,128 sbuf) where + // the DMA transpose tile <128|2> tried to write 128 partitions to a 2-partition + // tensor — out of bounds. + // Fix: skip mem_49, directly transpose each (seq,hd) block into mem_51. + // mem_51[hd, batch, head, seq] = mem_48[batch*2+head, seq, hd] + %mem_50 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + %mem_51 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c4 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<4x128x128xf32, #nisa.mem> %mem_48[%arg12 + 0, d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %0 = arith.divui %arg12, %c2 : index + %1 = arith.remui %arg12, %c2 : index + nisa.dma_transpose(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_51[d0, %1 + 0, %0 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + %mem_52 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + %1 = arith.muli %arg13, %c128 : index + nisa.dma_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_52[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg6[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + scf.for %arg14 = %c0 to %c2 step %c1 { + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], stationary<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_51[d0, %arg14 + 0, %arg12 + 0, d1], moving<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_52[d0, %arg14 + 0, %arg13 + 0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + } + nisa.tensor_copy(dst<128| 
128>=memref<128x2x2x128xf32, #nisa.mem> %mem_50[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_52 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_51 : memref<128x2x2x128xf32, #nisa.mem> + %mem_53 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %1 = arith.muli %arg13, %c128 : index + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.dma_copy(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg0[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], lhs<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], rhs<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_50[d0, %arg12 + 0, %arg13 + 0, d1], op=add) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_53[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1]) engine=vector + nisa.release %mem_79 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_50 : memref<128x2x2x128xf32, #nisa.mem> + %mem_54 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.activation(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_53[d0, %arg12 + 0, %arg13 + 0, d1], bias=f32 %cst_4, scale=f32 %cst, op=square) engine=scalar + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_54[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + %mem_55 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + nisa.memset(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_55[d0, %arg12 + 0, 0, d1], value=f32 %cst_4) engine=vector + } + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc : memref<128x1xf32, #nisa.mem> + nisa.tensor_reduce_arith(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_54[d0, %arg12 + 0, %arg13 + 0, d1], op=add, negated=false, num_r_dim=1) engine=vector + nisa.tensor_tensor_arith(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_55[d0, %arg12 + 0, 0, d1], lhs<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_55[d0, %arg12 + 0, 0, d1], rhs<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], op=add) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + } + nisa.release %mem_54 : memref<128x2x2x128xf32, #nisa.mem> + %mem_56 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : 
memref<128x1xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_55[d0, %arg12 + 0, 0, d1], operand0=f32 %cst_3, op0=multiply, reverse_operands=none_) engine=vector + nisa.tensor_copy(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_56[d0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + nisa.release %mem_55 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + %mem_57 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_56[d0, %arg12 + 0, 0, d1], operand0=f32 %cst_2, op0=add, reverse_operands=none_) engine=vector + nisa.tensor_copy(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_57[d0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + nisa.release %mem_56 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + %mem_58 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.activation(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_57[d0, %arg12 + 0, 0, d1], bias=f32 %cst_4, scale=f32 %cst, op=sqrt) engine=scalar + nisa.tensor_copy(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_58[d0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + nisa.release %mem_57 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + %mem_59 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + %mem_60 = nisa.alloc alignment=64 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.reciprocal(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_58[d0, %arg12 + 0, 0, d1]) engine=vector + nisa.tensor_copy(dst<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_60[d0, %arg12 + 0, 0, d1], src<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + } + nisa.release %mem_58 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_53[d0, %arg12 + 0, %arg13 + 0, d1], operand0<128| 1>=memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> %mem_60[d0, %arg12 + 0, 0, d1], op0=multiply, reverse_operands=none_) engine=vector + 
nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_59[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_60 : memref<128x2x1x1xf32, strided<[1, 128, 128, 128]>, #nisa.mem> + %mem_61 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x1xf32, #nisa.mem> + nisa.dma_copy(dst<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], src<128| 1>=memref<256x1xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg2[%0 + d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + %mem_79 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_59[d0, %arg12 + 0, %arg13 + 0, d1], operand0<128| 1>=memref<128x1xf32, #nisa.mem> %mem_78[d0, d1], op0=multiply, reverse_operands=none_) engine=vector + nisa.release %mem_78 : memref<128x1xf32, #nisa.mem> + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_61[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_79[d0, d1]) engine=vector + nisa.release %mem_79 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_59 : memref<128x2x2x128xf32, #nisa.mem> + %mem_62 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + %mem_63 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + nisa.dma_transpose(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_63[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_61[d0, %arg13 + 0, %arg12 + 0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + %mem_64 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + %1 = arith.muli %arg13, %c128 : index + nisa.dma_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_64[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg9[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + scf.for %arg14 = %c0 to %c2 step %c1 { + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], stationary<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_63[d0, %arg14 + 0, %arg12 + 0, d1], moving<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_64[d0, %arg14 + 0, %arg13 + 0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + } + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_62[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_64 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_63 : memref<128x2x2x128xf32, #nisa.mem> + %mem_65 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + %mem_66 = 
nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + nisa.dma_transpose(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_66[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_61[d0, %arg13 + 0, %arg12 + 0, d1], permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + nisa.release %mem_61 : memref<128x2x2x128xf32, #nisa.mem> + %mem_67 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + %1 = arith.muli %arg13, %c128 : index + nisa.dma_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_67[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg10[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + scf.for %arg14 = %c0 to %c2 step %c1 { + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], stationary<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_66[d0, %arg14 + 0, %arg12 + 0, d1], moving<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_67[d0, %arg14 + 0, %arg13 + 0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + } + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_65[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_67 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_66 : memref<128x2x2x128xf32, #nisa.mem> + %mem_68 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_62[d0, %arg12 + 0, %arg13 + 0, d1], operand0=f32 %cst_4, op0=subtract, reverse_operands=first) engine=vector + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_68[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + %mem_69 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.activation(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_68[d0, %arg12 + 0, %arg13 + 0, d1], bias=f32 %cst_4, scale=f32 %cst, op=exp) engine=scalar + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_69[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_68 : memref<128x2x2x128xf32, #nisa.mem> + %mem_70 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { 
+ %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_scalar_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_69[d0, %arg12 + 0, %arg13 + 0, d1], operand0=f32 %cst, op0=add, reverse_operands=none_) engine=vector + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_70[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_69 : memref<128x2x2x128xf32, #nisa.mem> + %mem_71 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.reciprocal(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_70[d0, %arg12 + 0, %arg13 + 0, d1]) engine=vector + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_71[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_70 : memref<128x2x2x128xf32, #nisa.mem> + %mem_72 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], lhs<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_62[d0, %arg12 + 0, %arg13 + 0, d1], rhs<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_71[d0, %arg12 + 0, %arg13 + 0, d1], op=multiply) engine=vector + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_72[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_71 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_62 : memref<128x2x2x128xf32, #nisa.mem> + %mem_73 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], lhs<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_72[d0, %arg12 + 0, %arg13 + 0, d1], rhs<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_65[d0, %arg12 + 0, %arg13 + 0, d1], op=multiply) engine=vector + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_73[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_72 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_65 : memref<128x2x2x128xf32, #nisa.mem> + %mem_74 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + %mem_75 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + nisa.dma_transpose(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_75[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_73[d0, %arg13 + 0, %arg12 + 0, d1], 
permutation=[1, 0], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + nisa.release %mem_73 : memref<128x2x2x128xf32, #nisa.mem> + %mem_76 = nisa.alloc alignment=64 : memref<128x2x2x128xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + %1 = arith.muli %arg13, %c128 : index + nisa.dma_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_76[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<256x256xf32, strided<[?, ?], offset: ?>, #nisa.mem> %arg11[%0 + d0, %1 + d1], dge_mode=no_dge, oob_is_err=true) engine=dma + } + } + scf.for %arg12 = %c0 to %c2 step %c1 { + scf.for %arg13 = %c0 to %c2 step %c1 { + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + scf.for %arg14 = %c0 to %c2 step %c1 { + nisa.matmul(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], stationary<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_75[d0, %arg14 + 0, %arg12 + 0, d1], moving<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_76[d0, %arg14 + 0, %arg13 + 0, d1], row_pos=index %c0, col_pos=index %c0, is_transpose=false, perf_opt=none_, psum_zero_region=size2048) engine=tensor + } + nisa.tensor_copy(dst<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_74[d0, %arg12 + 0, %arg13 + 0, d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1]) engine=vector + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_76 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_75 : memref<128x2x2x128xf32, #nisa.mem> + %mem_77 = nisa.alloc alignment=64 : memref<256x256xf32, #nisa.mem> + scf.for %arg12 = %c0 to %c2 step %c1 { + %0 = arith.muli %arg12, %c128 : index + scf.for %arg13 = %c0 to %c2 step %c1 { + %1 = arith.muli %arg13, %c128 : index + %mem_78 = nisa.alloc alignment=64 : memref<128x128xf32, #nisa.mem> + nisa.tensor_tensor_arith(dst<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], lhs<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_53[d0, %arg12 + 0, %arg13 + 0, d1], rhs<128| 128>=memref<128x2x2x128xf32, #nisa.mem> %mem_74[d0, %arg12 + 0, %arg13 + 0, d1], op=add) engine=vector + nisa.dma_copy(dst<128| 128>=memref<256x256xf32, #nisa.mem> %mem_77[%0 + d0, %1 + d1], src<128| 128>=memref<128x128xf32, #nisa.mem> %mem_78[d0, d1], dge_mode=no_dge, oob_is_err=true) engine=dma + nisa.release %mem_78 : memref<128x128xf32, #nisa.mem> + } + } + nisa.release %mem_74 : memref<128x2x2x128xf32, #nisa.mem> + nisa.release %mem_53 : memref<128x2x2x128xf32, #nisa.mem> + return %mem_77 : memref<256x256xf32, #nisa.mem> + } +} \ No newline at end of file diff --git a/kernelgen/tests/debug/qwen3_layer/kernel.py b/kernelgen/tests/debug/qwen3_layer/kernel.py new file mode 100644 index 0000000..6eba245 --- /dev/null +++ b/kernelgen/tests/debug/qwen3_layer/kernel.py @@ -0,0 +1,305 @@ +""" +Qwen3 Transformer Decoder Layer (inlined for readability). + +All sub-kernels (RMSNorm, RoPE, softmax, SiLU) are inlined so the full +data flow is visible in one place. + +Shape convention: + - 2D projections use (BS, hidden_size) where BS = batch * seq_len. + This flattens batch and sequence into one "token" dimension so the + matmuls are plain 2D. The reshape to (batch, seq_len, n_heads, head_dim) + recovers the sequence dimension when needed for multi-head attention. + - 3D attention tensors are (BH, seq_len, X) where BH = batch * n_heads. + The partition dimension for these is dim 1 (seq_len), NOT dim 0 (BH). 
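+
+ Sketch of the convention in plain NumPy (illustrative only; the traced kernel
+ below does the same thing in section 3):
+
+     q = np.reshape(q, (batch, seq_len, n_heads, head_dim))  # (2, 128, 2, 128)
+     q = np.transpose(q, (0, 2, 1, 3))                       # (2, 2, 128, 128)
+     q = np.reshape(q, (BH, seq_len, head_dim))              # (4, 128, 128)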
+""" +import numpy as np +from nkipy_kernelgen import trace, knob + +# ---------------------------------------------------------------- +# Model hyperparameters +# ---------------------------------------------------------------- +batch = 2 +seq_len = 128 +hidden_size = 256 +n_heads = 2 +head_dim = hidden_size // n_heads # 128 +intermediate_size = 256 +half_dim = head_dim // 2 # 64 +eps = 1e-6 +scale = 1.0 / np.sqrt(head_dim).item() + +# Derived (flattened dimensions) +BS = batch * seq_len # 256 (tokens = batch * seq_len) +BH = batch * n_heads # 4 (heads = batch * n_heads) + +# ---------------------------------------------------------------- +# Tile sizes +# ---------------------------------------------------------------- +matmul_tile_2d = [128, 128] +matmul_reduction_2d = [128] +attn_tile = [1, 128, 128] # (BH, seq_len, seq_len/head_dim) +attn_reduction = [128] +rope_tile = [1, 128, 64] # (BH, seq_len, half_dim) +elem_tile_2d = [128, 128] + + +@trace(input_specs=[ + # hidden_states: (BS, hidden_size) = (256, 256) + # BS = batch * seq_len, flattened for 2D matmul projections. + # Reshape to (batch, seq_len, n_heads, head_dim) recovers seq_len + # for multi-head attention. + ((BS, hidden_size), "f32"), + # RMSNorm weights — (hidden_size, 1) so broadcast is over the free dim + # ([P, 1] pattern), which maps to nisa.tensor_scalar_arith. + ((hidden_size, 1), "f32"), # ln1_weight + ((hidden_size, 1), "f32"), # ln2_weight + # Attention projection weights + ((hidden_size, hidden_size), "f32"), # w_q + ((hidden_size, hidden_size), "f32"), # w_k + ((hidden_size, hidden_size), "f32"), # w_v + ((hidden_size, hidden_size), "f32"), # w_o + # RoPE frequencies (position-dependent, broadcast over BH) + ((1, seq_len, half_dim), "f32"), # freqs_cos + ((1, seq_len, half_dim), "f32"), # freqs_sin + # FFN weights + ((hidden_size, intermediate_size), "f32"), # w_gate + ((hidden_size, intermediate_size), "f32"), # w_up + ((intermediate_size, hidden_size), "f32"), # w_down +]) +def qwen3_layer(hidden_states, + ln1_weight, ln2_weight, + w_q, w_k, w_v, w_o, + freqs_cos, freqs_sin, + w_gate, w_up, w_down): + + residual = hidden_states # (BS, hidden_size) + + # ================================================================ + # 1. Pre-attention RMSNorm + # norm(x) = x / sqrt(mean(x^2) + eps) * weight + # ================================================================ + x_fp32 = hidden_states.astype(np.float32) + w_fp32 = ln1_weight.astype(np.float32) + + sq = np.square(x_fp32) # (256, 256) + knob.knob(sq, mem_space="Sbuf", tile_size=elem_tile_2d) + + sum_sq = np.sum(sq, axis=-1, keepdims=True) # (256, 1) + knob.knob(sum_sq, mem_space="Sbuf", tile_size=[128], reduction_tile=[128]) + + mean_sq = sum_sq * np.float32(1.0 / hidden_size) # (256, 1) + knob.knob(mean_sq, mem_space="Sbuf", tile_size=[128, 1]) + + normed = x_fp32 / np.sqrt(mean_sq + eps) # (256, 256) + knob.knob(normed, mem_space="Sbuf", tile_size=elem_tile_2d) + + normed = normed * w_fp32 # (256, 256) + knob.knob(normed, mem_space="Sbuf", tile_size=elem_tile_2d) + + # ================================================================ + # 2. 
QKV projections (2D matmuls on flattened BS dimension) + # SharedHbm = sub-kernel boundary (results flow through reshape) + # ================================================================ + q = np.matmul(normed, w_q) # (256, 256) + knob.knob(q, mem_space="SharedHbm", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + k = np.matmul(normed, w_k) # (256, 256) + knob.knob(k, mem_space="SharedHbm", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + v = np.matmul(normed, w_v) # (256, 256) + knob.knob(v, mem_space="SharedHbm", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + # ================================================================ + # 3. Reshape to multi-head format + # (BS, hidden) -> (batch, seq_len, n_heads, head_dim) + # -> transpose to (batch, n_heads, seq_len, head_dim) + # -> flatten to (BH, seq_len, head_dim) + # ================================================================ + q = np.reshape(q, (batch, seq_len, n_heads, head_dim)) # (2, 128, 2, 128) + q = np.transpose(q, (0, 2, 1, 3)) # (2, 2, 128, 128) + q = np.reshape(q, (BH, seq_len, head_dim)) # (4, 128, 128) + + k = np.reshape(k, (batch, seq_len, n_heads, head_dim)) + k = np.transpose(k, (0, 2, 1, 3)) + k = np.reshape(k, (BH, seq_len, head_dim)) # (4, 128, 128) + + v = np.reshape(v, (batch, seq_len, n_heads, head_dim)) + v = np.transpose(v, (0, 2, 1, 3)) + v = np.reshape(v, (BH, seq_len, head_dim)) # (4, 128, 128) + # V is a sub-kernel boundary: keep in SharedHbm so the 4D transpose + # intermediate stays in HBM (avoids a 4D SBUF alloc that legalize-layout + # cannot tile). + knob.knob(v, mem_space="SharedHbm", tile_size=attn_tile) + + # ================================================================ + # 4. RoPE on Q and K (not V) + # Split head_dim in half, rotate: [x0, x1] -> [x0*cos - x1*sin, + # x0*sin + x1*cos] + # freqs_cos/sin are (1, 128, 64) — broadcast over BH dim + # ================================================================ + # --- RoPE on Q --- + q0 = q[:, :, :half_dim] # (4, 128, 64) + q1 = q[:, :, half_dim:] # (4, 128, 64) + + q0_cos = q0 * freqs_cos # (4, 128, 64) + knob.knob(q0_cos, mem_space="SharedHbm", tile_size=rope_tile) + q1_sin = q1 * freqs_sin # (4, 128, 64) + knob.knob(q1_sin, mem_space="SharedHbm", tile_size=rope_tile) + q_rot0 = q0_cos - q1_sin # (4, 128, 64) + knob.knob(q_rot0, mem_space="SharedHbm", tile_size=rope_tile) + + q0_sin = q0 * freqs_sin # (4, 128, 64) + knob.knob(q0_sin, mem_space="SharedHbm", tile_size=rope_tile) + q1_cos = q1 * freqs_cos # (4, 128, 64) + knob.knob(q1_cos, mem_space="SharedHbm", tile_size=rope_tile) + q_rot1 = q0_sin + q1_cos # (4, 128, 64) + knob.knob(q_rot1, mem_space="SharedHbm", tile_size=rope_tile) + + q = np.concatenate([q_rot0, q_rot1], axis=-1) # (4, 128, 128) + knob.knob(q, mem_space="SharedHbm", tile_size=attn_tile) + + # --- RoPE on K --- + k0 = k[:, :, :half_dim] # (4, 128, 64) + k1 = k[:, :, half_dim:] # (4, 128, 64) + + k0_cos = k0 * freqs_cos # (4, 128, 64) + knob.knob(k0_cos, mem_space="SharedHbm", tile_size=rope_tile) + k1_sin = k1 * freqs_sin # (4, 128, 64) + knob.knob(k1_sin, mem_space="SharedHbm", tile_size=rope_tile) + k_rot0 = k0_cos - k1_sin # (4, 128, 64) + knob.knob(k_rot0, mem_space="SharedHbm", tile_size=rope_tile) + + k0_sin = k0 * freqs_sin # (4, 128, 64) + knob.knob(k0_sin, mem_space="SharedHbm", tile_size=rope_tile) + k1_cos = k1 * freqs_cos # (4, 128, 64) + knob.knob(k1_cos, mem_space="SharedHbm", tile_size=rope_tile) + k_rot1 = k0_sin + k1_cos # (4, 128, 64) + 
knob.knob(k_rot1, mem_space="SharedHbm", tile_size=rope_tile) + + k = np.concatenate([k_rot0, k_rot1], axis=-1) # (4, 128, 128) + knob.knob(k, mem_space="SharedHbm", tile_size=attn_tile) + + # K^T for attention scores + k_t = np.transpose(k, (0, 2, 1)) # (4, 128, 128) + knob.knob(k_t, mem_space="SharedHbm", tile_size=attn_tile) + + # ================================================================ + # 5. Scaled dot-product attention + # scores = (Q @ K^T) * scale + # weights = softmax(scores) + # context = weights @ V + # ================================================================ + scores = np.matmul(q, k_t) # (4, 128, 128) + knob.knob(scores, mem_space="Sbuf", tile_size=attn_tile, reduction_tile=attn_reduction) + + scores = scores * scale # (4, 128, 128) + knob.knob(scores, mem_space="Sbuf", tile_size=attn_tile, partition_dim=1) + + # --- softmax (numerically stable) --- + scores_fp32 = scores.astype(np.float32) + + s_max = np.max(scores_fp32, axis=-1, keepdims=True) # (4, 128, 1) + knob.knob(s_max, mem_space="Sbuf", tile_size=[1, 128], + reduction_tile=[128], partition_dim=1) + + shifted = scores_fp32 - s_max # (4, 128, 128) + knob.knob(shifted, mem_space="Sbuf", tile_size=attn_tile, partition_dim=1) + + exp_s = np.exp(shifted) # (4, 128, 128) + knob.knob(exp_s, mem_space="Sbuf", tile_size=attn_tile, partition_dim=1) + + sum_exp = np.sum(exp_s, axis=-1, keepdims=True) # (4, 128, 1) + knob.knob(sum_exp, mem_space="Sbuf", tile_size=[1, 128], + reduction_tile=[128], partition_dim=1) + + attn_weights = exp_s / sum_exp # (4, 128, 128) + knob.knob(attn_weights, mem_space="SharedHbm", tile_size=attn_tile) + + # --- context = attn_weights @ V --- + context = np.matmul(attn_weights, v) # (4, 128, 128) + knob.knob(context, mem_space="SharedHbm", tile_size=attn_tile, reduction_tile=attn_reduction) + + # ================================================================ + # 6. Concat heads + output projection + # (BH, seq_len, head_dim) -> (batch, n_heads, seq_len, head_dim) + # -> transpose to (batch, seq_len, n_heads, head_dim) + # -> flatten to (BS, hidden_size) + # ================================================================ + context = np.reshape(context, (batch, n_heads, seq_len, head_dim)) + context = np.transpose(context, (0, 2, 1, 3)) + context = np.reshape(context, (BS, hidden_size)) # (256, 256) + + attn_out = np.matmul(context, w_o) # (256, 256) + knob.knob(attn_out, mem_space="Sbuf", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + # ================================================================ + # 7. First residual connection + # ================================================================ + hidden_states = residual + attn_out # (256, 256) + knob.knob(hidden_states, mem_space="Sbuf", tile_size=elem_tile_2d) + + residual = hidden_states + + # ================================================================ + # 8. 
Post-attention RMSNorm + # ================================================================ + x_fp32 = hidden_states.astype(np.float32) + w_fp32 = ln2_weight.astype(np.float32) + + sq = np.square(x_fp32) # (256, 256) + knob.knob(sq, mem_space="Sbuf", tile_size=elem_tile_2d) + + sum_sq = np.sum(sq, axis=-1, keepdims=True) # (256, 1) + knob.knob(sum_sq, mem_space="Sbuf", tile_size=[128], reduction_tile=[128]) + + mean_sq = sum_sq * np.float32(1.0 / hidden_size) # (256, 1) + knob.knob(mean_sq, mem_space="Sbuf", tile_size=[128, 1]) + + normed = x_fp32 / np.sqrt(mean_sq + eps) # (256, 256) + knob.knob(normed, mem_space="Sbuf", tile_size=elem_tile_2d) + + normed = normed * w_fp32 # (256, 256) + knob.knob(normed, mem_space="Sbuf", tile_size=elem_tile_2d) + + # ================================================================ + # 9. SwiGLU FFN + # gate = SiLU(normed @ w_gate) + # up = normed @ w_up + # out = (gate * up) @ w_down + # ================================================================ + gate = np.matmul(normed, w_gate) # (256, 256) + knob.knob(gate, mem_space="Sbuf", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + up = np.matmul(normed, w_up) # (256, 256) + knob.knob(up, mem_space="Sbuf", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + # --- SiLU(gate) = gate * sigmoid(gate) --- + neg_gate = -gate # (256, 256) + knob.knob(neg_gate, mem_space="Sbuf", tile_size=elem_tile_2d) + + exp_neg = np.exp(neg_gate) # (256, 256) + knob.knob(exp_neg, mem_space="Sbuf", tile_size=elem_tile_2d) + + one_plus = exp_neg + 1.0 # (256, 256) + knob.knob(one_plus, mem_space="Sbuf", tile_size=elem_tile_2d) + + sigmoid = 1.0 / one_plus # (256, 256) + knob.knob(sigmoid, mem_space="Sbuf", tile_size=elem_tile_2d) + + gate = gate * sigmoid # (256, 256) + knob.knob(gate, mem_space="Sbuf", tile_size=elem_tile_2d) + + # --- gated output --- + gated = gate * up # (256, 256) + knob.knob(gated, mem_space="Sbuf", tile_size=elem_tile_2d) + + ffn_out = np.matmul(gated, w_down) # (256, 256) + knob.knob(ffn_out, mem_space="Sbuf", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + # ================================================================ + # 10. Second residual connection + # ================================================================ + output = residual + ffn_out # (256, 256) + knob.knob(output, mem_space="SharedHbm", tile_size=elem_tile_2d) + + return output # (256, 256) diff --git a/kernelgen/tests/debug/run.sh b/kernelgen/tests/debug/run.sh new file mode 100755 index 0000000..fc1a9f6 --- /dev/null +++ b/kernelgen/tests/debug/run.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Usage: source ./run.sh <mlir_file> +# +# Runs BIRSim on a pre-compiled NISA MLIR file and compares against +# a NumPy reference (kernel.py in the same directory as the MLIR file). +# +# Examples: +# source ./run.sh matmul_add_sbuf_oom/buggy.mlir +# source ./run.sh psum_accumulate_flags_fix/fixed.mlir + +MLIR_FILE="${1:?Usage: source ./run.sh <mlir_file>}" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +NKIPY_ROOT="$(dirname "$(dirname "$SCRIPT_DIR")")" + +# Source NKI environment +source "$NKIPY_ROOT/scripts/setup_nki.sh" + +# Add NKIPyKernelGen and e2e test utils to PYTHONPATH +export PYTHONPATH="$NKIPY_ROOT:$NKIPY_ROOT/tests/e2e:$PYTHONPATH" + +# Resolve to absolute path if relative +if [[ "$MLIR_FILE" != /* ]]; then + MLIR_FILE="$SCRIPT_DIR/$MLIR_FILE" +fi + +if [ !
-f "$MLIR_FILE" ]; then + echo "Error: File not found: $MLIR_FILE" + return 1 2>/dev/null || exit 1 +fi + +echo "" +echo "Running BIRSim on: $MLIR_FILE" +echo "================================================" + +python3 "$SCRIPT_DIR/run_sim.py" "$MLIR_FILE" diff --git a/kernelgen/tests/debug/run_sim.py b/kernelgen/tests/debug/run_sim.py new file mode 100755 index 0000000..4046d4b --- /dev/null +++ b/kernelgen/tests/debug/run_sim.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +""" +Run BIRSim on a pre-compiled NISA MLIR file and compare against numpy reference. + +Usage: python3 run_sim.py + +Expects a kernel.py in the same directory as containing a function +with the same name as the MLIR func. After BIRSim succeeds, runs the numpy +reference with the same inputs and compares outputs. +""" + +import sys +import os +import shutil +import importlib.util +import numpy as np + +from nki.compiler.ncc_driver import CompileOptions, compile_mlir_to_neff +from nki.compiler._internal import ir, register_all_dialects + + +MLIR_DTYPE_TO_NP = { + "f32": np.float32, + "f16": np.float16, + "f64": np.float64, + "i32": np.int32, + "i64": np.int64, +} + + +def get_memref_info(memref_type): + """Extract shape and numpy dtype from an MLIR MemRefType.""" + shape = tuple(memref_type.shape) + etype = memref_type.element_type + np_dtype = MLIR_DTYPE_TO_NP.get(str(etype), np.float32) + return shape, np_dtype + + +def main(): + if len(sys.argv) < 2: + print("Usage: python3 run_sim.py ", file=sys.stderr) + sys.exit(1) + + mlir_file = sys.argv[1] + if not os.path.exists(mlir_file): + print(f"Error: File not found: {mlir_file}", file=sys.stderr) + sys.exit(1) + + with open(mlir_file, "r") as f: + mlir_str = f.read() + + # Parse MLIR and extract function signature + ctx = ir.Context() + register_all_dialects(ctx) + + with ctx: + module = ir.Module.parse(mlir_str, ctx) + + # Find the func.func operation + func_op = None + for op in module.body.operations: + if "function_type" in op.attributes: + func_op = op + break + + if func_op is None: + print("Error: No function found in module", file=sys.stderr) + sys.exit(1) + + func_name = func_op.attributes["sym_name"].value + func_type = func_op.attributes["function_type"].value + + # Extract input types + input_specs = [] + for arg_type in func_type.inputs: + shape, dtype = get_memref_info(ir.MemRefType(arg_type)) + input_specs.append((shape, dtype)) + + # Extract output type + results = list(func_type.results) + assert len(results) == 1, f"Expected 1 result, got {len(results)}" + out_type = results[0] + out_shape, out_dtype = get_memref_info(ir.MemRefType(out_type)) + + print(f"Function: {func_name}") + for i, (shape, dtype) in enumerate(input_specs): + print(f" arg{i}: {shape} {dtype}") + print(f" returns: {out_shape} {out_dtype}") + + # Generate small random inputs ([-1, 1] range for numerical stability) + np.random.seed(42) + test_inputs = [] + for shape, dtype in input_specs: + test_inputs.append(np.random.uniform(-1, 1, size=shape).astype(dtype)) + + output_placeholder = np.zeros(out_shape, dtype=out_dtype) + + # Artifacts dir + base = os.path.splitext(os.path.basename(mlir_file))[0] + artifacts_dir = os.path.join( + os.path.dirname(os.path.abspath(mlir_file)), f"artifacts_{base}" + ) + if os.path.exists(artifacts_dir): + shutil.rmtree(artifacts_dir) + os.makedirs(artifacts_dir) + + input_names = [f"in_tensor_{i}" for i in range(len(test_inputs))] + + # Extract output name from nki.output_names attribute if present + output_name = "out_tensor_0" + if "nki.output_names" 
in func_op.attributes: + names_attr = ir.ArrayAttr(func_op.attributes["nki.output_names"]) + output_name = str(ir.StringAttr(names_attr[0])).strip('"') + + opts = CompileOptions( + target="trn2", + verbose=True, + output_path=os.path.join(artifacts_dir, "kernel.neff"), + neuronx_cc_args=("--lnc=1",), + artifacts_dir=artifacts_dir, + enable_simulation=True, + ) + + print(f"\nArtifacts: {artifacts_dir}") + print("=" * 60) + + compile_result = compile_mlir_to_neff( + module, + func_name, + list(test_inputs) + [output_placeholder], + input_names + [output_name], + [output_name], + opts, + ) + + print("=" * 60) + + if compile_result.neuronx_cc_error: + print(f"\nFAILED: neuronx-cc error: {compile_result.neuronx_cc_error}") + log_path = os.path.join(artifacts_dir, "log-neuron-cc.txt") + if os.path.exists(log_path): + with open(log_path, "r") as f: + lines = f.readlines() + print(f"\nErrors from log:") + for line in lines: + if "ERROR" in line or "ISIM" in line: + print(f" {line.rstrip()}") + sys.exit(1) + + if compile_result.birsim_outputs is None: + print("\nFAILED: BIRSim produced no outputs") + sys.exit(1) + + result = compile_result.birsim_outputs[0] + print(f"\nBIRSim output: shape={result.shape}, dtype={result.dtype}") + print(f" range: [{result.min():.4f}, {result.max():.4f}]") + print(f" mean: {result.mean():.4f}") + print("\nBIRSim PASSED") + + # --- NumPy reference comparison --- + mlir_dir = os.path.dirname(os.path.abspath(mlir_file)) + kernel_py = os.path.join(mlir_dir, "kernel.py") + if not os.path.exists(kernel_py): + print(f"\nNo kernel.py found in {mlir_dir}, skipping numpy comparison") + return + + print(f"\n--- Running numpy reference from kernel.py ---") + spec = importlib.util.spec_from_file_location("kernel", kernel_py) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + + ref_fn = getattr(mod, func_name, None) + if ref_fn is None: + print(f" Warning: function '{func_name}' not found in kernel.py, skipping") + return + + ref_output = ref_fn(*test_inputs) + diff = np.abs(result.astype(np.float64) - ref_output.astype(np.float64)) + max_diff = diff.max() + mean_diff = diff.mean() + match = np.allclose(result, ref_output, atol=1e-2, rtol=1e-2) + + print(f" Max difference: {max_diff:.2e}") + print(f" Mean difference: {mean_diff:.2e}") + print(f" Match: {match}") + + if match: + print("\nSIMULATION PASSED") + else: + print(f"\nSIMULATION FAILED (max_diff={max_diff:.2e})") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/kernelgen/tests/e2e/__init__.py b/kernelgen/tests/e2e/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernelgen/tests/e2e/conftest.py b/kernelgen/tests/e2e/conftest.py new file mode 100644 index 0000000..7fd455a --- /dev/null +++ b/kernelgen/tests/e2e/conftest.py @@ -0,0 +1,10 @@ +""" +E2E test conftest.py. + +Auto-applies the 'e2e' marker to all tests in this directory. +Individual tests should add 'bir_sim' or 'device' markers as appropriate. +""" + +import pytest + +pytestmark = [pytest.mark.e2e] diff --git a/kernelgen/tests/e2e/nkipy_tests/__init__.py b/kernelgen/tests/e2e/nkipy_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernelgen/tests/e2e/nkipy_tests/test_attention.py b/kernelgen/tests/e2e/nkipy_tests/test_attention.py new file mode 100644 index 0000000..524b407 --- /dev/null +++ b/kernelgen/tests/e2e/nkipy_tests/test_attention.py @@ -0,0 +1,141 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +""" +Ported from nkipy/tests/kernels/attention_dynamo.py + +Full attention layer (prefill) with rotary position embeddings and KV cache, +generated from torch dynamo graph. + +Operations: transpose, reshape, matmul, expand_dims, broadcast_to, multiply, + negative, concatenate, copy, dynamic indexing, softmax, divide. +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace +from harness import run_kernel_test, Mode + + +@pytest.mark.xfail(reason="np.copy, np.concatenate, dynamic index assignment, multiple unsupported ops") +def test_attention_prefill(): + seq_len = 7 + hidden = 256 + n_heads = 4 + n_kv_heads = 2 + head_dim = 64 + max_seq = 16 + + @trace(input_specs=[ + ((1, seq_len, hidden), "f32"), # input + ((hidden, hidden), "f32"), # q_weight + ((n_kv_heads * head_dim, hidden), "f32"), # k_weight + ((n_kv_heads * head_dim, hidden), "f32"), # v_weight + ((1, seq_len, head_dim), "f32"), # rotary_cos + ((1, seq_len, head_dim), "f32"), # rotary_sin + ((seq_len,), "i32"), # position_ids + ((1, n_kv_heads, max_seq, head_dim), "f32"), # k_cache + ((1, n_kv_heads, max_seq, head_dim), "f32"), # v_cache + ((1, 1, seq_len, max_seq), "f32"), # attn_mask + ((hidden, hidden), "f32"), # o_weight + ]) + def kernel(x, q_w, k_w, v_w, cos_emb, sin_emb, pos_ids, + k_cache, v_cache, attn_mask, o_w): + # Q projection + q_wt = np.transpose(q_w, [1, 0]) + x_2d = np.reshape(x, [seq_len, hidden]) + q = np.matmul(x_2d, q_wt) + q = np.reshape(q, [1, seq_len, hidden]) + q = np.reshape(q, [1, seq_len, -1, head_dim]) + q = np.transpose(q, [0, 2, 1, 3]) + + # K projection + k_wt = np.transpose(k_w, [1, 0]) + k = np.matmul(x_2d, k_wt) + k = np.reshape(k, [1, seq_len, n_kv_heads * head_dim]) + k = np.reshape(k, [1, seq_len, -1, head_dim]) + k = np.transpose(k, [0, 2, 1, 3]) + + # V projection + v_wt = np.transpose(v_w, [1, 0]) + v = np.matmul(x_2d, v_wt) + v = np.reshape(v, [1, seq_len, n_kv_heads * head_dim]) + v = np.reshape(v, [1, seq_len, -1, head_dim]) + v = np.transpose(v, [0, 2, 1, 3]) + + # Apply rotary embeddings to Q + cos_unsq = np.expand_dims(cos_emb, 1) + sin_unsq = np.expand_dims(sin_emb, 1) + q_rot = np.multiply(q, cos_unsq) + q_half1 = q[:, :, :, 0:head_dim // 2] + q_half2 = q[:, :, :, head_dim // 2:] + q_neg = np.negative(q_half2) + q_cat = np.concatenate([q_neg, q_half1], -1) + q_rot = np.add(q_rot, np.multiply(q_cat, sin_unsq)) + + # Apply rotary embeddings to K + k_rot = np.multiply(k, cos_unsq) + k_half1 = k[:, :, :, 0:head_dim // 2] + k_half2 = k[:, :, :, head_dim // 2:] + k_neg = np.negative(k_half2) + k_cat = np.concatenate([k_neg, k_half1], -1) + k_rot = np.add(k_rot, np.multiply(k_cat, sin_unsq)) + + # Update KV cache + new_k_cache = np.copy(k_cache) + new_k_cache[:, :, pos_ids] = k_rot + new_v_cache = np.copy(v_cache) + new_v_cache[:, :, pos_ids] = v + + # GQA expand + kv_slice = new_k_cache[0:, 0:] + kv_unsq = np.expand_dims(kv_slice, 2) + kv_exp = np.broadcast_to(kv_unsq[:, :, :, :, 0:], + [1, n_kv_heads, n_heads // n_kv_heads, max_seq, head_dim]) + k_full = np.reshape(np.copy(kv_exp), [1, n_heads, max_seq, head_dim]) + + v_slice = new_v_cache[0:, 0:] + v_unsq = np.expand_dims(v_slice, 2) + v_exp = np.broadcast_to(v_unsq[:, :, :, :, 0:], + [1, n_kv_heads, n_heads // n_kv_heads, max_seq, head_dim]) + v_full = np.reshape(np.copy(v_exp), [1, n_heads, max_seq, head_dim]) + + # Attention scores + k_t = np.transpose(k_full, [0, 1, 3, 2]) + q_3d = np.reshape(np.broadcast_to(q_rot, [1, n_heads, seq_len, head_dim]), + [n_heads, seq_len, 
head_dim]) + k_3d = np.reshape(np.broadcast_to(k_t, [1, n_heads, head_dim, max_seq]), + [n_heads, head_dim, max_seq]) + scores = np.matmul(q_3d, k_3d) + scores = np.reshape(scores, [1, n_heads, seq_len, max_seq]) + scores = np.multiply(scores, 1.0 / np.sqrt(head_dim).item()) + + # Add mask and softmax + scores = np.add(scores, attn_mask[0:, 0:, 0:]) + scores_max = np.max(scores, axis=-1, keepdims=True) + scores_exp = np.exp(np.subtract(scores, scores_max)) + attn_weights = np.divide(scores_exp, np.sum(scores_exp, axis=-1, keepdims=True)) + + # Weighted sum of values + aw_3d = np.reshape(np.broadcast_to(np.copy(attn_weights), + [1, n_heads, seq_len, max_seq]), + [n_heads, seq_len, max_seq]) + v_3d = np.reshape(np.broadcast_to(v_full, [1, n_heads, max_seq, head_dim]), + [n_heads, max_seq, head_dim]) + context = np.matmul(aw_3d, v_3d) + context = np.reshape(context, [1, n_heads, seq_len, head_dim]) + context = np.transpose(context, [0, 2, 1, 3]) + context = np.reshape(np.copy(context), [1, seq_len, -1]) + + # Output projection + o_wt = np.transpose(o_w, [1, 0]) + out_2d = np.reshape(context, [seq_len, hidden]) + out = np.matmul(out_2d, o_wt) + return np.reshape(out, [1, seq_len, hidden]) + + run_kernel_test( + kernel, + modes=Mode.BIR_SIM | Mode.HW, + rtol=1e-3, + atol=1e-3, + ) diff --git a/kernelgen/tests/e2e/nkipy_tests/test_binary_ops.py b/kernelgen/tests/e2e/nkipy_tests/test_binary_ops.py new file mode 100644 index 0000000..545d4c0 --- /dev/null +++ b/kernelgen/tests/e2e/nkipy_tests/test_binary_ops.py @@ -0,0 +1,181 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Targeted binary operation tests, inspired by nkipy/tests/unit/test_tensor_api.py. + +Tests each binary NumPy op through the full NKIPyKernelGen pipeline. 
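+
+Each test follows the same skeleton (sketch; M, N = 128, 256 as defined below):
+
+    @trace(input_specs=[((M, N), "f32"), ((M, N), "f32")])
+    def kernel(a, b):
+        return np.add(a, b)
+
+    run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW)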
+""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace +from harness import run_kernel_test, Mode + +M, N = 128, 256 + + +# -- Currently supported binary ops -- + +def test_add(): + @trace(input_specs=[((M, N), "f32"), ((M, N), "f32")]) + def kernel(a, b): + return np.add(a, b) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_subtract(): + @trace(input_specs=[((M, N), "f32"), ((M, N), "f32")]) + def kernel(a, b): + return np.subtract(a, b) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_multiply(): + @trace(input_specs=[((M, N), "f32"), ((M, N), "f32")]) + def kernel(a, b): + return np.multiply(a, b) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_divide(): + @trace(input_specs=[((M, N), "f32"), ((M, N), "f32")]) + def kernel(a, b): + return np.divide(a, b) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) + + +def test_maximum(): + @trace(input_specs=[((M, N), "f32"), ((M, N), "f32")]) + def kernel(a, b): + return np.maximum(a, b) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_minimum(): + @trace(input_specs=[((M, N), "f32"), ((M, N), "f32")]) + def kernel(a, b): + return np.minimum(a, b) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +# -- Scalar-tensor arithmetic -- + +def test_add_scalar(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return x + 2.0 + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_scalar_subtract(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return 5.0 - x + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_multiply_scalar(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return x * 3.0 + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_scalar_divide(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return 1.0 / x + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) + + +# -- Binary ops that need to be added -- + +def test_power(): + @trace(input_specs=[((M, N), "f32"), ((M, N), "f32")]) + def kernel(a, b): + return np.power(a, b) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-2, atol=1e-2) + + +def test_power_scalar(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.power(x, 2) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +@pytest.mark.xfail(reason="np.floor_divide requires floor+divide decomposition") +def test_floor_divide(): + @trace(input_specs=[((M, N), "f32"), ((M, N), "f32")]) + def kernel(a, b): + return np.floor_divide(a, b) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +@pytest.mark.xfail(reason="NISA backend does not support tensor_tensor_arith(op=MOD)") +def test_mod(): + @trace(input_specs=[((M, N), "f32"), ((M, N), "f32")]) + def kernel(a, b): + return np.mod(a, b) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +# -- Comparison ops -- + +def test_greater_equal(): + @trace(input_specs=[((M, N), "f32"), ((M, N), "f32")]) + def kernel(a, b): + return np.greater_equal(a, b) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_less(): + @trace(input_specs=[((M, N), "f32"), ((M, N), "f32")]) + def kernel(a, b): + return np.less(a, b) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_equal(): + @trace(input_specs=[((M, N), "f32"), ((M, N), "f32")]) + def kernel(a, b): + return np.equal(a, b) + + run_kernel_test(kernel, 
modes=Mode.BIR_SIM | Mode.HW) + + +# -- Broadcasting -- + +def test_broadcast_column(): + """a (M,N) + b (M,1) with broadcasting.""" + @trace(input_specs=[((M, N), "f32"), ((M, 1), "f32")]) + def kernel(a, b): + return np.add(a, b) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_broadcast_row(): + """a (M,N) + b (1,N) with broadcasting.""" + @trace(input_specs=[((M, N), "f32"), ((1, N), "f32")]) + def kernel(a, b): + return np.add(a, b) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) diff --git a/kernelgen/tests/e2e/nkipy_tests/test_composite_patterns.py b/kernelgen/tests/e2e/nkipy_tests/test_composite_patterns.py new file mode 100644 index 0000000..9c87f0e --- /dev/null +++ b/kernelgen/tests/e2e/nkipy_tests/test_composite_patterns.py @@ -0,0 +1,121 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Composite pattern tests combining multiple ops into common ML patterns, +inspired by nkipy/tests/unit/test_tensor_api.py. + +Tests softmax, layer norm, sigmoid, GELU, and other compound operations. +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace +from harness import run_kernel_test, Mode + +M, N = 128, 256 + + +def test_softmax_full(): + """Full softmax: exp(x - max(x)) / sum(exp(x - max(x))).""" + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + x_max = np.max(x, axis=-1, keepdims=True) + exp_x = np.exp(x - x_max) + sum_x = np.sum(exp_x, axis=-1, keepdims=True) + return exp_x / sum_x + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) + + +def test_sigmoid(): + """Sigmoid: 1 / (1 + exp(-x)).""" + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return 1.0 / (1.0 + np.exp(-x)) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) + + +def test_elementwise_chain(): + """Chain of elementwise ops: exp(x) * 2.0 + 1.0.""" + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.exp(x) * 2.0 + 1.0 + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_residual_add(): + """Residual connection: x + f(x) where f(x) = exp(x).""" + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return x + np.exp(x) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_rmsnorm_no_weight(): + """RMSNorm without weight: x * rsqrt(mean(x^2) + eps).""" + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + x_sq = x * x + mean_x = np.sum(x_sq, axis=-1, keepdims=True) / x_sq.shape[-1] + rsqrt = 1.0 / np.sqrt(mean_x + 1e-5) + return x * rsqrt + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) + + +@pytest.mark.xfail(reason="1D broadcast multiply (weight * x) fails in legalize-layout: linalg.mul requires matching operand ranks") +def test_rmsnorm(): + """RMSNorm: x * weight * rsqrt(mean(x^2) + eps).""" + @trace(input_specs=[((M, N), "f32"), ((N,), "f32")]) + def kernel(x, weight): + x_sq = x * x + mean_x = np.sum(x_sq, axis=-1, keepdims=True) / x_sq.shape[-1] + rsqrt = 1.0 / np.sqrt(mean_x + 1e-5) + return np.multiply(weight, np.multiply(x, rsqrt)) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) + + +def test_rmsnorm_prebroadcast(): + """RMSNorm with pre-broadcasted weight to (M, N). + + Workaround variant of test_rmsnorm: the compiler cannot yet handle + 1D row-broadcast multiply, so the weight is materialized as 2D. 
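+
+    For example, a caller with a 1D weight could materialize it up front
+    (illustrative NumPy, not part of this test; weight_1d is a hypothetical
+    (N,) array):
+
+        weight_2d = np.ascontiguousarray(np.broadcast_to(weight_1d, (M, N)))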
+ """ + @trace(input_specs=[((M, N), "f32"), ((M, N), "f32")]) + def kernel(x, weight): + x_sq = x * x + mean_x = np.sum(x_sq, axis=-1, keepdims=True) / x_sq.shape[-1] + rsqrt = 1.0 / np.sqrt(mean_x + 1e-5) + return weight * (x * rsqrt) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) + + +def test_gelu_approx(): + """Approximate GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).""" + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + # Simplified: use sigmoid approximation instead + # GELU ≈ x * sigmoid(1.702 * x) + return x * (1.0 / (1.0 + np.exp(-1.702 * x))) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) + + +def test_scaled_dot_product(): + """Scaled dot-product pattern: (Q @ K^T) * scale.""" + M_dim, K_dim = 256, 256 + + @trace(input_specs=[((M_dim, K_dim), "f32"), ((M_dim, K_dim), "f32")]) + def kernel(q, k): + from nkipy_kernelgen import knob + scores = np.matmul(q, k) + knob.knob(scores, tile_size=[128, 128], reduction_tile=128) + return scores * (1.0 / np.sqrt(np.float32(K_dim))) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) diff --git a/kernelgen/tests/e2e/nkipy_tests/test_embedding.py b/kernelgen/tests/e2e/nkipy_tests/test_embedding.py new file mode 100644 index 0000000..69a9a1e --- /dev/null +++ b/kernelgen/tests/e2e/nkipy_tests/test_embedding.py @@ -0,0 +1,49 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Ported from nkipy/tests/kernels/embedding_dynamo.py + +Embedding lookup with boundary checking and masking. +Operations: greater_equal, less, bitwise_and, bitwise_or, logical_not, + multiply, add, subtract, take, expand_dims, where. +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace +from harness import run_kernel_test, Mode + + +def test_embedding_lookup(): + vocab_size = 256 + embed_dim = 128 + + @trace(input_specs=[ + ((1, 128), "i32"), # token indices + ((vocab_size, embed_dim), "f32"), # embedding table + ]) + def kernel(indices, embed_table): + ge = np.greater_equal(indices, 0) + lt = np.less(indices, vocab_size) + bitwise_and = np.bitwise_and(ge, lt) + ge_1 = np.greater_equal(indices, vocab_size * 2) + lt_1 = np.less(indices, vocab_size * 2) + bitwise_and_1 = np.bitwise_and(ge_1, lt_1) + mul = np.multiply(bitwise_and, 0, dtype=np.int32) + mul_1 = np.multiply(bitwise_and_1, vocab_size, dtype=np.int32) + add = np.add(mul, mul_1) + bitwise_or = np.bitwise_or(bitwise_and, bitwise_and_1) + sub = np.subtract(indices, add) + mul_2 = np.multiply(bitwise_or, sub) + bitwise_not = np.logical_not(bitwise_or) + embedding = np.take(embed_table, mul_2, axis=0) + unsqueeze = np.expand_dims(bitwise_not, -1) + scalar_tensor = np.float32(0.0) + where = np.where(unsqueeze, scalar_tensor, embedding) + return where + + run_kernel_test( + kernel, + modes=Mode.BIR_SIM | Mode.HW, + ) diff --git a/kernelgen/tests/e2e/nkipy_tests/test_indexing.py b/kernelgen/tests/e2e/nkipy_tests/test_indexing.py new file mode 100644 index 0000000..4847065 --- /dev/null +++ b/kernelgen/tests/e2e/nkipy_tests/test_indexing.py @@ -0,0 +1,28 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Ported from nkipy/tests/kernels/indexing.py + +Tensor slicing and addition: add slices [2:4,:] and [0:2,:] of input tensor. 
+""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace +from harness import run_kernel_test, Mode + + +def test_indexed_add(): + M, N = 256, 128 + + @trace(input_specs=[((M, N), "f32")]) + def kernel(input_tensor): + a = input_tensor[128:256, :] + b = input_tensor[0:128, :] + return np.add(a, b) + + run_kernel_test( + kernel, + modes=Mode.BIR_SIM | Mode.HW, + ) diff --git a/kernelgen/tests/e2e/nkipy_tests/test_llama_decoder.py b/kernelgen/tests/e2e/nkipy_tests/test_llama_decoder.py new file mode 100644 index 0000000..31ac233 --- /dev/null +++ b/kernelgen/tests/e2e/nkipy_tests/test_llama_decoder.py @@ -0,0 +1,187 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Ported from nkipy/tests/kernels/llama_decoder_dynamo.py + +Complete Llama transformer decoder block with self-attention, +rotary position embeddings, and SwiGLU MLP. + +This is the most complex test kernel, combining: +- Embedding lookup (np.take) +- RMSNorm (power, mean, rsqrt) +- Rotary position embeddings (cos, sin, concatenate) +- Multi-head attention with GQA and KV cache +- SwiGLU MLP (sigmoid gating) +- Residual connections +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace +from harness import run_kernel_test, Mode + + +@pytest.mark.xfail(reason="np.copy, np.concatenate, np.power, np.cos, np.sin, " + "dynamic index assignment, int dtype inputs not yet supported") +def test_llama_decoder(): + seq_len = 7 + hidden = 256 + n_heads = 4 + n_kv_heads = 2 + head_dim = 64 + intermediate = 512 + max_seq = 16 + vocab_size = 1024 + + @trace(input_specs=[ + ((1, seq_len), "i32"), # token_ids + ((vocab_size, hidden), "f32"), # embed_table + ((seq_len,), "i32"), # position_indices + ((1, seq_len), "i32"), # rotary_positions + ((1, n_kv_heads, max_seq, head_dim), "f32"),# kv_cache + ((1, 1, seq_len, max_seq), "f32"), # attn_mask + ((head_dim // 2,), "f32"), # rotary_base + ((hidden,), "f32"), # ln1_weight + ((hidden, hidden), "f32"), # q_weight + ((n_kv_heads * head_dim, hidden), "f32"), # k_weight + ((n_kv_heads * head_dim, hidden), "f32"), # v_weight + ((1, n_kv_heads, max_seq, head_dim), "f32"),# v_cache + ((hidden, hidden), "f32"), # attn_out_weight + ((hidden,), "f32"), # ln2_weight + ((intermediate, hidden), "f32"), # mlp_up_weight + ((intermediate, hidden), "f32"), # mlp_gate_weight + ((hidden, intermediate), "f32"), # mlp_down_weight + ((hidden,), "f32"), # final_ln_weight + ]) + def kernel(token_ids, embed_table, pos_indices, rotary_pos, + k_cache, attn_mask, rotary_base, ln1_w, + q_w, k_w, v_w, v_cache, attn_out_w, + ln2_w, mlp_up_w, mlp_gate_w, mlp_down_w, final_ln_w): + # Embedding + embedding = np.take(embed_table, token_ids, axis=0) + + # RoPE frequency computation + unsqueeze = np.expand_dims(rotary_base, 0) + unsqueeze_1 = np.expand_dims(unsqueeze[:, 0:], 2) + expand = np.broadcast_to(unsqueeze_1, [1, head_dim // 2, 1]) + slice_2 = rotary_pos[0:] + unsqueeze_2 = np.expand_dims(slice_2, 1) + to_float = unsqueeze_2[:, :, 0:].astype(np.float32) + view = np.reshape(np.broadcast_to(expand, [1, head_dim // 2, 1]), [1, head_dim // 2, 1]) + view_1 = np.reshape(np.broadcast_to(to_float, [1, 1, seq_len]), [1, 1, seq_len]) + bmm = np.matmul(view, view_1) + permute = np.transpose(np.reshape(bmm, [1, head_dim // 2, seq_len]), [0, 2, 1]) + cat = np.concatenate([permute, permute], -1) + cos = np.multiply(np.cos(cat), 1.0) + sin = np.multiply(np.sin(cat), 1.0) + + # RMSNorm 1 + pow_1 = np.power(embedding, 2) + mean = 
np.divide(np.sum(pow_1, axis=(-1,), keepdims=True), pow_1.shape[-1]) + rsqrt = np.divide(1, np.sqrt(np.add(mean, 1e-05))) + normed = np.multiply(ln1_w, np.multiply(embedding, rsqrt)) + + # Q/K/V projections + normed_2d = np.reshape(normed, [seq_len, hidden]) + q = np.reshape(np.reshape(np.matmul(normed_2d, np.transpose(q_w, [1, 0])), + [1, seq_len, hidden]), [1, seq_len, -1, head_dim]) + q = np.transpose(q, [0, 2, 1, 3]) + k = np.transpose(np.reshape(np.reshape(np.matmul(normed_2d, np.transpose(k_w, [1, 0])), + [1, seq_len, n_kv_heads * head_dim]), [1, seq_len, -1, head_dim]), [0, 2, 1, 3]) + v = np.transpose(np.reshape(np.reshape(np.matmul(normed_2d, np.transpose(v_w, [1, 0])), + [1, seq_len, n_kv_heads * head_dim]), [1, seq_len, -1, head_dim]), [0, 2, 1, 3]) + + # Apply RoPE + cos_unsq = np.expand_dims(cos, 1) + sin_unsq = np.expand_dims(sin, 1) + half = head_dim // 2 + q_rot = np.add(np.multiply(q, cos_unsq), + np.multiply(np.concatenate([np.negative(q[:, :, :, half:]), + q[:, :, :, 0:half]], -1), sin_unsq)) + k_rot = np.add(np.multiply(k, cos_unsq), + np.multiply(np.concatenate([np.negative(k[:, :, :, half:]), + k[:, :, :, 0:half]], -1), sin_unsq)) + + # Update KV cache + new_k = np.copy(k_cache) + new_k[:, :, pos_indices] = k_rot + new_v = np.copy(v_cache) + new_v[:, :, pos_indices] = v + + # GQA expand keys + k_exp = np.reshape(np.copy(np.broadcast_to( + np.expand_dims(new_k[0:, 0:], 2)[:, :, :, 0:, 0:], + [1, n_kv_heads, n_heads // n_kv_heads, max_seq, head_dim])), + [1, n_heads, max_seq, head_dim]) + v_exp = np.reshape(np.copy(np.broadcast_to( + np.expand_dims(new_v[0:, 0:], 2)[:, :, :, 0:, 0:], + [1, n_kv_heads, n_heads // n_kv_heads, max_seq, head_dim])), + [1, n_heads, max_seq, head_dim]) + + # Attention + k_t = np.transpose(k_exp, [0, 1, 3, 2]) + q_3d = np.reshape(np.broadcast_to(q_rot, [1, n_heads, seq_len, head_dim]), + [n_heads, seq_len, head_dim]) + k_3d = np.reshape(np.broadcast_to(k_t, [1, n_heads, head_dim, max_seq]), + [n_heads, head_dim, max_seq]) + scores = np.multiply(np.reshape(np.matmul(q_3d, k_3d), + [1, n_heads, seq_len, max_seq]), 0.125) + scores = np.add(scores, attn_mask[0:, 0:, 0:]) + scores_max = np.max(scores, axis=-1, keepdims=True) + softmax = np.divide(np.exp(np.subtract(scores, scores_max)), + np.sum(np.exp(np.subtract(scores, scores_max)), axis=-1, keepdims=True)) + attn_out = np.matmul( + np.reshape(np.broadcast_to(np.copy(softmax), [1, n_heads, seq_len, max_seq]), + [n_heads, seq_len, max_seq]), + np.reshape(np.broadcast_to(v_exp, [1, n_heads, max_seq, head_dim]), + [n_heads, max_seq, head_dim])) + attn_out = np.reshape(np.copy(np.transpose( + np.reshape(attn_out, [1, n_heads, seq_len, head_dim]), [0, 2, 1, 3])), + [1, seq_len, -1]) + + # Output projection + out = np.reshape(np.matmul(np.reshape(attn_out, [seq_len, hidden]), + np.transpose(attn_out_w, [1, 0])), [1, seq_len, hidden]) + + # Residual + residual = np.add(embedding, out) + + # RMSNorm 2 + pow_2 = np.power(residual, 2) + mean_2 = np.divide(np.sum(pow_2, axis=(-1,), keepdims=True), pow_2.shape[-1]) + rsqrt_2 = np.divide(1, np.sqrt(np.add(mean_2, 1e-05))) + normed_2 = np.multiply(ln2_w, np.multiply(residual, rsqrt_2)) + + # SwiGLU MLP + normed_2_2d = np.reshape(normed_2, [seq_len, hidden]) + gate = np.reshape(np.matmul(normed_2_2d, np.transpose(mlp_up_w, [1, 0])), + [1, seq_len, intermediate]) + gate_sig = np.divide(1, np.add(1, np.exp(np.negative(gate)))) + gate_out = np.multiply(gate, gate_sig) + up = np.reshape(np.matmul(normed_2_2d, np.transpose(mlp_gate_w, [1, 0])), + [1, seq_len, 
intermediate]) + mlp_mid = np.multiply(gate_out, up) + mlp_out = np.reshape(np.matmul(np.reshape(mlp_mid, [seq_len, intermediate]), + np.transpose(mlp_down_w, [1, 0])), [1, seq_len, hidden]) + + # Residual + final norm + final_residual = np.add(residual, mlp_out) + pow_3 = np.power(final_residual, 2) + mean_3 = np.divide(np.sum(pow_3, axis=(-1,), keepdims=True), pow_3.shape[-1]) + rsqrt_3 = np.divide(1, np.sqrt(np.add(mean_3, 1e-05))) + final_normed = np.multiply(final_ln_w, np.multiply(final_residual, rsqrt_3)) + + # LM head + lm_out = np.reshape( + np.matmul(np.reshape(final_normed[:, final_normed.shape[1] - 1:, 0:], [1, hidden]), + np.transpose(embed_table, [1, 0])), + [1, 1, vocab_size]) + return lm_out + + run_kernel_test( + kernel, + modes=Mode.BIR_SIM | Mode.HW, + rtol=1e-3, + atol=1e-3, + ) diff --git a/kernelgen/tests/e2e/nkipy_tests/test_matmul_shapes.py b/kernelgen/tests/e2e/nkipy_tests/test_matmul_shapes.py new file mode 100644 index 0000000..773c26f --- /dev/null +++ b/kernelgen/tests/e2e/nkipy_tests/test_matmul_shapes.py @@ -0,0 +1,106 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Targeted matmul tests with various shapes, inspired by +nkipy/tests/unit/test_tensor_api_native_ops.py and test_tensor_api.py. + +Tests matmul with different M/N/K dimensions and batch dimensions. +""" + +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +def test_matmul_square(): + """Square matmul: (256,256) @ (256,256).""" + M = N = K = 256 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) + def kernel(a, b): + result = np.matmul(a, b) + knob.knob(result, tile_size=[128, 128], reduction_tile=128) + return result + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) + + +def test_matmul_rectangular(): + """Rectangular matmul: (128,256) @ (256,512).""" + M, K, N = 128, 256, 512 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) + def kernel(a, b): + result = np.matmul(a, b) + knob.knob(result, tile_size=[128, 128], reduction_tile=128) + return result + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) + + +def test_matmul_tall(): + """Tall output: (512,128) @ (128,128).""" + M, K, N = 512, 128, 128 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) + def kernel(a, b): + result = np.matmul(a, b) + knob.knob(result, tile_size=[128, 128], reduction_tile=128) + return result + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) + + +def test_matmul_wide(): + """Wide output: (128,128) @ (128,512).""" + M, K, N = 128, 128, 512 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) + def kernel(a, b): + result = np.matmul(a, b) + knob.knob(result, tile_size=[128, 128], reduction_tile=128) + return result + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) + + +def test_batch_matmul(): + """Batched matmul: (4, 128, 256) @ (4, 256, 128).""" + B, M, K, N = 4, 128, 256, 128 + + @trace(input_specs=[((B, M, K), "f32"), ((B, K, N), "f32")]) + def kernel(a, b): + return np.matmul(a, b) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) + + +def test_matmul_add(): + """Matmul followed by bias add: C = A @ B + bias.""" + M, K, N = 256, 256, 256 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((M, N), "f32")]) + def kernel(a, b, bias): + result = np.matmul(a, b) + knob.knob(result, tile_size=[128, 128], 
reduction_tile=128) + out = np.add(result, bias) + knob.knob(out, mem_space="SharedHbm", tile_size=[128, 128]) + return out + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) + + +def test_matmul_chain(): + """Two matmuls chained: D = (A @ B) @ C.""" + M, K1, K2, N = 256, 256, 256, 256 + + @trace(input_specs=[((M, K1), "f32"), ((K1, K2), "f32"), ((K2, N), "f32")]) + def kernel(a, b, c): + ab = np.matmul(a, b) + knob.knob(ab, tile_size=[128, 128], reduction_tile=128, mem_space="Sbuf") + result = np.matmul(ab, c) + knob.knob(result, tile_size=[128, 128], reduction_tile=128) + return result + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) diff --git a/kernelgen/tests/e2e/nkipy_tests/test_mlp.py b/kernelgen/tests/e2e/nkipy_tests/test_mlp.py new file mode 100644 index 0000000..0d8eff2 --- /dev/null +++ b/kernelgen/tests/e2e/nkipy_tests/test_mlp.py @@ -0,0 +1,119 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +MLP with gated activation (SwiGLU). + +test_mlp_swiglu: compiler-friendly version (passes) +test_mlp_swiglu_original: original dynamo-traced version (xfail, see comments) +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +def test_mlp_swiglu(): + batch = 256 + hidden = 256 + intermediate = 256 + + matmul_tile = [128, 128] + matmul_reduction = [128] + elementwise_tile = [128, 128] + + @trace(input_specs=[ + ((batch, hidden), "f32"), # x + ((hidden, 2 * intermediate), "f32"), # gate_up_weight + ((intermediate, hidden), "f32"), # down_weight + ]) + def kernel(x, gate_up_weight, down_weight): + # Combined gate+up projection + mm_gup = np.matmul(x, gate_up_weight) + knob.knob(mm_gup, mem_space="Sbuf", tile_size=matmul_tile, reduction_tile=matmul_reduction) + + # Split into gate and up + split_axis = mm_gup.ndim - 1 + gate, up = np.split(mm_gup, 2, axis=split_axis) + + # SiLU(gate) = gate * sigmoid(gate) + neg_gate = -gate + knob.knob(neg_gate, mem_space="Sbuf", tile_size=elementwise_tile) + + exp_neg = np.exp(neg_gate) + knob.knob(exp_neg, mem_space="Sbuf", tile_size=elementwise_tile) + + one_plus_exp = exp_neg + 1.0 + knob.knob(one_plus_exp, mem_space="Sbuf", tile_size=elementwise_tile) + + sigmoid = 1.0 / one_plus_exp + knob.knob(sigmoid, mem_space="Sbuf", tile_size=elementwise_tile) + + swish_gate = gate * sigmoid + knob.knob(swish_gate, mem_space="Sbuf", tile_size=elementwise_tile) + + # Gating + gated = swish_gate * up + knob.knob(gated, mem_space="Sbuf", tile_size=elementwise_tile) + + # Down projection + output = np.matmul(gated, down_weight) + knob.knob(output, mem_space="SharedHbm", tile_size=matmul_tile, reduction_tile=matmul_reduction) + + return output + + run_kernel_test( + kernel, + modes=Mode.BIR_SIM | Mode.HW, + rtol=1e-3, + atol=1e-3, + ) + + + +@pytest.mark.xfail(reason="1D input + 3D intermediates: DMA partition mismatch and expand_shape legalization") +def test_mlp_swiglu_original(): + """Original dynamo-traced SwiGLU — 1D input, separate matmuls, 3D reshapes. 
+ + Remaining issues after AnnotateMemorySpace resolveViewConflicts fix: + - 1D input reshape(x, [1, hidden]) creates partition dim = 1, but DMA + requires partition >= 128 for HBM->SBUF transpose + - 3D expand_shape of SBUF allocs can't be legalized consistently + (expand_shape assumes contiguous layout, legalization interleaves) + """ + hidden = 256 + intermediate = 256 + + @trace(input_specs=[ + ((intermediate, hidden), "f32"), # gate weight + ((hidden,), "f32"), # input vector + ((intermediate, hidden), "f32"), # up weight + ((hidden, intermediate), "f32"), # down weight + ]) + def kernel(gate_w, x, up_w, down_w): + gate_wt = np.transpose(gate_w, [1, 0]) + view = np.reshape(x, [1, hidden]) + mm = np.matmul(view, gate_wt) + view_1 = np.reshape(mm, [1, 1, intermediate]) + sigmoid = 1 / (1 + np.exp(-view_1)) + mul = np.multiply(view_1, sigmoid) + + up_wt = np.transpose(up_w, [1, 0]) + view_2 = np.reshape(x, [1, hidden]) + mm_1 = np.matmul(view_2, up_wt) + view_3 = np.reshape(mm_1, [1, 1, intermediate]) + mul_1 = np.multiply(mul, view_3) + + down_wt = np.transpose(down_w, [1, 0]) + view_4 = np.reshape(mul_1, [1, intermediate]) + mm_2 = np.matmul(view_4, down_wt) + view_5 = np.reshape(mm_2, [1, 1, hidden]) + return view_5 + + run_kernel_test( + kernel, + modes=Mode.BIR_SIM | Mode.HW, + rtol=1e-3, + atol=1e-3, + ) diff --git a/kernelgen/tests/e2e/nkipy_tests/test_reductions.py b/kernelgen/tests/e2e/nkipy_tests/test_reductions.py new file mode 100644 index 0000000..940ffb6 --- /dev/null +++ b/kernelgen/tests/e2e/nkipy_tests/test_reductions.py @@ -0,0 +1,106 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Targeted reduction operation tests, inspired by nkipy/tests/unit/test_tensor_api.py. + +Tests reduction ops (sum, mean, max, min) with various axis configurations. 
+""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace +from harness import run_kernel_test, Mode + +M, N = 128, 256 + + +# -- Currently supported reductions -- + +def test_sum_axis_last(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.sum(x, axis=-1, keepdims=True) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_mean_axis_last(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.mean(x, axis=-1, keepdims=True) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) + + +def test_max_axis_last(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.max(x, axis=-1, keepdims=True) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +# -- Reductions used in common patterns -- + +def test_sum_subtract_pattern(): + """sum along axis then broadcast-subtract (log-softmax style).""" + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + s = np.sum(x, axis=-1, keepdims=True) + return x - s + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_max_normalize_pattern(): + """max along axis then subtract (numerical stability for softmax).""" + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + m = np.max(x, axis=-1, keepdims=True) + return x - m + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +# -- Reductions that need to be added -- + +def test_min_axis_last(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.min(x, axis=-1, keepdims=True) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_sum_axis_first(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.sum(x, axis=0, keepdims=True) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_prod_axis_last(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.prod(x, axis=-1, keepdims=True) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +@pytest.mark.xfail(reason="np.argmax not yet supported in tracer") +def test_argmax(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.argmax(x, axis=-1, keepdims=True) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_std(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.std(x, axis=-1, keepdims=True) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) diff --git a/kernelgen/tests/e2e/nkipy_tests/test_rope.py b/kernelgen/tests/e2e/nkipy_tests/test_rope.py new file mode 100644 index 0000000..f04d863 --- /dev/null +++ b/kernelgen/tests/e2e/nkipy_tests/test_rope.py @@ -0,0 +1,54 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Ported from nkipy/tests/kernels/rope_dynamo.py + +Rotary Position Embedding (RoPE) generation kernel from torch dynamo graph. +Operations: expand_dims, broadcast_to, reshape, matmul, transpose, concatenate, cos, sin, multiply. 
+""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace +from harness import run_kernel_test, Mode + + +@pytest.mark.xfail(reason="np.concatenate, np.cos, np.sin, int input dtype not yet supported") +def test_rope(): + @trace(input_specs=[ + ((32,), "f32"), # rotary embedding base + ((1, 7), "i32"), # position ids + ]) + def kernel(freq_base, pos_ids): + unsqueeze = np.expand_dims(freq_base, 0) + slice_1 = unsqueeze[:, 0:] + unsqueeze_1 = np.expand_dims(slice_1, 2) + expand = np.broadcast_to(unsqueeze_1, [1, 32, 1]) + + slice_2 = pos_ids[0:] + unsqueeze_2 = np.expand_dims(slice_2, 1) + slice_3 = unsqueeze_2[:, :, 0:] + to_float = slice_3.astype(np.float32) + + expand_1 = np.broadcast_to(expand, [1, 32, 1]) + view = np.reshape(expand_1, [1, 32, 1]) + expand_2 = np.broadcast_to(to_float, [1, 1, 7]) + view_1 = np.reshape(expand_2, [1, 1, 7]) + bmm = np.matmul(view, view_1) + view_2 = np.reshape(bmm, [1, 32, 7]) + permute = np.transpose(view_2, [0, 2, 1]) + + cat = np.concatenate([permute, permute], -1) + cos = np.cos(cat) + sin = np.sin(cat) + + mul = np.multiply(cos, 1.0) + mul_1 = np.multiply(sin, 1.0) + # Original returns tuple; return first output for testing + return mul + + run_kernel_test( + kernel, + modes=Mode.BIR_SIM | Mode.HW, + ) diff --git a/kernelgen/tests/e2e/nkipy_tests/test_simple_add.py b/kernelgen/tests/e2e/nkipy_tests/test_simple_add.py new file mode 100644 index 0000000..3d5ca88 --- /dev/null +++ b/kernelgen/tests/e2e/nkipy_tests/test_simple_add.py @@ -0,0 +1,25 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Ported from nkipy/tests/kernels/simple.py + +Simple tensor addition kernel. +""" + +import numpy as np + +from nkipy_kernelgen import trace +from harness import run_kernel_test, Mode + + +def test_simple_add(): + M, N = 128, 256 + + @trace(input_specs=[((M, N), "f32"), ((M, N), "f32")]) + def kernel(a, b): + return np.add(a, b) + + run_kernel_test( + kernel, + modes=Mode.BIR_SIM | Mode.HW, + ) diff --git a/kernelgen/tests/e2e/nkipy_tests/test_softmax.py b/kernelgen/tests/e2e/nkipy_tests/test_softmax.py new file mode 100644 index 0000000..cd0a900 --- /dev/null +++ b/kernelgen/tests/e2e/nkipy_tests/test_softmax.py @@ -0,0 +1,29 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Ported from nkipy/tests/kernels/softmax.py + +Softmax kernel: exp(x - max(x)) / sum(exp(x - max(x))) +""" + +import numpy as np + +from nkipy_kernelgen import trace +from harness import run_kernel_test, Mode + + +def test_softmax(): + M, N = 128, 256 + + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True)) + sum_x = np.sum(exp_x, axis=-1, keepdims=True) + return exp_x / sum_x + + run_kernel_test( + kernel, + modes=Mode.BIR_SIM | Mode.HW, + rtol=1e-3, + atol=1e-3, + ) diff --git a/kernelgen/tests/e2e/nkipy_tests/test_tensor_manipulation.py b/kernelgen/tests/e2e/nkipy_tests/test_tensor_manipulation.py new file mode 100644 index 0000000..3c99ad6 --- /dev/null +++ b/kernelgen/tests/e2e/nkipy_tests/test_tensor_manipulation.py @@ -0,0 +1,178 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Targeted tensor manipulation tests, inspired by nkipy/tests/unit/test_tensor_api.py +and test_tensor_api_native_ops.py. + +Tests reshape, transpose, expand_dims, broadcast_to, concatenate. 
+""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace +from nkipy_kernelgen.knob import knob +from harness import run_kernel_test, Mode + +M, N = 128, 256 + + +# -- Reshape (Category 1: contiguous dim merge/split) -- + +def test_reshape_merge_dims(): + """Merge first two dims: (2, 128, 256) -> (256, 256).""" + @trace(input_specs=[((2, M, N), "f32")]) + def kernel(x): + return np.reshape(x, (2 * M, N)) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_reshape_split_dim(): + """Split first dim: (256, 256) -> (2, 128, 256).""" + @trace(input_specs=[((2 * M, N), "f32")]) + def kernel(x): + return np.reshape(x, (2, M, N)) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_reshape_insert_unit_dim(): + """Insert unit dim: (128, 256) -> (128, 1, 256).""" + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.reshape(x, (M, 1, N)) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_reshape_remove_unit_dim(): + """Remove unit dim (squeeze): (128, 1, 256) -> (128, 256).""" + @trace(input_specs=[((M, 1, N), "f32")]) + def kernel(x): + return np.reshape(x, (M, N)) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_reshape_infer_dim_category1(): + """Reshape with -1 inferred dimension (Category 1 merge).""" + @trace(input_specs=[((2, M, N), "f32")]) + def kernel(x): + return np.reshape(x, (-1, N)) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_reshape_identity(): + """Identity reshape: same shape, should be a no-op.""" + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.reshape(x, (M, N)) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +# -- Reshape (Category 2: non-contiguous -- not supported) -- + +def test_reshape_2d(): + """Category 2 reshape (non-contiguous): (128, 256) -> (256, 128).""" + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.reshape(x, (M * N // 128, 128)) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_reshape_infer_dim(): + """Reshape with -1 inferred dimension (Category 2).""" + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.reshape(x, (-1, 128)) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +# -- Transpose -- + +def test_transpose_2d(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + # Output shape: (N, M) = (256, 128) + result = np.transpose(x, [1, 0]) + return knob(result, tile_size=[128, 128]) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_transpose_3d(): + @trace(input_specs=[((2, M, N), "f32")]) + def kernel(x): + # Output shape: (2, N, M) = (2, 256, 128) + # Batch dim tiled to 1, inner dims tiled to 128 + result = np.transpose(x, [0, 2, 1]) + return knob(result, tile_size=[1, 128, 128]) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +# -- Expand dims -- + +def test_expand_dims_last(): + @trace(input_specs=[((M,), "f32")]) + def kernel(x): + return np.expand_dims(x, axis=-1) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_expand_dims_first(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.expand_dims(x, axis=0) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +# -- Broadcast -- + +def test_broadcast_to(): + @trace(input_specs=[((1, N), "f32")]) + def kernel(x): + return np.broadcast_to(x, (M, N)) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +# -- Concatenate -- + +def test_concatenate_axis_last(): + 
@trace(input_specs=[((M, N), "f32"), ((M, N), "f32")]) + def kernel(a, b): + return np.concatenate([a, b], axis=-1) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_concatenate_axis_first(): + @trace(input_specs=[((M, N), "f32"), ((M, N), "f32")]) + def kernel(a, b): + return np.concatenate([a, b], axis=0) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +# -- Copy -- + +def test_copy(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.copy(x) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +# -- Test Runner -- + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/e2e/nkipy_tests/test_unary_ops.py b/kernelgen/tests/e2e/nkipy_tests/test_unary_ops.py new file mode 100644 index 0000000..232fdd6 --- /dev/null +++ b/kernelgen/tests/e2e/nkipy_tests/test_unary_ops.py @@ -0,0 +1,123 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Targeted unary operation tests, inspired by nkipy/tests/unit/test_tensor_api.py. + +Tests each unary NumPy op through the full NKIPyKernelGen pipeline. +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace +from harness import run_kernel_test, Mode + +M, N = 128, 256 +TILE = [128, 128] + + +def test_exp(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.exp(x) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_sqrt(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + # sqrt needs positive inputs; tracer uses random [0,1) by default + return np.sqrt(x) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_negative(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.negative(x) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_reciprocal(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.reciprocal(x) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW, rtol=1e-3, atol=1e-3) + + +def test_abs(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.abs(x) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_log(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.log(x) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_sin(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.sin(x) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_cos(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.cos(x) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_tanh(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.tanh(x) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +@pytest.mark.xfail(reason="No NISA activation for ceil — no hardware support") +def test_ceil(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.ceil(x) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +@pytest.mark.xfail(reason="No NISA activation for floor — no hardware support") +def test_floor(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.floor(x) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_square(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.square(x) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) + + +def test_sign(): + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + 
return np.sign(x) + + run_kernel_test(kernel, modes=Mode.BIR_SIM | Mode.HW) diff --git a/kernelgen/tests/e2e/test_3d_elementwise.py b/kernelgen/tests/e2e/test_3d_elementwise.py new file mode 100644 index 0000000..0abf0f1 --- /dev/null +++ b/kernelgen/tests/e2e/test_3d_elementwise.py @@ -0,0 +1,90 @@ +""" +End-to-end tests for 3D elementwise operations. + +These tests exercise the generalized rank-R pipeline with 3D tensors: +1. Trace 3D Python ops to MLIR +2. Tile with 3D tile_size (middle dims must be 1) +3. Legalize layout: 3D -> 5D physical -> 2D collapse +4. Lower to NISA dialect +5. Simulate and validate against NumPy + +Run with: pytest tests/e2e/test_3d_elementwise.py -v +""" + +import pytest + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +# ============================================================================ +# Test Cases +# ============================================================================ + +def test_3d_add_chain(): + """ + Test 3D add chain: result = (a + b) + c + + Shape: (256, 2, 256) with tile [128, 1, 128] + - Intermediate (a+b) stored in SBUF (triggers 5D physical layout) + - Result stored in SharedHbm + + This exercises: + - 3D -> 5D SBUF legalization: [128, 2, 2, 2, 128] + - 3-level tiled copy loops (HBM <-> SBUF) + - Collapse to 2D for NISA compute ops + - Named op reconstruction (linalg.add with 2D operands) + """ + B, M, N = 256, 2, 256 + tile_size = [128, 1, 128] + + @trace(input_specs=[((B, M, N), "f32"), ((B, M, N), "f32"), ((B, M, N), "f32")]) + def add_chain_3d(a, b, c): + intermediate = a + b + knob.knob(intermediate, mem_space="Sbuf", tile_size=tile_size) + + result = intermediate + c + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + + return result + + run_kernel_test( + add_chain_3d, + check_ir_contains=[ + "nisa.alloc", "nisa.tensor_tensor_arith", "nisa.target", + ], + check_ir_not_contains=["transform.named_sequence"], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +def test_3d_add_hbm_only(): + """ + Test 3D add with HBM intermediate (no SBUF legalization needed). + + Shape: (256, 2, 256) with tile [128, 1, 128] + - Both intermediate and result go to SharedHbm + - No SBUF layout transformation — serves as a simpler 3D baseline + """ + B, M, N = 256, 2, 256 + tile_size = [128, 1, 128] + + @trace(input_specs=[((B, M, N), "f32"), ((B, M, N), "f32"), ((B, M, N), "f32")]) + def add_chain_3d_hbm(a, b, c): + intermediate = a + b + knob.knob(intermediate, mem_space="SharedHbm", tile_size=tile_size) + + result = intermediate + c + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + + return result + + run_kernel_test( + add_chain_3d_hbm, + check_ir_contains=["nisa.alloc", "nisa.target"], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/e2e/test_attention.py b/kernelgen/tests/e2e/test_attention.py new file mode 100644 index 0000000..2f7d739 --- /dev/null +++ b/kernelgen/tests/e2e/test_attention.py @@ -0,0 +1,274 @@ +""" +End-to-end tests for attention kernels. + +Covers the core attention building blocks used in Qwen3: +1. Softmax: exp(x - max(x)) / sum(exp(x - max(x))) +2. QKV projection: matmul + split into Q, K, V +3. 
Attention scores: (Q @ K^T) / sqrt(d) -> softmax + +Run with: pytest tests/e2e/test_attention.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from nkipy_kernelgen.apis import fori_loop +from harness import run_kernel_test, Mode + + +# ============================================================================ +# Softmax (standalone, for easier isolated testing) +# ============================================================================ + + +@pytest.mark.parametrize( + "M, N, tile_size", + [ + (128, 128, [128, 128]), + (128, 256, [128, 128]), + (256, 256, [128, 128]), + ], +) +def test_softmax(M, N, tile_size): + """ + Test softmax in isolation: exp(x - max(x)) / sum(exp(x - max(x))). + + Each intermediate step is annotated with a knob for tiling control. + """ + + @trace(input_specs=[((M, N), "f32")]) + def softmax_kernel(x): + x_fp32 = x.astype(np.float32) + + x_max = np.max(x_fp32, axis=-1, keepdims=True) + knob.knob(x_max, mem_space="Sbuf", tile_size=[128], reduction_tile=[128]) + + shifted = x_fp32 - x_max + knob.knob(shifted, mem_space="Sbuf", tile_size=tile_size) + + exp_x = np.exp(shifted) + knob.knob(exp_x, mem_space="Sbuf", tile_size=tile_size) + + sum_exp = np.sum(exp_x, axis=-1, keepdims=True) + knob.knob(sum_exp, mem_space="Sbuf", tile_size=[128], reduction_tile=[128]) + + result = exp_x / sum_exp + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + return result + + run_kernel_test( + softmax_kernel, + check_ir_contains=["nisa.activation", "op=exp"], + check_ir_not_contains=["transform.named_sequence"], + rtol=1e-3, + atol=1e-3, + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +# ============================================================================ +# QKV Projection +# ============================================================================ + + +@pytest.mark.parametrize( + "M, hidden_size, matmul_tile, reduction_tile, elementwise_tile", + [ + (256, 256, [128, 128], [128], [128, 128]), + ], +) +def test_qkv_projection(M, hidden_size, matmul_tile, reduction_tile, elementwise_tile): + """ + Test QKV projection: x @ weight -> split into Q, K, V. + + Returns all three outputs (Q, K, V) using multi-output support. + The pipeline auto-inserts DMA copies for SBUF→HBM on return values. + """ + + @trace( + input_specs=[ + ((M, hidden_size), "f32"), + ((hidden_size, hidden_size * 3), "f32"), + ] + ) + def qkv_kernel(x, weight): + qkv = np.matmul(x, weight) + knob.knob( + qkv, mem_space="Sbuf", tile_size=matmul_tile, reduction_tile=reduction_tile + ) + + q, k, v = np.split(qkv, 3, axis=-1) + return q, k, v + + run_kernel_test( + qkv_kernel, + check_ir_contains=["nisa.alloc", "nisa.matmul", "nisa.target"], + check_ir_not_contains=["transform.named_sequence"], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +# ============================================================================ +# Attention Scores with fori_loop (batched) +# ============================================================================ + + +@pytest.mark.parametrize( + "batch, n_heads, seq_len, head_dim, tile_size", + [ + (2, 4, 256, 256, [128, 128]), + ], +) +def test_attention_scores_loop(batch, n_heads, seq_len, head_dim, tile_size): + """ + Test batched attention scores with fori_loop: softmax((Q @ K^T) / sqrt(d)). + + Uses fori_loop to iterate over batch * n_heads, writing each result slice + back to the output via the eliminate-same-memspace-copy HBM intermediate + elimination pattern. 
+ """ + scale = 1.0 / np.sqrt(head_dim).item() + + @trace( + input_specs=[ + ((batch * n_heads, seq_len, head_dim), "f32"), + ((batch * n_heads, head_dim, seq_len), "f32"), + ] + ) + def attention_kernel_loop(q, k_transposed): + init_result = np.empty((batch * n_heads, seq_len, seq_len), dtype=np.float32) + + def body(i, acc): + q_i = q[i] + k_i = k_transposed[i] + + scores = np.matmul(q_i, k_i) * scale + knob.knob( + scores, mem_space="Sbuf", tile_size=tile_size, reduction_tile=[128] + ) + + scores_fp32 = scores.astype(np.float32) + + scores_max = np.max(scores_fp32, axis=-1, keepdims=True) + knob.knob( + scores_max, mem_space="Sbuf", tile_size=[128], reduction_tile=[128] + ) + + shifted = scores_fp32 - scores_max + knob.knob(shifted, mem_space="Sbuf", tile_size=tile_size) + + exp_s = np.exp(shifted) + knob.knob(exp_s, mem_space="Sbuf", tile_size=tile_size) + + sum_exp = np.sum(exp_s, axis=-1, keepdims=True) + knob.knob(sum_exp, mem_space="Sbuf", tile_size=[128], reduction_tile=[128]) + + softmax_out = exp_s / sum_exp + knob.knob(softmax_out, mem_space="SharedHbm", tile_size=tile_size) + + acc[i] = softmax_out + return acc + + results = fori_loop(0, batch * n_heads, body, init_result) + return results + + run_kernel_test( + attention_kernel_loop, + check_ir_contains=["nisa.dma_copy"], + check_ir_not_contains=["memref.reshape", "transform.named_sequence"], + rtol=1e-3, + atol=1e-3, + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +# ============================================================================ +# Attention Scores with SBUF BMM output + partition_dim=1 +# ============================================================================ + + +@pytest.mark.parametrize( + "batch, n_heads, seq_len, head_dim, tile_size", + [ + (2, 4, 256, 256, [1, 128, 128]), + ], +) +def test_attention_scores_sbuf_bmm(batch, n_heads, seq_len, head_dim, tile_size): + """ + Test attention scores with BMM output in SBUF and softmax using + partition_dim=1. + + The BMM is converted to loop + matmul with MxBxN output layout in SBUF. + Softmax ops use partition_dim=1, requiring canonicalize-partition-dim + to insert boundary transposes. LegalizeLayout expands the 3D SBUF alloc + to physical layout and tileCopyAndTranspose handles the HBM→SBUF copy. 
+ """ + scale = 1.0 / np.sqrt(head_dim).item() + + @trace( + input_specs=[ + ((batch * n_heads, seq_len, head_dim), "f32"), + ((batch * n_heads, head_dim, seq_len), "f32"), + ] + ) + def attention_kernel(q, k_transposed): + bmm_result = np.matmul(q, k_transposed) + knob.knob( + bmm_result, mem_space="Sbuf", tile_size=tile_size, reduction_tile=[128] + ) + + scores = bmm_result * scale + knob.knob(scores, mem_space="Sbuf", tile_size=tile_size, partition_dim=1) + + scores_fp32 = scores.astype(np.float32) + + scores_max = np.max(scores_fp32, axis=-1, keepdims=True) + knob.knob( + scores_max, + mem_space="Sbuf", + tile_size=[1, 128], + reduction_tile=[128], + partition_dim=1, + ) + + shifted = scores_fp32 - scores_max + knob.knob(shifted, mem_space="Sbuf", tile_size=tile_size, partition_dim=1) + + exp_s = np.exp(shifted) + knob.knob(exp_s, mem_space="Sbuf", tile_size=tile_size, partition_dim=1) + + sum_exp = np.sum(exp_s, axis=-1, keepdims=True) + knob.knob( + sum_exp, + mem_space="Sbuf", + tile_size=[1, 128], + reduction_tile=[128], + partition_dim=1, + ) + + result = exp_s / sum_exp + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size, partition_dim=1) + return result + + # Verify LLVM simulation matches NumPy through legalize-layout + run_kernel_test( + attention_kernel, + stop_after="legalize-layout", + rtol=1e-3, + atol=1e-3, + modes=Mode.LLVM, + ) + + # Verify full pipeline generates NISA dialect ops and simulates correctly + run_kernel_test( + attention_kernel, + rtol=1e-3, + atol=1e-3, + check_ir_contains=["nisa.matmul", "nisa.dma_copy"], + modes=Mode.STRING_CHECK | Mode.BIR_SIM | Mode.HW, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/e2e/test_auto_layout.py b/kernelgen/tests/e2e/test_auto_layout.py new file mode 100644 index 0000000..6fa9e51 --- /dev/null +++ b/kernelgen/tests/e2e/test_auto_layout.py @@ -0,0 +1,132 @@ +""" +End-to-end tests for auto-inferred layouts (no user annotations). + +These tests verify that the infer-layout pass can automatically determine +tile sizes, partition dims, and memory spaces, and that the resulting code +compiles and runs correctly through BIR simulation and hardware. + +Run with: pytest tests/e2e/test_auto_layout.py -v +""" + +import numpy as np +import pytest + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +# ============================================================================ +# Elementwise: no user annotations at all +# ============================================================================ + +def test_exp_no_annotations(): + """ + exp(x) with no user knobs. The pass should auto-infer: + partition_dim=0, tile_size=[128, 256], mem_space=SharedHbm (return val) + """ + M, N = 128, 256 + + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return np.exp(x) + + run_kernel_test( + kernel, + modes=Mode.BIR_SIM | Mode.HW, + ) + + +def test_elementwise_chain_no_annotations(): + """ + exp -> square -> add_scalar chain with no user annotations. + All ops should get auto-inferred layouts and compile/run correctly. + """ + M, N = 256, 128 + + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + y = np.exp(x) + z = np.square(y) + return z + 1.0 + + run_kernel_test( + kernel, + modes=Mode.BIR_SIM | Mode.HW, + ) + + +def test_sigmoid_no_annotations(): + """ + sigmoid(x) = 1 / (1 + exp(-x)) with no user annotations. + This exercises: negate, exp, add_scalar, reciprocal, mul_scalar. 
+ """ + M, N = 128, 128 + + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + return 1.0 / (1.0 + np.exp(-x)) + + run_kernel_test( + kernel, + modes=Mode.BIR_SIM | Mode.HW, + rtol=1e-5, + atol=1e-5, + ) + + +# ============================================================================ +# Matmul: no user annotations +# ============================================================================ + +def test_matmul_no_annotations(): + """ + matmul(a, b) with no user knobs. The pass should auto-seed: + Result C [M,N]: partition_dim=0, tile=[128, 128], reduction_tile=[128] + Operand A [M,K]: partition_dim=1 + Operand B [K,N]: partition_dim=0 + + TODO: Once KnobDrivenTiling supports non-blocked matmul for small dims, + the auto-seeded tile can be larger (e.g., [128, 256] for N=512). + """ + M, N, K = 256, 256, 128 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) + def kernel(a, b): + return np.matmul(a, b) + + run_kernel_test( + kernel, + modes=Mode.BIR_SIM | Mode.HW, + rtol=1e-3, + atol=1e-3, + ) + + +def test_matmul_add_no_annotations(): + """ + matmul(a, b) + c with no user annotations. Tests auto-seeding of + matmul and forward propagation to the elementwise add. + + Dims must be large enough to satisfy KnobDrivenTiling's blocking + factor=2 constraint: tile * 2 <= dim for each tiled dimension. + + TODO: Once KnobDrivenTiling supports non-blocked matmul for small dims, + smaller dimensions can be used here. + """ + M, N, K = 256, 512, 128 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((M, N), "f32")]) + def kernel(a, b, c): + return np.matmul(a, b) + c + + run_kernel_test( + kernel, + modes=Mode.BIR_SIM | Mode.HW, + rtol=1e-3, + atol=1e-3, + ) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/e2e/test_custom_op.py b/kernelgen/tests/e2e/test_custom_op.py new file mode 100644 index 0000000..142765e --- /dev/null +++ b/kernelgen/tests/e2e/test_custom_op.py @@ -0,0 +1,317 @@ +""" +End-to-end tests for custom op integration. + +Tests the full flow: tracing with CustomOp -> pipeline passes -> resolve-custom-ops. +The custom op replaces the activation function in a kernel with a +pre-compiled NISA function body (either hand-written or built via kernel_builder). + +Run with: pytest tests/e2e/test_custom_op.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from nkipy_kernelgen.custom_op import CustomOp +from harness import run_kernel_test, Mode + +import nki.compiler.kernel_builder as nb + + +def _make_silu_custom_op(M, N, tile_p=128, tile_f=128): + """Create a CustomOp with real NISA MLIR compiled via kernel_builder. + + Uses kernel_builder to compile a real SiLU activation, then extracts the + MLIR string and passes it to the direct CustomOp() constructor. This tests + the raw-constructor path with genuine NISA ops (dma_copy, activation, etc.) + rather than a hand-written stub. 
+ """ + shape = (M, N) + + def silu_kernel(x_hbm, out_hbm): + import nki.language as nl + + n_row_tiles = M // tile_p + n_col_tiles = N // tile_f + for r in nl.affine_range(n_row_tiles): + for t in nl.affine_range(n_col_tiles): + x_sbuf = nb.ndarray((tile_p, tile_f), x_hbm.dtype, nb.sbuf) + nb.isa.dma_copy( + dst=x_sbuf, + src=x_hbm[ + r * tile_p : (r + 1) * tile_p, t * tile_f : (t + 1) * tile_f + ], + ) + + out_sbuf = nb.ndarray((tile_p, tile_f), x_hbm.dtype, nb.sbuf) + + bias = nb.ndarray((tile_p, 1), x_hbm.dtype, nb.sbuf) + nb.isa.memset(dst=bias, value=0.0) + scale = nb.ndarray((tile_p, 1), x_hbm.dtype, nb.sbuf) + nb.isa.memset(dst=scale, value=1.0) + + nb.isa.activation( + dst=out_sbuf, + src=x_sbuf, + bias=bias, + scale=scale, + op=nb.isa.activation_function.silu, + ) + + nb.isa.dma_copy( + dst=out_hbm[ + r * tile_p : (r + 1) * tile_p, t * tile_f : (t + 1) * tile_f + ], + src=out_sbuf, + ) + + # Compile via kernel_builder and extract MLIR string + module = nb.build_kernel( + silu_kernel, + input_specs={"x_hbm": nb.Tensor(shape, nb.float32, nb.shared_hbm)}, + output_specs={"out_hbm": nb.Tensor(shape, nb.float32, nb.shared_hbm)}, + ) + nisa_mlir = module.operation.get_asm(print_generic_op_form=True) + + def silu_reference(x): + return x / (1.0 + np.exp(-x)) + + return CustomOp( + nisa_mlir=nisa_mlir, + func_name=f"silu_{M}x{N}_{M}x{N}", + input_names=["x_hbm"], + output_names=["out_hbm"], + input_shapes=[shape], + output_shapes=[shape], + input_dtypes=["f32"], + output_dtypes=["f32"], + reference_fn=silu_reference, + ) + + +# ============================================================================ +# Test: custom op tracing produces correct IR structure +# ============================================================================ + + +def test_custom_op_trace_ir_structure(): + """ + Verify that tracing with a CustomOp produces the expected IR: + - func.call to the custom op + - func.func private declaration with nkipy.custom_op + - nkipy.custom_op_bodies stashed on the module + """ + custom_silu = _make_silu_custom_op(256, 256) + + @trace(input_specs=[((256, 256), "f32")]) + def kernel(x): + return custom_silu(x) + + module = kernel.to_mlir() + mlir_str = str(module) + + # Verify call site + assert "call @__custom_op__silu_256x256_256x256" in mlir_str + # Verify declaration + assert "nkipy.custom_op" in mlir_str + # Verify body stashing + assert "nkipy.custom_op_bodies" in mlir_str + # Verify the NISA body string is stashed + assert "nisa.target" in mlir_str + + +# ============================================================================ +# Test: custom op in feedforward kernel (full pipeline, STRING_CHECK) +# ============================================================================ + + +def test_matmul_custom_activation_string_check(): + """ + Simple kernel: matmul followed by a CustomOp activation. + + The pipeline should: + 1. Trace the kernel with func.call to the custom op + 2. Run all passes (tiling, bufferize, annotate, legalize, linalg-to-nisa) + - The custom op declaration passes through as a bodyless func.func + 3. resolve-custom-ops links the NISA body and rewrites call sites + 4. prepare-for-nki strips nkipy.* attrs and adds nisa.target + + We verify the final IR contains the resolved custom op. 
+ """ + custom_silu = _make_silu_custom_op(256, 256) + + @trace( + input_specs=[ + ((256, 256), "f32"), # x + ((256, 256), "f32"), # weight + ] + ) + def matmul_activation_kernel(x, weight): + # Matrix multiply + mm_out = np.matmul(x, weight) + knob.knob( + mm_out, mem_space="SharedHbm", tile_size=[128, 128], reduction_tile=[128] + ) + + # Custom SiLU activation on result (input/output on HBM) + output = custom_silu(mm_out) + + return output + + run_kernel_test( + matmul_activation_kernel, + check_ir_contains=[ + # NISA ops from the main kernel + "nisa.matmul", + "nisa.target", + # NISA ops from the inlined SiLU custom op + "nisa.activation", + "nisa.dma_copy", + "nisa.memset", + ], + check_ir_not_contains=[ + # These should be stripped by prepare-for-nki + "nkipy.custom_op_bodies", + "nkipy.custom_op", + "transform.named_sequence", + # Function should be inlined, not linked + "__custom_op__silu_256x256_256x256", + ], + rtol=1e-3, + atol=1e-3, + modes=Mode.BIR_SIM | Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Test: custom op reference_fn works for numpy execution +# ============================================================================ + + +def test_custom_op_reference_fn_numpy(): + """ + Verify that the custom op's reference_fn produces correct numpy results + when called outside of tracing (for test validation). + """ + custom_silu = _make_silu_custom_op(4, 4) + + x = np.random.randn(4, 4).astype(np.float32) + result = custom_silu(x) + expected = x / (1.0 + np.exp(-x)) + np.testing.assert_allclose(result, expected, rtol=1e-6) + + +# ============================================================================ +# Test: kernel_builder SiLU custom op (full pipeline, STRING_CHECK) +# ============================================================================ + + +def _make_silu_kernel_builder_op(M, N, tile_p=128, tile_f=128): + """Create a CustomOp using kernel_builder to compile a real SiLU activation. + + The kernel tiles internally: processes (tile_p x tile_f) chunks of the + (M x N) input, one column-tile at a time. This keeps SBUF usage to + tile_p*tile_f elements (fitting in one SBUF partition row). + """ + + def silu_kernel(x_hbm, out_hbm): + import nki.language as nl + + n_row_tiles = M // tile_p + n_col_tiles = N // tile_f + for r in nl.affine_range(n_row_tiles): + for t in nl.affine_range(n_col_tiles): + x_sbuf = nb.ndarray((tile_p, tile_f), x_hbm.dtype, nb.sbuf) + nb.isa.dma_copy( + dst=x_sbuf, + src=x_hbm[ + r * tile_p : (r + 1) * tile_p, t * tile_f : (t + 1) * tile_f + ], + ) + + out_sbuf = nb.ndarray((tile_p, tile_f), x_hbm.dtype, nb.sbuf) + + bias = nb.ndarray((tile_p, 1), x_hbm.dtype, nb.sbuf) + nb.isa.memset(dst=bias, value=0.0) + scale = nb.ndarray((tile_p, 1), x_hbm.dtype, nb.sbuf) + nb.isa.memset(dst=scale, value=1.0) + + nb.isa.activation( + dst=out_sbuf, + src=x_sbuf, + bias=bias, + scale=scale, + op=nb.isa.activation_function.silu, + ) + + nb.isa.dma_copy( + dst=out_hbm[ + r * tile_p : (r + 1) * tile_p, t * tile_f : (t + 1) * tile_f + ], + src=out_sbuf, + ) + + def silu_reference(x): + return x / (1.0 + np.exp(-x)) + + return CustomOp.from_kernel_builder( + kernel_func=silu_kernel, + input_specs={"x_hbm": nb.Tensor((M, N), nb.float32, nb.shared_hbm)}, + output_specs={"out_hbm": nb.Tensor((M, N), nb.float32, nb.shared_hbm)}, + reference_fn=silu_reference, + ) + + +def test_kernel_builder_silu(): + """ + Matmul + SiLU activation where SiLU is compiled via kernel_builder. 
+ + This tests the from_kernel_builder() path which produces real NISA ops + (dma_copy, activation, memset) rather than a hand-written stub. + + Uses 128x128 tiles because SBUF partition dim max is 128. + Verifies both IR structure (STRING_CHECK) and numerical correctness (BIR_SIM). + """ + # Custom op processes 256x256 HBM buffer, tiling to 128x128 internally + custom_silu = _make_silu_kernel_builder_op(256, 256, tile_p=128, tile_f=128) + + @trace( + input_specs=[ + ((256, 256), "f32"), # x + ((256, 256), "f32"), # weight + ] + ) + def matmul_silu_kernel(x, weight): + mm_out = np.matmul(x, weight) + knob.knob( + mm_out, mem_space="SharedHbm", tile_size=[128, 128], reduction_tile=[128] + ) + + output = custom_silu(mm_out) + return output + + run_kernel_test( + matmul_silu_kernel, + check_ir_contains=[ + # Custom op body is inlined — check for NISA ops from both + # the main kernel (matmul) and the inlined SiLU activation + "nisa.activation", + "nisa.dma_copy", + "nisa.matmul", + "nisa.target", + "nisa.memset", # from SiLU bias/scale initialization + ], + check_ir_not_contains=[ + "nkipy.custom_op_bodies", + "nkipy.custom_op", + # Function should be inlined, not linked as separate func + "__custom_op__silu_kernel", + ], + rtol=1e-3, + atol=1e-3, + modes=Mode.BIR_SIM | Mode.STRING_CHECK, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/e2e/test_feedforward.py b/kernelgen/tests/e2e/test_feedforward.py new file mode 100644 index 0000000..f6c4061 --- /dev/null +++ b/kernelgen/tests/e2e/test_feedforward.py @@ -0,0 +1,141 @@ +""" +End-to-end tests for feedforward network kernel without bias. + +The feedforward kernel implements: +1. Gate+Up projection: x @ gate_up_weight -> split into gate and up +2. SwiGLU activation: SiLU(gate) * up +3. Down projection: result @ down_weight + +This test runs the full pipeline: +1. Trace Python code to MLIR +2. Run passes through nkipy-opt (assign-linalg-op-ids, knob-driven-tiling, etc.) +3. Convert to NISA dialect (linalg-to-nisa, prepare-for-nki) +4. Simulate with neuron-cc + +Run with: pytest tests/e2e/test_feedforward.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +# ============================================================================ +# Test Cases +# ============================================================================ + +@pytest.mark.parametrize("batch_size, hidden_size, intermediate_size, matmul_tile, matmul_reduction_tile, elementwise_tile", [ + (256, 256, 256, [128, 128], [128], [128, 128]), +]) +def test_feedforward_sbuf(batch_size, hidden_size, intermediate_size, + matmul_tile, matmul_reduction_tile, elementwise_tile): + """ + Test feedforward network kernel with SBUF intermediates. + SiLU broken down into individual ops, each annotated with a knob. 
+ """ + @trace(input_specs=[ + ((batch_size, hidden_size), "f32"), + ((hidden_size, 2 * intermediate_size), "f32"), + ((intermediate_size, hidden_size), "f32"), + ]) + def feedforward_kernel(x, gate_up_weight, down_weight): + """Feedforward network: Gate+Up projection -> SwiGLU -> Down projection""" + # Gate and Up projection + mm_gup = np.matmul(x, gate_up_weight) + knob.knob(mm_gup, mem_space="Sbuf", tile_size=matmul_tile, reduction_tile=matmul_reduction_tile) + + # Split into gate and up components + split_axis = mm_gup.ndim - 1 + gate, up = np.split(mm_gup, 2, axis=split_axis) + + # Apply SiLU activation to gate: sigmoid(gate) * gate + # Break down sigmoid into individual ops so each can be tiled + neg_gate = -gate + knob.knob(neg_gate, mem_space="Sbuf", tile_size=elementwise_tile) + + exp_neg_gate = np.exp(neg_gate) + knob.knob(exp_neg_gate, mem_space="Sbuf", tile_size=elementwise_tile) + + one_plus_exp = exp_neg_gate + 1.0 + knob.knob(one_plus_exp, mem_space="Sbuf", tile_size=elementwise_tile) + + sigmoid_gate = 1.0 / one_plus_exp + knob.knob(sigmoid_gate, mem_space="Sbuf", tile_size=elementwise_tile) + + swish_gate = gate * sigmoid_gate + knob.knob(swish_gate, mem_space="Sbuf", tile_size=elementwise_tile) + + # Element-wise multiplication (gating) + gated = swish_gate * up + knob.knob(gated, mem_space="Sbuf", tile_size=elementwise_tile) + + # Down projection + output = np.matmul(gated, down_weight) + knob.knob(output, mem_space="SharedHbm", tile_size=matmul_tile, reduction_tile=matmul_reduction_tile) + + return output + + run_kernel_test( + feedforward_kernel, + check_ir_contains=[ + "nisa.alloc", "nisa.matmul", "nisa.target", + ], + check_ir_not_contains=["transform.named_sequence"], + rtol=1e-3, # Relaxed due to accumulated errors across many ops + atol=1e-3, + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +@pytest.mark.parametrize("batch_size, hidden_size, intermediate_size, matmul_tile, matmul_reduction_tile, elementwise_tile", [ + (256, 256, 256, [128, 128], [128], [128, 128]), +]) +def test_feedforward_sbuf_compact_silu(batch_size, hidden_size, intermediate_size, + matmul_tile, matmul_reduction_tile, elementwise_tile): + """ + Test feedforward network kernel with compact SiLU expression. + + Same computation as test_feedforward_sbuf but with SiLU written + as a single expression instead of broken-down ops with per-op knobs. 
+ """ + @trace(input_specs=[ + ((batch_size, hidden_size), "f32"), + ((hidden_size, 2 * intermediate_size), "f32"), + ((intermediate_size, hidden_size), "f32"), + ]) + def feedforward_kernel(x, gate_up_weight, down_weight): + """Feedforward network: Gate+Up projection -> SwiGLU -> Down projection""" + # Gate and Up projection + mm_gup = np.matmul(x, gate_up_weight) + knob.knob(mm_gup, mem_space="Sbuf", tile_size=matmul_tile, reduction_tile=matmul_reduction_tile) + + # Split into gate and up components + split_axis = mm_gup.ndim - 1 + gate, up = np.split(mm_gup, 2, axis=split_axis) + + # SwiGLU: SiLU(gate) * up = (gate / (1 + exp(-gate))) * up + gated = gate / (1.0 + np.exp(-gate)) * up + knob.knob(gated, mem_space="Sbuf", tile_size=elementwise_tile) + + # Down projection + output = np.matmul(gated, down_weight) + knob.knob(output, mem_space="SharedHbm", tile_size=matmul_tile, reduction_tile=matmul_reduction_tile) + + return output + + run_kernel_test( + feedforward_kernel, + check_ir_contains=[ + "nisa.alloc", "nisa.matmul", "nisa.target", + ], + check_ir_not_contains=["transform.named_sequence"], + rtol=1e-3, + atol=1e-3, + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/e2e/test_head_deconcat.py b/kernelgen/tests/e2e/test_head_deconcat.py new file mode 100644 index 0000000..5401662 --- /dev/null +++ b/kernelgen/tests/e2e/test_head_deconcat.py @@ -0,0 +1,88 @@ +""" +End-to-end tests for head de-concatenation (reshape + transpose + reshape). + +This is the pattern used in multi-head attention to merge head outputs back +into the hidden dimension: + (BH, seq, hdim) -> (B, N, seq, hdim) -> (B, seq, N, hdim) -> (BS, hidden) + +The 4D transpose [0,2,1,3] creates a 4D SBUF alloc where dim 0 = batch (small), +not the partition dim (128). Without the SharedHbm workaround, legalize-layout +cannot tile this alloc and getBaseAndOffsets maps d0 to the batch dim, causing +OOB access in BIR simulation. + +The workaround annotates the 4D transpose output as SharedHbm so the transpose +stays in HBM and the SBUF alloc is never created. + +Run with: pytest tests/e2e/test_head_deconcat.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +def test_head_deconcat(request): + """ + Minimal reproducer for the head-deconcat SBUF dim 0 bug. + + Without the SharedHbm knob on the transpose output, the compiler creates + memref<2x128x2x128xf32, sbuf> with dim 0 = batch = 2. NISA lowering + assumes dim 0 = partition (128), causing OOB. + + The SharedHbm workaround keeps the transpose in HBM, sidestepping the issue. + """ + batch = 2 + n_heads = 2 + seq_len = 128 + head_dim = 128 + BH = batch * n_heads + BS = batch * seq_len + hidden = n_heads * head_dim + + @trace(input_specs=[ + ((BH, seq_len, head_dim), "f32"), + ((hidden, hidden), "f32"), + ]) + def head_deconcat_kernel(x, w): + # Reshape to expose batch and head dims + x = np.reshape(x, (batch, n_heads, seq_len, head_dim)) + + # Transpose to (batch, seq, heads, hdim) — the problematic op + x = np.transpose(x, (0, 2, 1, 3)) + + # Collapse back to 2D + x = np.reshape(x, (BS, hidden)) + # Workaround: annotate the 2D result as SharedHbm so the + # 4D transpose intermediate stays in HBM (not promoted to SBUF). + # Without this, the 4D SBUF alloc has dim 0 = batch (not partition), + # which legalize-layout cannot handle. 
+ knob.knob(x, mem_space="SharedHbm", tile_size=[128, 128]) + + # Downstream matmul + result = np.matmul(x, w) + knob.knob(result, mem_space="SharedHbm", tile_size=[128, 128], + reduction_tile=[128]) + return result + + # Verify LLVM simulation through legalize-layout + run_kernel_test( + head_deconcat_kernel, + stop_after="legalize-layout", + modes=Mode.LLVM, + request=request, + ) + + # Full pipeline: NISA generation + BIR simulation + run_kernel_test( + head_deconcat_kernel, + check_ir_contains=["nisa.matmul", "nisa.alloc"], + check_ir_not_contains=["transform.named_sequence"], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + request=request, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/e2e/test_matmul_add.py b/kernelgen/tests/e2e/test_matmul_add.py new file mode 100644 index 0000000..a2434f8 --- /dev/null +++ b/kernelgen/tests/e2e/test_matmul_add.py @@ -0,0 +1,119 @@ +""" +End-to-end tests for matmul + add using the complete C++ pass pipeline. + +This test runs the full pipeline: +1. Trace Python code to MLIR +2. Run passes through nkipy-opt (assign-linalg-op-ids, knob-driven-tiling, etc.) +3. Convert to NISA dialect (linalg-to-nisa, prepare-for-nki) +4. Simulate with neuron-cc + +Run with: pytest tests/e2e/test_matmul_add.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +# ============================================================================ +# Parameterized shapes and tile sizes +# ============================================================================ + +# (M, N, K, matmul_tile, matmul_reduction_tile, add_tile) +# SBUF budget per partition (trn1): 180,224 bytes. +# Peak SBUF during matmul = result_buf + lhs_buf + rhs_buf, where: +# result_buf = (M/tileM) * (N/tileN) * tileN * 4 (full MxN in SBUF) +# lhs_buf = (K/tileK) * 2 * tileK * 4 (full K for one BLOCK_M) +# rhs_buf = (K/tileK) * 2 * tileK * 4 (full K for one BLOCK_N) +# With tile=128: result = M*N/32, each operand = K*8. Total <= 180,224. +MATMUL_ADD_CONFIGS = [ + # Small matmul: blocking degenerates to 1 (tile == dim for M and N) + (128, 128, 128, [128, 128], [128], [128, 128]), + # Standard cases with block size 2 + (256, 256, 256, [128, 128], [128], [128, 128]), + (1024, 1024, 1024, [128, 128], [128], [128, 128]), + (2048, 2048, 2048, [128, 128], [128], [128, 128]), + (4096, 1024, 1024, [128, 128], [128], [128, 128]), + (1024, 4096, 1024, [128, 128], [128], [128, 128]), + (2048, 1024, 2048, [128, 128], [128], [128, 128]), + (1024, 2048, 2048, [128, 128], [128], [128, 128]), + (2048, 2048, 1024, [128, 128], [128], [128, 128]), +] + + +def _config_id(val): + M, N, K, mt, rt, at = val + return f"{M}x{N}x{K}_mt{'x'.join(map(str, mt))}_rt{rt[0]}_at{'x'.join(map(str, at))}" + + +# ============================================================================ +# Test Cases +# ============================================================================ + +@pytest.mark.parametrize( + "M, N, K, matmul_tile, matmul_reduction_tile, add_tile", + MATMUL_ADD_CONFIGS, + ids=[_config_id(c) for c in MATMUL_ADD_CONFIGS], +) +def test_matmul_sbuf_add_hbm(M, N, K, matmul_tile, matmul_reduction_tile, add_tile): + """ + Test matmul + add with SBUF intermediate. 
+ + Pattern: result = matmul(A, B) + bias + - matmul output: SBUF intermediate + - final result: HBM output + """ + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((M, N), "f32")]) + def matmul_add_kernel(a, b, bias): + c = np.matmul(a, b) + knob.knob(c, mem_space="Sbuf", tile_size=matmul_tile, reduction_tile=matmul_reduction_tile) + + # Add outputs to SharedHbm (returned from kernel) + result = c + bias + knob.knob(result, mem_space="SharedHbm", tile_size=add_tile) + + return result + + run_kernel_test( + matmul_add_kernel, + check_ir_contains=[ + "nisa.alloc", "nisa.matmul", "nisa.tensor_tensor_arith", + "nisa.dma_transpose", "nisa.dma_copy", "nisa.target", + ], + check_ir_not_contains=["transform.named_sequence"], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +@pytest.mark.parametrize( + "M, N, K, matmul_tile, matmul_reduction_tile, add_tile", + MATMUL_ADD_CONFIGS, + ids=[_config_id(c) for c in MATMUL_ADD_CONFIGS], +) +def test_matmul_hbm_add_hbm(M, N, K, matmul_tile, matmul_reduction_tile, add_tile): + """ + Test matmul + add with HBM intermediate (no SBUF buffer reuse). + """ + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((M, N), "f32")]) + def matmul_add_kernel_hbm(a, b, bias): + c = np.matmul(a, b) + knob.knob(c, mem_space="SharedHbm", tile_size=matmul_tile, reduction_tile=matmul_reduction_tile) + result = c + bias + knob.knob(result, mem_space="SharedHbm", tile_size=add_tile) + return result + + run_kernel_test( + matmul_add_kernel_hbm, + check_ir_contains=["nisa.alloc", "nisa.matmul", "nisa.target"], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/e2e/test_multi_output.py b/kernelgen/tests/e2e/test_multi_output.py new file mode 100644 index 0000000..28e63bc --- /dev/null +++ b/kernelgen/tests/e2e/test_multi_output.py @@ -0,0 +1,114 @@ +""" +End-to-end tests for multi-output kernels. + +Verifies that kernels returning tuples of tensors compile and produce +correct results through LLVM JIT, BIR simulation, and hardware execution. 
+""" + +from nkipy_kernelgen.trace import trace +from harness import run_kernel_test, Mode + + +# ============================================================================ +# Multi-output kernel definitions +# ============================================================================ + + +@trace(input_specs=[((256, 256), "f32"), ((256, 256), "f32")]) +def add_and_sub(a, b): + """Return both sum and difference.""" + return a + b, a - b + + +@trace(input_specs=[((256, 256), "f32"), ((256, 256), "f32")]) +def add_and_mul(a, b): + """Return both sum and product.""" + s = a + b + p = a * b + return s, p + + +@trace(input_specs=[((256, 256), "f32"), ((256, 256), "f32")]) +def three_outputs(a, b): + """Return three outputs: sum, difference, product.""" + return a + b, a - b, a * b + + +# ============================================================================ +# Tests: Tracing +# ============================================================================ + + +def test_multi_output_traces(): + """Verify that multi-output kernels trace to MLIR with correct func signature.""" + module = add_and_sub.to_mlir() + mlir_str = str(module) + # Function should have two result types + assert "-> (tensor<256x256xf32>, tensor<256x256xf32>)" in mlir_str + + +def test_three_output_traces(): + """Verify three-output kernel traces correctly.""" + module = three_outputs.to_mlir() + mlir_str = str(module) + assert ( + "-> (tensor<256x256xf32>, tensor<256x256xf32>, tensor<256x256xf32>)" in mlir_str + ) + + +# ============================================================================ +# Tests: LLVM JIT verification +# ============================================================================ + + +def test_add_and_sub_llvm(request): + """Two-output kernel: sum and difference, verified via LLVM JIT.""" + run_kernel_test( + add_and_sub, + stop_after="trace", + modes=Mode.LLVM, + request=request, + ) + + +def test_add_and_mul_llvm(request): + """Two-output kernel: sum and product, verified via LLVM JIT.""" + run_kernel_test( + add_and_mul, + stop_after="trace", + modes=Mode.LLVM, + request=request, + ) + + +def test_three_outputs_llvm(request): + """Three-output kernel verified via LLVM JIT.""" + run_kernel_test( + three_outputs, + stop_after="trace", + modes=Mode.LLVM, + request=request, + ) + + +# ============================================================================ +# Tests: Full pipeline (BIR_SIM) +# ============================================================================ + + +def test_add_and_sub_bir_sim(request): + """Two-output kernel through full pipeline + BIR simulation.""" + run_kernel_test( + add_and_sub, + modes=Mode.BIR_SIM, + request=request, + ) + + +def test_add_and_mul_bir_sim(request): + """Two-output kernel through full pipeline + BIR simulation.""" + run_kernel_test( + add_and_mul, + modes=Mode.BIR_SIM, + request=request, + ) diff --git a/kernelgen/tests/e2e/test_partition_dim.py b/kernelgen/tests/e2e/test_partition_dim.py new file mode 100644 index 0000000..58164c9 --- /dev/null +++ b/kernelgen/tests/e2e/test_partition_dim.py @@ -0,0 +1,94 @@ +""" +End-to-end tests for non-zero partition_dim. + +These tests verify that the canonicalize-partition-dim pass correctly inserts +transposes so that kernels with partition_dim != 0 produce numerically correct +results through the full compilation pipeline. 
+ +Run with: pytest tests/e2e/test_partition_dim.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +# ============================================================================ +# Test: 2D elementwise with partition_dim=1 +# ============================================================================ + +def test_exp_partition_dim_1(): + """ + exp(x) with partition_dim=1 through full pipeline to BIR simulation. + Verifies transposes are correctly inserted and the result matches NumPy. + + Shape (64, 128) with partition_dim=1: dim 1 (size 128) is the partition dim. + After canonicalization: tensor becomes (128, 64) with partition_dim=0. + """ + M, N = 64, 128 + tile_size = [M, N] + + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + result = np.exp(x) + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size, + partition_dim=1) + return result + + run_kernel_test( + kernel, + + check_ir_contains=["nisa.activation", "op=exp"], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +# ============================================================================ +# Test: 2D elementwise chain with partition_dim=1 +# ============================================================================ + +def test_sigmoid_partition_dim_1(): + """ + Sigmoid with partition_dim=1: sigmoid(x) = 1 / (1 + exp(-x)) + + The entire elementwise chain should be rewritten with permuted shapes, + and the result should match NumPy through BIR simulation. + + Shape (64, 128) with partition_dim=1: dim 1 (size 128) is the partition dim. + After canonicalization: tensors become (128, 64) with partition_dim=0. + """ + M, N = 64, 128 + tile_size = [M, N] + + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + neg_x = -x + knob.knob(neg_x, mem_space="Sbuf", tile_size=tile_size, + partition_dim=1) + + exp_neg_x = np.exp(neg_x) + knob.knob(exp_neg_x, mem_space="Sbuf", tile_size=tile_size, + partition_dim=1) + + one_plus_exp = 1.0 + exp_neg_x + knob.knob(one_plus_exp, mem_space="Sbuf", tile_size=tile_size, + partition_dim=1) + + result = 1.0 / one_plus_exp + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size, + partition_dim=1) + + return result + + run_kernel_test( + kernel, + + check_ir_contains=["nisa.activation", "op=exp"], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/e2e/test_qwen3_layer.py b/kernelgen/tests/e2e/test_qwen3_layer.py new file mode 100644 index 0000000..11e0de7 --- /dev/null +++ b/kernelgen/tests/e2e/test_qwen3_layer.py @@ -0,0 +1,309 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Qwen3 Transformer Decoder Layer. + +Ported from compiler_explorer/examples/qwen3_layer.py + +Architecture: + 1. Pre-attention RMSNorm + 2. QKV projection (hidden -> q, k, v per head) + 3. Reshape + transpose to multi-head format + 4. RoPE on Q and K + 5. Scaled dot-product attention + 6. Concat heads + output projection + 7. Residual connection + 8. Post-attention RMSNorm + 9. SwiGLU feedforward (gate_up projection, SiLU, down projection) + 10. Residual connection + +Sub-kernel boundaries (values that flow through reshape/transpose or between +independent compute stages) are annotated as SharedHbm so the compiler can +freely reshape/transpose them without partition_dim constraints. 
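+
+ Shapes used in this test: batch=2, seq_len=128, hidden_size=256, n_heads=2,
+ head_dim=128, intermediate_size=256 (so BS = 256 and BH = 4).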
+""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + +# ---------------------------------------------------------------- +# Model hyperparameters +# ---------------------------------------------------------------- +batch = 2 +seq_len = 128 +hidden_size = 256 +n_heads = 2 +head_dim = hidden_size // n_heads # 128 +intermediate_size = 256 +half_dim = head_dim // 2 # 64 +eps = 1e-6 +scale = 1.0 / np.sqrt(head_dim).item() + +# Derived +BS = batch * seq_len # 256 +BH = batch * n_heads # 4 + +# ---------------------------------------------------------------- +# Tile sizes +# ---------------------------------------------------------------- +matmul_tile_2d = [128, 128] +matmul_reduction_2d = [128] +attn_tile = [1, 128, 128] +attn_reduction = [128] +rope_tile = [1, 128, 64] # (BH, seq_len, half_dim) +elem_tile_2d = [128, 128] + + +# ---- helpers ---- + + +def rmsnorm(x, weight): + x_fp32 = x.astype(np.float32) + w_fp32 = weight.astype(np.float32) + + sq = np.square(x_fp32) + knob.knob(sq, mem_space="Sbuf", tile_size=elem_tile_2d) + + sum_sq = np.sum(sq, axis=-1, keepdims=True) + knob.knob(sum_sq, mem_space="Sbuf", tile_size=[128], reduction_tile=[128]) + + mean_sq = sum_sq * np.float32(1.0 / hidden_size) + knob.knob(mean_sq, mem_space="Sbuf", tile_size=[128, 1]) + + normed = x_fp32 / np.sqrt(mean_sq + eps) + knob.knob(normed, mem_space="Sbuf", tile_size=elem_tile_2d) + + result = normed * w_fp32 + knob.knob(result, mem_space="Sbuf", tile_size=elem_tile_2d) + + return result + + +def softmax_3d(x): + x_fp32 = x.astype(np.float32) + + # Workaround: reduction accumulators (x_max, sum_exp) use SharedHbm + # to avoid 5D SBUF allocs from legalize-layout. The 3D shape + # (BH, 128, 1) creates a 5D physical layout where the collapse_shape + # back to 2D has a tile/base mismatch in linalg-to-nisa. + x_max = np.max(x_fp32, axis=-1, keepdims=True) + knob.knob(x_max, mem_space="SharedHbm", tile_size=[1, 128], + reduction_tile=[128], partition_dim=1) + + shifted = x_fp32 - x_max + knob.knob(shifted, mem_space="Sbuf", tile_size=attn_tile, partition_dim=1) + + exp_s = np.exp(shifted) + knob.knob(exp_s, mem_space="Sbuf", tile_size=attn_tile, partition_dim=1) + + sum_exp = np.sum(exp_s, axis=-1, keepdims=True) + knob.knob(sum_exp, mem_space="SharedHbm", tile_size=[1, 128], + reduction_tile=[128], partition_dim=1) + + # Softmax result is a sub-kernel boundary (feeds into context matmul). + result = exp_s / sum_exp + knob.knob(result, mem_space="SharedHbm", tile_size=attn_tile) + + return result + + +def silu(x): + neg_x = -x + knob.knob(neg_x, mem_space="Sbuf", tile_size=elem_tile_2d) + + exp_neg = np.exp(neg_x) + knob.knob(exp_neg, mem_space="Sbuf", tile_size=elem_tile_2d) + + one_plus = exp_neg + 1.0 + knob.knob(one_plus, mem_space="Sbuf", tile_size=elem_tile_2d) + + sigmoid = 1.0 / one_plus + knob.knob(sigmoid, mem_space="Sbuf", tile_size=elem_tile_2d) + + result = x * sigmoid + knob.knob(result, mem_space="Sbuf", tile_size=elem_tile_2d) + + return result + + +def apply_rope(x, freqs_cos, freqs_sin): + x0 = x[:, :, :half_dim] + x1 = x[:, :, half_dim:] + + # Workaround: split compound expressions so each intermediate gets a + # SharedHbm knob. Without this, the multiply intermediates default to + # SBUF with shape (BH, seq, half_dim) = (4, 128, 64) where dim 0 = 4 + # (not partition). getBaseAndOffsets maps d0 to dim 0, causing OOB. 
+ x0_cos = x0 * freqs_cos + knob.knob(x0_cos, mem_space="SharedHbm", tile_size=rope_tile) + x1_sin = x1 * freqs_sin + knob.knob(x1_sin, mem_space="SharedHbm", tile_size=rope_tile) + out_0 = x0_cos - x1_sin + knob.knob(out_0, mem_space="SharedHbm", tile_size=rope_tile) + + x0_sin = x0 * freqs_sin + knob.knob(x0_sin, mem_space="SharedHbm", tile_size=rope_tile) + x1_cos = x1 * freqs_cos + knob.knob(x1_cos, mem_space="SharedHbm", tile_size=rope_tile) + out_1 = x0_sin + x1_cos + knob.knob(out_1, mem_space="SharedHbm", tile_size=rope_tile) + + # RoPE result is a sub-kernel boundary (feeds into matmul/transpose). + result = np.concatenate([out_0, out_1], axis=-1) + knob.knob(result, mem_space="SharedHbm", tile_size=attn_tile) + + return result + + +# ---- test ---- + + +def test_qwen3_layer(request): + @trace(input_specs=[ + ((BS, hidden_size), "f32"), # hidden_states + ((hidden_size, 1), "f32"), # ln1_weight + ((hidden_size, 1), "f32"), # ln2_weight + ((hidden_size, hidden_size), "f32"), # w_q + ((hidden_size, hidden_size), "f32"), # w_k + ((hidden_size, hidden_size), "f32"), # w_v + ((hidden_size, hidden_size), "f32"), # w_o + ((1, seq_len, half_dim), "f32"), # freqs_cos + ((1, seq_len, half_dim), "f32"), # freqs_sin + ((hidden_size, intermediate_size), "f32"), # w_gate + ((hidden_size, intermediate_size), "f32"), # w_up + ((intermediate_size, hidden_size), "f32"), # w_down + ]) + def kernel(hidden_states, + ln1_weight, ln2_weight, + w_q, w_k, w_v, w_o, + freqs_cos, freqs_sin, + w_gate, w_up, w_down): + residual = hidden_states + + # 1. Pre-attention RMSNorm + normed = rmsnorm(hidden_states, ln1_weight) + + # 2. QKV projections — SharedHbm boundary (results go through reshape) + q = np.matmul(normed, w_q) + knob.knob(q, mem_space="SharedHbm", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + k = np.matmul(normed, w_k) + knob.knob(k, mem_space="SharedHbm", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + v = np.matmul(normed, w_v) + knob.knob(v, mem_space="SharedHbm", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + # 3. Reshape to multi-head + q = np.reshape(q, (batch, seq_len, n_heads, head_dim)) + q = np.transpose(q, (0, 2, 1, 3)) + q = np.reshape(q, (BH, seq_len, head_dim)) + + k = np.reshape(k, (batch, seq_len, n_heads, head_dim)) + k = np.transpose(k, (0, 2, 1, 3)) + k = np.reshape(k, (BH, seq_len, head_dim)) + + v = np.reshape(v, (batch, seq_len, n_heads, head_dim)) + v = np.transpose(v, (0, 2, 1, 3)) + v = np.reshape(v, (BH, seq_len, head_dim)) + # V is a sub-kernel boundary (feeds into context matmul via DMA). + # Annotate as SharedHbm so the 4D reshape intermediate stays in HBM + # rather than being promoted to SBUF (which would create a 4D SBUF + # alloc that legalize-layout cannot tile). + knob.knob(v, mem_space="SharedHbm", tile_size=attn_tile) + + # 4. RoPE on Q and K + q = apply_rope(q, freqs_cos, freqs_sin) + k = apply_rope(k, freqs_cos, freqs_sin) + + # K^T — SharedHbm boundary (transpose feeds into matmul) + k_t = np.transpose(k, (0, 2, 1)) + knob.knob(k_t, mem_space="SharedHbm", tile_size=attn_tile) + + # 5. 
Scaled dot-product attention + scores = np.matmul(q, k_t) + knob.knob(scores, mem_space="Sbuf", tile_size=attn_tile, reduction_tile=attn_reduction) + + scores = scores * scale + knob.knob(scores, mem_space="Sbuf", tile_size=attn_tile, partition_dim=1) + + attn_weights = softmax_3d(scores) + + # Context — SharedHbm boundary (result goes through reshape) + context = np.matmul(attn_weights, v) + knob.knob(context, mem_space="SharedHbm", tile_size=attn_tile, reduction_tile=attn_reduction) + + # 6. Concat heads + output projection + context = np.reshape(context, (batch, n_heads, seq_len, head_dim)) + context = np.transpose(context, (0, 2, 1, 3)) + context = np.reshape(context, (BS, hidden_size)) + # Workaround: annotate the 2D result as SharedHbm so the 4D transpose + # intermediate stays in HBM (not promoted to SBUF). Without this, + # the 4D SBUF alloc has dim 0 = batch (not partition) and + # legalize-layout cannot tile it, causing OOB in NISA lowering. + knob.knob(context, mem_space="SharedHbm", tile_size=matmul_tile_2d) + + attn_out = np.matmul(context, w_o) + knob.knob(attn_out, mem_space="Sbuf", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + # 7. Residual connection + hidden_states = residual + attn_out + knob.knob(hidden_states, mem_space="Sbuf", tile_size=elem_tile_2d) + + residual = hidden_states + + # 8. Post-attention RMSNorm + normed = rmsnorm(hidden_states, ln2_weight) + + # 9. SwiGLU FFN + gate = np.matmul(normed, w_gate) + knob.knob(gate, mem_space="Sbuf", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + up = np.matmul(normed, w_up) + knob.knob(up, mem_space="Sbuf", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + gate = silu(gate) + + gated = gate * up + knob.knob(gated, mem_space="Sbuf", tile_size=elem_tile_2d) + + ffn_out = np.matmul(gated, w_down) + knob.knob(ffn_out, mem_space="Sbuf", tile_size=matmul_tile_2d, reduction_tile=matmul_reduction_2d) + + # 10. Residual connection + output = residual + ffn_out + knob.knob(output, mem_space="SharedHbm", tile_size=elem_tile_2d) + + return output + + # First verify compilation succeeds + run_kernel_test( + kernel, + modes=Mode.STRING_CHECK, + check_ir_contains=["nisa.matmul", "nisa.alloc"], + request=request, + ) + + # Verify numerical correctness via LLVM JIT (stop after insert-memref-dealloc, + # before linalg-to-nisa which LLVM JIT cannot execute) + run_kernel_test( + kernel, + stop_after="insert-memref-dealloc", + modes=Mode.LLVM, + rtol=1e-3, + atol=1e-3, + request=request, + ) + + run_kernel_test( + kernel, + modes=Mode.BIR_SIM | Mode.HW, + rtol=1e-3, + atol=1e-3, + request=request, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/e2e/test_reduce.py b/kernelgen/tests/e2e/test_reduce.py new file mode 100644 index 0000000..84fcd91 --- /dev/null +++ b/kernelgen/tests/e2e/test_reduce.py @@ -0,0 +1,144 @@ +""" +End-to-end tests for reduction kernels (sum, mean). + +Exercises the full reduce pipeline: +1. Element-wise square +2. Reduction over last axis with keepdims=True +3. 
Tiled accumulation (tensor_reduce_arith + tensor_tensor_arith) + +Run with: pytest tests/e2e/test_reduce.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +M, N = 256, 256 +TILE_SIZE = [128, 128] + + +# ============================================================================ +# Trace-level tests (LLVM JIT) — verify linalg lowering +# ============================================================================ + + +@pytest.mark.parametrize("reduce_fn", ["sum", "mean"]) +def test_reduce_square_trace(reduce_fn): + """ + Test np.sum / np.mean of squared input at trace level. + + Verifies tracing produces a single linalg.generic + instead of linalg.reduce + tensor.reshape. + """ + reduce_op = getattr(np, reduce_fn) + + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + sq = np.square(x.astype(np.float32)) + knob.knob(sq, mem_space="Sbuf", tile_size=TILE_SIZE) + + result = reduce_op(sq, axis=-1, keepdims=True) + knob.knob( + result, + mem_space="SharedHbm", + tile_size=[128, 1], + reduction_tile=[128], + ) + return result + + run_kernel_test( + kernel, + stop_after="trace", + check_ir_contains=["linalg.generic"], + check_ir_not_contains=["linalg.reduce", "tensor.reshape"], + modes=Mode.LLVM | Mode.STRING_CHECK, + rtol=1e-3, + atol=1e-3, + ) + + +# ============================================================================ +# BIR simulation tests — verify full pipeline correctness +# ============================================================================ + + +def test_reduce_sum_sim(): + """ + BIR simulation: np.sum of squared input. + + Verifies the tiled reduction accumulation pattern: + tensor_reduce_arith(dst=temp, src=tile) -- partial reduce + tensor_tensor_arith(accum += temp) -- accumulate + """ + + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + sq = np.square(x.astype(np.float32)) + knob.knob(sq, mem_space="Sbuf", tile_size=TILE_SIZE) + + result = np.sum(sq, axis=-1, keepdims=True) + knob.knob( + result, + mem_space="SharedHbm", + tile_size=[128], + reduction_tile=[128], + ) + return result + + run_kernel_test( + kernel, + check_ir_contains=["nisa.tensor_reduce_arith", "nisa.tensor_tensor_arith"], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + rtol=1e-3, + atol=1e-3, + ) + + +def test_reduce_mean_sim(): + """ + BIR simulation: np.mean of squared input. + + Mean is expressed as sum * (1/N) with separate knobs on the + intermediate sum and the final multiply, since they use different + memory spaces and tile configurations. + """ + + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + sq = np.square(x.astype(np.float32)) + knob.knob(sq, mem_space="Sbuf", tile_size=TILE_SIZE) + + sm = np.sum(sq, axis=-1, keepdims=True) + knob.knob( + sm, + mem_space="SharedHbm", + tile_size=[128], + reduction_tile=[128], + ) + + result = sm * np.float32(1.0 / N) + knob.knob( + result, + mem_space="SharedHbm", + tile_size=[128, 1], + ) + return result + + run_kernel_test( + kernel, + check_ir_contains=[ + "nisa.tensor_reduce_arith", + "nisa.tensor_tensor_arith", + "nisa.tensor_scalar_arith", + ], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + rtol=1e-3, + atol=1e-3, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/e2e/test_rmsnorm.py b/kernelgen/tests/e2e/test_rmsnorm.py new file mode 100644 index 0000000..411ed21 --- /dev/null +++ b/kernelgen/tests/e2e/test_rmsnorm.py @@ -0,0 +1,87 @@ +""" +End-to-end tests for RMSNorm kernel. 
+ +RMSNorm: output = (x / sqrt(mean(x^2) + eps)) * weight + +This exercises: +1. Element-wise square (multiply) +2. Sum reduction over last axis +3. Scalar multiply for mean (sum * 1/N) +4. Addition with scalar epsilon +5. Square root + division for normalization +6. Element-wise multiply with weight + +Run with: pytest tests/e2e/test_rmsnorm.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +# ============================================================================ +# Test Cases +# ============================================================================ + +@pytest.mark.parametrize("M, N, tile_size", [ + (256, 256, [128, 128]), +]) +def test_rmsnorm(M, N, tile_size): + """ + Test RMSNorm: x / sqrt(mean(x^2) + eps) * weight. + + Broken down into individual ops with per-op knobs, matching + compiler_explorer/examples/rmsnorm.py. + """ + eps = 1e-6 + + @trace(input_specs=[((M, N), "f32"), ((N, 1), "f32")]) + def rmsnorm_kernel(x, weight): + x_fp32 = x.astype(np.float32) + w_fp32 = weight.astype(np.float32) + + sq = np.square(x_fp32) + knob.knob(sq, mem_space="Sbuf", tile_size=tile_size) + + sum_sq = np.sum(sq, axis=-1, keepdims=True) + knob.knob( + sum_sq, + mem_space="Sbuf", + tile_size=[128], + reduction_tile=[128], + ) + + mean_sq = sum_sq * np.float32(1.0 / N) + knob.knob( + mean_sq, + mem_space="Sbuf", + tile_size=[128, 1], + ) + + normed = x_fp32 / np.sqrt(mean_sq + eps) + knob.knob(normed, mem_space="Sbuf", tile_size=tile_size) + + result = normed * w_fp32 + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + return result + + run_kernel_test( + rmsnorm_kernel, + stop_after="legalize-layout", + modes=Mode.LLVM, + rtol=1e-3, + atol=1e-3, + ) + + run_kernel_test( + rmsnorm_kernel, + modes=Mode.BIR_SIM | Mode.HW, + rtol=1e-3, + atol=1e-3, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/e2e/test_rope.py b/kernelgen/tests/e2e/test_rope.py new file mode 100644 index 0000000..59d5f42 --- /dev/null +++ b/kernelgen/tests/e2e/test_rope.py @@ -0,0 +1,238 @@ +""" +End-to-end tests for Rotary Position Embedding (RoPE) kernel. + +RoPE applies rotary embeddings to query and key tensors: + x_out[..., :half] = x[..., :half] * cos - x[..., half:] * sin + x_out[..., half:] = x[..., :half] * sin + x[..., half:] * cos + +This exercises: +1. Tensor slicing (split along last axis) +2. Element-wise multiply with broadcast cos/sin (via expand_dims) +3. Subtraction and addition +4. Concatenation along last axis +5. Phase 0 fold-reshape-into-alloc (2D cos/sin reshaped to 3D) +6. Trivial-broadcast generic canonicalization to named ops + +Run with: pytest tests/e2e/test_rope.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +# ============================================================================ +# Test Cases +# ============================================================================ + +def test_rope(): + """ + Test RoPE with 3D tensors: x(bs, n_heads, head_dim), cos/sin(bs, half_h). + + cos/sin are expanded to (bs, 1, half_h) via np.expand_dims, creating a + broadcast multiply pattern. After tiling with tile_size=[128, 1, 64], + the broadcast becomes trivial (size-1 on both sides). 
+ + This exercises: + - Phase 0: 2D alloc + copy + reshape -> 3D SBUF alloc + - Trivial-broadcast linalg.generic -> named op canonicalization + - 3D SBUF legalization (5D physical layout) + - Concatenation via insert_slice + """ + batch = 2 + seq_len = 128 + n_heads = 4 + head_dim = 128 + half_h = head_dim // 2 + bs = batch * seq_len + tile_size = [128, 1, 64] + + @trace(input_specs=[ + ((bs, n_heads, head_dim), "f32"), + ((bs, half_h), "f32"), + ((bs, half_h), "f32"), + ]) + def rope_kernel(x, freqs_cos, freqs_sin): + # Broadcast cos/sin to (bs, 1, half_h) + # No knobs on cos/sin: they are views (expand_dims) of HBM inputs. + # Tiling promotes them to SBUF automatically as inputs to SBUF compute. + cos = np.expand_dims(freqs_cos, axis=1) + sin = np.expand_dims(freqs_sin, axis=1) + + # Split input into two halves along head_dim + x0 = x[:, :, :half_h] + x1 = x[:, :, half_h:] + + # Apply rotation + out_0 = x0 * cos - x1 * sin + knob.knob(out_0, mem_space="Sbuf", tile_size=tile_size) + + out_1 = x0 * sin + x1 * cos + knob.knob(out_1, mem_space="Sbuf", tile_size=tile_size) + + # Concatenate back along head_dim axis + result = np.concatenate([out_0, out_1], axis=-1) + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + return result + + run_kernel_test( + rope_kernel, + stop_after="legalize-layout", + modes=Mode.LLVM, + ) + + run_kernel_test( + rope_kernel, + check_ir_contains=[ + "nisa.alloc", "nisa.tensor_tensor_arith", "nisa.target", + ], + check_ir_not_contains=["transform.named_sequence"], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +def test_rope_3d_multi_partition(): + """ + Test RoPE with 3D tensors where dim 0 spans multiple partitions: + x(BH, seq_len, head_dim), cos/sin(1, seq_len, half_dim). + + BH=4 with tile [1, 128, 64] means 4 partitions in SBUF. + The subtract/add ops produce 4-partition intermediates. + The vector engine cannot selectively address individual partitions, + so the compiler must bounce through 1-partition DMA temps. + + This exercises: + - Multi-partition SBUF element-wise with engine=vector + - DMA materialization for multi-partition vector operands + - Concatenation of multi-partition results + """ + BH = 4 + seq_len = 128 + head_dim = 128 + half_dim = head_dim // 2 + rope_tile = [1, 128, 64] + attn_tile = [1, 128, 128] + + @trace(input_specs=[ + ((BH, seq_len, half_dim), "f32"), # q0 + ((BH, seq_len, half_dim), "f32"), # q1 + ((1, seq_len, half_dim), "f32"), # freqs_cos (broadcast over BH) + ((1, seq_len, half_dim), "f32"), # freqs_sin (broadcast over BH) + ]) + def rope_3d_kernel(q0, q1, freqs_cos, freqs_sin): + # Split compound expressions so each intermediate gets an HBM knob. + # Without this, the multiply intermediates default to SBUF, and the + # vector engine illegally indexes into specific partitions. 
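+        # q0/q1 are (4, 128, 64); freqs_cos/freqs_sin are (1, 128, 64) and broadcast
+        # over dim 0, so every product below is (4, 128, 64) tiled as [1, 128, 64].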
+ t0 = q0 * freqs_cos + knob.knob(t0, mem_space="SharedHbm", tile_size=rope_tile) + t1 = q1 * freqs_sin + knob.knob(t1, mem_space="SharedHbm", tile_size=rope_tile) + q_rot0 = t0 - t1 + knob.knob(q_rot0, mem_space="SharedHbm", tile_size=rope_tile) + + t2 = q0 * freqs_sin + knob.knob(t2, mem_space="SharedHbm", tile_size=rope_tile) + t3 = q1 * freqs_cos + knob.knob(t3, mem_space="SharedHbm", tile_size=rope_tile) + q_rot1 = t2 + t3 + knob.knob(q_rot1, mem_space="SharedHbm", tile_size=rope_tile) + + result = np.concatenate([q_rot0, q_rot1], axis=-1) + knob.knob(result, mem_space="SharedHbm", tile_size=attn_tile) + return result + + run_kernel_test( + rope_3d_kernel, + stop_after="legalize-layout", + modes=Mode.LLVM, + ) + + run_kernel_test( + rope_3d_kernel, + check_ir_contains=[ + "nisa.alloc", "nisa.tensor_tensor_arith", "nisa.target", + ], + check_ir_not_contains=["transform.named_sequence"], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +# ============================================================================ +# 3D RoPE: compound expression in SBUF (no HBM workaround) +# +# Same kernel body, parametrized by partition_dim. The shape is permuted so +# that the partition dim (seq_len=128) sits at position `pdim`. +# +# pdim=0: (seq=128, BH=4, half=64) — partition at dim 0, works today. +# pdim=1: (BH=4, seq=128, half=64) — partition at dim 1, requires +# infer-layout to propagate partition_dim=1 to compound expression +# intermediates. See docs/qwen3_sbuf_partition_dim_workarounds.md +# ============================================================================ + +_ROPE_COMPOUND_PARAMS = [ + pytest.param(0, id="pdim0"), + pytest.param(1, id="pdim1"), +] + + +@pytest.mark.parametrize("pdim", _ROPE_COMPOUND_PARAMS) +def test_rope_3d_compound(pdim): + """ + 3D RoPE compound expression with SBUF output, no intermediate knobs. + + The kernel body is identical for both partition_dim values; only the + input shapes and tile layout change. + """ + seq_len = 128 + BH = 4 + half_dim = 64 + + # Build shapes: partition dim (seq_len) at position `pdim`, + # batch dim (BH) at the other position. Last dim is always half_dim. 
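+    # In both layouts the final concat along the last axis doubles it to 128,
+    # which is why concat_tile ends in 128 while tile ends in 64.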
+ if pdim == 0: + q_shape = (seq_len, BH, half_dim) + cos_shape = (seq_len, 1, half_dim) + tile = [128, 1, 64] + concat_tile = [128, 1, 128] + else: + q_shape = (BH, seq_len, half_dim) + cos_shape = (1, seq_len, half_dim) + tile = [1, 128, 64] + concat_tile = [1, 128, 128] + + @trace(input_specs=[ + (q_shape, "f32"), # q0 + (q_shape, "f32"), # q1 + (cos_shape, "f32"), # freqs_cos (broadcast over BH) + (cos_shape, "f32"), # freqs_sin + ]) + def kernel(q0, q1, freqs_cos, freqs_sin): + q_rot0 = q0 * freqs_cos - q1 * freqs_sin + knob.knob(q_rot0, mem_space="Sbuf", tile_size=tile, + partition_dim=pdim) + + q_rot1 = q0 * freqs_sin + q1 * freqs_cos + knob.knob(q_rot1, mem_space="Sbuf", tile_size=tile, + partition_dim=pdim) + + result = np.concatenate([q_rot0, q_rot1], axis=-1) + knob.knob(result, mem_space="SharedHbm", tile_size=concat_tile) + return result + + run_kernel_test( + kernel, + stop_after="legalize-layout", + modes=Mode.LLVM, + ) + + run_kernel_test( + kernel, + check_ir_contains=["nisa.alloc", "nisa.tensor_tensor_arith"], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/e2e/test_sigmoid.py b/kernelgen/tests/e2e/test_sigmoid.py new file mode 100644 index 0000000..2f7b803 --- /dev/null +++ b/kernelgen/tests/e2e/test_sigmoid.py @@ -0,0 +1,172 @@ +""" +End-to-end tests for sigmoid activation and tensor-scalar arithmetic. + +Sigmoid: sigmoid(x) = 1.0 / (1.0 + exp(-x)) + +This exercises the full Path 2 implementation: +1. Division converted to multiply + reciprocal (prepare-arithmetic pass) +2. Tensor-scalar operations for the 1.0 constants (nisa.tensor_scalar_arith) +3. Negation for -x +4. Exponential via nisa.activation(op=exp) +5. Reciprocal via nisa.reciprocal + +Run with: pytest tests/e2e/test_sigmoid.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +# ============================================================================ +# Test Cases +# ============================================================================ + +def test_exp_activation(): + """ + Test exp(x) lowering to nisa.activation. + + This verifies: + - linalg.exp is converted to nisa.activation with op=exp + - Scalar bias=0.0 and scale=1.0 are used + """ + M, N = 128, 256 + tile_size = [128, 128] + + @trace(input_specs=[((M, N), "f32")]) + def exp_kernel(x): + result = np.exp(x) + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + return result + + run_kernel_test( + exp_kernel, + + check_ir_contains=["nisa.activation", "op=exp"], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +def test_tensor_add_scalar(): + """ + Test tensor + scalar arithmetic. + + This verifies: + - Constant 2.0 is broadcast via linalg.fill with CONSTANT memspace + - linalg.add detects CONSTANT operand and emits nisa.tensor_scalar_arith + """ + M, N = 128, 256 + tile_size = [128, 128] + scalar_value = 2.0 + + @trace(input_specs=[((M, N), "f32")]) + def add_scalar_kernel(x): + result = x + scalar_value + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + return result + + run_kernel_test( + add_scalar_kernel, + check_ir_contains=[ + "nisa.alloc", "nisa.target", + "nisa.tensor_scalar_arith", "op0=add", + ], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +def test_sigmoid(): + """ + Test sigmoid activation: sigmoid(x) = 1.0 / (1.0 + exp(-x)) + + This is the full test that exercises: + 1. 
Negation: -x (via linalg.negf or multiply by -1) + 2. Exponential: exp(-x) via nisa.activation(op=exp) + 3. Addition with scalar: 1.0 + exp(-x) via nisa.tensor_scalar_arith(op=add) + 4. Division: 1.0 / (result) converted to reciprocal by prepare-arithmetic + 5. Multiply by 1.0 (or direct reciprocal output) + """ + M, N = 128, 256 + tile_size = [128, 128] + + @trace(input_specs=[((M, N), "f32")]) + def sigmoid_kernel(x): + # Sigmoid: 1 / (1 + exp(-x)) + neg_x = -x + knob.knob(neg_x, mem_space="Sbuf", tile_size=tile_size) + + exp_neg_x = np.exp(neg_x) + knob.knob(exp_neg_x, mem_space="Sbuf", tile_size=tile_size) + + one_plus_exp = 1.0 + exp_neg_x + knob.knob(one_plus_exp, mem_space="Sbuf", tile_size=tile_size) + + # Division gets converted to reciprocal by prepare-arithmetic + result = 1.0 / one_plus_exp + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + + return result + + run_kernel_test( + sigmoid_kernel, + + check_ir_contains=["nisa.activation", "op=exp"], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +def test_scalar_minus_tensor(): + """ + Test scalar - tensor arithmetic (reverse operands). + + This verifies: + - When lhs is CONSTANT (scalar), reverse_operands=first is set + - nisa.tensor_scalar_arith correctly computes scalar - tensor + """ + M, N = 128, 256 + tile_size = [128, 128] + scalar_value = 5.0 + + @trace(input_specs=[((M, N), "f32")]) + def sub_scalar_kernel(x): + # scalar - tensor requires reverse_operands + result = scalar_value - x + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + return result + + run_kernel_test( + sub_scalar_kernel, + check_ir_contains=[ + "nisa.tensor_scalar_arith", "op0=subtract", "reverse_operands=first", + ], + modes=Mode.BIR_SIM | Mode.STRING_CHECK | Mode.HW, + ) + + +def test_division_to_reciprocal(): + """ + Test division conversion to multiply + reciprocal. + + x / 2.0 is converted by prepare-arithmetic to: + x * reciprocal(broadcast(2.0)) + """ + M, N = 128, 256 + tile_size = [128, 128] + divisor = 2.0 + + @trace(input_specs=[((M, N), "f32")]) + def div_kernel(x): + result = x / divisor + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + return result + + run_kernel_test( + div_kernel, + modes=Mode.BIR_SIM | Mode.HW, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/harness.py b/kernelgen/tests/harness.py new file mode 100644 index 0000000..efaeec5 --- /dev/null +++ b/kernelgen/tests/harness.py @@ -0,0 +1,1020 @@ +""" +Unified declarative test harness for NKIPyKernelGen. 
+ +Provides: + - Mode enum for declaring test verification modes + - run_kernel_test() for programmatic test invocation + - @nkipy_kernelgen_test decorator for declarative test definitions + - Auto input generation from traced function's input_specs + - Per-mode default tolerances + - Parameter validation with clear error messages + +Usage: + from harness import nkipy_kernelgen_test, run_kernel_test, Mode + + # Decorator form (non-parametrized tests): + @nkipy_kernelgen_test( + input_specs=[((256, 256), "f32"), ((256, 256), "f32")], + stop_after="apply-and-strip-transforms", + check_patterns="CHECK: scf.for\\nCHECK: linalg.matmul", + modes=Mode.LLVM | Mode.FILECHECK, + ) + def test_matmul_tiling(a, b): + result = np.matmul(a, b) + knob.knob(result, tile_size=[128, 128, 128]) + return result + + # Function form (parametrized tests): + @pytest.mark.parametrize("shape,tile_size", [...]) + def test_add_shapes(shape, tile_size, request): + @trace(input_specs=[(shape, "f32"), (shape, "f32")]) + def kernel(a, b): + ... + run_kernel_test(kernel, stop_after="apply-and-strip-transforms", + modes=Mode.LLVM | Mode.FILECHECK, request=request) +""" + +import enum +import os +import tempfile +from typing import List, Optional, Tuple + +import pytest + +import numpy as np + +# Set by conftest.py pytest_configure when --dump-ir is passed. +_DUMP_IR_ENABLED = False + + +# ============================================================================ +# Mode Enum +# ============================================================================ + + +class Mode(enum.Flag): + """Test execution modes (combinable via |). + + LLVM and BIR_SIM/HW are mutually exclusive (different pipeline stages). + STRING_CHECK and FILECHECK can combine with any execution mode. + """ + + LLVM = enum.auto() # LLVM JIT execution, compare to NumPy (requires stop_after) + BIR_SIM = ( + enum.auto() + ) # BIR simulation via neuron-cc (full pipeline, stop_after=None) + HW = enum.auto() # Trainium hardware execution (full pipeline, stop_after=None) + STRING_CHECK = ( + enum.auto() + ) # check_ir_contains / check_ir_not_contains on compiled IR + FILECHECK = enum.auto() # check_patterns via external FileCheck tool on compiled IR + + +# ============================================================================ +# Constants +# ============================================================================ + +# Default tolerances per execution mode +DEFAULT_TOLERANCES = { + Mode.LLVM: {"rtol": 1e-5, "atol": 1e-8}, + Mode.BIR_SIM: {"rtol": 1e-4, "atol": 1e-4}, + Mode.HW: {"rtol": 1e-3, "atol": 1e-3}, +} + +# dtype string -> numpy dtype mapping +DTYPE_MAP = { + "f32": np.float32, + "f64": np.float64, + "f16": np.float16, + "bf16": np.dtype("bfloat16") if hasattr(np, "bfloat16") else "bfloat16", + "i32": np.int32, + "i64": np.int64, +} + +# Sentinel for "trace only, no passes" +TRACE = "trace" + + +def _default_target() -> str: + """Auto-detect hardware target, defaulting to trn2 on trn1 machines. + + trn1 has restrictive limitations (e.g. scalar bias in activations) + and is no longer the preferred target. On trn1 instances we warn and + fall back to trn2, which BIR simulation supports via cross-target + compilation. trn2+ targets are returned as-is. 
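+
+    For example, resolve_target() == "trn1" yields "trn2" (with a warning),
+    while "trn2" and newer detected targets are returned unchanged.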
+ """ + from nki.compiler.target import resolve_target + + detected = resolve_target() + if detected == "trn1": + import warnings + warnings.warn( + f"Detected {detected} instance — defaulting to trn2 target " + "for compilation (trn1 has known NISA limitations)", + stacklevel=2, + ) + return "trn2" + return detected + + +# ============================================================================ +# Parameter Validation +# ============================================================================ + + +def _validate_params( + modes: Mode, + stop_after, + check_patterns: Optional[str], + check_ir_contains: Optional[List[str]], + check_ir_not_contains: Optional[List[str]], +) -> None: + """Validate parameter combinations and raise ValueError for misconfigurations.""" + + has_llvm = Mode.LLVM in modes + has_bir_sim = Mode.BIR_SIM in modes + has_hw = Mode.HW in modes + has_filecheck = Mode.FILECHECK in modes + has_string_check = Mode.STRING_CHECK in modes + + # LLVM requires stop_after (LLVM JIT cannot execute NISA IR) + if has_llvm and stop_after is None: + raise ValueError( + "Mode.LLVM requires stop_after (LLVM JIT cannot execute NISA IR " + "from full pipeline). Use stop_after='trace' for raw traced MLIR " + "or stop_after='' for intermediate IR." + ) + + # BIR_SIM / HW require full pipeline (no stop_after) + if has_bir_sim and stop_after is not None: + raise ValueError( + "Mode.BIR_SIM requires full pipeline (stop_after must not be specified)." + ) + if has_hw and stop_after is not None: + raise ValueError( + "Mode.HW requires full pipeline (stop_after must not be specified)." + ) + + # LLVM is mutually exclusive with BIR_SIM / HW + if has_llvm and (has_bir_sim or has_hw): + raise ValueError( + "Mode.LLVM cannot be combined with Mode.BIR_SIM/HW " + "(different pipeline stages). Use separate tests." + ) + + # FILECHECK requires check_patterns + if has_filecheck and not check_patterns: + raise ValueError("Mode.FILECHECK requires check_patterns string.") + + # STRING_CHECK requires check_ir_contains or check_ir_not_contains + if has_string_check and not check_ir_contains and not check_ir_not_contains: + raise ValueError( + "Mode.STRING_CHECK requires check_ir_contains or check_ir_not_contains." + ) + + +# ============================================================================ +# Input Generation +# ============================================================================ + + +def generate_inputs(input_specs, seed: int = 42) -> List[np.ndarray]: + """Generate random test inputs from input_specs. + + Args: + input_specs: List of (shape, dtype_str) tuples, e.g. [((256, 256), "f32")] + seed: Random seed for reproducibility + + Returns: + List of numpy arrays matching the specs + """ + np.random.seed(seed) + inputs = [] + for shape, dtype_str in input_specs: + if dtype_str not in DTYPE_MAP: + raise ValueError( + f"Unsupported dtype: {dtype_str}. Supported: {list(DTYPE_MAP.keys())}" + ) + np_dtype = DTYPE_MAP[dtype_str] + if np_dtype in (np.float32, np.float64, np.float16): + arr = np.random.rand(*shape).astype(np_dtype) + else: + arr = np.random.randint(0, 100, size=shape).astype(np_dtype) + inputs.append(arr) + return inputs + + +def compute_reference(traced_func, inputs: List[np.ndarray]) -> List[np.ndarray]: + """Compute reference output by calling the original (unwrapped) function. 
+ + Args: + traced_func: A traced function decorated with @trace + inputs: List of numpy input arrays + + Returns: + List of NumPy reference outputs (single-element list for single-output kernels) + """ + original_func = traced_func.__wrapped__ + result = original_func(*inputs) + if isinstance(result, tuple): + return list(result) + return [result] + + +# ============================================================================ +# Hardware Detection +# ============================================================================ + + +def is_hw_available() -> bool: + """Check if Trainium hardware is available.""" + try: + from nki.runtime import SpikeModel # noqa: F401 + + # Check if a neuron device is actually present + import subprocess + + result = subprocess.run( + ["neuron-ls"], capture_output=True, text=True, timeout=5 + ) + return result.returncode == 0 and "NEURON" in result.stdout + except Exception: + return False + + +def _can_run_hw(target: Optional[str]) -> bool: + """Check if Mode.HW can run for the given target. + + Mode.HW requires trn2+. Even if a Neuron device is detected, trn1 + cannot execute HW mode. The detected device must also match the + requested target (e.g. a trn2 test cannot run on trn1 hardware). + """ + if not is_hw_available(): + return False + from nki.compiler.target import resolve_target + detected = resolve_target() + # trn1 does not support Mode.HW + if detected == "trn1": + return False + # If the test requests a specific target, the device must match + if target is not None and detected != target: + return False + return True + + +# ============================================================================ +# Compilation Pipeline +# ============================================================================ + + +def _compile_pipeline( + traced_func, + stop_after, + target: Optional[str] = None, + dump_dir=None, + print_generic: bool = False, +) -> str: + """Compile a traced function through the pipeline. 
+ + Args: + traced_func: A traced function with .to_mlir() method + stop_after: "trace" for trace-only, pass name for intermediate, None for full pipeline + target: Hardware target (trn1, trn2, trn3) + dump_dir: Optional directory to save intermediate MLIR + + Returns: + Compiled MLIR/NISA IR as string + """ + if stop_after == TRACE: + # Trace only -- no pass pipeline + from pass_utils import trace_to_mlir_with_preprocessing + + mlir_str = trace_to_mlir_with_preprocessing(traced_func) + if dump_dir: + os.makedirs(dump_dir, exist_ok=True) + with open(os.path.join(dump_dir, "00_traced.mlir"), "w") as f: + f.write(mlir_str) + return mlir_str + + elif stop_after is not None: + # Intermediate pass -- use compile_knob_pipeline with stop_after + from pass_utils import compile_knob_pipeline + + return compile_knob_pipeline( + traced_func, stop_after=stop_after, dump_dir=dump_dir + ) + + else: + # Full pipeline to NISA + if target is None: + target = _default_target() + from pass_utils import trace_to_mlir_with_preprocessing + from nkipy_kernelgen.transforms.nkipy_opt import apply_complete_knob_pipeline + + mlir_str = trace_to_mlir_with_preprocessing(traced_func) + return apply_complete_knob_pipeline( + mlir_str, target=target, dump_dir=dump_dir, print_generic=print_generic + ) + + +# ============================================================================ +# Per-Mode Verification +# ============================================================================ + + +def _run_string_check(compiled_ir: str, check_ir_contains, check_ir_not_contains): + """Run simple string containment checks on the compiled IR.""" + if check_ir_contains: + for pattern in check_ir_contains: + assert pattern in compiled_ir, ( + f"Expected pattern not found in IR: '{pattern}'\n" + f"IR (first 2000 chars):\n{compiled_ir[:2000]}" + ) + if check_ir_not_contains: + for pattern in check_ir_not_contains: + assert pattern not in compiled_ir, ( + f"Unexpected pattern found in IR: '{pattern}'\n" + f"IR (first 2000 chars):\n{compiled_ir[:2000]}" + ) + + +def _run_filecheck(compiled_ir: str, check_patterns: str): + """Run FileCheck verification on the compiled IR.""" + from pass_utils import run_filecheck + + run_filecheck(compiled_ir, check_patterns) + + +def _run_llvm_verification(compiled_ir: str, traced_func, rtol: float, atol: float): + """Run LLVM JIT verification against NumPy reference.""" + from pass_utils import verify_tiled_mlir_with_numpy + + verify_tiled_mlir_with_numpy(compiled_ir, traced_func, rtol=rtol, atol=atol) + + +def simulate_mlir( + mlir_str: str, + func_name: str, + test_inputs: List[np.ndarray], + expected_output: np.ndarray, + rtol: float = 1e-4, + atol: float = 1e-4, + verbose: bool = False, + keep_artifacts: bool = False, + artifacts_dir: Optional[str] = None, +) -> Tuple[bool, float, Optional[str]]: + """ + Run simulation on an MLIR string. + + Parses the MLIR, compiles to NEFF with simulation enabled, and compares + the simulation output against expected_output. + + Args: + mlir_str: MLIR module as a string + func_name: Name of the function to simulate + test_inputs: List of input tensors (in function argument order) + expected_output: Expected output for validation + rtol: Relative tolerance for comparison + atol: Absolute tolerance for comparison + verbose: Print detailed output + keep_artifacts: Keep debug artifacts on success + artifacts_dir: Custom path for artifacts (created if doesn't exist). + If None, uses a temp directory. 
+ + Returns: + Tuple of (success, max_diff, artifacts_dir or None) + """ + import tempfile + from nki.compiler.ncc_driver import CompileOptions, compile_mlir_to_neff + from nki.compiler._internal import ir, register_all_dialects + + # Setup artifacts directory + if artifacts_dir: + os.makedirs(artifacts_dir, exist_ok=True) + debug_dir = artifacts_dir + else: + debug_dir = tempfile.mkdtemp(prefix="e2e_sim_") + + if verbose: + print("MLIR for simulation:") + print(mlir_str) + + opts = CompileOptions( + target="trn2", + verbose=False, + output_path=os.path.join(debug_dir, "kernel.neff"), + neuronx_cc_args=("--lnc=1",), + artifacts_dir=debug_dir, + enable_simulation=True, + kernel_json_filename="kernel.json", + ) + + # Parse and simulate + ctx = ir.Context() + register_all_dialects(ctx) + + try: + with ctx: + mlir = ir.Module.parse(mlir_str, ctx) + + # Extract nki.output_names from the target function so that + # output_arg_names match what the backend pipeline uses for BIRSim. + output_arg_names = None + for op in mlir.body.operations: + if hasattr(op, "name") and hasattr(op, "attributes"): + op_name = None + try: + op_name = op.attributes["sym_name"] + except (KeyError, IndexError): + pass + if op_name and str(op_name).strip('"') == func_name: + try: + names_attr = op.attributes["nki.output_names"] + output_arg_names = [ + str(names_attr[i]).strip('"') + for i in range(len(names_attr)) + ] + except (KeyError, IndexError): + pass + if output_arg_names is None: + output_arg_names = ["out_tensor"] + + input_names = [f"in_tensor_{i}" for i in range(len(test_inputs))] + + # Extract output names from nki.output_names attribute set by + # prepare-for-nki pass; fall back to "output_0" if not present. + output_names = ["output_0"] + for op in mlir.body.operations: + if ( + "function_type" in op.attributes + and "nki.output_names" in op.attributes + ): + names_attr = ir.ArrayAttr(op.attributes["nki.output_names"]) + output_names = [ + str(ir.StringAttr(names_attr[i])).strip('"') + for i in range(len(names_attr)) + ] + break + + # Normalize expected_output to list + if isinstance(expected_output, (list, tuple)): + expected_outputs = list(expected_output) + else: + expected_outputs = [expected_output] + + output_placeholders = [np.zeros_like(eo) for eo in expected_outputs] + all_arrays = list(test_inputs) + output_placeholders + argument_names = input_names + output_names + output_arg_names = output_names + + compile_result = compile_mlir_to_neff( + mlir, + func_name, + all_arrays, + argument_names, + output_arg_names, + opts, + ) + + if compile_result.birsim_outputs is None: + err = getattr(compile_result, 'neuronx_cc_error', None) + raise RuntimeError( + f"BIRSim produced no outputs. 
" + f"neuronx-cc error: {err}, artifacts: {debug_dir}" + ) + results = [ + compile_result.birsim_outputs[i] for i in range(len(expected_outputs)) + ] + except Exception as e: + import traceback + + print(f"Parsing/simulation failed: {e}") + print(f"Artifacts: {debug_dir}") + if verbose: + traceback.print_exc() + return False, float("inf"), debug_dir + + # Validate all outputs + max_diff = 0.0 + matches = True + for i, (result, expected) in enumerate(zip(results, expected_outputs)): + diff = np.max(np.abs(result - expected)) + max_diff = max(max_diff, diff) + if not np.allclose(result, expected, rtol=rtol, atol=atol): + matches = False + + if verbose: + print(f"Max difference: {max_diff:.2e}") + print(f"Match: {matches}") + + # Cleanup if success and not keeping artifacts (only for temp dirs) + artifacts_path = debug_dir + if matches and not keep_artifacts and not artifacts_dir: + import shutil + + try: + shutil.rmtree(debug_dir) + artifacts_path = None + except: + pass + + return matches, max_diff, artifacts_path + + +def _compile_nisa_to_neff( + compiled_ir: str, + traced_func, + inputs: List[np.ndarray], + reference_output: np.ndarray, + target: Optional[str], + dump_dir: Optional[str], + enable_simulation: bool = True, +): + """Compile NISA IR to NEFF (and optionally run BIR simulation). + + This is the shared compilation step for both BIR_SIM and HW modes. + When enable_simulation=True, the result includes birsim_outputs. + The NEFF is always produced and can be used for HW execution. + + Returns: + Tuple of (compile_result, input_names, debug_dir) + """ + import tempfile + from nki.compiler.ncc_driver import CompileOptions, compile_mlir_to_neff + from nki.compiler._internal import ir, register_all_dialects + + if target is None: + target = _default_target() + + func_name = traced_func.__wrapped__.__name__ + debug_dir = dump_dir or tempfile.mkdtemp(prefix="e2e_compile_") + os.makedirs(debug_dir, exist_ok=True) + + opts = CompileOptions( + target=target, + verbose=False, + output_path=os.path.join(debug_dir, "kernel.neff"), + neuronx_cc_args=("--lnc=1",), + artifacts_dir=debug_dir, + enable_simulation=enable_simulation, + kernel_json_filename="kernel.json", + ) + + ctx = ir.Context() + register_all_dialects(ctx) + with ctx: + mlir = ir.Module.parse(compiled_ir, ctx) + + input_names = [f"in_tensor_{i}" for i in range(len(inputs))] + + output_names = ["output_0"] + for op in mlir.body.operations: + if "function_type" in op.attributes and "nki.output_names" in op.attributes: + names_attr = ir.ArrayAttr(op.attributes["nki.output_names"]) + output_names = [ + str(ir.StringAttr(names_attr[i])).strip('"') + for i in range(len(names_attr)) + ] + break + + output_placeholders = [np.zeros_like(ro) for ro in reference_output] + all_arrays = list(inputs) + output_placeholders + argument_names = input_names + output_names + output_arg_names = output_names + + compile_result = compile_mlir_to_neff( + mlir, + func_name, + all_arrays, + argument_names, + output_arg_names, + opts, + ) + + return compile_result, input_names, debug_dir + + +def _run_bir_sim( + compiled_ir: str, + traced_func, + inputs: List[np.ndarray], + reference_output: List[np.ndarray], + rtol: float, + atol: float, + dump_dir: Optional[str], + compile_result=None, + target: Optional[str] = None, +): + """Run BIR simulation and compare to reference. + + If compile_result is provided, uses the pre-compiled birsim_outputs. + Otherwise compiles from scratch. 
+ """ + if compile_result is None: + compile_result, _, _ = _compile_nisa_to_neff( + compiled_ir, + traced_func, + inputs, + reference_output, + target=target, + dump_dir=dump_dir, + enable_simulation=True, + ) + + if compile_result.birsim_outputs is None: + err = getattr(compile_result, 'neuronx_cc_error', None) or "unknown" + artifacts = getattr(compile_result, 'artifacts_dir', None) or "N/A" + raise AssertionError( + f"BIRSim produced no outputs (birsim_outputs is None).\n" + f"neuronx-cc error: {err}\n" + f"Artifacts dir: {artifacts}" + ) + + for i, expected in enumerate(reference_output): + result = compile_result.birsim_outputs[i] + max_diff = np.max(np.abs(result - expected)) + success = np.allclose(result, expected, rtol=rtol, atol=atol) + + print(f"\nOutput {i} max difference: {max_diff:.2e}") + print(f"Output {i} match: {success}") + + assert success, ( + f"BIR simulation failed on output {i} with max_diff={max_diff:.2e} " + f"(rtol={rtol}, atol={atol})" + ) + + +def _run_hw_execution( + compiled_ir: str, + traced_func, + inputs: List[np.ndarray], + reference_output: List[np.ndarray], + rtol: float, + atol: float, + target: Optional[str], + dump_dir: Optional[str], + compile_result=None, + input_names=None, + compile_debug_dir=None, +): + """Run hardware execution and compare to reference. + + If compile_result is provided, reuses the NEFF from a previous compilation. + Otherwise compiles from scratch. + """ + import pytest + + if not is_hw_available(): + pytest.skip("No Trainium device detected -- skipping Mode.HW") + + from nki.runtime import SpikeModel, SpikeTensor + + if compile_result is None: + compile_result, input_names, compile_debug_dir = _compile_nisa_to_neff( + compiled_ir, + traced_func, + inputs, + reference_output, + target=target, + dump_dir=dump_dir, + enable_simulation=False, + ) + if input_names is None: + input_names = [f"in_tensor_{i}" for i in range(len(inputs))] + + neff_path = compile_result.neff_path + model = SpikeModel.load_from_neff(neff_path) + + neff_input_names = list(model.input_tensors_info.keys()) + compile_input_map = dict(zip(input_names, inputs)) + spike_inputs = { + name: SpikeTensor.from_numpy(compile_input_map[name], name=name) + for name in neff_input_names + } + + spike_outputs = model(inputs=spike_inputs, outputs=None) + artifacts = compile_debug_dir or dump_dir or "unknown" + + # Look up outputs by name rather than relying on dict iteration order, + # which may not match the expected order. + neff_output_names = sorted(model.output_tensors_info.keys()) + + for i, expected in enumerate(reference_output): + name = neff_output_names[i] if i < len(neff_output_names) else None + result_tensor = spike_outputs.get(name) if name else list(spike_outputs.values())[i] + raw = result_tensor.numpy() + # SpikeTensor.numpy() may return void dtype (V4). Interpret using + # the expected element size to determine the correct float type. + if raw.dtype.kind == 'V': + elem_size = raw.dtype.itemsize + float_dtype = {2: np.float16, 4: np.float32, 8: np.float64}.get(elem_size, np.float32) + result = raw.view(float_dtype) + else: + result = raw + # Cast expected to match the HW output dtype (e.g., bool -> f32 for + # comparison ops, since NISA always produces float results). 
+ if result.dtype != expected.dtype: + expected = expected.astype(result.dtype) + + max_diff = np.max(np.abs(result - expected)) + success = np.allclose(result, expected, rtol=rtol, atol=atol) + assert success, ( + f"HW execution failed on output {i} with max_diff={max_diff:.2e} " + f"(rtol={rtol}, atol={atol}). Artifacts: {artifacts}" + ) + + +# ============================================================================ +# Dump-IR Helpers +# ============================================================================ + + +def _print_dump_dir_listing(dump_dir: str) -> None: + """Print the list of IR files saved in dump_dir.""" + if not dump_dir or not os.path.isdir(dump_dir): + return + files = sorted(f for f in os.listdir(dump_dir) if f.endswith((".mlir", ".txt"))) + if not files: + print(f"[dump-ir] No IR files in {dump_dir}") + return + print(f"[dump-ir] {len(files)} IR files saved to: {dump_dir}") + for f in files: + size = os.path.getsize(os.path.join(dump_dir, f)) + print(f" {f} ({size:,} bytes)") + + +# ============================================================================ +# Main Entry Point +# ============================================================================ + + +def run_kernel_test( + traced_func, + *, + stop_after=None, + target: Optional[str] = None, + check_patterns: Optional[str] = None, + check_ir_contains: Optional[List[str]] = None, + check_ir_not_contains: Optional[List[str]] = None, + rtol: Optional[float] = None, + atol: Optional[float] = None, + seed: int = 42, + inputs: Optional[List[np.ndarray]] = None, + reference_output: Optional[np.ndarray] = None, + modes: Mode = Mode.LLVM, + dump_dir: Optional[str] = None, + request=None, +): + """Unified test runner for NKIPyKernelGen kernels. + + Compiles the traced function through the pipeline and runs all requested + verification modes. + + Args: + traced_func: A @trace-decorated function + stop_after: Pipeline stop point: + - "trace": trace to MLIR only, no passes + - "": stop after a specific pass + - None: run all passes (full pipeline to NISA) + target: Hardware target (trn1, trn2, trn3) + check_patterns: FileCheck patterns string (required for Mode.FILECHECK) + check_ir_contains: List of strings that must appear in compiled IR + check_ir_not_contains: List of strings that must NOT appear in compiled IR + rtol: Relative tolerance override (None = use per-mode defaults) + atol: Absolute tolerance override (None = use per-mode defaults) + seed: Random seed for input generation + inputs: Override auto-generated inputs + reference_output: Override auto-computed reference output + modes: Verification modes to run (combined with |) + dump_dir: Directory for compilation artifacts. Auto-created from + request.node.name if request is provided and dump_dir is None. + request: pytest request fixture (for auto-naming artifacts) + + Raises: + ValueError: If parameter validation fails + AssertionError: If any verification mode fails + """ + # 1. Validate parameters + _validate_params( + modes, stop_after, check_patterns, check_ir_contains, check_ir_not_contains + ) + + # 1b. Strip Mode.HW if hardware cannot run it (trn1, no device, or + # device/target mismatch). Other modes still execute normally. + if Mode.HW in modes and not _can_run_hw(target): + modes = modes & ~Mode.HW + if not modes: + pytest.skip("Test requires Mode.HW but no compatible device detected") + + # 2. Resolve artifact directory + # --dump-ir: always create a dump directory so intermediate IR is saved. 
+ # Without --dump-ir: only create when request is provided (backward compat). + if dump_dir is None and request is not None: + dump_dir = os.path.join( + os.path.dirname(os.path.abspath(request.fspath)), + "outputs", + request.node.name, + ) + elif dump_dir is None and _DUMP_IR_ENABLED: + # --dump-ir without request fixture: use a temp directory + dump_dir = tempfile.mkdtemp(prefix="nkipy_dump_ir_") + + if dump_dir and _DUMP_IR_ENABLED: + print(f"\n[dump-ir] IR will be saved to: {dump_dir}") + + # 3. Generate inputs if not provided + if inputs is None: + inputs = generate_inputs(traced_func.input_specs, seed=seed) + + # 4. Compute reference output if not provided, and normalize to list + if reference_output is None: + reference_output = compute_reference(traced_func, inputs) + elif not isinstance(reference_output, list): + if isinstance(reference_output, tuple): + reference_output = list(reference_output) + else: + reference_output = [reference_output] + + # 5. Compile through pipeline + # On failure: if no dump_dir was set, re-run pass-by-pass into a temp + # directory so the user gets intermediate IR for debugging. + try: + compiled_ir = _compile_pipeline( + traced_func, stop_after, target=target, dump_dir=dump_dir + ) + except Exception as exc: + if dump_dir: + # dump_dir was already set — IR files are already there + _print_dump_dir_listing(dump_dir) + raise + # No dump_dir: re-run pass-by-pass to capture intermediate IR + fallback_dir = tempfile.mkdtemp(prefix="nkipy_fail_dump_") + print(f"\n[dump-ir] Compilation failed — dumping intermediate IR to: {fallback_dir}") + try: + _compile_pipeline( + traced_func, stop_after, target=target, dump_dir=fallback_dir + ) + except Exception: + pass # expected to fail again; we just want the IR files + _print_dump_dir_listing(fallback_dir) + raise exc + + # 6. Run each verification mode + if Mode.STRING_CHECK in modes: + _run_string_check(compiled_ir, check_ir_contains, check_ir_not_contains) + + if Mode.FILECHECK in modes: + _run_filecheck(compiled_ir, check_patterns) + + if Mode.LLVM in modes: + tol = _resolve_tolerances(Mode.LLVM, rtol, atol) + _run_llvm_verification(compiled_ir, traced_func, **tol) + + # BIR_SIM and HW both need NISA→NEFF compilation. Compile once if both + # are requested, then share the result. + need_bir = Mode.BIR_SIM in modes + need_hw = Mode.HW in modes + + if need_bir or need_hw: + try: + compile_result, input_names, compile_dir = _compile_nisa_to_neff( + compiled_ir, + traced_func, + inputs, + reference_output, + target=target, + dump_dir=dump_dir, + enable_simulation=need_bir, + ) + except Exception: + # Non-round-trippable custom assembly (e.g. view(...) syntax) + # can't be parsed. Recompile with generic MLIR form. 
+ compiled_ir = _compile_pipeline( + traced_func, + stop_after, + target=target, + dump_dir=dump_dir, + print_generic=True, + ) + compile_result, input_names, compile_dir = _compile_nisa_to_neff( + compiled_ir, + traced_func, + inputs, + reference_output, + target=target, + dump_dir=dump_dir, + enable_simulation=need_bir, + ) + + if need_bir: + tol = _resolve_tolerances(Mode.BIR_SIM, rtol, atol) + _run_bir_sim( + compiled_ir, + traced_func, + inputs, + reference_output, + dump_dir=dump_dir, + compile_result=compile_result, + target=target, + **tol, + ) + + if need_hw: + tol = _resolve_tolerances(Mode.HW, rtol, atol) + _run_hw_execution( + compiled_ir, + traced_func, + inputs, + reference_output, + target=target, + dump_dir=dump_dir, + compile_result=compile_result, + input_names=input_names, + compile_debug_dir=compile_dir, + **tol, + ) + + # Print dump directory listing on success when --dump-ir is active + if dump_dir and _DUMP_IR_ENABLED: + _print_dump_dir_listing(dump_dir) + + +def _resolve_tolerances(mode: Mode, rtol: Optional[float], atol: Optional[float]): + """Get tolerances: use overrides if provided, else per-mode defaults.""" + defaults = DEFAULT_TOLERANCES.get(mode, {"rtol": 1e-5, "atol": 1e-6}) + return { + "rtol": rtol if rtol is not None else defaults["rtol"], + "atol": atol if atol is not None else defaults["atol"], + } + + +# ============================================================================ +# Decorator +# ============================================================================ + + +def nkipy_kernelgen_test( + input_specs, + *, + modes: Mode = Mode.LLVM, + stop_after=None, + **kwargs, +): + """Decorator that turns a kernel function into a pytest test. + + The decorated function is the kernel body. It will be traced with @trace + using the given input_specs, then compiled and verified according to the + specified modes. + + Args: + input_specs: Input specifications for @trace, e.g. 
[((256, 256), "f32")] + modes: Verification modes (combined with |) + stop_after: Pipeline stop point ("trace", "", or None) + **kwargs: Additional arguments passed to run_kernel_test() + + Returns: + A pytest-compatible test function + + Example: + @nkipy_kernelgen_test( + input_specs=[((256, 256), "f32"), ((256, 256), "f32")], + stop_after="apply-and-strip-transforms", + check_patterns="CHECK: scf.for", + modes=Mode.LLVM | Mode.FILECHECK, + ) + def test_matmul_tiling(a, b): + result = np.matmul(a, b) + knob.knob(result, tile_size=[128, 128, 128]) + return result + """ + # Validate early (at decoration time) so misconfigured tests fail on import + check_patterns = kwargs.get("check_patterns") + check_ir_contains = kwargs.get("check_ir_contains") + check_ir_not_contains = kwargs.get("check_ir_not_contains") + _validate_params( + modes, stop_after, check_patterns, check_ir_contains, check_ir_not_contains + ) + + def decorator(func): + def test_wrapper(): + from nkipy_kernelgen import trace as nkipy_trace + + traced = nkipy_trace(input_specs=input_specs)(func) + run_kernel_test( + traced, + modes=modes, + stop_after=stop_after, + **kwargs, + ) + + # Preserve name/qualname for pytest discovery (do NOT use functools.wraps + # here -- it copies __wrapped__/__signature__ from func, which makes pytest + # think test_wrapper has func's kernel parameters and try to resolve them + # as fixtures) + test_wrapper.__name__ = func.__name__ + test_wrapper.__qualname__ = func.__qualname__ + test_wrapper.__doc__ = func.__doc__ + test_wrapper.__module__ = func.__module__ + return test_wrapper + + return decorator diff --git a/kernelgen/tests/passes/__init__.py b/kernelgen/tests/passes/__init__.py new file mode 100644 index 0000000..6c70bef --- /dev/null +++ b/kernelgen/tests/passes/__init__.py @@ -0,0 +1 @@ +"""Pass-specific tests for nkipy transforms.""" diff --git a/kernelgen/tests/passes/annotate_memory_space/__init__.py b/kernelgen/tests/passes/annotate_memory_space/__init__.py new file mode 100644 index 0000000..e81e03e --- /dev/null +++ b/kernelgen/tests/passes/annotate_memory_space/__init__.py @@ -0,0 +1 @@ +"""Tests for the annotate-memory-space pass.""" diff --git a/kernelgen/tests/passes/annotate_memory_space/test_basic.py b/kernelgen/tests/passes/annotate_memory_space/test_basic.py new file mode 100644 index 0000000..21a782a --- /dev/null +++ b/kernelgen/tests/passes/annotate_memory_space/test_basic.py @@ -0,0 +1,204 @@ +""" +Tests for annotate-memory-space pass. + +The annotate-memory-space pass: +1. Annotates function inputs/outputs with SharedHbm (#nisa.mem) +2. Applies memory space attributes from nkipy.annotate to internal memrefs +3. Propagates memory spaces through subview, collapse_shape, expand_shape, reshape ops +4. Removes nkipy.annotate ops after processing + +Run with: python -m pytest tests/passes/annotate_memory_space/test_basic.py -v +Or directly: python tests/passes/annotate_memory_space/test_basic.py +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + +# ============================================================================ +# Test Cases +# ============================================================================ + +def test_matmul_sbuf_add_hbm(): + """ + Test matmul-add chain with: + - matmul output -> SBUF (intermediate) + - add output -> SharedHbm (returned result) + + This tests that the pass correctly: + 1. Annotates function arguments with #nisa.mem + 2. 
Applies 3 : i32 to matmul intermediate buffer + 3. Propagates memory space to subviews + 4. Removes all nkipy.annotate ops + """ + M, N, K = 256, 256, 256 + matmul_tile = [128, 128] # TILE_M, TILE_N + matmul_reduction_tile = [128] # TILE_K + add_tile = [128, 128] # TILE_M, TILE_N + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((M, N), "f32")]) + def matmul_add_kernel(a, b, bias): + # Matmul outputs to SBUF for reuse in the add + c = np.matmul(a, b) + knob.knob(c, mem_space="Sbuf", tile_size=matmul_tile, reduction_tile=matmul_reduction_tile) + + # Add outputs to SharedHbm (returned from kernel) + result = c + bias + knob.knob(result, mem_space="SharedHbm", tile_size=add_tile) + + return result + + # FileCheck patterns to verify: + # 1. All 3 function arguments get 4 : i32 + # 2. Matmul intermediate gets 3 : i32 from nkipy.annotate + # 3. Transpose temp stays without memory space annotation + # 4. LHS/RHS promote buffers already have 3 : i32 + # 5. PSUM accumulator already has 2 : i32 + # 6. Add's inputs/output promoted to SBUF (by elementwise SBUF promotion) + # 7. Add final output gets 4 : i32 from nkipy.annotate + # 8. All nkipy.annotate ops are removed + # 9. Memory space propagates to subviews + # + # With SBUF promotion for elementwise ops, the add tiling loop contains: + # - 3 SBUF allocs for add's promoted inputs and output + # - Copy from matmul SBUF output to promoted input 1 (SBUF->SBUF) + # - Copy from bias SharedHbm to promoted input 2 (SharedHbm->SBUF) + # - Copy from promoted output to SharedHbm result (SBUF->SharedHbm) + check_patterns = ''' +CHECK: func.func @matmul_add_kernel +CHECK-SAME: memref<256x256xf32, strided<[?, ?], offset: ?>, 4 : i32> +CHECK-SAME: memref<256x256xf32, strided<[?, ?], offset: ?>, 4 : i32> +CHECK-SAME: memref<256x256xf32, strided<[?, ?], offset: ?>, 4 : i32> +CHECK-SAME: -> memref<256x256xf32, 4 : i32> +CHECK: memref.alloc(){{.*}}: memref<256x256xf32, 3 : i32> +CHECK: memref.alloc(){{.*}}: memref<256x256xf32, 3 : i32> +CHECK: linalg.transpose{{.*}}outs({{.*}}memref<256x256xf32, 3 : i32>) +CHECK: memref.alloc(){{.*}}: memref<256x256xf32, 3 : i32> +CHECK: memref.copy{{.*}}3 : i32>{{.*}}to{{.*}}3 : i32> +CHECK: memref.alloc(){{.*}}: memref<256x256xf32, 3 : i32> +CHECK: memref.copy{{.*}}to memref<256x256xf32, 3 : i32> +CHECK: scf.for +CHECK: memref.subview{{.*}}3 : i32> +CHECK: memref.subview{{.*}}3 : i32> +CHECK: scf.for +CHECK: memref.subview{{.*}}3 : i32> +CHECK: memref.alloc(){{.*}}: memref<128x128xf32, 2 : i32> +CHECK: scf.for +CHECK: memref.subview{{.*}}3 : i32> +CHECK: memref.subview{{.*}}3 : i32> +CHECK: linalg.matmul +CHECK: memref.subview{{.*}}3 : i32> +CHECK: memref.copy{{.*}}2 : i32>{{.*}}to{{.*}}3 : i32> +CHECK: memref.subview{{.*}}3 : i32> +CHECK: memref.copy{{.*}}3 : i32>{{.*}}to{{.*}}3 : i32> +CHECK: memref.alloc(){{.*}}: memref<256x256xf32, 4 : i32> +CHECK: scf.for +CHECK: scf.for +CHECK: memref.subview{{.*}}3 : i32> +CHECK: memref.alloc(){{.*}}: memref<128x128xf32, 3 : i32> +CHECK: memref.copy{{.*}}3 : i32>{{.*}}to{{.*}}3 : i32> +CHECK: memref.subview{{.*}}4 : i32> +CHECK: memref.alloc(){{.*}}: memref<128x128xf32, 3 : i32> +CHECK: memref.copy{{.*}}4 : i32>{{.*}}to{{.*}}3 : i32> +CHECK: memref.alloc(){{.*}}: memref<128x128xf32, 3 : i32> +CHECK: linalg.add{{.*}}3 : i32>{{.*}}3 : i32>{{.*}}3 : i32> +CHECK: memref.subview{{.*}}4 : i32> +CHECK: memref.copy{{.*}}3 : i32>{{.*}}to{{.*}}4 : i32> +CHECK-NOT: nkipy.annotate +CHECK: return{{.*}}memref<256x256xf32, 4 : i32> +''' + run_kernel_test( + matmul_add_kernel, + 
stop_after='annotate-memory-space', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +def test_rope_reshape_propagation(): + """ + Test RoPE kernel where expand_dims produces expand_shape view ops. + + expand_dims creates a tensor.expand_shape (a view) of the HBM input. + Since views can't cross memory spaces, the expand_shape stays on HBM. + Tiling automatically promotes inputs of SBUF compute ops to SBUF + by inserting HBM->SBUF copies. + + This tests that the pass correctly: + 1. expand_shape views stay in HBM (same memory space as source) + 2. Tiling creates SBUF allocs for compute inputs + 3. Intermediate elementwise results annotated as SBUF get sbuf + 4. Final concatenated output annotated as SharedHbm gets shared_hbm + 5. All nkipy.annotate ops are removed + """ + bs = 256 + n_heads = 4 + head_dim = 128 + half_h = head_dim // 2 + tile_size = [128, 1, 64] + + @trace(input_specs=[ + ((bs, n_heads, head_dim), "f32"), + ((bs, half_h), "f32"), + ((bs, half_h), "f32"), + ]) + def rope_kernel(x, freqs_cos, freqs_sin): + # No knobs on cos/sin: they are views (expand_dims) of HBM inputs. + # Tiling promotes them to SBUF automatically as inputs to SBUF compute. + cos = np.expand_dims(freqs_cos, axis=1) + sin = np.expand_dims(freqs_sin, axis=1) + + x0 = x[:, :, :half_h] + x1 = x[:, :, half_h:] + + out_0 = x0 * cos - x1 * sin + knob.knob(out_0, mem_space="Sbuf", tile_size=tile_size) + + out_1 = x0 * sin + x1 * cos + knob.knob(out_1, mem_space="Sbuf", tile_size=tile_size) + + result = np.concatenate([out_0, out_1], axis=-1) + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + return result + + # Key checks: + # 1. Function args get 4 : i32 + # 2. expand_dims stays on HBM side (expand_shape is a view, same mem space) + # 3. Tiling creates SBUF allocs for compute inputs (copies from HBM) + # 4. Intermediate SBUF allocs are annotated + # 5. Final output is 4 : i32 + # 6. 
No nkipy.annotate ops remain + check_patterns = ''' +CHECK: func.func @rope_kernel +CHECK-SAME: memref<256x4x128xf32, strided<[?, ?, ?], offset: ?>, 4 : i32> +CHECK-SAME: memref<256x64xf32, strided<[?, ?], offset: ?>, 4 : i32> +CHECK-SAME: memref<256x64xf32, strided<[?, ?], offset: ?>, 4 : i32> +CHECK-SAME: -> memref<256x4x128xf32, 4 : i32> +CHECK: memref.expand_shape{{.*}}4 : i32> +CHECK: memref.expand_shape{{.*}}4 : i32> +CHECK: memref.alloc(){{.*}}: memref<256x4x64xf32, 3 : i32> +CHECK: memref.alloc(){{.*}}: memref<256x4x64xf32, 3 : i32> +CHECK: memref.alloc(){{.*}}: memref<256x4x64xf32, 3 : i32> +CHECK: memref.alloc(){{.*}}: memref<256x4x64xf32, 3 : i32> +CHECK: memref.alloc(){{.*}}: memref<256x4x64xf32, 3 : i32> +CHECK: memref.alloc(){{.*}}: memref<256x4x64xf32, 3 : i32> +CHECK: memref.alloc(){{.*}}: memref<256x4x128xf32, 4 : i32> +CHECK-NOT: nkipy.annotate +CHECK: return{{.*}}memref<256x4x128xf32, 4 : i32> +''' + run_kernel_test( + rope_kernel, + stop_after='annotate-memory-space', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/canonicalize_loop_step/__init__.py b/kernelgen/tests/passes/canonicalize_loop_step/__init__.py new file mode 100644 index 0000000..c66e5f9 --- /dev/null +++ b/kernelgen/tests/passes/canonicalize_loop_step/__init__.py @@ -0,0 +1 @@ +"""Tests for canonicalize-loop-step pass.""" diff --git a/kernelgen/tests/passes/canonicalize_loop_step/test_elementwise.py b/kernelgen/tests/passes/canonicalize_loop_step/test_elementwise.py new file mode 100644 index 0000000..fe25a10 --- /dev/null +++ b/kernelgen/tests/passes/canonicalize_loop_step/test_elementwise.py @@ -0,0 +1,154 @@ +""" +Tests for canonicalize-loop-step pass with elementwise operations. + +Run with: python -m pytest tests/passes/canonicalize_loop_step/test_elementwise.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + +# ============================================================================ +# Test Configurations +# ============================================================================ + +# Elementwise add test configurations: (shape, tile_size) +ADD_CONFIGS = [ + pytest.param((256, 256), [128, 128], id="256x256_tile128"), + pytest.param((512, 512), [128, 128], id="512x512_tile128"), + pytest.param((256, 256), [64, 64], id="256x256_tile64"), + pytest.param((512, 512), [64, 64], id="512x512_tile64_8iters"), + pytest.param((256, 256), [256, 256], id="256x256_single_iter"), +] + +# 3D elementwise configurations +ADD_3D_CONFIGS = [ + pytest.param((128, 256, 64), [64, 128, 32], id="3d_128x256x64"), + pytest.param((64, 128, 256), [32, 64, 128], id="3d_64x128x256"), +] + + +# ============================================================================ +# 2D Elementwise Tests +# ============================================================================ + +@pytest.mark.parametrize("shape,tile_size", ADD_CONFIGS) +def test_add_loop_canonicalization(shape, tile_size): + """ + Test loop step canonicalization on 2D elementwise add. 
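+
+    Worked instance (illustrative) for shape=(256, 256), tile_size=[128, 128]:
+        before: scf.for %i = 0 to 256 step 128 (and likewise for the inner dim)
+        after:  scf.for %i = 0 to 2 step 1, with the original offset rebuilt
+                as %i * 128 via arith.muli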
+ + After canonicalization: + - Loop bounds change from (0 to X step Y) to (0 to X/Y step 1) + - Original offset is recovered via arith.muli: offset = loop_idx * original_step + """ + M, N = shape + tile_m, tile_n = tile_size + + @trace(input_specs=[(shape, "f32"), (shape, "f32")]) + def add_kernel(a, b): + result = a + b + knob.knob(result, tile_size=tile_size) + return result + + # Strict checks: + # 1. Loop bounds: step is always 1 + # 2. arith.muli to recover original offset with tile size constant + # 3. Loop upper bound matches num_iters (but constant naming may vary with suffix) + check_patterns = f""" + CHECK: func.func + CHECK: scf.for {{{{.*}}}} step %c1{{{{.*}}}} + CHECK: arith.muli {{{{.*}}}}, %c{tile_m}{{{{.*}}}} : index + CHECK: scf.for {{{{.*}}}} step %c1{{{{.*}}}} + CHECK: arith.muli {{{{.*}}}}, %c{tile_n}{{{{.*}}}} : index + CHECK: linalg.add + """ + run_kernel_test( + add_kernel, + stop_after='canonicalize-loop-step', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +@pytest.mark.parametrize("shape,tile_size", ADD_CONFIGS) +def test_mul_loop_canonicalization(shape, tile_size): + """ + Test loop step canonicalization on 2D elementwise mul. + + After canonicalization: + - Loop bounds change from (0 to X step Y) to (0 to X/Y step 1) + - Original offset is recovered via arith.muli: offset = loop_idx * original_step + """ + M, N = shape + tile_m, tile_n = tile_size + + @trace(input_specs=[(shape, "f32"), (shape, "f32")]) + def mul_kernel(a, b): + result = a * b + knob.knob(result, tile_size=tile_size) + return result + + check_patterns = f""" + CHECK: func.func + CHECK: scf.for {{{{.*}}}} step %c1{{{{.*}}}} + CHECK: arith.muli {{{{.*}}}}, %c{tile_m}{{{{.*}}}} : index + CHECK: scf.for {{{{.*}}}} step %c1{{{{.*}}}} + CHECK: arith.muli {{{{.*}}}}, %c{tile_n}{{{{.*}}}} : index + CHECK: linalg.mul + """ + run_kernel_test( + mul_kernel, + stop_after='canonicalize-loop-step', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +# ============================================================================ +# 3D Elementwise Tests +# ============================================================================ + +@pytest.mark.parametrize("shape,tile_size", ADD_3D_CONFIGS) +def test_add_3d_loop_canonicalization(shape, tile_size): + """ + Test loop step canonicalization on 3D elementwise add. + + For 3D tensors, we expect 3 nested scf.for loops with step 1, + each with arith.muli to recover original offsets. 
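+
+    Illustrative offset recovery for shape=(128, 256, 64), tile_size=[64, 128, 32]:
+    each of the three canonical loops runs dim/tile = 2 iterations at step 1,
+    and the offsets are rebuilt as i0 * 64, i1 * 128 and i2 * 32.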
+ """ + D0, D1, D2 = shape + t0, t1, t2 = tile_size + + @trace(input_specs=[(shape, "f32"), (shape, "f32")]) + def add_kernel(a, b): + result = a + b + knob.knob(result, tile_size=tile_size) + return result + + # Strict checks: 3 nested loops with arith.muli for each offset recovery + check_patterns = f""" + CHECK: func.func + CHECK: scf.for {{{{.*}}}} step %c1{{{{.*}}}} + CHECK: arith.muli {{{{.*}}}}, %c{t0}{{{{.*}}}} : index + CHECK: scf.for {{{{.*}}}} step %c1{{{{.*}}}} + CHECK: arith.muli {{{{.*}}}}, %c{t1}{{{{.*}}}} : index + CHECK: scf.for {{{{.*}}}} step %c1{{{{.*}}}} + CHECK: arith.muli {{{{.*}}}}, %c{t2}{{{{.*}}}} : index + CHECK: linalg.add + """ + run_kernel_test( + add_kernel, + stop_after='canonicalize-loop-step', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/canonicalize_loop_step/test_matmul.py b/kernelgen/tests/passes/canonicalize_loop_step/test_matmul.py new file mode 100644 index 0000000..72bd553 --- /dev/null +++ b/kernelgen/tests/passes/canonicalize_loop_step/test_matmul.py @@ -0,0 +1,91 @@ +""" +Tests for canonicalize-loop-step pass with matmul operations. + +The canonicalize-loop-step pass normalizes loop steps to 1: +- Before: for %i = 0 to 256 step 128 (2 iterations) +- After: for %i = 0 to 2 step 1 (2 iterations) + +The pass requires: +- Constant loop bounds and step +- Upper bound must be evenly divisible by step +- Lower bound must be 0 + +Run with: python -m pytest tests/passes/canonicalize_loop_step/test_matmul.py -v +Or directly: python tests/passes/canonicalize_loop_step/test_matmul.py +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + +# ============================================================================ +# Test Configurations +# ============================================================================ + +# Matmul test configurations: (M, N, K, tile_size, reduction_tile) +# Block size is dynamically chosen: 2 when dim >= tile*2, else 1 +MATMUL_CONFIGS = [ + pytest.param(256, 256, 256, [128, 128], [128], id="256x256x256_tile128"), + pytest.param(512, 512, 512, [128, 128], [128], id="512x512x512_tile128"), + pytest.param(256, 256, 256, [64, 64], [64], id="256x256x256_tile64"), + pytest.param(512, 512, 256, [128, 128], [64], id="512x512x256_tile_128_128_64"), +] + +# ============================================================================ +# Matmul Tests +# ============================================================================ + +@pytest.mark.parametrize("M,N,K,tile_size,reduction_tile", MATMUL_CONFIGS) +def test_matmul_loop_canonicalization(M, N, K, tile_size, reduction_tile, request): + """ + Test that loop steps are normalized to 1 after canonicalization. 
+ + After canonicalize-loop-step: + - Loop bounds change from (0 to X step Y) to (0 to X/Y step 1) + - Original offsets are recovered via arith.muli: offset = loop_idx * original_step + + Matmul has 5 nested loops after tiling: + - 2 outer block loops (M, N dimensions) + - 1 K reduction loop + - 2 inner tile loops (M, N tile dimensions) + """ + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) + def matmul_kernel(a, b): + result = np.matmul(a, b) + knob.knob(result, tile_size=tile_size, reduction_tile=reduction_tile) + return result + + # Strict checks for matmul: + # 1. All 5 loops should have step 1 + # 2. arith.muli should appear to recover offsets for each loop + # Note: The exact order and tile sizes vary, so we check for key patterns + check_patterns = f""" + CHECK: func.func + CHECK: scf.for {{{{.*}}}} step %c1{{{{.*}}}} + CHECK: arith.muli {{{{.*}}}} : index + CHECK: scf.for {{{{.*}}}} step %c1{{{{.*}}}} + CHECK: arith.muli {{{{.*}}}} : index + CHECK: scf.for {{{{.*}}}} step %c1{{{{.*}}}} + CHECK: arith.muli {{{{.*}}}} : index + CHECK: scf.for {{{{.*}}}} step %c1{{{{.*}}}} + CHECK: arith.muli {{{{.*}}}} : index + CHECK: scf.for {{{{.*}}}} step %c1{{{{.*}}}} + CHECK: arith.muli {{{{.*}}}} : index + CHECK: linalg.matmul + """ + run_kernel_test( + matmul_kernel, + stop_after='canonicalize-loop-step', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/canonicalize_loop_step/test_multi_op.py b/kernelgen/tests/passes/canonicalize_loop_step/test_multi_op.py new file mode 100644 index 0000000..6580e16 --- /dev/null +++ b/kernelgen/tests/passes/canonicalize_loop_step/test_multi_op.py @@ -0,0 +1,177 @@ +""" +Tests for canonicalize-loop-step pass with multiple operations. + +Run with: python -m pytest tests/passes/canonicalize_loop_step/test_multi_op.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + +# ============================================================================ +# Multi-Op Tests: Matmul + Elementwise +# ============================================================================ + +def test_matmul_add_chain(): + """ + Test matmul followed by elementwise add (common pattern: C = A @ B + bias). + + Both operations should have canonicalized loop steps, and each loop + should have arith.muli to recover original offsets. 
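+
+    The reference the harness checks against is simply the traced NumPy body,
+    i.e. np.matmul(a, b) + bias on the auto-generated inputs.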
+ """ + M, N, K = 256, 256, 256 + matmul_tile = [128, 128] + matmul_reduction_tile = [128] + add_tile = [128, 128] + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((M, N), "f32")]) + def matmul_add_kernel(a, b, bias): + c = np.matmul(a, b) + knob.knob(c, tile_size=matmul_tile, reduction_tile=matmul_reduction_tile) + + result = c + bias + knob.knob(result, tile_size=add_tile) + + return result + + # Strict checks: + # - Matmul has 5 nested loops (block-M, block-N, tile-M, tile-N, tile-K) + # - Add has 2 nested loops (tile-M, tile-N) + # All loops should have step 1 and arith.muli for offset recovery + check_patterns = """ + CHECK: func.func + CHECK: scf.for {{.*}} step %c1{{.*}} + CHECK: arith.muli {{.*}} : index + CHECK: scf.for {{.*}} step %c1{{.*}} + CHECK: arith.muli {{.*}} : index + CHECK: scf.for {{.*}} step %c1{{.*}} + CHECK: arith.muli {{.*}} : index + CHECK: scf.for {{.*}} step %c1{{.*}} + CHECK: arith.muli {{.*}} : index + CHECK: scf.for {{.*}} step %c1{{.*}} + CHECK: arith.muli {{.*}} : index + CHECK: linalg.matmul + CHECK: scf.for {{.*}} step %c1{{.*}} + CHECK: arith.muli {{.*}} : index + CHECK: scf.for {{.*}} step %c1{{.*}} + CHECK: arith.muli {{.*}} : index + CHECK: linalg.add + CHECK: return + """ + run_kernel_test( + matmul_add_kernel, + stop_after='canonicalize-loop-step', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +def test_matmul_add_different_tile_sizes(): + """ + Test matmul + add with different tile sizes for each operation. + + Different tile sizes mean different constants in the arith.muli operations. + """ + M, N, K = 512, 512, 256 + matmul_tile = [128, 128] + matmul_reduction_tile = [64] + add_tile = [64, 64] + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((M, N), "f32")]) + def matmul_add_kernel(a, b, bias): + c = np.matmul(a, b) + knob.knob(c, tile_size=matmul_tile, reduction_tile=matmul_reduction_tile) + + result = c + bias + knob.knob(result, tile_size=add_tile) + + return result + + # Strict checks: + # - Matmul has 5 nested loops (block-M, block-N, tile-M, tile-N, tile-K) + # - Add has 2 nested loops (tile-M, tile-N) + # All loops should have step 1 and arith.muli for offset recovery + check_patterns = """ + CHECK: func.func + CHECK: scf.for {{.*}} step %c1{{.*}} + CHECK: arith.muli {{.*}} : index + CHECK: scf.for {{.*}} step %c1{{.*}} + CHECK: arith.muli {{.*}} : index + CHECK: scf.for {{.*}} step %c1{{.*}} + CHECK: arith.muli {{.*}} : index + CHECK: scf.for {{.*}} step %c1{{.*}} + CHECK: arith.muli {{.*}} : index + CHECK: scf.for {{.*}} step %c1{{.*}} + CHECK: arith.muli {{.*}} : index + CHECK: linalg.matmul + CHECK: scf.for {{.*}} step %c1{{.*}} + CHECK: arith.muli {{.*}} : index + CHECK: scf.for {{.*}} step %c1{{.*}} + CHECK: arith.muli {{.*}} : index + CHECK: linalg.add + """ + run_kernel_test( + matmul_add_kernel, + stop_after='canonicalize-loop-step', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +# ============================================================================ +# Multi-Op Tests: Multiple Elementwise +# ============================================================================ + +def test_add_add_chain(): + """ + Test two elementwise adds in sequence: C = (A + B) + D. + + Both add operations should have canonicalized loops with arith.muli + for offset recovery. 
+ """ + shape = (256, 256) + tile_size = [128, 128] + tile_m, tile_n = tile_size + + @trace(input_specs=[(shape, "f32"), (shape, "f32"), (shape, "f32")]) + def add_add_kernel(a, b, d): + c = a + b + knob.knob(c, tile_size=tile_size) + + result = c + d + knob.knob(result, tile_size=tile_size) + + return result + + # Strict checks: Two add operations, each with 2 nested loops + # Each loop should have step 1 and arith.muli for offset recovery + # Note: Using flexible pattern to handle various constant naming conventions + check_patterns = f""" + CHECK: func.func + CHECK: scf.for {{{{.*}}}} step %c1{{{{.*}}}} + CHECK: arith.muli {{{{.*}}}}, %c{tile_m}{{{{.*}}}} : index + CHECK: scf.for {{{{.*}}}} step %c1{{{{.*}}}} + CHECK: arith.muli {{{{.*}}}}, %c{tile_n}{{{{.*}}}} : index + CHECK: linalg.add + CHECK: scf.for {{{{.*}}}} step %c1{{{{.*}}}} + CHECK: arith.muli {{{{.*}}}} : index + CHECK: scf.for {{{{.*}}}} step %c1{{{{.*}}}} + CHECK: arith.muli {{{{.*}}}} : index + CHECK: linalg.add + """ + run_kernel_test( + add_add_kernel, + stop_after='canonicalize-loop-step', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/canonicalize_partition_dim/__init__.py b/kernelgen/tests/passes/canonicalize_partition_dim/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernelgen/tests/passes/canonicalize_partition_dim/test_basic.py b/kernelgen/tests/passes/canonicalize_partition_dim/test_basic.py new file mode 100644 index 0000000..0b6e6bb --- /dev/null +++ b/kernelgen/tests/passes/canonicalize_partition_dim/test_basic.py @@ -0,0 +1,325 @@ +""" +Tests for the canonicalize-partition-dim pass. + +This pass inserts linalg.transpose operations to ensure partition_dim=0 +everywhere. It operates on connected components of elementwise ops that +share a non-zero partition_dim, inserting transposes at component boundaries +and rewriting elementwise ops with permuted shapes. + +Run with: python -m pytest tests/passes/canonicalize_partition_dim/ -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +# ============================================================================ +# Test: No-op when partition_dim=0 (default) +# ============================================================================ + +def test_partition_dim_zero_is_noop(): + """ + When all annotations have partition_dim=0 (or no partition_dim), + the pass should be a no-op: no transposes inserted. + """ + shape = (256, 256) + + @trace(input_specs=[(shape, "f32")]) + def kernel(x): + y = np.exp(x) + knob.knob(y, mem_space="Sbuf", tile_size=[128, 128], partition_dim=0) + return y + + run_kernel_test( + kernel, + stop_after='canonicalize-partition-dim', + check_ir_not_contains=["linalg.transpose"], + modes=Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Test: Single op with partition_dim=1 (2D) +# ============================================================================ + +def test_single_op_partition_dim_1(): + """ + A single elementwise op annotated with partition_dim=1 on a 2D tensor. 
+ + Input: y = exp(x) : tensor<256x128xf32>, partition_dim=1 + After: transpose inputs [1,0] -> exp on tensor<128x256xf32> -> transpose output [1,0] + annotation updated to partition_dim=0, tile_size permuted + """ + M, N = 256, 128 + + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + y = np.exp(x) + knob.knob(y, mem_space="Sbuf", tile_size=[M, N], partition_dim=1) + return y + + # After the pass, we expect: + # - linalg.transpose ops inserted at boundaries + # - partition_dim updated to 0 + # - tile_size permuted from [256, 128] to [128, 256] + run_kernel_test( + kernel, + stop_after='canonicalize-partition-dim', + check_ir_contains=[ + "linalg.transpose", + "permutation = [1, 0]", + "linalg.exp", + "tensor<128x256xf32>", + "partition_dim = 0", + "tile_size = array", + ], + modes=Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Test: Elementwise chain with partition_dim=1 +# ============================================================================ + +def test_elementwise_chain_partition_dim_1(): + """ + A chain of elementwise ops where the final op is annotated with + partition_dim=1. After infer-layout propagates partition_dim backward, + the entire chain should be rewritten. + + Input: y = exp(x), z = y + 1.0, knob on z with partition_dim=1 + After: transpose x -> exp -> add -> transpose back + All ops in the chain have permuted shapes. + """ + M, N = 256, 128 + + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + y = np.exp(x) + z = y + 1.0 + knob.knob(z, mem_space="Sbuf", tile_size=[M, N], partition_dim=1) + return z + + # After the pass: + # - One input transpose at the boundary + # - Elementwise ops rewritten with permuted shapes + # - One output transpose at the boundary + # - partition_dim=0 in all annotations + run_kernel_test( + kernel, + stop_after='canonicalize-partition-dim', + check_ir_contains=[ + "linalg.transpose", + "permutation = [1, 0]", + "linalg.exp", + "linalg.generic", + "tensor<128x256xf32>", + "partition_dim = 0", + "tile_size = array", + ], + modes=Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Test: Numerical correctness after tiling (partition_dim=1) +# ============================================================================ + +def test_partition_dim_1_tiling_executes(): + """ + Verify that after canonicalize-partition-dim, the pipeline can tile + and execute the kernel correctly via LLVM JIT. + + This is the key correctness check: the transposes preserve semantics. + """ + M, N = 256, 128 + + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + y = np.exp(x) + z = y + 1.0 + knob.knob(z, mem_space="Sbuf", tile_size=[M, N], partition_dim=1) + return z + + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + rtol=1e-5, + atol=1e-5, + ) + + +# ============================================================================ +# Test: 3D tensor with partition_dim=2 +# ============================================================================ + +def test_3d_partition_dim_2(): + """ + A 3D tensor annotated with partition_dim=2. 
+ + Input: y = x + 1.0 : tensor<4x256x128xf32>, partition_dim=2 + After: permutation [2, 0, 1] moves dim 2 to position 0 + tensor becomes <128x4x256xf32> + tile_size permuted accordingly + """ + B, M, N = 4, 256, 128 + + @trace(input_specs=[((B, M, N), "f32")]) + def kernel(x): + y = x + 1.0 + knob.knob(y, mem_space="Sbuf", tile_size=[1, M, N], partition_dim=2) + return y + + run_kernel_test( + kernel, + stop_after='canonicalize-partition-dim', + check_ir_contains=[ + "linalg.transpose", + "permutation = [2, 0, 1]", + "linalg.generic", + "tensor<128x4x256xf32>", + "partition_dim = 0", + "tile_size = array", + "permutation = [1, 2, 0]", + ], + modes=Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Test: 3D numerical correctness +# ============================================================================ + +def test_3d_partition_dim_2_executes(): + """ + Verify 3D tensor with partition_dim=2 executes correctly after tiling. + """ + B, M, N = 4, 256, 128 + + @trace(input_specs=[((B, M, N), "f32")]) + def kernel(x): + y = x + 1.0 + knob.knob(y, mem_space="Sbuf", tile_size=[1, M, N], partition_dim=2) + return y + + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + rtol=1e-5, + atol=1e-5, + ) + + +# ============================================================================ +# Test: 3D broadcast generic with partition_dim=1 +# ============================================================================ + +def test_3d_broadcast_generic_partition_dim_1(): + """ + Verify that canonicalize-partition-dim correctly permutes indexing maps + of broadcast linalg.generic ops. + + Input: a(4,128,64) * b(1,128,64) with partition_dim=1 + After perm=[1,0,2]: + - a becomes (128,4,64), b becomes (128,1,64) + - The broadcast generic's indexing map for b must change from + (d0,d1,d2)->(0,d1,d2) to (d0,d1,d2)->(d0,0,d2) + + Without the fix: shapes are permuted but indexing maps are not, + causing a verifier error ('inferred shape dimension #1 to be 4, + but found 1'). + """ + BH, S, D = 4, 128, 64 + tile_size = [1, 128, 64] + + @trace(input_specs=[ + ((BH, S, D), "f32"), + ((1, S, D), "f32"), + ]) + def kernel(a, b): + result = a * b + knob.knob(result, mem_space="Sbuf", tile_size=tile_size, + partition_dim=1) + return result + + # After the pass, shapes should be permuted and the broadcast generic + # should verify correctly with updated indexing maps. + run_kernel_test( + kernel, + stop_after='canonicalize-partition-dim', + check_ir_contains=[ + "linalg.transpose", + "permutation = [1, 0, 2]", + "tensor<128x4x64xf32>", + "tensor<128x1x64xf32>", + "partition_dim = 0", + ], + modes=Mode.STRING_CHECK, + ) + + +def test_3d_broadcast_generic_executes(): + """ + Verify that a 3D broadcast multiply with partition_dim=1 produces + correct numerical results through tiling after canonicalization. 
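+
+    The NumPy reference is plain broadcasting: expected = a * b with a of shape
+    (4, 128, 64) and b of shape (1, 128, 64), so b is reused across the leading
+    batch*heads dimension.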
+ """ + BH, S, D = 4, 128, 64 + tile_size = [1, 128, 64] + + @trace(input_specs=[ + ((BH, S, D), "f32"), + ((1, S, D), "f32"), + ]) + def kernel(a, b): + result = a * b + knob.knob(result, mem_space="Sbuf", tile_size=tile_size, + partition_dim=1) + return result + + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + rtol=1e-5, + atol=1e-5, + ) + + +# ============================================================================ +# Test: Matmul with partition_dim errors +# ============================================================================ + +def test_matmul_partition_dim_errors(): + """ + partition_dim on a matmul/bmm should error out. + Users must split annotations: no partition_dim on the matmul itself. + """ + M, K, N = 128, 64, 128 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) + def kernel(a, b): + c = np.matmul(a, b) + knob.knob(c, mem_space="Sbuf", tile_size=[M, N], + reduction_tile=[K], partition_dim=1) + return c + + with pytest.raises(RuntimeError, match="matmul"): + run_kernel_test( + kernel, + stop_after='canonicalize-partition-dim', + check_ir_contains=["should_not_get_here"], + modes=Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/canonicalize_partition_dim/test_reduce.py b/kernelgen/tests/passes/canonicalize_partition_dim/test_reduce.py new file mode 100644 index 0000000..5da550d --- /dev/null +++ b/kernelgen/tests/passes/canonicalize_partition_dim/test_reduce.py @@ -0,0 +1,119 @@ +""" +Tests for canonicalize-partition-dim pass with reduction ops. + +Verifies that reductions (np.max, np.sum) with non-zero partition_dim +get their shapes correctly permuted and produce correct results. + +Run with: python -m pytest tests/passes/canonicalize_partition_dim/test_reduce.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +# ============================================================================ +# Test: 3D reduction (max) with partition_dim=1 +# ============================================================================ + +def test_3d_reduction_max_partition_dim_1(): + """ + 3D np.max with keepdims=True and partition_dim=1. 
+ + Input: tensor<8x128x64xf32>, reduce axis=-1, keepdims=True + Output: tensor<8x128x1xf32> with partition_dim=1 + + After canonicalization (perm=[1,0,2]): + - Input becomes tensor<128x8x64xf32> + - Reduction output becomes tensor<128x8x1xf32> + """ + B, M, N = 8, 128, 64 + + @trace(input_specs=[((B, M, N), "f32")]) + def kernel(x): + sq = x * x + knob.knob(sq, mem_space="Sbuf", tile_size=[1, M, N], partition_dim=1) + + sm = np.max(sq, axis=-1, keepdims=True) + knob.knob(sm, mem_space="SharedHbm", tile_size=[1, M], + reduction_tile=[N], partition_dim=1) + return sm + + # String check: verify transposes and permuted shapes + run_kernel_test( + kernel, + stop_after='canonicalize-partition-dim', + check_ir_contains=[ + "linalg.transpose", + "permutation = [1, 0, 2]", + "tensor<128x8x1xf32>", + "partition_dim = 0", + ], + modes=Mode.STRING_CHECK, + ) + + # Numerical correctness via LLVM JIT + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + rtol=1e-5, + atol=1e-5, + ) + + +# ============================================================================ +# Test: 3D reduction (sum) with partition_dim=1 +# ============================================================================ + +def test_3d_reduction_sum_partition_dim_1(): + """ + 3D np.sum with keepdims=True and partition_dim=1. + + Verifies that reduction with linalg.fill(tensor.empty) init gets + its shapes correctly permuted, and produces correct results. + """ + B, M, N = 8, 128, 64 + + @trace(input_specs=[((B, M, N), "f32")]) + def kernel(x): + y = x + 1.0 + knob.knob(y, mem_space="Sbuf", tile_size=[1, M, N], partition_dim=1) + + sm = np.sum(y, axis=-1, keepdims=True) + knob.knob(sm, mem_space="SharedHbm", tile_size=[1, M], + reduction_tile=[N], partition_dim=1) + return sm + + # String check: verify transposes and permuted shapes + run_kernel_test( + kernel, + stop_after='canonicalize-partition-dim', + check_ir_contains=[ + "linalg.transpose", + "permutation = [1, 0, 2]", + "tensor<128x8x64xf32>", + "tensor<128x8x1xf32>", + "partition_dim = 0", + ], + modes=Mode.STRING_CHECK, + ) + + # Numerical correctness via LLVM JIT + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + rtol=1e-5, + atol=1e-5, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/cleanup_bufferization_artifacts/__init__.py b/kernelgen/tests/passes/cleanup_bufferization_artifacts/__init__.py new file mode 100644 index 0000000..21a9ace --- /dev/null +++ b/kernelgen/tests/passes/cleanup_bufferization_artifacts/__init__.py @@ -0,0 +1 @@ +"""Tests for the cleanup-bufferization-artifacts pass.""" diff --git a/kernelgen/tests/passes/cleanup_bufferization_artifacts/test_basic.py b/kernelgen/tests/passes/cleanup_bufferization_artifacts/test_basic.py new file mode 100644 index 0000000..30a72af --- /dev/null +++ b/kernelgen/tests/passes/cleanup_bufferization_artifacts/test_basic.py @@ -0,0 +1,78 @@ +""" +Tests that AnnotateOp bufferizes correctly via BufferizableOpInterface. + +After one-shot-bufferize, nkipy.annotate ops should target memrefs (not tensors), +with no leftover bufferization.to_tensor wrappers. 
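+
+Sketch of the intended difference (abbreviated, not verbatim pass output):
+without the interface, bufferization would leave
+    %t = bufferization.to_tensor %buf ... ; nkipy.annotate %t ... (tensor operand)
+whereas with BufferizableOpInterface the annotate op takes %buf (a memref)
+directly, with its mem_space / tile_size attributes preserved.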
+ +Run with: python -m pytest tests/passes/cleanup_bufferization_artifacts/test_basic.py -v +Or directly: python tests/passes/cleanup_bufferization_artifacts/test_basic.py +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + +# ============================================================================ +# Test Cases +# ============================================================================ + +def test_matmul_sbuf_add_hbm(): + """ + Test matmul-add chain with: + - matmul output -> SBUF (intermediate) + - add output -> SharedHbm (returned result) + + After bufferization, nkipy.annotate ops should operate on memrefs directly + (BufferizableOpInterface on AnnotateOp handles the tensor→memref conversion). + """ + M, N, K = 256, 256, 256 + matmul_tile = [128, 128] # TILE_M, TILE_N + matmul_reduction_tile = [128] # TILE_K + add_tile = [128, 128] # TILE_M, TILE_N + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((M, N), "f32")]) + def matmul_add_kernel(a, b, bias): + # Matmul outputs to SBUF for reuse in the add + c = np.matmul(a, b) + knob.knob(c, mem_space="Sbuf", tile_size=matmul_tile, reduction_tile=matmul_reduction_tile) + + # Add outputs to SharedHbm (returned from kernel) + result = c + bias + knob.knob(result, mem_space="SharedHbm", tile_size=add_tile) + + return result + + # After bufferization + canonicalize (stop_after=9): + # - nkipy.annotate ops should target memrefs, not tensors + # - No bufferization.to_tensor wrappers should remain for annotate + # - Sbuf (mem_space=3) and SharedHbm (mem_space=4) annotations preserved + check_patterns = ''' +CHECK: func.func +CHECK-NOT: bufferization.to_tensor +CHECK: nkipy.annotate +CHECK-SAME: mem_space = 3 +CHECK-SAME: reduction_tile = array +CHECK-SAME: tile_size = array +CHECK-NOT: bufferization.to_tensor +CHECK: nkipy.annotate +CHECK-SAME: mem_space = 4 +CHECK-SAME: tile_size = array +CHECK-NOT: bufferization.to_tensor +CHECK: return +''' + run_kernel_test( + matmul_add_kernel, + stop_after='one-shot-bufferize', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/conftest.py b/kernelgen/tests/passes/conftest.py new file mode 100644 index 0000000..f5f0aac --- /dev/null +++ b/kernelgen/tests/passes/conftest.py @@ -0,0 +1,61 @@ +""" +Passes test conftest.py. 
+ +Provides: + - Auto-marks all pass tests with 'passes' marker + - Shared pass utilities exposed as fixtures +""" + +import pytest + +from pass_utils import ( + run_passes, + trace_to_mlir_with_preprocessing, + compile_through_passes, + compile_knob_pipeline, + save_mlir_to_file, + get_test_output_dir, + run_filecheck, + assert_ir_unchanged, + get_filecheck_path, + verify_tiled_mlir_with_numpy, +) + +# Auto-mark all tests in this directory tree +pytestmark = pytest.mark.passes + + +@pytest.fixture +def filecheck(): + """Fixture providing the run_filecheck function.""" + return run_filecheck + + +@pytest.fixture +def pass_runner(): + """Fixture providing the run_passes function.""" + return run_passes + + +@pytest.fixture +def pass_compiler(): + """Fixture providing the compile_through_passes function.""" + return compile_through_passes + + +@pytest.fixture +def knob_pipeline(): + """Fixture providing the compile_knob_pipeline function.""" + return compile_knob_pipeline + + +@pytest.fixture +def mlir_preprocessor(): + """Fixture providing trace_to_mlir_with_preprocessing.""" + return trace_to_mlir_with_preprocessing + + +@pytest.fixture +def mlir_verifier(): + """Fixture providing the verify_tiled_mlir_with_numpy function.""" + return verify_tiled_mlir_with_numpy diff --git a/kernelgen/tests/passes/eliminate_same_memspace_copy/__init__.py b/kernelgen/tests/passes/eliminate_same_memspace_copy/__init__.py new file mode 100644 index 0000000..180188b --- /dev/null +++ b/kernelgen/tests/passes/eliminate_same_memspace_copy/__init__.py @@ -0,0 +1 @@ +"""Tests for eliminate-same-memspace-copy pass.""" diff --git a/kernelgen/tests/passes/eliminate_same_memspace_copy/test_basic.py b/kernelgen/tests/passes/eliminate_same_memspace_copy/test_basic.py new file mode 100644 index 0000000..7cd08bb --- /dev/null +++ b/kernelgen/tests/passes/eliminate_same_memspace_copy/test_basic.py @@ -0,0 +1,129 @@ +""" +Tests for eliminate-same-memspace-copy pass. + +The eliminate-same-memspace-copy pass: +1. Removes memref.copy ops where source and target are subviews of the same base + with the same offsets (i.e., they're copying to themselves) +2. This commonly occurs when tiling generates copy-back operations for promoted + buffers that happen to be the same as the original memory region + +Run with: python -m pytest tests/passes/eliminate_same_memspace_copy/test_basic.py -v +Or directly: python tests/passes/eliminate_same_memspace_copy/test_basic.py +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + +# ============================================================================ +# Test Cases +# ============================================================================ + +def test_matmul_sbuf_add_hbm(): + """ + Test matmul-add chain with: + - matmul output -> SBUF (intermediate) + - add output -> SharedHbm (returned result) + + This tests that the pass correctly: + 1. Removes redundant SBUF->SBUF copies where source and target are the same region + + Before the pass (in annotate-memory-space output), we have: + %subview_4 = memref.subview %alloc[%0, 0] [128, 256] ... + ... (inner loops) ... + %subview_5 = memref.subview %alloc[%0, 0] [128, 256] ... // SAME region! + memref.copy %subview_4, %subview_5 // REDUNDANT - copying to itself + + After the pass: + The redundant memref.copy is removed since both subviews access + the same memory region of %alloc. 
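+
+    For intuition, the redundancy condition the pass checks amounts to a
+    structural comparison of the two copy operands (rough Python sketch of
+    the idea, not the actual C++ implementation):
+
+        def is_self_copy(src_subview, dst_subview):
+            # Redundant iff both subviews slice the same base allocation
+            # at identical offsets/sizes/strides.
+            return (src_subview.base is dst_subview.base
+                    and src_subview.offsets == dst_subview.offsets
+                    and src_subview.sizes == dst_subview.sizes
+                    and src_subview.strides == dst_subview.strides)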
+ """ + M, N, K = 256, 256, 256 + matmul_tile = [128, 128] # TILE_M, TILE_N + matmul_reduction_tile = [128] # TILE_K + add_tile = [128, 128] # TILE_M, TILE_N + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((M, N), "f32")]) + def matmul_add_kernel(a, b, bias): + # Matmul outputs to SBUF for reuse in the add + c = np.matmul(a, b) + knob.knob(c, mem_space="Sbuf", tile_size=matmul_tile, reduction_tile=matmul_reduction_tile) + + # Add outputs to SharedHbm (returned from kernel) + result = c + bias + knob.knob(result, mem_space="SharedHbm", tile_size=add_tile) + + return result + + # FileCheck patterns to verify: + # + # After eliminate-same-memspace-copy, we should see: + # 1. All the valid copies preserved (HBM->SBUF, SBUF->HBM, PSUM->SBUF, etc.) + # 2. The redundant SBUF->SBUF self-copy at the end of matmul M-loop should be GONE + # + # The matmul loop structure: + # scf.for %arg3 (M loop) + # %subview = subview LHS promoted + # %subview_4 = subview %alloc (matmul output tile) + # scf.for %arg4 (N loop) + # ... matmul inner loops ... + # memref.copy %psum, %subview_8 (writeback from PSUM to SBUF subview) + # } end N loop + # // REMOVED: memref.copy %subview_4, %subview_5 (was redundant) + # } end M loop + # + # After the pass runs, the redundant copy should not appear. + # We verify by checking the structure and that no SBUF->SBUF copy exists + # right after the inner N loop closes. + + check_patterns = ''' +CHECK: func.func @matmul_add_kernel +CHECK-SAME: 4 : i32 +CHECK: memref.alloc(){{.*}}: memref<256x256xf32, 3 : i32> +CHECK: memref.alloc(){{.*}}: memref<256x256xf32, 3 : i32> +CHECK: linalg.transpose{{.*}}outs({{.*}}memref<256x256xf32, 3 : i32>) +CHECK: memref.alloc(){{.*}}: memref<256x256xf32, 3 : i32> +CHECK: memref.copy{{.*}}to memref<256x256xf32, 3 : i32> +CHECK-NOT: memref.copy{{.*}}3 : i32>{{.*}}to{{.*}}3 : i32> +CHECK: scf.for +CHECK: memref.subview{{.*}}3 : i32> +CHECK: memref.subview{{.*}}3 : i32> +CHECK: scf.for +CHECK: memref.subview{{.*}}3 : i32> +CHECK: memref.alloc(){{.*}}: memref<128x128xf32, 2 : i32> +CHECK: scf.for +CHECK: linalg.matmul +CHECK: memref.subview{{.*}}3 : i32> +CHECK: memref.copy{{.*}}2 : i32>{{.*}}to{{.*}}3 : i32> +CHECK: } +CHECK: } +CHECK-NOT: memref.copy{{.*}}to memref<128x256xf32{{.*}}3 : i32> +CHECK: memref.alloc(){{.*}}: memref<256x256xf32, 4 : i32> +CHECK: scf.for +CHECK: scf.for +CHECK: memref.subview{{.*}}3 : i32> +CHECK: memref.subview{{.*}}4 : i32> +CHECK: memref.alloc(){{.*}}: memref<128x128xf32, 3 : i32> +CHECK: memref.copy{{.*}}4 : i32>{{.*}}to{{.*}}3 : i32> +CHECK: memref.alloc(){{.*}}: memref<128x128xf32, 3 : i32> +CHECK: linalg.add{{.*}}memref<128x128xf32, strided{{.*}}3 : i32>{{.*}}memref<128x128xf32, 3 : i32>{{.*}}memref<128x128xf32, 3 : i32> +CHECK: memref.subview{{.*}}4 : i32> +CHECK: memref.copy{{.*}}3 : i32>{{.*}}to{{.*}}4 : i32> +CHECK: return{{.*}}memref<256x256xf32, 4 : i32> +''' + run_kernel_test( + matmul_add_kernel, + stop_after='eliminate-same-memspace-copy', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/eliminate_uninitialized_copies/__init__.py b/kernelgen/tests/passes/eliminate_uninitialized_copies/__init__.py new file mode 100644 index 0000000..7f0ef08 --- /dev/null +++ 
b/kernelgen/tests/passes/eliminate_uninitialized_copies/__init__.py @@ -0,0 +1 @@ +"""Tests for the eliminate-uninitialized-copies pass.""" diff --git a/kernelgen/tests/passes/eliminate_uninitialized_copies/test_basic.py b/kernelgen/tests/passes/eliminate_uninitialized_copies/test_basic.py new file mode 100644 index 0000000..33d9890 --- /dev/null +++ b/kernelgen/tests/passes/eliminate_uninitialized_copies/test_basic.py @@ -0,0 +1,131 @@ +""" +Tests for eliminate-uninitialized-copies pass. + +The eliminate-uninitialized-copies pass: +1. Finds memref.copy ops where source is a fresh alloc with no prior writes +2. Eliminates such copies since they copy undefined values +3. Common pattern: copy from SBUF subview to PSUM for accumulator init + +Run with: python -m pytest tests/passes/eliminate_uninitialized_copies/test_basic.py -v +Or directly: python tests/passes/eliminate_uninitialized_copies/test_basic.py +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + +# ============================================================================ +# Test Cases +# ============================================================================ + +def test_matmul_sbuf_add_hbm(): + """ + Test matmul-add chain with: + - matmul output -> SBUF (intermediate) + - add output -> SharedHbm (returned result) + + This tests that the pass correctly: + 1. Eliminates the copy from uninitialized SBUF subview to PSUM accumulator + Pattern before: + %alloc = memref.alloc() : memref<256x256xf32> // matmul output (uninit) + %subview_7 = memref.subview %subview_4... // subview chain to %alloc + %alloc_8 = memref.alloc() : memref<128x128xf32, #psum> // PSUM accumulator + memref.copy %subview_7, %alloc_8 // THIS GETS ELIMINATED + linalg.matmul ... outs(%alloc_8) // matmul writes to PSUM + + Pattern after: + %alloc = memref.alloc() : memref<256x256xf32> + %alloc_8 = memref.alloc() : memref<128x128xf32, #psum> + linalg.matmul ... outs(%alloc_8) // No copy before matmul! + + 2. Preserves necessary copies (those with initialized source data) + """ + M, N, K = 256, 256, 256 + matmul_tile = [128, 128] # TILE_M, TILE_N + matmul_reduction_tile = [128] # TILE_K + add_tile = [128, 128] # TILE_M, TILE_N + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((M, N), "f32")]) + def matmul_add_kernel(a, b, bias): + # Matmul outputs to SBUF for reuse in the add + c = np.matmul(a, b) + knob.knob(c, mem_space="Sbuf", tile_size=matmul_tile, reduction_tile=matmul_reduction_tile) + + # Add outputs to SharedHbm (returned from kernel) + result = c + bias + knob.knob(result, mem_space="SharedHbm", tile_size=add_tile) + + return result + + # FileCheck patterns to verify: + # + # NOTE: At this point in the pipeline, annotate-memory-space hasn't run yet. + # - Matmul output alloc is memref<256x256xf32> (no mem space - nkipy.annotate exists but not applied) + # - Add output alloc is memref<256x256xf32> (no mem space) + # - BUT promoted buffers DO have memory spaces (from tiling transform) + # + # COPIES THAT SHOULD BE ELIMINATED (uninitialized source): + # 1. Copy from uninitialized SBUF subview to PSUM accumulator before matmul + # + # COPIES THAT MUST BE PRESERVED (initialized source): + # 2. Copy from transpose temp to SBUF (LHS promote) + # 3. Copy from arg1 to SBUF (RHS promote) + # 4. Copy from PSUM to matmul output subview (matmul result writeback) + # 5. Copy from matmul output subview to add's promoted SBUF input (INITIALIZED!) + # 6. 
Copy from bias arg subview to add's promoted SBUF input + # 7. Copy from add's SBUF output to add output subview + check_patterns = ''' +CHECK: func.func @matmul_add_kernel +CHECK: memref.alloc(){{.*}}: memref<256x256xf32> +CHECK: memref.alloc(){{.*}}: memref<256x256xf32, 3 : i32> +CHECK: linalg.transpose{{.*}}outs({{.*}}memref<256x256xf32, 3 : i32>) +CHECK: memref.alloc(){{.*}}: memref<256x256xf32, 3 : i32> +CHECK: memref.copy{{.*}}3 : i32>{{.*}}to{{.*}}3 : i32> +CHECK: memref.alloc(){{.*}}: memref<256x256xf32, 3 : i32> +CHECK: memref.copy{{.*}}to memref<256x256xf32, 3 : i32> +CHECK: scf.for +CHECK: memref.subview{{.*}}3 : i32> +CHECK: memref.subview +CHECK: scf.for +CHECK: memref.subview{{.*}}3 : i32> +CHECK: memref.alloc(){{.*}}: memref<128x128xf32, 2 : i32> +CHECK-NOT: memref.copy{{.*}}to memref<128x128xf32, 2 : i32> +CHECK: scf.for +CHECK: memref.subview{{.*}}3 : i32> +CHECK: memref.subview{{.*}}3 : i32> +CHECK: linalg.matmul +CHECK: memref.subview +CHECK: memref.copy{{.*}}2 : i32>{{.*}}to{{.*}}memref<128x128xf32 +CHECK: nkipy.annotate +CHECK: memref.alloc(){{.*}}: memref<256x256xf32> +CHECK: scf.for +CHECK: scf.for +CHECK: memref.subview +CHECK: memref.alloc(){{.*}}: memref<128x128xf32, 3 : i32> +CHECK: memref.copy{{.*}}to{{.*}}memref<128x128xf32, 3 : i32> +CHECK: memref.subview +CHECK: memref.alloc(){{.*}}: memref<128x128xf32, 3 : i32> +CHECK: memref.copy{{.*}}to{{.*}}3 : i32> +CHECK: memref.alloc(){{.*}}: memref<128x128xf32, 3 : i32> +CHECK: linalg.add +CHECK: memref.subview +CHECK: memref.copy{{.*}}3 : i32>{{.*}}to{{.*}}memref<128x128xf32 +CHECK: nkipy.annotate +CHECK: return +''' + run_kernel_test( + matmul_add_kernel, + stop_after='eliminate-uninitialized-copies', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/infer_layout/__init__.py b/kernelgen/tests/passes/infer_layout/__init__.py new file mode 100644 index 0000000..a8612c4 --- /dev/null +++ b/kernelgen/tests/passes/infer_layout/__init__.py @@ -0,0 +1 @@ +"""Tests for the infer-layout pass.""" diff --git a/kernelgen/tests/passes/infer_layout/test_infer_layout_broadcast.py b/kernelgen/tests/passes/infer_layout/test_infer_layout_broadcast.py new file mode 100644 index 0000000..809a90a --- /dev/null +++ b/kernelgen/tests/passes/infer_layout/test_infer_layout_broadcast.py @@ -0,0 +1,97 @@ +""" +Tests for InferLayout backward propagation through broadcast operations. + +When a broadcast division like `x (M,N) / sqrt(y (M,1))` has a knob on +the result, InferLayout should propagate backward to the (M,1)-shaped +producers with tile_size clamped to their dimensions. + +This is the pattern from RMSNorm: normed = x / sqrt(mean_sq + eps) + +Run with: pytest tests/passes/infer_layout/test_infer_layout_broadcast.py -v +""" + +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +def test_broadcast_div_propagates_clamped_tile(): + """ + Verify InferLayout propagates through broadcast with clamped tile_size. 
+ + Pattern (from RMSNorm): + intermediate = reduced + eps # (256,1), no knob + normed = x / sqrt(intermediate) # (256,256), knob tile=[128,128] + + Expected after infer-layout: + intermediate gets tile_size=[128, 1] (clamped from [128, 128]) + sqrt(intermediate) gets tile_size=[128, 1] (clamped) + """ + M, N = 256, 256 + + @trace(input_specs=[((M, N), "f32"), ((M, 1), "f32")]) + def kernel(x, reduced): + intermediate = reduced + np.float32(1e-6) + normed = x / np.sqrt(intermediate) + knob.knob(normed, mem_space="Sbuf", tile_size=[128, 128]) + return normed + + # After infer-layout, the (256,1) ops should get clamped tile_size=[128, 1] + run_kernel_test( + kernel, + stop_after='infer-layout', + check_ir_contains=["tile_size = array"], + modes=Mode.STRING_CHECK, + ) + + +def test_broadcast_div_full_rmsnorm_pattern(): + """ + Full RMSNorm broadcast pattern: reduction output (M,1) flows into + a broadcast division with (M,N). + + Chain: + sq = square(x) # (256,256), knob tile=[128,128] + sum_sq = sum(sq, axis=-1) # (256,1), knob tile=[128], red=[128] + mean_sq = sum_sq * (1/N) # (256,1), knob tile=[128,1] + normed = x / sqrt(mean_sq + eps) # (256,256), knob tile=[128,128] + + The intermediate (mean_sq + eps) and sqrt(...) have shape (256,1) + but NO knob. InferLayout should infer tile_size=[128, 1] for them. + """ + M, N = 256, 256 + tile_size = [128, 128] + eps = 1e-6 + + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + x_fp32 = x.astype(np.float32) + + sq = np.square(x_fp32) + knob.knob(sq, mem_space="Sbuf", tile_size=tile_size) + + sum_sq = np.sum(sq, axis=-1, keepdims=True) + knob.knob(sum_sq, mem_space="Sbuf", tile_size=[128], reduction_tile=[128]) + + mean_sq = sum_sq * np.float32(1.0 / N) + knob.knob(mean_sq, mem_space="Sbuf", tile_size=[128, 1]) + + normed = x_fp32 / np.sqrt(mean_sq + eps) + knob.knob(normed, mem_space="Sbuf", tile_size=tile_size) + + return normed + + # The (mean_sq + eps) intermediate should get inferred tile_size=[128, 1] + # via backward propagation from normed's [128, 128] clamped to (256,1) shape + run_kernel_test( + kernel, + stop_after='infer-layout', + check_ir_contains=["tile_size = array"], + modes=Mode.STRING_CHECK, + ) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/infer_layout/test_infer_layout_elementwise.py b/kernelgen/tests/passes/infer_layout/test_infer_layout_elementwise.py new file mode 100644 index 0000000..fc61e95 --- /dev/null +++ b/kernelgen/tests/passes/infer_layout/test_infer_layout_elementwise.py @@ -0,0 +1,420 @@ +""" +Tests for the infer-layout pass. + +The infer-layout pass infers layout information (tiling and placement) for +elementwise operations that lack explicit annotations. It propagates tile_size +and mem_space from annotated elementwise ops to adjacent unannotated ones +along the SSA use-def chain. + +Run with: python -m pytest tests/passes/infer_layout/test_infer_layout.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +# ============================================================================ +# Test: SiLU chain propagation (the motivating case) +# ============================================================================ + +def test_silu_chain_annotations(): + """ + Verify that infer-layout propagates tiling and placement to all intermediate + elementwise ops in a compact SiLU expression. 
+ + Input: gated = gate / (1.0 + exp(-gate)) * up + knob only on final 'gated' (linalg.mul) + + After prepare-arithmetic, the chain is: + linalg.generic (negate) -> no layout + linalg.exp -> no layout + linalg.generic (add 1.0) -> no layout + linalg.reciprocal -> no layout + linalg.mul (x * recip) -> no layout + linalg.mul (swish * up) -> HAS layout (tile_size=[128, 128], mem_space=Sbuf) + + After infer-layout, ALL elementwise ops should have nkipy.annotate with + tile_size = [128, 128]. + """ + shape = (256, 256) + tile_size = [128, 128] + + @trace(input_specs=[(shape, "f32"), (shape, "f32")]) + def silu_kernel(gate, up): + gated = gate / (1.0 + np.exp(-gate)) * up + knob.knob(gated, mem_space="Sbuf", tile_size=tile_size) + return gated + + # After infer-layout, we expect nkipy.annotate ops with tile_size on every + # elementwise op in the chain. Each op should be followed by one. + check_patterns = """ + CHECK: linalg.generic + CHECK: nkipy.annotate{{.*}}tile_size + CHECK: linalg.exp + CHECK: nkipy.annotate{{.*}}tile_size + CHECK: linalg.generic + CHECK: nkipy.annotate{{.*}}tile_size + CHECK: linalg.reciprocal + CHECK: nkipy.annotate{{.*}}tile_size + CHECK: linalg.mul + CHECK: nkipy.annotate{{.*}}tile_size + CHECK: linalg.mul + CHECK: nkipy.annotate{{.*}}tile_size + """ + run_kernel_test( + silu_kernel, + stop_after='infer-layout', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +def test_silu_chain_tiling_executes(): + """ + Verify that after infer-layout, knob-driven-tiling can tile ALL ops in the + SiLU chain and the result is numerically correct via LLVM JIT. + + This is the end-to-end correctness check: inferred layout -> tiling -> JIT. + """ + shape = (256, 256) + tile_size = [128, 128] + + @trace(input_specs=[(shape, "f32"), (shape, "f32")]) + def silu_kernel(gate, up): + gated = gate / (1.0 + np.exp(-gate)) * up + knob.knob(gated, mem_space="Sbuf", tile_size=tile_size) + return gated + + run_kernel_test( + silu_kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + rtol=1e-5, + atol=1e-5, + ) + + +# ============================================================================ +# Test: Simple chain (A -> B -> C, only C annotated) +# ============================================================================ + +def test_simple_chain(): + """ + Test backward propagation through a simple 2-op chain. + + Input: y = exp(x), z = y + 1.0, knob on z + After infer-layout: exp should also get a layout annotation. + """ + shape = (256, 256) + tile_size = [128, 128] + + @trace(input_specs=[(shape, "f32")]) + def chain_kernel(x): + y = np.exp(x) + z = y + 1.0 + knob.knob(z, mem_space="Sbuf", tile_size=tile_size) + return z + + # After infer-layout, both linalg.exp and linalg.generic(add) should + # have nkipy.annotate with tile_size + check_patterns = """ + CHECK: linalg.exp + CHECK: nkipy.annotate{{.*}}tile_size + CHECK: linalg.generic + CHECK: nkipy.annotate{{.*}}tile_size + """ + run_kernel_test( + chain_kernel, + stop_after='infer-layout', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +def test_simple_chain_tiling_executes(): + """ + Verify that a simple chain with inferred layout tiles and executes correctly. 
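+
+    The JIT output is compared against the plain-NumPy evaluation of the same
+    kernel body, i.e. roughly:
+
+        expected = np.exp(x) + 1.0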
+ """ + shape = (256, 256) + tile_size = [128, 128] + + @trace(input_specs=[(shape, "f32")]) + def chain_kernel(x): + y = np.exp(x) + z = y + 1.0 + knob.knob(z, mem_space="Sbuf", tile_size=tile_size) + return z + + run_kernel_test( + chain_kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + ) + + +# ============================================================================ +# Test: Already-annotated ops should not be overridden +# ============================================================================ + +def test_existing_annotations_preserved(): + """ + If a producer op already has a layout annotation, infer-layout should NOT + override it. Each op retains its original annotation. + + Input: y = exp(x) with knob tile=[128, 128] + z = y + 1.0 with knob tile=[128, 128] + Both ops already annotated -> no new annotations should be created. + """ + shape = (256, 256) + tile_size = [128, 128] + + @trace(input_specs=[(shape, "f32")]) + def kernel(x): + y = np.exp(x) + knob.knob(y, mem_space="Sbuf", tile_size=tile_size) + z = y + 1.0 + knob.knob(z, mem_space="Sbuf", tile_size=tile_size) + return z + + # Both ops already have annotations. After infer-layout, the pass should + # report 0 inferred annotations. We verify the IR still has the same + # structure: exp then annotate, then generic then annotate. + check_patterns = """ + CHECK: linalg.exp + CHECK: nkipy.annotate{{.*}}tile_size + CHECK: linalg.generic + CHECK: nkipy.annotate{{.*}}tile_size + """ + run_kernel_test( + kernel, + stop_after='infer-layout', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +# ============================================================================ +# Test: Non-elementwise boundary (matmul -> elementwise chain) +# ============================================================================ + +def test_stops_at_matmul_boundary(): + """ + Verify that layout propagation does NOT cross non-elementwise ops. + + Input: mm = matmul(a, b) with its own matmul knob + y = exp(mm) -> no layout + z = y + 1.0 -> knob on z + + The infer-layout pass should propagate from z to y (both elementwise), + but should NOT create a new annotation on mm (it's a matmul, not elementwise). + """ + shape = (256, 256) + matmul_tile = [128, 128] + ew_tile = [128, 128] + + @trace(input_specs=[(shape, "f32"), (shape, "f32")]) + def kernel(a, b): + mm = np.matmul(a, b) + knob.knob(mm, mem_space="Sbuf", tile_size=matmul_tile, reduction_tile=[128]) + y = np.exp(mm) + z = y + 1.0 + knob.knob(z, mem_space="Sbuf", tile_size=ew_tile) + return z + + # After infer-layout: + # - matmul should still have its own annotate with tile_size [128,128,128] + # - exp should get an inferred annotate with tile_size [128,128] + # - generic(add) should have its original annotate with tile_size [128,128] + check_patterns = """ + CHECK: linalg.matmul + CHECK: nkipy.annotate{{.*}}tile_size + CHECK: linalg.exp + CHECK: nkipy.annotate{{.*}}tile_size + CHECK: linalg.generic + CHECK: nkipy.annotate{{.*}}tile_size + """ + run_kernel_test( + kernel, + stop_after='infer-layout', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +def test_matmul_plus_elementwise_tiling_executes(): + """ + Verify that matmul followed by elementwise chain with inferred layout + tiles and executes correctly. 
+ """ + shape = (256, 256) + matmul_tile = [128, 128] + ew_tile = [128, 128] + + @trace(input_specs=[(shape, "f32"), (shape, "f32")]) + def kernel(a, b): + mm = np.matmul(a, b) + knob.knob(mm, mem_space="Sbuf", tile_size=matmul_tile, reduction_tile=[128]) + y = np.exp(mm) + z = y + 1.0 + knob.knob(z, mem_space="Sbuf", tile_size=ew_tile) + return z + + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + rtol=1e-3, + atol=1e-3, + ) + + +# ============================================================================ +# Test: No annotations (pass should be a no-op) +# ============================================================================ + +def test_no_annotations_generates_defaults(): + """ + When no ops have layout annotations, infer-layout should generate + default annotations: partition_dim=0, tile_size=[min(dim0,128), dim_last], + mem_space=SBUF for intermediates. + """ + shape = (256, 256) + + @trace(input_specs=[(shape, "f32")]) + def kernel(x): + return np.exp(x) + + # Default tile: partition_dim=0, tile_size=[min(256,128), 256] = [128, 256]. + run_kernel_test( + kernel, + stop_after='infer-layout', + check_ir_contains=[ + "tile_size = array", + "partition_dim = 0", + ], + modes=Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Test: 3D partition_dim BFS propagation +# ============================================================================ + +def test_3d_partition_dim_inferred_from_tile(): + """ + Verify that both ops retain partition_dim=1 from their explicit knob + annotations. Both exp and mul are annotated with partition_dim=1, + so infer-layout should preserve both. + """ + B, S, D = 4, 128, 64 + tile_size = [1, 128, 64] + + @trace(input_specs=[((B, S, D), "f32"), ((B, S, D), "f32")]) + def kernel_3d_pdim(a, b): + intermediate = np.exp(a) + knob.knob(intermediate, mem_space="Sbuf", tile_size=tile_size, + partition_dim=1) + result = intermediate * b + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size, + partition_dim=1) + return result + + # After infer-layout, ALL ops should have partition_dim = 1 since + # tile dim 1 = 128 = MAX_PARTITION_DIM. + check_patterns = """ + CHECK: linalg.exp + CHECK: nkipy.annotate{{.*}}partition_dim = 1 + CHECK: linalg.mul + CHECK: nkipy.annotate{{.*}}partition_dim = 1 + """ + run_kernel_test( + kernel_3d_pdim, + stop_after='infer-layout', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +def test_3d_partition_dim_propagation_unannotated(): + """ + Verify that BFS backward propagation copies partition_dim=1 from an + annotated consumer to an unannotated producer. + + Chain: x → exp(x) → exp(x) + bias → result + Only result has a knob with partition_dim=1. BFS should propagate + partition_dim=1 to the intermediate exp op. 
+ """ + B, S, D = 4, 128, 64 + tile_size = [1, 128, 64] + + @trace(input_specs=[((B, S, D), "f32"), ((B, S, D), "f32")]) + def kernel_3d_chain(x, bias): + y = np.exp(x) + # Only annotate the final result — y gets tile_size via BFS + z = y + bias + knob.knob(z, mem_space="Sbuf", tile_size=tile_size, + partition_dim=1) + return z + + # Both ops should have partition_dim = 1 after infer-layout + check_patterns = """ + CHECK: linalg.exp + CHECK: nkipy.annotate{{.*}}partition_dim = 1 + CHECK: linalg.add + CHECK: nkipy.annotate{{.*}}partition_dim = 1 + """ + run_kernel_test( + kernel_3d_chain, + stop_after='infer-layout', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +def test_3d_partition_dim_enables_canonicalize(): + """ + Verify that BFS-propagated partition_dim=1 enables + canonicalize-partition-dim to insert transposes. + + Chain: x → exp(x) → exp(x) + bias → result + Only result has explicit partition_dim=1. After BFS propagation + and canonicalize-partition-dim, transposes should be inserted to + move partition_dim=1 to position 0. + """ + B, S, D = 4, 128, 64 + tile_size = [1, 128, 64] + + @trace(input_specs=[((B, S, D), "f32"), ((B, S, D), "f32")]) + def kernel_3d_canon(x, bias): + y = np.exp(x) + z = y + bias + knob.knob(z, mem_space="SharedHbm", tile_size=tile_size, + partition_dim=1) + return z + + # After canonicalize-partition-dim, transposes should be inserted + # to move partition_dim=1 to dim 0. This proves that infer-layout + # correctly filled partition_dim on the unannotated exp op. + check_patterns = """ + CHECK: func.func @kernel_3d_canon + CHECK: linalg.transpose + CHECK: linalg.exp + CHECK: linalg.add + """ + run_kernel_test( + kernel_3d_canon, + stop_after='canonicalize-partition-dim', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/infer_layout/test_infer_layout_matmul.py b/kernelgen/tests/passes/infer_layout/test_infer_layout_matmul.py new file mode 100644 index 0000000..eb5c2ad --- /dev/null +++ b/kernelgen/tests/passes/infer_layout/test_infer_layout_matmul.py @@ -0,0 +1,474 @@ +""" +Tests for the infer-layout pass: matmul seeding and propagation. + +Verifies that infer-layout correctly seeds matmul operands with the +hardware-specific layout rules: + - Result C [M,N]: partition_dim=0 (M), tile=[min(M,128), min(N,512)] + - Operand A [M,K]: partition_dim=1 (K), tile=[min(M,128), min(K,128)] + - Operand B [K,N]: partition_dim=0 (K), tile=[min(K,128), min(N,512)] + +Also tests forward propagation from matmul operands to downstream +elementwise consumers. + +Run with: pytest tests/passes/infer_layout/test_infer_layout_matmul.py -v +""" + +import numpy as np +import pytest + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +# ============================================================================ +# Test: Matmul seeding produces correct operand layouts +# ============================================================================ + +def test_matmul_seed_result_layout(): + """ + A matmul with a user knob on its result should keep that annotation. + Operands that are produced by linalg ops should get matmul-specific + layouts via backward propagation. 
+ + matmul(exp(a), b) with knob on result: + - result C: user annotation preserved + - operand A (exp result): partition_dim=1, tile=[128, 128] + """ + M, N, K = 256, 256, 256 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) + def kernel(a, b): + a_exp = np.exp(a) + mm = np.matmul(a_exp, b) + knob.knob(mm, mem_space="Sbuf", tile_size=[128, 128], reduction_tile=[128]) + return mm + + # exp(a) should get matmul operand A layout: partition_dim=1 + run_kernel_test( + kernel, + stop_after='infer-layout', + check_ir_contains=[ + "partition_dim = 1", # operand A gets partition_dim=1 + "tile_size = array", + ], + modes=Mode.STRING_CHECK, + ) + + +def test_matmul_no_annotation_auto_seeds(): + """ + A matmul with NO user annotation should get auto-seeded defaults. + + matmul(a, b) -> result [M, N]: + tile_size=[min(M,128), min(N,512)], partition_dim=0, reduction_tile=[min(K,128)] + KnobDrivenTiling dynamically adjusts blocking based on dim/tile ratio. + """ + M, N, K = 256, 512, 128 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) + def kernel(a, b): + return np.matmul(a, b) + + # Result C: tile=[min(256,128), min(512,512)] = [128, 512] + # reduction_tile=[min(128,128)] = [128] + run_kernel_test( + kernel, + stop_after='infer-layout', + check_ir_contains=[ + "tile_size = array", + "reduction_tile = array", + "partition_dim = 0", + ], + modes=Mode.STRING_CHECK, + ) + + +def test_matmul_auto_seed_large_dims(): + """ + Verify matmul auto-seeding respects hardware limits for large dims. + + matmul [1024, 256] x [256, 2048] -> [1024, 2048] + Result: tile=[min(1024/2,128), min(2048/2,512)] = [128, 512] + reduction_tile=[min(256,128)] = [128] + + TODO: Once KnobDrivenTiling supports non-blocked matmul for small dims, + the BF=2 divisor can be removed. + """ + M, N, K = 1024, 2048, 256 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) + def kernel(a, b): + return np.matmul(a, b) + + run_kernel_test( + kernel, + stop_after='infer-layout', + check_ir_contains=[ + "tile_size = array", + "reduction_tile = array", + ], + modes=Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Test: Matmul + elementwise chain propagation +# ============================================================================ + +def test_matmul_forward_propagates_to_elementwise(): + """ + After matmul seeding, forward propagation should reach downstream + elementwise ops. + + matmul(a, b) -> exp -> result + Only matmul result is seeded. exp should get layout via forward + propagation from the matmul result. + """ + M, N, K = 256, 256, 256 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) + def kernel(a, b): + mm = np.matmul(a, b) + knob.knob(mm, mem_space="Sbuf", tile_size=[128, 128], reduction_tile=[128]) + return np.exp(mm) + + # exp should get an annotation via forward propagation + check_patterns = """ + CHECK: linalg.matmul + CHECK: nkipy.annotate{{.*}}tile_size + CHECK: linalg.exp + CHECK: nkipy.annotate{{.*}}tile_size + """ + run_kernel_test( + kernel, + stop_after='infer-layout', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +def test_matmul_elementwise_chain_executes(): + """ + Verify that matmul -> elementwise chain with auto-seeded layouts + tiles and executes correctly. 
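+
+    For reference, the matmul seeding rules from the module docstring amount
+    to (illustrative Python; the 128/512 limits are taken from those rules):
+
+        tile_c   = [min(M, 128), min(N, 512)]   # result C, partition_dim=0
+        tile_a   = [min(M, 128), min(K, 128)]   # operand A, partition_dim=1
+        tile_b   = [min(K, 128), min(N, 512)]   # operand B, partition_dim=0
+        red_tile = [min(K, 128)]                # contraction dim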
+ """ + M, N, K = 256, 256, 256 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) + def kernel(a, b): + mm = np.matmul(a, b) + knob.knob(mm, mem_space="Sbuf", tile_size=[128, 128], reduction_tile=[128]) + y = np.exp(mm) + knob.knob(y, mem_space="SharedHbm", tile_size=[128, 128]) + return y + + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + rtol=1e-3, + atol=1e-3, + ) + + +# ============================================================================ +# Test: Matmul operand A backward propagation through elementwise chain +# ============================================================================ + +def test_matmul_operand_backward_chain(): + """ + Matmul operand A is produced by an elementwise chain: + x -> exp -> square -> matmul(square, b) + + The matmul seeding should set operand A (square) to partition_dim=1. + Backward propagation from square should reach exp with partition_dim=1. + """ + M, N, K = 256, 256, 256 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) + def kernel(x, b): + y = np.exp(x) + z = np.square(y) + mm = np.matmul(z, b) + knob.knob(mm, mem_space="Sbuf", tile_size=[128, 128], reduction_tile=[128]) + return mm + + # Both exp and square should get partition_dim=1 (matmul operand A layout) + check_patterns = """ + CHECK: linalg.exp + CHECK: nkipy.annotate{{.*}}partition_dim = 1 + CHECK: linalg.square + CHECK: nkipy.annotate{{.*}}partition_dim = 1 + CHECK: linalg.matmul + """ + run_kernel_test( + kernel, + stop_after='infer-layout', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +# ============================================================================ +# Test: Compatible tile sizes across producer-consumer boundary +# ============================================================================ + +def test_compatible_tile_sizes_no_conflict(): + """ + Two ops with different but compatible tile sizes (one divides the other) + should NOT trigger a conflict error. + + add1 tile=[128,128], add2 tile=[64,64]: 64 divides 128 -> compatible. + """ + shape = (256, 256) + + @trace(input_specs=[(shape, "f32"), (shape, "f32"), (shape, "f32")]) + def kernel(a, b, c): + x = a + b + knob.knob(x, mem_space="Sbuf", tile_size=[128, 128]) + y = x + c + knob.knob(y, mem_space="SharedHbm", tile_size=[64, 64]) + return y + + # Should compile without conflict + run_kernel_test( + kernel, + stop_after='infer-layout', + check_ir_contains=[ + "tile_size = array", + "tile_size = array", + ], + modes=Mode.STRING_CHECK, + ) + + +def test_compatible_tile_sizes_executes(): + """ + Verify that compatible but different tile sizes tile and execute correctly. + """ + shape = (256, 256) + + @trace(input_specs=[(shape, "f32"), (shape, "f32"), (shape, "f32")]) + def kernel(a, b, c): + x = a + b + knob.knob(x, mem_space="Sbuf", tile_size=[128, 128]) + y = x + c + knob.knob(y, mem_space="SharedHbm", tile_size=[64, 64]) + return y + + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + ) + + +# ============================================================================ +# Test: Forward propagation through elementwise (no matmul) +# ============================================================================ + +def test_forward_propagation_elementwise(): + """ + When the first op in a chain has a knob, forward propagation should + reach downstream unannotated ops. + + exp(x) [knob] -> square -> result + square should get layout from forward propagation. 
+ """ + shape = (256, 256) + + @trace(input_specs=[(shape, "f32")]) + def kernel(x): + y = np.exp(x) + knob.knob(y, mem_space="Sbuf", tile_size=[128, 128]) + z = np.square(y) + return z + + check_patterns = """ + CHECK: linalg.exp + CHECK: nkipy.annotate{{.*}}tile_size + CHECK: linalg.square + CHECK: nkipy.annotate{{.*}}tile_size + """ + run_kernel_test( + kernel, + stop_after='infer-layout', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +# ============================================================================ +# Test: Elementwise fallback defaults +# ============================================================================ + +def test_fallback_3d_defaults(): + """ + For a 3D tensor with no annotations, fallback should produce: + partition_dim=0, tile_size=[min(dim0,128), 1, dim_last] + + Shape [128, 4, 256]: + tile = [min(128,128), 1, 256] = [128, 1, 256] + """ + shape = (128, 4, 256) + + @trace(input_specs=[(shape, "f32")]) + def kernel(x): + return np.exp(x) + + run_kernel_test( + kernel, + stop_after='infer-layout', + check_ir_contains=[ + "tile_size = array", + "partition_dim = 0", + ], + modes=Mode.STRING_CHECK, + ) + + +def test_fallback_small_partition_dim(): + """ + When dim 0 < 128, tile_size[0] should be the actual dim size. + + Shape [64, 512]: + tile = [min(64,128), 512] = [64, 512] + """ + shape = (64, 512) + + @trace(input_specs=[(shape, "f32")]) + def kernel(x): + return np.exp(x) + + run_kernel_test( + kernel, + stop_after='infer-layout', + check_ir_contains=[ + "tile_size = array", + "partition_dim = 0", + ], + modes=Mode.STRING_CHECK, + ) + + +def test_fallback_chain_no_annotations(): + """ + A chain of ops with no annotations should all get fallback defaults + and tile/execute correctly. + """ + shape = (256, 256) + + @trace(input_specs=[(shape, "f32")]) + def kernel(x): + y = np.exp(x) + z = np.square(y) + return z + + # All ops should get annotations + check_patterns = """ + CHECK: linalg.exp + CHECK: nkipy.annotate{{.*}}tile_size + CHECK: linalg.square + CHECK: nkipy.annotate{{.*}}tile_size + """ + run_kernel_test( + kernel, + stop_after='infer-layout', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +def test_fallback_chain_executes(): + """ + Verify that a chain with only fallback defaults tiles and runs correctly. + """ + shape = (256, 256) + + @trace(input_specs=[(shape, "f32")]) + def kernel(x): + y = np.exp(x) + z = np.square(y) + return z + + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + ) + + +# ============================================================================ +# Test: Mixed user annotation + fallback +# ============================================================================ + +def test_partial_annotation_fills_gaps(): + """ + When only some ops are annotated, the pass should: + 1. Propagate from user annotations + 2. Fill remaining gaps with fallback defaults + + exp(x) [knob] -> square [no knob] -> add(square, y) [no knob] + exp has user annotation; square gets it via forward propagation; + add gets it via forward propagation from square. 
+ """ + shape = (256, 256) + + @trace(input_specs=[(shape, "f32"), (shape, "f32")]) + def kernel(x, y): + a = np.exp(x) + knob.knob(a, mem_space="Sbuf", tile_size=[128, 128]) + b = np.square(a) + c = b + y + return c + + # All three ops should have annotations + check_patterns = """ + CHECK: linalg.exp + CHECK: nkipy.annotate{{.*}}tile_size + CHECK: linalg.square + CHECK: nkipy.annotate{{.*}}tile_size + CHECK: linalg.add + CHECK: nkipy.annotate{{.*}}tile_size + """ + run_kernel_test( + kernel, + stop_after='infer-layout', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +# ============================================================================ +# Test: Return value gets SharedHbm +# ============================================================================ + +def test_return_value_gets_shared_hbm(): + """ + Values that flow to func.return should get mem_space=SharedHbm. + The fallback default for non-return values is SBUF (mem_space=3), + but return values need SharedHbm (mem_space=4). + """ + shape = (256, 256) + + @trace(input_specs=[(shape, "f32")]) + def kernel(x): + return np.exp(x) + + # mem_space = 4 is SharedHbm + run_kernel_test( + kernel, + stop_after='infer-layout', + check_ir_contains=[ + "mem_space = 4", + ], + modes=Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/infer_layout/test_infer_layout_reduce.py b/kernelgen/tests/passes/infer_layout/test_infer_layout_reduce.py new file mode 100644 index 0000000..c47e00f --- /dev/null +++ b/kernelgen/tests/passes/infer_layout/test_infer_layout_reduce.py @@ -0,0 +1,135 @@ +""" +Tests for the infer-layout pass with reduction operations. + +Verifies that infer-layout propagates layout annotations backward through +reduction generics (linalg.generic with reduction iterator types). + +np.mean decomposes into sum (reduction generic) + divide (elementwise generic). +A knob on the final mean result should propagate backward to the sum. + +Run with: pytest tests/passes/infer_layout/test_infer_layout_reduce.py -v +""" + +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +M, N = 256, 256 +TILE_SIZE = [128, 128] + + +def test_mean_propagates_to_sum(): + """ + np.mean = sum (linalg.generic reduction) + divide (linalg.generic elementwise). + A knob on the mean result should propagate to the intermediate sum. + + Chain: linalg.square -> linalg.generic(sum) -> linalg.generic(div) -> result + ^ knob + + After infer-layout, all three ops should have nkipy.annotate with tile_size. + The sum's inferred tile_size should be [128] (not [128, 1]) — for reduction + ops, tile_size only covers the non-reduced dimensions. If [128, 1] is + propagated instead, knob-driven-tiling will error out. 
+ """ + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + sq = np.square(x.astype(np.float32)) + knob.knob(sq, mem_space="Sbuf", tile_size=TILE_SIZE) + + result = np.mean(sq, axis=-1, keepdims=True) + knob.knob( + result, + mem_space="SharedHbm", + tile_size=[128, 1], + reduction_tile=[128], + ) + return result + + # Verify annotation propagation ordering: + # square (user knob) -> sum (inferred, tile=[128]) -> div (user knob, tile=[128,1]) + # The sum is a reduction generic — its tile_size must be [128] (partition dim only), + # NOT [128, 1], otherwise knob-driven-tiling will fail. + check_patterns = """ + CHECK: linalg.square + CHECK: nkipy.annotate{{.*}}tile_size + CHECK: linalg.generic + CHECK: nkipy.annotate{{.*}}tile_size = array + CHECK: linalg.generic + CHECK: nkipy.annotate{{.*}}tile_size = array + """ + run_kernel_test( + kernel, + stop_after='infer-layout', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + # Verify numerical correctness after infer-layout + run_kernel_test( + kernel, + stop_after='infer-layout', + modes=Mode.LLVM, + rtol=1e-3, + atol=1e-3, + ) + + # Verify knob-driven-tiling succeeds with the inferred sum tile_size. + # If InferLayout propagates [128, 1] instead of [128], this will error out. + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + rtol=1e-3, + atol=1e-3, + ) + + +def test_rmsnorm_reduction_knob_propagation(): + """ + RMSNorm pattern: square -> sum -> scale -> add_eps -> sqrt -> broadcast div. + + The sum reduction has an explicit knob with reduction_tile. InferLayout + should NOT propagate a knob without reduction_tile to the sum op, which + would cause knob-driven-tiling to fail with: + "Invalid tile configuration: reduction op requires reduction_tile, got none" + """ + tile_size = [128, 128] + eps = 1e-6 + + @trace(input_specs=[((M, N), "f32")]) + def kernel(x): + sq = np.square(x) + knob.knob(sq, mem_space="Sbuf", tile_size=tile_size) + mean_sq = np.sum(sq, axis=1, keepdims=True) / 256.0 + rms = np.sqrt(mean_sq + eps) + result = np.divide(x, rms) + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + return result + + # Verify the reduction op gets tile_size=[128] with reduction_tile, + # not an incorrectly propagated [128, 128] without reduction_tile + run_kernel_test( + kernel, + stop_after='infer-layout', + check_ir_contains=[ + "tile_size = array", + "reduction_tile = array", + ], + modes=Mode.STRING_CHECK, + ) + + # Verify knob-driven-tiling succeeds (would fail if reduction_tile is missing) + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + rtol=1e-3, + atol=1e-3, + ) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/insert_spill_reload/test_basic_spill.py b/kernelgen/tests/passes/insert_spill_reload/test_basic_spill.py new file mode 100644 index 0000000..3b6ce63 --- /dev/null +++ b/kernelgen/tests/passes/insert_spill_reload/test_basic_spill.py @@ -0,0 +1,350 @@ +""" +Tests for the insert-spill-reload pass. + +Two test categories: + 1. IR-level tests: Run the pass on hand-crafted MLIR to verify spill/reload + insertion logic (victims selected, HBM slots created, copies inserted). + 2. Kernel correctness test: Trace a real kernel through the full pipeline + (including insert-spill-reload) and verify numerical output via BIR sim. 
+""" + +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode +from pass_utils import run_passes, run_filecheck + + +# ============================================================================ +# IR-level pass tests +# ============================================================================ + + +def test_sbuf_overflow_spill(): + """ + SBUF overflow triggers spill insertion. + + Uses post-legalize-layout physical shapes [partTile, nB0, nB1, freeTile]. + Three 128×1×1×2048 SBUF allocations → per-partition = 1×1×2048×4 = 8192 bytes + each, 24576 bytes peak. Capacity set to 16384 → spill required. + + Expected: HBM spill slot and memref.copy spill/reload ops. + """ + input_ir = """ +module { + func.func @sbuf_overflow(%arg0: memref<128x1x1x2048xf32, 4 : i32>) -> memref<128x1x1x2048xf32, 4 : i32> { + %a = memref.alloc() : memref<128x1x1x2048xf32, 3 : i32> + %b = memref.alloc() : memref<128x1x1x2048xf32, 3 : i32> + %c = memref.alloc() : memref<128x1x1x2048xf32, 3 : i32> + linalg.copy ins(%arg0 : memref<128x1x1x2048xf32, 4 : i32>) + outs(%a : memref<128x1x1x2048xf32, 3 : i32>) + linalg.exp ins(%a : memref<128x1x1x2048xf32, 3 : i32>) + outs(%b : memref<128x1x1x2048xf32, 3 : i32>) + linalg.mul ins(%b, %b : memref<128x1x1x2048xf32, 3 : i32>, memref<128x1x1x2048xf32, 3 : i32>) + outs(%c : memref<128x1x1x2048xf32, 3 : i32>) + linalg.copy ins(%c : memref<128x1x1x2048xf32, 3 : i32>) + outs(%arg0 : memref<128x1x1x2048xf32, 4 : i32>) + memref.dealloc %a : memref<128x1x1x2048xf32, 3 : i32> + memref.dealloc %b : memref<128x1x1x2048xf32, 3 : i32> + memref.dealloc %c : memref<128x1x1x2048xf32, 3 : i32> + return %arg0 : memref<128x1x1x2048xf32, 4 : i32> + } +} +""" + # Per-partition: 1*1*2048*4 = 8192 bytes each, 3 live = 24576 bytes + # Capacity 16384 < 24576 → spill triggered + output_ir = run_passes(input_ir, ["insert-spill-reload=sbuf-capacity=16384"]) + + check_patterns = """ + CHECK: func.func @sbuf_overflow + CHECK: memref.alloc() : memref<128x1x1x2048xf32, 3 : i32> + CHECK: memref.alloc() : memref<128x1x1x2048xf32, 1 : i32> + CHECK: linalg.exp + CHECK: memref.copy {{.*}} 3 : i32> to memref<128x1x1x2048xf32, 1 : i32> + CHECK: memref.copy {{.*}} 1 : i32> to memref<128x1x1x2048xf32, 3 : i32> + CHECK: return + """ + run_filecheck(output_ir, check_patterns) + + +def test_no_spill_below_capacity(): + """ + Single 128×1×1×512 allocation: per-partition = 1×1×512×4 = 2048 bytes, + well below the 16384-byte capacity. + + Expected: no HBM spill slot created. 
+ """ + input_ir = """ +module { + func.func @no_spill(%arg0: memref<128x1x1x512xf32, 4 : i32>) -> memref<128x1x1x512xf32, 4 : i32> { + %a = memref.alloc() : memref<128x1x1x512xf32, 3 : i32> + linalg.copy ins(%arg0 : memref<128x1x1x512xf32, 4 : i32>) + outs(%a : memref<128x1x1x512xf32, 3 : i32>) + linalg.exp ins(%a : memref<128x1x1x512xf32, 3 : i32>) + outs(%a : memref<128x1x1x512xf32, 3 : i32>) + linalg.copy ins(%a : memref<128x1x1x512xf32, 3 : i32>) + outs(%arg0 : memref<128x1x1x512xf32, 4 : i32>) + memref.dealloc %a : memref<128x1x1x512xf32, 3 : i32> + return %arg0 : memref<128x1x1x512xf32, 4 : i32> + } +} +""" + # Per-partition: 1*1*512*4 = 2048 bytes; capacity 16384 > 2048 → no spill + output_ir = run_passes(input_ir, ["insert-spill-reload=sbuf-capacity=16384"]) + + check_patterns = """ + CHECK: func.func @no_spill + CHECK: memref.alloc() : memref<128x1x1x512xf32, 3 : i32> + CHECK-NOT: 1 : i32 + CHECK: linalg.exp + CHECK: return + """ + run_filecheck(output_ir, check_patterns) + + +def test_non_overlapping_lifetimes(): + """ + Two sequential 128×1×1×2048 allocations whose lifetimes don't overlap. + + Per-partition: 8192 bytes each. Total (16384) exceeds capacity, + but peak live usage (8192) does not. + Expected: no HBM spill slot created. + """ + input_ir = """ +module { + func.func @sequential(%arg0: memref<128x1x1x2048xf32, 4 : i32>) -> memref<128x1x1x2048xf32, 4 : i32> { + %a = memref.alloc() : memref<128x1x1x2048xf32, 3 : i32> + linalg.copy ins(%arg0 : memref<128x1x1x2048xf32, 4 : i32>) + outs(%a : memref<128x1x1x2048xf32, 3 : i32>) + linalg.exp ins(%a : memref<128x1x1x2048xf32, 3 : i32>) + outs(%a : memref<128x1x1x2048xf32, 3 : i32>) + linalg.copy ins(%a : memref<128x1x1x2048xf32, 3 : i32>) + outs(%arg0 : memref<128x1x1x2048xf32, 4 : i32>) + memref.dealloc %a : memref<128x1x1x2048xf32, 3 : i32> + %b = memref.alloc() : memref<128x1x1x2048xf32, 3 : i32> + linalg.copy ins(%arg0 : memref<128x1x1x2048xf32, 4 : i32>) + outs(%b : memref<128x1x1x2048xf32, 3 : i32>) + linalg.sqrt ins(%b : memref<128x1x1x2048xf32, 3 : i32>) + outs(%b : memref<128x1x1x2048xf32, 3 : i32>) + linalg.copy ins(%b : memref<128x1x1x2048xf32, 3 : i32>) + outs(%arg0 : memref<128x1x1x2048xf32, 4 : i32>) + memref.dealloc %b : memref<128x1x1x2048xf32, 3 : i32> + return %arg0 : memref<128x1x1x2048xf32, 4 : i32> + } +} +""" + # Per-partition: 8192 bytes each, non-overlapping → peak = 8192 + # Capacity 16384 > 8192 → no spill + output_ir = run_passes(input_ir, ["insert-spill-reload=sbuf-capacity=16384"]) + + check_patterns = """ + CHECK: func.func @sequential + CHECK: memref.alloc() : memref<128x1x1x2048xf32, 3 : i32> + CHECK-NOT: 1 : i32 + CHECK: linalg.exp + CHECK-NOT: 1 : i32 + CHECK: linalg.sqrt + CHECK: return + """ + run_filecheck(output_ir, check_patterns) + + +def test_spill_with_loop_use(): + """ + A spilled value is used inside a loop body after the spill point. + + %a and %b are both 128×1×1×2048 → per-partition = 8192 bytes each; + together (16384) they exceed the 12288-byte capacity. %a is spilled at + the peak pressure point (second linalg.copy). Both %a and %b are used + inside the subsequent scf.for loop, so the reload must be inserted before + the loop — not omitted because the use is in a nested region. + + Verifies the nested-region fix: the pass walks up to the entry block to + find the ancestor of each user, so scf.for is correctly identified as the + first use after the spill. 
+ """ + input_ir = """ +module { + func.func @spill_with_loop_use(%arg0: memref<128x1x1x2048xf32, 4 : i32>) -> memref<128x1x1x2048xf32, 4 : i32> { + %a = memref.alloc() : memref<128x1x1x2048xf32, 3 : i32> + %b = memref.alloc() : memref<128x1x1x2048xf32, 3 : i32> + linalg.copy ins(%arg0 : memref<128x1x1x2048xf32, 4 : i32>) + outs(%a : memref<128x1x1x2048xf32, 3 : i32>) + linalg.copy ins(%arg0 : memref<128x1x1x2048xf32, 4 : i32>) + outs(%b : memref<128x1x1x2048xf32, 3 : i32>) + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + scf.for %i = %c0 to %c4 step %c1 { + linalg.add ins(%a, %b : memref<128x1x1x2048xf32, 3 : i32>, memref<128x1x1x2048xf32, 3 : i32>) + outs(%b : memref<128x1x1x2048xf32, 3 : i32>) + } + linalg.copy ins(%b : memref<128x1x1x2048xf32, 3 : i32>) + outs(%arg0 : memref<128x1x1x2048xf32, 4 : i32>) + memref.dealloc %a : memref<128x1x1x2048xf32, 3 : i32> + memref.dealloc %b : memref<128x1x1x2048xf32, 3 : i32> + return %arg0 : memref<128x1x1x2048xf32, 4 : i32> + } +} +""" + # Per-partition: 8192 bytes each, 2 live = 16384 bytes + # Capacity 12288 < 16384 → spill triggered + output_ir = run_passes(input_ir, ["insert-spill-reload=sbuf-capacity=12288"]) + + check_patterns = """ + CHECK: func.func @spill_with_loop_use + CHECK: memref.alloc() : memref<128x1x1x2048xf32, 1 : i32> + CHECK: memref.copy {{.*}} 3 : i32> to memref<128x1x1x2048xf32, 1 : i32> + CHECK: memref.copy {{.*}} 1 : i32> to memref<128x1x1x2048xf32, 3 : i32> + CHECK-NEXT: scf.for + """ + run_filecheck(output_ir, check_patterns) + + +def test_multiple_pressure_peaks(): + """ + Two independent high-pressure windows, each requiring a separate spill. + + Window 1: %a + %b simultaneously live (2 × 8192 = 16384 > 12288) around linalg.exp. + Window 2: %c + %d simultaneously live (16384 > 12288) around linalg.sqrt. + The two windows are separated by deallocations, so they have non-overlapping + lifetimes. The single-peak algorithm would only fix one window; the + multi-peak fix ensures both receive a spill. + + Verifies: two HBM spill slots are created and two SBUF→HBM spill copies + are inserted — one for each window. 
+ """ + input_ir = """ +module { + func.func @two_peaks(%arg0: memref<128x1x1x2048xf32, 4 : i32>) -> memref<128x1x1x2048xf32, 4 : i32> { + %a = memref.alloc() : memref<128x1x1x2048xf32, 3 : i32> + %b = memref.alloc() : memref<128x1x1x2048xf32, 3 : i32> + linalg.copy ins(%arg0 : memref<128x1x1x2048xf32, 4 : i32>) + outs(%a : memref<128x1x1x2048xf32, 3 : i32>) + linalg.exp ins(%a : memref<128x1x1x2048xf32, 3 : i32>) + outs(%b : memref<128x1x1x2048xf32, 3 : i32>) + linalg.copy ins(%b : memref<128x1x1x2048xf32, 3 : i32>) + outs(%arg0 : memref<128x1x1x2048xf32, 4 : i32>) + memref.dealloc %a : memref<128x1x1x2048xf32, 3 : i32> + memref.dealloc %b : memref<128x1x1x2048xf32, 3 : i32> + %c = memref.alloc() : memref<128x1x1x2048xf32, 3 : i32> + %d = memref.alloc() : memref<128x1x1x2048xf32, 3 : i32> + linalg.copy ins(%arg0 : memref<128x1x1x2048xf32, 4 : i32>) + outs(%c : memref<128x1x1x2048xf32, 3 : i32>) + linalg.sqrt ins(%c : memref<128x1x1x2048xf32, 3 : i32>) + outs(%d : memref<128x1x1x2048xf32, 3 : i32>) + linalg.copy ins(%d : memref<128x1x1x2048xf32, 3 : i32>) + outs(%arg0 : memref<128x1x1x2048xf32, 4 : i32>) + memref.dealloc %c : memref<128x1x1x2048xf32, 3 : i32> + memref.dealloc %d : memref<128x1x1x2048xf32, 3 : i32> + return %arg0 : memref<128x1x1x2048xf32, 4 : i32> + } +} +""" + # Per-partition: 8192 bytes each, 2 live per window = 16384 bytes + # Capacity 12288 < 16384 → spill triggered in each window + output_ir = run_passes(input_ir, ["insert-spill-reload=sbuf-capacity=12288"]) + + check_patterns = """ + CHECK: func.func @two_peaks + CHECK: memref.alloc() : memref<128x1x1x2048xf32, 1 : i32> + CHECK: linalg.exp + CHECK: memref.copy {{.*}} 3 : i32> to memref<128x1x1x2048xf32, 1 : i32> + CHECK: memref.alloc() : memref<128x1x1x2048xf32, 1 : i32> + CHECK: linalg.sqrt + CHECK: memref.copy {{.*}} 3 : i32> to memref<128x1x1x2048xf32, 1 : i32> + CHECK: return + """ + run_filecheck(output_ir, check_patterns) + + +def test_rmsnorm_no_spill(): + """ + Verify insert-spill-reload is a no-op for a 256×256 RMSNorm kernel. + + RMSNorm: output = (x / sqrt(mean(x^2) + eps)) * weight + + After tiling and bufferization, the full sq buffer is allocated as a single + 256×256×f32 SBUF alloc (subviews alias into it). Per-partition size is + 256 × 4 = 1024 bytes — well within the trn2 per-partition SBUF capacity + (~208 KB). No spilling should occur. 
+ """ + M, N = 256, 256 + tile_size = [128, 128] + eps = 1e-6 + + @trace(input_specs=[((M, N), "f32"), ((N, 1), "f32")]) + def rmsnorm_kernel(x, weight): + x_fp32 = x.astype(np.float32) + w_fp32 = weight.astype(np.float32) + + sq = np.square(x_fp32) + knob.knob(sq, mem_space="Sbuf", tile_size=tile_size) + + sum_sq = np.sum(sq, axis=-1, keepdims=True) + knob.knob(sum_sq, mem_space="Sbuf", tile_size=[128], reduction_tile=[128]) + + mean_sq = sum_sq * np.float32(1.0 / N) + knob.knob(mean_sq, mem_space="Sbuf", tile_size=[128, 1]) + + normed = x_fp32 / np.sqrt(mean_sq + eps) + knob.knob(normed, mem_space="Sbuf", tile_size=tile_size) + + result = normed * w_fp32 + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + return result + + check_patterns = """ + CHECK: func.func @rmsnorm_kernel + CHECK: memref.alloc() {{.*}} : memref<128x{{.*}}xf32, 3 : i32> + CHECK-NOT: 1 : i32 + CHECK: return + """ + run_kernel_test( + rmsnorm_kernel, + stop_after="insert-spill-reload", + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +# ============================================================================ +# Kernel correctness test (full pipeline → BIR simulation) +# ============================================================================ + + +def test_exp_kernel_no_spurious_spill(): + """ + Verify insert-spill-reload is a no-op for a kernel that fits in SBUF. + + A 128×128 exp kernel uses 2 SBUF allocs of 128×128×f32. Per-partition + size = 128 × 4 = 512 bytes each, well within the trn2 per-partition SBUF + capacity (~208 KB). The pass should leave the IR unchanged: no HBM spill + slots inserted. + + Stops immediately after insert-spill-reload so the check runs on the pass + output directly, before later passes could transform or eliminate any + spill-related ops. + """ + M, N = 128, 128 + tile_size = [128, 128] + + @trace(input_specs=[((M, N), "f32")]) + def exp_kernel(x): + result = np.exp(x) + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + return result + + check_patterns = """ + CHECK: func.func @exp_kernel + CHECK: memref.alloc() {{.*}} 3 : i32 + CHECK-NOT: 1 : i32 + CHECK: linalg.exp + CHECK: return + """ + run_kernel_test( + exp_kernel, + stop_after="insert-spill-reload", + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) diff --git a/kernelgen/tests/passes/knob_driven_tiling/__init__.py b/kernelgen/tests/passes/knob_driven_tiling/__init__.py new file mode 100644 index 0000000..99a19d4 --- /dev/null +++ b/kernelgen/tests/passes/knob_driven_tiling/__init__.py @@ -0,0 +1 @@ +"""Tests for the knob-driven-tiling pass.""" diff --git a/kernelgen/tests/passes/knob_driven_tiling/test_elementwise.py b/kernelgen/tests/passes/knob_driven_tiling/test_elementwise.py new file mode 100644 index 0000000..41be1cd --- /dev/null +++ b/kernelgen/tests/passes/knob_driven_tiling/test_elementwise.py @@ -0,0 +1,369 @@ +""" +Tests for knob-driven-tiling pass with elementwise operations. + +These tests verify simple N-dimensional tiling for elementwise ops like add, sub, etc. 
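+
+For a 256x256 add with tile_size=[128, 128] the generated structure is roughly
+(Python pseudo-code sketch, not the emitted IR):
+
+    for i in range(0, 256, 128):
+        for j in range(0, 256, 128):
+            # copy 128x128 tiles of a and b to SBUF, add, write the tile back
+            out[i:i+128, j:j+128] = a[i:i+128, j:j+128] + b[i:i+128, j:j+128]
+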
+Run with: python -m pytest tests/passes/knob_driven_tiling/test_elementwise.py -v +Or directly: python tests/passes/knob_driven_tiling/test_elementwise.py +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + +# ============================================================================ +# Test Configurations +# ============================================================================ + +# (shape, tile_size, test_id) +ADD_TEST_CONFIGS = [ + pytest.param((256, 256), [128, 128], id="add_256x256_tile128"), + pytest.param((512, 512), [128, 128], id="add_512x512_tile128"), + pytest.param((256, 256), [64, 64], id="add_256x256_tile64"), + pytest.param((128, 256, 64), [64, 128, 32], id="add_3d_128x256x64"), +] + + +# ============================================================================ +# Test Functions +# ============================================================================ + +@pytest.mark.parametrize("shape,tile_size", ADD_TEST_CONFIGS) +def test_add_tiling(shape, tile_size): + """ + Test knob-driven-tiling on linalg.add. + + Pattern: result = A + B + + Elementwise ops use simple single-level tiling (no blocking). + For 2D: for i in [0, dim0, tile0): for j in [0, dim1, tile1): add_tile + """ + @trace(input_specs=[(shape, "f32"), (shape, "f32")]) + def add_kernel(a, b): + result = a + b + knob.knob(result, tile_size=tile_size) + return result + + # Build FileCheck patterns based on dimensionality + # Elementwise ops have one scf.for loop per dimension + # With SBUF promotion, we expect: + # 1. scf.for loops for tiling + # 2. tensor.extract_slice for each input/output + # 3. bufferization.alloc_tensor for SBUF promotion + # 4. linalg.add on SBUF tensors + # 5. tensor.insert_slice to write back + # + # FileCheck regex uses {{.*}} - in f-strings, {{{{.*}}}} produces {{.*}} + check_patterns = "CHECK: func.func\n" + for i, (dim, tile) in enumerate(zip(shape, tile_size)): + check_patterns += f" CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{dim}{{{{.*}}}} step %c{tile}\n" + + # Tiled tensor shapes - IR uses [dim1, dim2] with comma separator + tile_shape_comma = ", ".join(str(t) for t in tile_size) + tile_shape_x = "x".join(str(t) for t in tile_size) + + # With SBUF promotion, we expect alloc_tensor ops with memory_space + check_patterns += f" CHECK: tensor.extract_slice {{{{.*}}}} [{tile_shape_comma}]\n" + check_patterns += f" CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32\n" + check_patterns += f" CHECK: tensor.extract_slice {{{{.*}}}} [{tile_shape_comma}]\n" + check_patterns += f" CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32\n" + check_patterns += f" CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32\n" + check_patterns += f" CHECK: linalg.add {{{{.*}}}} tensor<{tile_shape_x}xf32\n" + check_patterns += f" CHECK: tensor.insert_slice\n" + check_patterns += f" CHECK: scf.yield\n" + + run_kernel_test( + add_kernel, + stop_after='apply-and-strip-transforms', + check_patterns=check_patterns, + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +def test_sub_2d(): + """ + Test knob-driven-tiling on linalg.sub. 
+ """ + shape = (256, 256) + tile_size = [128, 128] + + @trace(input_specs=[(shape, "f32"), (shape, "f32")]) + def sub_kernel(a, b): + result = a - b + knob.knob(result, tile_size=tile_size) + return result + + run_kernel_test( + sub_kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + ) + + +def test_mul_2d(): + """ + Test knob-driven-tiling on linalg.mul (element-wise multiplication). + """ + shape = (256, 256) + tile_size = [128, 128] + + @trace(input_specs=[(shape, "f32"), (shape, "f32")]) + def mul_kernel(a, b): + result = a * b + knob.knob(result, tile_size=tile_size) + return result + + run_kernel_test( + mul_kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + ) + + +def test_add_simple(): + """ + Simple test for 256x256 add to verify basic elementwise functionality. + + For 256x256 with tile_size=[128, 128]: + for i in [0, 256, 128): for j in [0, 256, 128): add([128, 128]) + """ + shape = (256, 256) + tile_size = [128, 128] + dim0, dim1 = shape + tile0, tile1 = tile_size + + @trace(input_specs=[(shape, "f32"), (shape, "f32")]) + def add_kernel(a, b): + result = a + b + knob.knob(result, tile_size=tile_size) + return result + + # FileCheck patterns for 2D elementwise tiling with SBUF promotion + # After promotion, we expect: + # 1. extract_slice for input tiles + # 2. alloc_tensor with 3 : i32 for SBUF copies + # 3. linalg.add on promoted tensors + # 4. insert_slice to write back + check_patterns = f""" + CHECK: func.func + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{dim0}{{{{.*}}}} step %c{tile0} + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{dim1}{{{{.*}}}} step %c{tile1} + CHECK: tensor.extract_slice {{{{.*}}}} [{tile0}, {tile1}] + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: tensor.extract_slice {{{{.*}}}} [{tile0}, {tile1}] + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: linalg.add {{{{.*}}}} tensor<{tile0}x{tile1}xf32 + CHECK: tensor.insert_slice + CHECK: scf.yield + """ + run_kernel_test( + add_kernel, + stop_after='apply-and-strip-transforms', + check_patterns=check_patterns, + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +# ============================================================================ +# Scalar Constant Tests +# ============================================================================ +# When one operand is a scalar constant, the tracer generates a linalg.generic +# with the arith.constant embedded in the region body. This has 1 DPS input +# (the tensor) and 1 DPS init (the output), so tiling promotes 1 input + 1 +# output to SBUF (2 alloc_tensors, not 3 like binary tensor-tensor ops). + + +def test_tensor_add_scalar(): + """ + Test tiling of tensor + scalar constant. + + Pattern: result = x + 2.0 + Generated IR: linalg.generic with arith.addf and embedded arith.constant. 
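+
+    Because the scalar is embedded as an arith.constant inside the region rather
+    than passed as a tensor operand, each tile iteration allocates only two SBUF
+    buffers (input tile + output tile), which is what the check patterns below
+    assert.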
+ """ + shape = (256, 256) + tile_size = [128, 128] + dim0, dim1 = shape + tile0, tile1 = tile_size + + @trace(input_specs=[(shape, "f32")]) + def kernel(x): + result = x + 2.0 + knob.knob(result, tile_size=tile_size) + return result + + # linalg.generic with 1 input: 1 extract_slice + 2 alloc_tensors (1 input + 1 output) + check_patterns = f""" + CHECK: func.func + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{dim0}{{{{.*}}}} step %c{tile0} + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{dim1}{{{{.*}}}} step %c{tile1} + CHECK: tensor.extract_slice {{{{.*}}}} [{tile0}, {tile1}] + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: linalg.generic + CHECK: tensor.insert_slice + CHECK: scf.yield + """ + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + check_patterns=check_patterns, + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +def test_scalar_minus_tensor(): + """ + Test tiling of scalar - tensor (non-commutative, scalar on LHS). + + Pattern: result = 5.0 - x + Generated IR: linalg.generic with arith.subf, scalar_is_lhs=True. + """ + shape = (256, 256) + tile_size = [128, 128] + dim0, dim1 = shape + tile0, tile1 = tile_size + + @trace(input_specs=[(shape, "f32")]) + def kernel(x): + result = 5.0 - x + knob.knob(result, tile_size=tile_size) + return result + + # 1 DPS input -> 1 extract_slice, 2 alloc_tensors (1 input + 1 output) + check_patterns = f""" + CHECK: func.func + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{dim0}{{{{.*}}}} step %c{tile0} + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{dim1}{{{{.*}}}} step %c{tile1} + CHECK: tensor.extract_slice {{{{.*}}}} [{tile0}, {tile1}] + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: linalg.generic + CHECK: tensor.insert_slice + CHECK: scf.yield + """ + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + check_patterns=check_patterns, + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +def test_tensor_mul_scalar(): + """ + Test tiling of tensor * scalar constant. + + Pattern: result = x * 3.0 + """ + shape = (256, 256) + tile_size = [128, 128] + dim0, dim1 = shape + tile0, tile1 = tile_size + + @trace(input_specs=[(shape, "f32")]) + def kernel(x): + result = x * 3.0 + knob.knob(result, tile_size=tile_size) + return result + + check_patterns = f""" + CHECK: func.func + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{dim0}{{{{.*}}}} step %c{tile0} + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{dim1}{{{{.*}}}} step %c{tile1} + CHECK: tensor.extract_slice {{{{.*}}}} [{tile0}, {tile1}] + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: linalg.generic + CHECK: tensor.insert_slice + CHECK: scf.yield + """ + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + check_patterns=check_patterns, + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +def test_tensor_div_scalar(): + """ + Test tiling of tensor / scalar constant. 
+ + Pattern: result = x / 2.0 + """ + shape = (256, 256) + tile_size = [128, 128] + dim0, dim1 = shape + tile0, tile1 = tile_size + + @trace(input_specs=[(shape, "f32")]) + def kernel(x): + result = x / 2.0 + knob.knob(result, tile_size=tile_size) + return result + + check_patterns = f""" + CHECK: func.func + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{dim0}{{{{.*}}}} step %c{tile0} + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{dim1}{{{{.*}}}} step %c{tile1} + CHECK: tensor.extract_slice {{{{.*}}}} [{tile0}, {tile1}] + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: linalg.generic + CHECK: tensor.insert_slice + CHECK: scf.yield + """ + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + check_patterns=check_patterns, + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +def test_scalar_div_tensor(): + """ + Test tiling of scalar / tensor (non-commutative, scalar on LHS). + + Pattern: result = 1.0 / x (reciprocal-like pattern used in sigmoid) + prepare-arithmetic converts this to linalg.reciprocal. + """ + shape = (256, 256) + tile_size = [128, 128] + dim0, dim1 = shape + tile0, tile1 = tile_size + + @trace(input_specs=[(shape, "f32")]) + def kernel(x): + result = 1.0 / x + knob.knob(result, tile_size=tile_size) + return result + + # prepare-arithmetic converts 1.0/x into linalg.reciprocal + check_patterns = f""" + CHECK: func.func + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{dim0}{{{{.*}}}} step %c{tile0} + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{dim1}{{{{.*}}}} step %c{tile1} + CHECK: tensor.extract_slice {{{{.*}}}} [{tile0}, {tile1}] + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: linalg.reciprocal + CHECK: tensor.insert_slice + CHECK: scf.yield + """ + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + check_patterns=check_patterns, + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/knob_driven_tiling/test_matmul.py b/kernelgen/tests/passes/knob_driven_tiling/test_matmul.py new file mode 100644 index 0000000..ff0030f --- /dev/null +++ b/kernelgen/tests/passes/knob_driven_tiling/test_matmul.py @@ -0,0 +1,272 @@ +""" +Tests for knob-driven-tiling pass with matmul operations. + +These tests generate MLIR output files for manual inspection. 
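+
+Block sizes are derived from the tile sizes with a two-tiles-per-block policy
+whenever the dimension is large enough (sketch of the helper logic used by the
+tests below):
+
+    blocks_m = 2 if M >= tile_m * 2 else 1    # BLOCK_M = tile_m * blocks_m
+    blocks_n = 2 if N >= tile_n * 2 else 1    # BLOCK_N = tile_n * blocks_n
+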
+Run with: python -m pytest tests/passes/knob_driven_tiling/test_matmul.py -v +Or directly: python tests/passes/knob_driven_tiling/test_matmul.py +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from pass_utils import compile_knob_pipeline +from harness import run_kernel_test, Mode + +# ============================================================================ +# Test Configurations +# ============================================================================ + +# (M, N, K, tile_size [M, N], reduction_tile [K]) +TEST_CONFIGS = [ + # Square matrices + pytest.param(256, 256, 256, [128, 128], [128], id="256x256x256_tile128"), + pytest.param(512, 512, 512, [128, 128], [128], id="512x512x512_tile128"), + + # Different tile sizes + pytest.param(256, 256, 256, [64, 64], [64], id="256x256x256_tile64"), + pytest.param(256, 256, 256, [128, 64], [128], id="256x256x256_tile_mnk_128_64_128"), + + # Block size 1 cases: tile == dim, so blocking degenerates (no data reuse) + pytest.param(128, 128, 128, [128, 128], [128], id="128x128x128_no_blocking"), + pytest.param(256, 128, 256, [128, 128], [128], id="256x128x256_no_blocking_N"), + pytest.param(128, 256, 256, [128, 128], [128], id="128x256x256_no_blocking_M"), +] + + +# ============================================================================ +# Test Functions +# ============================================================================ + +@pytest.mark.parametrize("M,N,K,tile_size,reduction_tile", TEST_CONFIGS) +def test_matmul_tiling(M, N, K, tile_size, reduction_tile, request): + """ + Test knob-driven-tiling on simple matmul with 6-level blocking. + + KnobDrivenTiling generates (TILES_IN_BLOCK = 2): + for block_m in [0, M, BLOCK_M): // BLOCK_M = TILE_M * 2 + for block_n in [0, N, BLOCK_N): // BLOCK_N = TILE_N * 2 + for tile_m in [0, BLOCK_M, TILE_M): + for tile_n in [0, BLOCK_N, TILE_N): + for k in [0, K, TILE_K): + matmul_transpose_a + + Args: + M, N, K: Matrix dimensions (A is MxK, B is KxN) + tile_size: Tile size for matmul output [tileM, tileN] + reduction_tile: Tile size for reduction dim [tileK] + """ + # Extract tile sizes + tile_m, tile_n = tile_size + tile_k = reduction_tile[0] + + # Dynamic blocking: use 2 if dim is large enough, else 1 + blocks_m = 2 if M >= tile_m * 2 else 1 + blocks_n = 2 if N >= tile_n * 2 else 1 + block_m = tile_m * blocks_m + block_n = tile_n * blocks_n + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) + def matmul_kernel(a, b): + result = np.matmul(a, b) + knob.knob(result, tile_size=tile_size, reduction_tile=reduction_tile) + return result + + # FileCheck verification for 6-level blocking + # LHS block size: [K, block_m] (transposed) + # RHS block size: [K, block_n] + lhs_block_k = K + lhs_block_m = block_m + rhs_block_k = K + rhs_block_n = block_n + + # LHS is transposed: original [M,K] -> transposed [K,M] + # FileCheck regex uses {{.*}} - in f-strings we need {{{{.*}}}} to escape the braces + # Constants may have suffixes like %c0_3, so we use %c0{{.*}} to match %c0, %c0_1, %c0_3, etc. 
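+    # Example rendering (for the 256x256x256 / tile128 configuration, where
+    # M == 256 and block_m == 256):
+    #   f"CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{M}{{{{.*}}}} step %c{block_m}"
+    # evaluates to
+    #   "CHECK: scf.for %{{.*}} = %c0{{.*}} to %c256{{.*}} step %c256"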
+ check_patterns = f""" + CHECK: func.func + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{M}{{{{.*}}}} step %c{block_m} + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32{{{{.*}}}} : tensor<{lhs_block_k}x{block_m}xf32> + CHECK: linalg.transpose {{{{.*}}}} outs({{{{.*}}}} : tensor<{lhs_block_k}x{block_m}xf32>) permutation = [1, 0] + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32{{{{.*}}}} : tensor<{lhs_block_k}x{lhs_block_m}xf32> + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{N}{{{{.*}}}} step %c{block_n} + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32{{{{.*}}}} : tensor<{rhs_block_k}x{rhs_block_n}xf32> + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{block_m}{{{{.*}}}} step %c{tile_m} + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{block_n}{{{{.*}}}} step %c{tile_n} + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 2 : i32{{{{.*}}}} : tensor<{tile_m}x{tile_n}xf32> + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{K}{{{{.*}}}} step %c{tile_k} + CHECK: tensor.extract_slice %{{{{.*}}}} [{tile_k}, {tile_m}] + CHECK: tensor.extract_slice %{{{{.*}}}} [{tile_k}, {tile_n}] + CHECK: linalg.matmul_transpose_a ins(%{{{{.*}}}}, %{{{{.*}}}} : tensor<{tile_k}x{tile_m}xf32>, tensor<{tile_k}x{tile_n}xf32>) + CHECK-SAME: outs(%{{{{.*}}}} : tensor<{tile_m}x{tile_n}xf32>) + CHECK: tensor.insert_slice + CHECK: scf.yield + """ + run_kernel_test( + matmul_kernel, + stop_after='apply-and-strip-transforms', + check_patterns=check_patterns, + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +def test_matmul_simple_256(): + """ + Simple test for 256x256 matmul to verify 6-level blocking. + + KnobDrivenTiling generates 6-level blocking for matmul: + for block_m in [0, M, BLOCK_M): // BLOCK_M = TILE_M * 2 + for block_n in [0, N, BLOCK_N): // BLOCK_N = TILE_N * 2 + for tile_m in [0, BLOCK_M, TILE_M): + for tile_n in [0, BLOCK_N, TILE_N): + for k in [0, K, TILE_K): + matmul (linalg.matmul_transpose_a) + + For 256x256x256 with tile_size=[128, 128], reduction_tile=[128]: + BLOCK_M = 256, BLOCK_N = 256 + Steps: 256, 256, 128, 128, 128 + """ + M, N, K = 256, 256, 256 + tile_m, tile_n, tile_k = 128, 128, 128 + tile_size = [tile_m, tile_n] + reduction_tile = [tile_k] + + # Block size = tile * 2 (dim is large enough for blocking) + block_m = tile_m * 2 # 256 + block_n = tile_n * 2 # 256 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) + def matmul_kernel(a, b): + result = np.matmul(a, b) + knob.knob(result, tile_size=tile_size, reduction_tile=reduction_tile) + return result + + # FileCheck verification for 6-level blocking with input promotion + # LHS block size: [K, block_m] (transposed) + # RHS block size: [K, block_n] + lhs_block_k = K + lhs_block_m = block_m + rhs_block_k = K + rhs_block_n = block_n + + check_patterns = f""" + CHECK: func.func + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{M}{{{{.*}}}} step %c{block_m} + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32{{{{.*}}}} : tensor<{lhs_block_k}x{block_m}xf32> + CHECK: linalg.transpose {{{{.*}}}} outs({{{{.*}}}} : tensor<{lhs_block_k}x{block_m}xf32>) permutation = [1, 0] + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32{{{{.*}}}} : tensor<{lhs_block_k}x{lhs_block_m}xf32> + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{N}{{{{.*}}}} step %c{block_n} + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32{{{{.*}}}} : tensor<{rhs_block_k}x{rhs_block_n}xf32> + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to 
%c{block_m}{{{{.*}}}} step %c{tile_m} + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{block_n}{{{{.*}}}} step %c{tile_n} + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 2 : i32{{{{.*}}}} : tensor<{tile_m}x{tile_n}xf32> + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{K}{{{{.*}}}} step %c{tile_k} + CHECK: tensor.extract_slice %{{{{.*}}}} [{tile_k}, {tile_m}] + CHECK: tensor.extract_slice %{{{{.*}}}} [{tile_k}, {tile_n}] + CHECK: linalg.matmul_transpose_a ins(%{{{{.*}}}}, %{{{{.*}}}} : tensor<{tile_k}x{tile_m}xf32>, tensor<{tile_k}x{tile_n}xf32>) + CHECK-SAME: outs(%{{{{.*}}}} : tensor<{tile_m}x{tile_n}xf32>) + CHECK: tensor.insert_slice + CHECK: scf.yield + """ + run_kernel_test( + matmul_kernel, + stop_after='apply-and-strip-transforms', + check_patterns=check_patterns, + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +# ============================================================================ +# Batch Matmul Tests +# ============================================================================ + +BATCH_TEST_CONFIGS = [ + pytest.param(8, 256, 256, 256, [1, 128, 128], [128], id="b8_256x256x256_tile128"), + pytest.param(4, 512, 512, 512, [1, 128, 128], [128], id="b4_512x512x512_tile128"), +] + + +@pytest.mark.parametrize("B,M,N,K,tile_size,reduction_tile", BATCH_TEST_CONFIGS) +def test_batch_matmul_tiling(B, M, N, K, tile_size, reduction_tile, request): + """ + Test knob-driven-tiling on batch_matmul with 6-level blocking. + + Same blocking structure as matmul, with a leading batch dim tiled at 1. + """ + tile_m = tile_size[-2] + tile_n = tile_size[-1] + tile_k = reduction_tile[0] + + block_m = tile_m * 2 + block_n = tile_n * 2 + + @trace(input_specs=[((B, M, K), "f32"), ((B, K, N), "f32")]) + def batch_matmul_kernel(a, b): + result = np.matmul(a, b) + knob.knob(result, tile_size=tile_size, reduction_tile=reduction_tile) + return result + + lhs_block_k = K + lhs_block_m = block_m + rhs_block_k = K + rhs_block_n = block_n + + # Batch matmul: same 6-level blocking but with batch dim + # After tiling batch dim at 1, inner ops become 2D slices + check_patterns = f""" + CHECK: func.func + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{B}{{{{.*}}}} step %c1 + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{M}{{{{.*}}}} step %c{block_m} + CHECK: linalg.transpose + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{N}{{{{.*}}}} step %c{block_n} + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{block_m}{{{{.*}}}} step %c{tile_m} + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{block_n}{{{{.*}}}} step %c{tile_n} + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 2 : i32{{{{.*}}}} + CHECK: scf.for %{{{{.*}}}} = %c0{{{{.*}}}} to %c{K}{{{{.*}}}} step %c{tile_k} + CHECK: linalg.matmul_transpose_a + """ + run_kernel_test( + batch_matmul_kernel, + stop_after='apply-and-strip-transforms', + check_patterns=check_patterns, + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +# ============================================================================ +# Error Handling Tests +# ============================================================================ + + +def test_matmul_k_tile_too_large(): + """ + Test that K tile size larger than K dimension produces an error. 
+ """ + M, N, K = 256, 256, 64 # K is small + tile_size = [64, 64] # output tile + reduction_tile = [128] # K tile (128) is larger than K dimension (64) + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32")]) + def matmul_kernel(a, b): + result = np.matmul(a, b) + knob.knob(result, tile_size=tile_size, reduction_tile=reduction_tile) + return result + + # This should raise an exception + with pytest.raises(Exception) as excinfo: + compile_knob_pipeline(matmul_kernel, stop_after='apply-and-strip-transforms') + + error_msg = str(excinfo.value) + + # Check for the exact error message from KnobDrivenTiling.cpp + expected_error = "matmul K tile (128) is larger than K dimension (64)" + assert expected_error in error_msg, \ + f"Expected error message:\n {expected_error}\nGot:\n {error_msg}" + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/knob_driven_tiling/test_multi_op.py b/kernelgen/tests/passes/knob_driven_tiling/test_multi_op.py new file mode 100644 index 0000000..c89e954 --- /dev/null +++ b/kernelgen/tests/passes/knob_driven_tiling/test_multi_op.py @@ -0,0 +1,218 @@ +""" +Tests for knob-driven-tiling with multiple operations. + +These tests verify tiling of kernels with multiple linalg ops, +including same-type ops with different tile sizes using nkipy.op_id. +Run with: python -m pytest tests/passes/knob_driven_tiling/test_multi_op.py -v +Or directly: python tests/passes/knob_driven_tiling/test_multi_op.py +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + +# ============================================================================ +# Test Functions +# ============================================================================ + +def test_two_matmuls_different_tiles(): + """ + Test with two matmul operations with different tile sizes. + + Pattern: + C = matmul(A, B) # tile_size=[128, 128], reduction_tile=[128] + D = matmul(C, E) # tile_size=[64, 64], reduction_tile=[64] + + This tests per-instance matching via nkipy.op_id - each matmul + should get its own transform sequence with its specific tile size. 
+ """ + M, N, K = 256, 256, 256 + # First matmul: tile 128x128, K=128, block 256x256 + tile1_m, tile1_n, tile1_k = 128, 128, 128 + # Second matmul: tile 64x64, K=64, block 128x128 + tile2_m, tile2_n, tile2_k = 64, 64, 64 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((N, N), "f32")]) + def two_matmul_kernel(a, b, e): + c = np.matmul(a, b) + knob.knob(c, tile_size=[128, 128], reduction_tile=[128]) + + d = np.matmul(c, e) + knob.knob(d, tile_size=[64, 64], reduction_tile=[64]) + + return d + + # FileCheck: verify both matmuls get correct tile sizes + # First matmul: 128x128x128 tiles, with linalg.transpose for LHS + # Second matmul: 64x64x64 tiles + # Transpose output is promoted to SBUF (bufferization.alloc_tensor before linalg.transpose) + check_patterns = f""" + CHECK: func.func + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: linalg.transpose {{{{.*}}}} permutation = [1, 0] + CHECK: linalg.matmul_transpose_a {{{{.*}}}}tensor<{tile1_k}x{tile1_m}xf32>, tensor<{tile1_k}x{tile1_n}xf32> + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: linalg.transpose {{{{.*}}}} permutation = [1, 0] + CHECK: linalg.matmul_transpose_a {{{{.*}}}}tensor<{tile2_k}x{tile2_m}xf32>, tensor<{tile2_k}x{tile2_n}xf32> + """ + run_kernel_test( + two_matmul_kernel, + stop_after='apply-and-strip-transforms', + check_patterns=check_patterns, + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +def test_two_adds_different_tiles(): + """ + Test with two add operations with different tile sizes. + + Pattern: + C = A + B # tile_size=[128, 128] + D = C + E # tile_size=[64, 64] + + Both adds will have all operands promoted to SBUF. + """ + shape = (256, 256) + tile1 = [128, 128] + tile2 = [64, 64] + + @trace(input_specs=[(shape, "f32"), (shape, "f32"), (shape, "f32")]) + def two_add_kernel(a, b, e): + c = a + b + knob.knob(c, tile_size=[128, 128]) + + d = c + e + knob.knob(d, tile_size=[64, 64]) + + return d + + # FileCheck: verify both adds get tiled with different tile sizes + # With SBUF promotion, each add will have alloc_tensor with memory_space + check_patterns = f""" + CHECK: func.func + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: linalg.add {{{{.*}}}}tensor<{tile1[0]}x{tile1[1]}xf32> + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: linalg.add {{{{.*}}}}tensor<{tile2[0]}x{tile2[1]}xf32> + """ + run_kernel_test( + two_add_kernel, + stop_after='apply-and-strip-transforms', + check_patterns=check_patterns, + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +def test_three_matmuls_same_and_different_tiles(): + """ + Test with three matmuls - two with same tile size, one different. 
+ + Pattern: + C = matmul(A, B) # tile_size=[128, 128], reduction_tile=[128] + D = matmul(C, E) # tile_size=[64, 64], reduction_tile=[64] + F = matmul(D, G) # tile_size=[128, 128], reduction_tile=[128] + """ + M, N, K = 256, 256, 256 + + @trace(input_specs=[ + ((M, K), "f32"), ((K, N), "f32"), ((N, N), "f32"), ((N, N), "f32") + ]) + def three_matmul_kernel(a, b, e, g): + c = np.matmul(a, b) + knob.knob(c, tile_size=[128, 128], reduction_tile=[128]) + + d = np.matmul(c, e) + knob.knob(d, tile_size=[64, 64], reduction_tile=[64]) + + f = np.matmul(d, g) + knob.knob(f, tile_size=[128, 128], reduction_tile=[128]) + + return f + + run_kernel_test( + three_matmul_kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + ) + + +def test_mixed_matmul_and_add(): + """ + Test with mixed operations - matmul followed by add with different tiles. + + Pattern: + C = matmul(A, B) # tile_size=[128, 128], reduction_tile=[128] + D = C + E # tile_size=[64, 64] + """ + M, N, K = 256, 256, 256 + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((M, N), "f32")]) + def mixed_kernel(a, b, e): + c = np.matmul(a, b) + knob.knob(c, tile_size=[128, 128], reduction_tile=[128]) + + d = c + e + knob.knob(d, tile_size=[64, 64]) + + return d + + run_kernel_test( + mixed_kernel, + stop_after='apply-and-strip-transforms', + modes=Mode.LLVM, + ) + + +def test_matmul_add_chain(): + """ + Test typical matmul + bias add pattern with different tile sizes. + + Pattern: + C = matmul(A, B) # tile_size=[128, 128], reduction_tile=[128] + D = C + bias # tile_size=[128, 128] (same spatial dims) + + Matmul: LHS/RHS promoted to SBUF, output to PSUM + Add: all operands promoted to SBUF + """ + M, N, K = 256, 256, 256 + tile_m, tile_n, tile_k = 128, 128, 128 + add_tile = [128, 128] + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((M, N), "f32")]) + def matmul_add_kernel(a, b, bias): + c = np.matmul(a, b) + knob.knob(c, tile_size=[128, 128], reduction_tile=[128]) + + # Add bias - using same spatial tile sizes as matmul output + d = c + bias + knob.knob(d, tile_size=[128, 128]) + + return d + + # FileCheck: matmul tiled then add tiled + # Transpose output promoted to SBUF, then add operands promoted to SBUF + check_patterns = f""" + CHECK: func.func + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: linalg.transpose {{{{.*}}}} permutation = [1, 0] + CHECK: linalg.matmul_transpose_a {{{{.*}}}}tensor<{tile_k}x{tile_m}xf32>, tensor<{tile_k}x{tile_n}xf32> + CHECK: bufferization.alloc_tensor() {{{{.*}}}}memory_space = 3 : i32 + CHECK: linalg.add {{{{.*}}}}tensor<{add_tile[0]}x{add_tile[1]}xf32> + """ + run_kernel_test( + matmul_add_kernel, + stop_after='apply-and-strip-transforms', + check_patterns=check_patterns, + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/legalize_layout/__init__.py b/kernelgen/tests/passes/legalize_layout/__init__.py new file mode 100644 index 0000000..2d2fceb --- /dev/null +++ b/kernelgen/tests/passes/legalize_layout/__init__.py @@ -0,0 +1 @@ +# Tests for legalize-layout pass diff --git a/kernelgen/tests/passes/legalize_layout/test_basic.py b/kernelgen/tests/passes/legalize_layout/test_basic.py new file mode 100644 index 0000000..d797a3e --- /dev/null +++ 
b/kernelgen/tests/passes/legalize_layout/test_basic.py @@ -0,0 +1,188 @@ +""" +Tests for legalize-layout pass. + +The legalize-layout pass: +1. Transforms SBUF memrefs to (R+2)-D physical layout + - 2D: [tileM, numTilesM, numTilesN, tileN] (4D physical) + - 3D: [tileP, numB0, numB1, numB2, tileF] (5D physical) +2. Updates memref.subview ops to use physical indexing +3. Tiles HBM↔SBUF memref.copy and linalg.transpose into looped tile transfers +4. Collapses all (R+2)-D and R-D memrefs to 2D for NISA consumption + +Run with: python -m pytest tests/passes/legalize_layout/test_basic.py -v +Or directly: python tests/passes/legalize_layout/test_basic.py +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + +# ============================================================================ +# Test Cases +# ============================================================================ + +def test_matmul_sbuf_add_hbm(): + """ + Test matmul-add chain with: + - matmul output -> SBUF (intermediate) + - add output -> SharedHbm (returned result) + + This tests that the pass correctly: + 1. Transforms 2D SBUF allocs to 4D: memref<256x256xf32, #sbuf> -> memref<128x2x2x128xf32, #sbuf> + 2. Tiles HBM→SBUF transpose into nested loops + 3. Transforms subview ops to use 4D indexing + + Before the pass (annotate-memory-space output): + %alloc = memref.alloc() : memref<256x256xf32, 3 : i32> + linalg.transpose ins(%hbm_input) outs(%alloc) + scf.for ... { + %subview = memref.subview %alloc[%off_m, 0][128, 256]... + // ... linalg ops using 2D tiles ... + } + + After the pass (legalize-layout output): + %alloc_4d = memref.alloc() : memref<128x2x2x128xf32, 3 : i32> + scf.for %blk_m = 0 to 2 { + scf.for %blk_n = 0 to 2 { + %input_tile = memref.subview %hbm_input[...] + %output_tile = memref.subview %alloc_4d[0, %blk_m, %blk_n, 0][128, 1, 1, 128] + linalg.transpose ins(%input_tile) outs(%output_tile) + } + } + scf.for ... { + %subview_4d = memref.subview %alloc_4d[0, %blk, 0, 0][128, 1, 2, 128]... + // ... linalg ops using 4D tiles with rank reduction ... + } + """ + M, N, K = 256, 256, 256 + matmul_tile = [128, 128] # TILE_M, TILE_N + matmul_reduction_tile = [128] # TILE_K + add_tile = [128, 128] # TILE_M, TILE_N + + @trace(input_specs=[((M, K), "f32"), ((K, N), "f32"), ((M, N), "f32")]) + def matmul_add_kernel(a, b, bias): + # Matmul outputs to SBUF for reuse in the add + c = np.matmul(a, b) + knob.knob(c, mem_space="Sbuf", tile_size=matmul_tile, reduction_tile=matmul_reduction_tile) + + # Add outputs to SharedHbm (returned from kernel) + result = c + bias + knob.knob(result, mem_space="SharedHbm", tile_size=add_tile) + + return result + + # FileCheck patterns to verify: + # + # After legalize-layout, we should see: + # 1. 4D SBUF allocations: memref<128x2x2x128xf32, 3 : i32> + # 2. Tiled transpose loops (HBM→SBUF) + # 3. 
4D subview operations with rank reduction + + check_patterns = ''' +CHECK: func.func @matmul_add_kernel +CHECK-SAME: 4 : i32 +CHECK: memref.alloc(){{.*}}: memref<128x2x2x128xf32, 3 : i32> +CHECK: scf.for +CHECK: scf.for +CHECK: linalg.transpose{{.*}}outs({{.*}}memref<128x{{.*}}128xf32{{.*}}3 : i32>) +CHECK: memref.alloc(){{.*}}: memref<128x2x2x128xf32, 3 : i32> +CHECK: scf.for +CHECK: scf.for +CHECK: memref.copy{{.*}}to{{.*}}memref<128x{{.*}}128xf32{{.*}}3 : i32> +CHECK: scf.for +CHECK: memref.subview{{.*}}3 : i32> +CHECK: memref.subview{{.*}}3 : i32> +CHECK: scf.for +CHECK: memref.subview{{.*}}3 : i32> +CHECK: memref.alloc(){{.*}}: memref<128x128xf32, 2 : i32> +CHECK: linalg.matmul +CHECK: memref.copy{{.*}}2 : i32>{{.*}}to{{.*}}3 : i32> +CHECK: } +CHECK: } +CHECK: memref.alloc(){{.*}}: memref<256x256xf32, 4 : i32> +CHECK: scf.for +CHECK: scf.for +CHECK: memref.subview{{.*}}3 : i32> +CHECK: memref.subview{{.*}}4 : i32> +CHECK: memref.alloc(){{.*}}: memref<128x128xf32, 3 : i32> +CHECK: memref.copy{{.*}}4 : i32>{{.*}}to{{.*}}3 : i32> +CHECK: memref.alloc(){{.*}}: memref<128x128xf32, 3 : i32> +CHECK: linalg.add +CHECK: memref.copy{{.*}}3 : i32>{{.*}}to{{.*}}4 : i32> +CHECK: return{{.*}}memref<256x256xf32, 4 : i32> +''' + run_kernel_test( + matmul_add_kernel, + stop_after='legalize-layout', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +def test_3d_add_chain_sbuf(): + """ + Test 3D add chain with SBUF intermediate: + - Shape: (256, 2, 256) — 3D tensor where dim0 > 128 triggers legalization + - tile_size: [128, 1, 128] — middle dim must have tile=1 (design constraint) + - intermediate (a+b) goes to SBUF, result (intermediate+c) goes to SharedHbm + + This tests that the pass correctly handles rank-3 tensors: + 1. Transforms 3D SBUF alloc to 5D physical layout: + memref<256x2x256xf32, #sbuf> -> memref<128x2x2x2x128xf32, #sbuf> + Physical shape: [tileP=128, numB0=2, numB1=2, numB2=2, tileF=128] + 2. Tiles HBM↔SBUF memref.copy into 3-level nested loops + 3. Collapses (5D physical) and (3D logical) memrefs to 2D for compute + 4. Reconstructs linalg ops with 2D iteration domain + """ + B, M, N = 256, 2, 256 + tile_size = [128, 1, 128] + + @trace(input_specs=[((B, M, N), "f32"), ((B, M, N), "f32"), ((B, M, N), "f32")]) + def add_chain_3d(a, b, c): + # First add: intermediate stored in SBUF + intermediate = a + b + knob.knob(intermediate, mem_space="Sbuf", tile_size=tile_size) + + # Second add: result goes to SharedHbm + result = intermediate + c + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + + return result + + # FileCheck patterns to verify 3D -> 5D legalization: + # + # 1. 5D SBUF allocation: memref<128x2x2x2x128xf32, 3 : i32> + # 2. 3-level tiled copy loops (HBM→SBUF) + # 3. memref.collapse_shape to 2D for compute ops + # 4. 
2D linalg.add (after reconstruction from 3D) + check_patterns = ''' +CHECK: func.func @add_chain_3d +CHECK: memref.alloc(){{.*}}: memref<128x2x2x2x128xf32, 3 : i32> +CHECK: scf.for +CHECK: scf.for +CHECK: scf.for +CHECK: memref.collapse_shape +CHECK: memref.alloc(){{.*}}: memref<256x2x256xf32, 4 : i32> +CHECK: scf.for +CHECK: scf.for +CHECK: scf.for +CHECK: memref.subview{{.*}}3 : i32> +CHECK: linalg.add +CHECK: return{{.*}}memref<256x2x256xf32, 4 : i32> +''' + run_kernel_test( + add_chain_3d, + stop_after='legalize-layout', + check_patterns=check_patterns, + modes=Mode.FILECHECK, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/legalize_layout/test_fold_reshape_copy.py b/kernelgen/tests/passes/legalize_layout/test_fold_reshape_copy.py new file mode 100644 index 0000000..0aef0c5 --- /dev/null +++ b/kernelgen/tests/passes/legalize_layout/test_fold_reshape_copy.py @@ -0,0 +1,431 @@ +""" +Tests for legalize-layout handling of Phase 0 (foldReshapeIntoAlloc) patterns. + +When an SBUF alloc is followed by a copy then a reshape, Phase 0 folds the +reshape into the alloc, creating a collapse_shape for the copy. Phases 1-3 +must follow through the collapse_shape to: + - discover tile sizes (Phase 1 -- traceToLinalgOperands) + - resolve the legalized SBUF alloc (Phase 3 -- tileMemrefCopy) + - handle the HBM/SBUF rank mismatch when the reshape inserted dims (Phase 3) + +These tests use crafted MLIR input (the output of annotate-memory-space) and +run only the legalize-layout pass to verify the transformation in isolation. + +Run with: python -m pytest tests/passes/legalize_layout/test_fold_reshape_copy.py -v +""" + +import pytest + +from nkipy_kernelgen.transforms.nkipy_opt import run_nkipy_opt_passes +from passes.pass_utils import run_filecheck + + +# ============================================================================ +# Helpers +# ============================================================================ + +def run_legalize_layout(mlir_input: str) -> str: + """Run only the legalize-layout pass on the given MLIR.""" + return run_nkipy_opt_passes(mlir_input, ['legalize-layout']) + + +# ============================================================================ +# Test: 2D SBUF alloc + copy + reshape to 3D +# ============================================================================ + +# This is the pattern produced by the upstream pipeline for: +# cos_sbuf = alloc(256x64, sbuf); copy(hbm -> cos_sbuf); reshape -> 256x1x64 +# The reshape is then subviewed [128,1,64] inside a tiled loop. 
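+#
+# Schematically (shapes taken from the IR below), Phase 0 rewrites
+#     alloc(256x64, sbuf) ; copy(hbm -> alloc) ; reshape(alloc) -> 256x1x64
+# into
+#     alloc(256x1x64, sbuf) ; collapse_shape(alloc) -> 256x64 ; copy(hbm -> collapse)
+# so Phases 1-3 must look through the collapse_shape to reach the real SBUF alloc.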
+ +MLIR_2D_COPY_RESHAPE_3D = ''' +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, 0, d2)> +module { + memref.global "private" constant @__constant_3xindex : memref<3xindex> = dense<[256, 1, 64]> {alignment = 64 : i64} + func.func @test_reshape_copy( + %arg0: memref<256x1x64xf32, strided<[?, ?, ?], offset: ?>, 4 : i32>, + %arg1: memref<256x64xf32, strided<[?, ?], offset: ?>, 4 : i32> + ) -> memref<256x1x64xf32, 4 : i32> { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + %0 = memref.get_global @__constant_3xindex : memref<3xindex> + + // Pattern under test: 2D SBUF alloc + copy from HBM + reshape to 3D + %alloc = memref.alloc() {alignment = 64 : i64} : memref<256x64xf32, 3 : i32> + memref.copy %arg1, %alloc : memref<256x64xf32, strided<[?, ?], offset: ?>, 4 : i32> to memref<256x64xf32, 3 : i32> + %reshape = memref.reshape %alloc(%0) : (memref<256x64xf32, 3 : i32>, memref<3xindex>) -> memref<256x1x64xf32, 3 : i32> + + // Output buffer + %alloc_out = memref.alloc() {alignment = 64 : i64} : memref<256x1x64xf32, 3 : i32> + + // Tiled loop using [128, 1, 64] tiles of the reshaped buffer + // Linalg ops use 3D operands directly (Phase 4 handles collapse to 2D) + scf.for %arg3 = %c0 to %c2 step %c1 { + %off = arith.muli %arg3, %c128 : index + + // Load a tile from HBM input + %sv_hbm = memref.subview %arg0[%off, 0, 0] [128, 1, 64] [1, 1, 1] + : memref<256x1x64xf32, strided<[?, ?, ?], offset: ?>, 4 : i32> + to memref<128x1x64xf32, strided<[?, ?, ?], offset: ?>, 4 : i32> + %tile_in = memref.alloc() {alignment = 64 : i64} : memref<128x1x64xf32, 3 : i32> + memref.copy %sv_hbm, %tile_in + : memref<128x1x64xf32, strided<[?, ?, ?], offset: ?>, 4 : i32> + to memref<128x1x64xf32, 3 : i32> + + // Subview into the reshaped cos/sin buffer (the Phase 0 subject) + %sv_reshape = memref.subview %reshape[%off, 0, 0] [128, 1, 64] [1, 1, 1] + : memref<256x1x64xf32, 3 : i32> + to memref<128x1x64xf32, strided<[64, 64, 1], offset: ?>, 3 : i32> + + // Compute: elementwise multiply using 3D operands + %tile_out = memref.alloc() {alignment = 64 : i64} : memref<128x1x64xf32, 3 : i32> + linalg.generic {indexing_maps = [#map, #map1, #map], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%tile_in, %sv_reshape + : memref<128x1x64xf32, 3 : i32>, + memref<128x1x64xf32, strided<[64, 64, 1], offset: ?>, 3 : i32>) + outs(%tile_out : memref<128x1x64xf32, 3 : i32>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %mul = arith.mulf %in, %in_1 : f32 + linalg.yield %mul : f32 + } + + // Store result tile + %sv_out = memref.subview %alloc_out[%off, 0, 0] [128, 1, 64] [1, 1, 1] + : memref<256x1x64xf32, 3 : i32> + to memref<128x1x64xf32, strided<[64, 64, 1], offset: ?>, 3 : i32> + memref.copy %tile_out, %sv_out + : memref<128x1x64xf32, 3 : i32> + to memref<128x1x64xf32, strided<[64, 64, 1], offset: ?>, 3 : i32> + } + + // Copy result to HBM + %alloc_hbm = memref.alloc() {alignment = 64 : i64} : memref<256x1x64xf32, 4 : i32> + memref.copy %alloc_out, %alloc_hbm + : memref<256x1x64xf32, 3 : i32> + to memref<256x1x64xf32, 4 : i32> + return %alloc_hbm : memref<256x1x64xf32, 4 : i32> + } +} +''' + + +def test_2d_copy_reshape_3d_legalized(): + """ + The 256x64 SBUF alloc + copy + reshape(256x1x64) pattern must be legalized. + + Phase 0 folds the reshape into the alloc: + alloc(256x1x64, sbuf) + collapse_shape(256x64) + copy(hbm, collapse) + + Then Phases 1-3 must: + 1. 
Trace through collapse_shape to find tile sizes [128, 1, 64] + 2. Legalize alloc to 128x2x1x1x64 physical layout + 3. Tile the HBM->SBUF copy with a loop, handling HBM rank (2) < SBUF rank (3) + + Checks: + - The 256x1x64 SBUF alloc is replaced by a legalized 5D alloc + - The HBM->SBUF copy is tiled (appears inside a scf.for) + - No un-legalized 256x1x64 SBUF alloc remains + """ + result = run_legalize_layout(MLIR_2D_COPY_RESHAPE_3D) + + check_patterns = ''' +CHECK: func.func @test_reshape_copy +CHECK-NOT: memref.alloc(){{.*}}: memref<256x1x64xf32, 3 : i32> +CHECK-NOT: memref.alloc(){{.*}}: memref<256x64xf32, 3 : i32> +CHECK: memref.alloc(){{.*}}: memref<128x2x1x1x64xf32, 3 : i32> +CHECK: scf.for +CHECK: memref.copy{{.*}}4 : i32>{{.*}}to{{.*}}3 : i32> +CHECK: linalg.generic +''' + run_filecheck(result, check_patterns) + + +# ============================================================================ +# Test: 2D SBUF alloc (no reshape) — baseline regression +# ============================================================================ + +MLIR_2D_SBUF_BASELINE = ''' +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @test_2d_baseline( + %arg0: memref<256x128xf32, strided<[?, ?], offset: ?>, 4 : i32>, + %arg1: memref<256x128xf32, strided<[?, ?], offset: ?>, 4 : i32> + ) -> memref<256x128xf32, 4 : i32> { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + + %alloc_out = memref.alloc() {alignment = 64 : i64} : memref<256x128xf32, 3 : i32> + + scf.for %iv = %c0 to %c2 step %c1 { + %off = arith.muli %iv, %c128 : index + + %sv_a = memref.subview %arg0[%off, 0] [128, 128] [1, 1] + : memref<256x128xf32, strided<[?, ?], offset: ?>, 4 : i32> + to memref<128x128xf32, strided<[?, ?], offset: ?>, 4 : i32> + %tile_a = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32, 3 : i32> + memref.copy %sv_a, %tile_a + : memref<128x128xf32, strided<[?, ?], offset: ?>, 4 : i32> + to memref<128x128xf32, 3 : i32> + + %sv_b = memref.subview %arg1[%off, 0] [128, 128] [1, 1] + : memref<256x128xf32, strided<[?, ?], offset: ?>, 4 : i32> + to memref<128x128xf32, strided<[?, ?], offset: ?>, 4 : i32> + %tile_b = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32, 3 : i32> + memref.copy %sv_b, %tile_b + : memref<128x128xf32, strided<[?, ?], offset: ?>, 4 : i32> + to memref<128x128xf32, 3 : i32> + + %tile_out = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32, 3 : i32> + linalg.add ins(%tile_a, %tile_b + : memref<128x128xf32, 3 : i32>, + memref<128x128xf32, 3 : i32>) + outs(%tile_out : memref<128x128xf32, 3 : i32>) + + %sv_out = memref.subview %alloc_out[%off, 0] [128, 128] [1, 1] + : memref<256x128xf32, 3 : i32> + to memref<128x128xf32, strided<[128, 1], offset: ?>, 3 : i32> + memref.copy %tile_out, %sv_out + : memref<128x128xf32, 3 : i32> + to memref<128x128xf32, strided<[128, 1], offset: ?>, 3 : i32> + } + + %alloc_hbm = memref.alloc() {alignment = 64 : i64} : memref<256x128xf32, 4 : i32> + memref.copy %alloc_out, %alloc_hbm + : memref<256x128xf32, 3 : i32> + to memref<256x128xf32, 4 : i32> + return %alloc_hbm : memref<256x128xf32, 4 : i32> + } +} +''' + + +def test_2d_sbuf_baseline(): + """ + Baseline: 2D SBUF alloc (256x128) without reshape legalizes normally. + + The 256x128 alloc with tile [128, 128] should become 128x2x1x128 (4D). + No Phase 0 pattern is involved — this is the standard path. 
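+
+    Layout arithmetic (per the [tileM, numTilesM, numTilesN, tileN] scheme
+    described in the legalize-layout docs):
+
+        numTilesM = 256 // 128 = 2,  numTilesN = 128 // 128 = 1
+        physical  = [128, 2, 1, 128]    # the alloc the CHECK lines expect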
+ """ + result = run_legalize_layout(MLIR_2D_SBUF_BASELINE) + + check_patterns = ''' +CHECK: func.func @test_2d_baseline +CHECK: memref.alloc(){{.*}}: memref<128x2x1x128xf32, 3 : i32> +CHECK: scf.for +CHECK: linalg.add +CHECK: return{{.*}}4 : i32 +''' + run_filecheck(result, check_patterns) + + +# ============================================================================ +# Test: 3D SBUF alloc with full-buffer copy from 3D HBM (no reshape, no rank mismatch) +# ============================================================================ + +MLIR_3D_SBUF_COPY = ''' +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +module { + func.func @test_3d_copy( + %arg0: memref<256x2x64xf32, strided<[?, ?, ?], offset: ?>, 4 : i32>, + %arg1: memref<256x2x64xf32, strided<[?, ?, ?], offset: ?>, 4 : i32> + ) -> memref<256x2x64xf32, 4 : i32> { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + + // Full SBUF alloc loaded from HBM (no reshape -- ranks match) + %alloc = memref.alloc() {alignment = 64 : i64} : memref<256x2x64xf32, 3 : i32> + memref.copy %arg1, %alloc + : memref<256x2x64xf32, strided<[?, ?, ?], offset: ?>, 4 : i32> + to memref<256x2x64xf32, 3 : i32> + + %alloc_out = memref.alloc() {alignment = 64 : i64} : memref<256x2x64xf32, 3 : i32> + + // Linalg ops use 3D operands directly (Phase 4 handles collapse to 2D) + scf.for %i = %c0 to %c2 step %c1 { + %off = arith.muli %i, %c128 : index + scf.for %j = %c0 to %c2 step %c1 { + %sv_a = memref.subview %arg0[%off, %j, 0] [128, 1, 64] [1, 1, 1] + : memref<256x2x64xf32, strided<[?, ?, ?], offset: ?>, 4 : i32> + to memref<128x1x64xf32, strided<[?, ?, ?], offset: ?>, 4 : i32> + %tile_a = memref.alloc() {alignment = 64 : i64} : memref<128x1x64xf32, 3 : i32> + memref.copy %sv_a, %tile_a + : memref<128x1x64xf32, strided<[?, ?, ?], offset: ?>, 4 : i32> + to memref<128x1x64xf32, 3 : i32> + + %sv_b = memref.subview %alloc[%off, %j, 0] [128, 1, 64] [1, 1, 1] + : memref<256x2x64xf32, 3 : i32> + to memref<128x1x64xf32, strided<[128, 64, 1], offset: ?>, 3 : i32> + + %tile_out = memref.alloc() {alignment = 64 : i64} : memref<128x1x64xf32, 3 : i32> + linalg.generic {indexing_maps = [#map, #map, #map], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%tile_a, %sv_b + : memref<128x1x64xf32, 3 : i32>, + memref<128x1x64xf32, strided<[128, 64, 1], offset: ?>, 3 : i32>) + outs(%tile_out : memref<128x1x64xf32, 3 : i32>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %add = arith.addf %in, %in_1 : f32 + linalg.yield %add : f32 + } + + %sv_out = memref.subview %alloc_out[%off, %j, 0] [128, 1, 64] [1, 1, 1] + : memref<256x2x64xf32, 3 : i32> + to memref<128x1x64xf32, strided<[128, 64, 1], offset: ?>, 3 : i32> + memref.copy %tile_out, %sv_out + : memref<128x1x64xf32, 3 : i32> + to memref<128x1x64xf32, strided<[128, 64, 1], offset: ?>, 3 : i32> + } + } + + %hbm = memref.alloc() {alignment = 64 : i64} : memref<256x2x64xf32, 4 : i32> + memref.copy %alloc_out, %hbm + : memref<256x2x64xf32, 3 : i32> + to memref<256x2x64xf32, 4 : i32> + return %hbm : memref<256x2x64xf32, 4 : i32> + } +} +''' + + +def test_3d_sbuf_full_copy(): + """ + 3D SBUF alloc (256x2x64) with full-buffer copy from 3D HBM. + + No reshape involved — HBM and SBUF have the same rank (3). + The alloc should be legalized to 5D: 128x2x2x1x64. + The full-buffer HBM->SBUF copy should be tiled into a 3-level loop. 
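+
+    Layout arithmetic for the 5D shape (tile = [128, 1, 64], taken from the
+    subviews in the input IR):
+
+        numB0 = 256 // 128 = 2,  numB1 = 2 // 1 = 2,  numB2 = 64 // 64 = 1
+        physical = [tileP, numB0, numB1, numB2, tileF] = [128, 2, 2, 1, 64]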
+ """ + result = run_legalize_layout(MLIR_3D_SBUF_COPY) + + check_patterns = ''' +CHECK: func.func @test_3d_copy +CHECK: memref.alloc(){{.*}}: memref<128x2x2x1x64xf32, 3 : i32> +CHECK: scf.for +CHECK: scf.for +CHECK: scf.for +CHECK: memref.copy{{.*}}4 : i32>{{.*}}to{{.*}}3 : i32> +CHECK: linalg.generic +CHECK: return{{.*}}4 : i32 +''' + run_filecheck(result, check_patterns) + + +# ============================================================================ +# Test: expandTileShape with multi-non-unit collapse group (Fix 2) +# ============================================================================ + +# Pattern: 3D SBUF alloc (128, 2, 128) with collapse_shape [[0],[1,2]] → 2D (128, 256). +# A linalg op uses 2D tiles [128, 128]. expandTileShape must expand +# tile=128 for group [1,2] with srcShape=[2, 128]. +# +# Before fix: expanded=[2, 64] → middle tile=2 ≠ 1 → REJECTED by legalize-layout. +# After fix: expanded=[1, 128] → middle tile=1 → legalized to 5D physical. + +MLIR_MULTI_NON_UNIT_COLLAPSE = ''' +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @test_expand_tile_multi_non_unit( + %arg0: memref<128x2x128xf32, strided<[?, ?, ?], offset: ?>, 4 : i32> + ) -> memref<128x256xf32, 4 : i32> { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + + // 3D SBUF alloc — partition at dim 0, batch at dim 1, free at dim 2 + %alloc_3d = memref.alloc() {alignment = 64 : i64} : memref<128x2x128xf32, 3 : i32> + + // Manually tiled HBM→SBUF copy + scf.for %j = %c0 to %c2 step %c1 { + %sv_hbm = memref.subview %arg0[0, %j, 0] [128, 1, 128] [1, 1, 1] + : memref<128x2x128xf32, strided<[?, ?, ?], offset: ?>, 4 : i32> + to memref<128x1x128xf32, strided<[?, ?, ?], offset: ?>, 4 : i32> + %sv_sbuf = memref.subview %alloc_3d[0, %j, 0] [128, 1, 128] [1, 1, 1] + : memref<128x2x128xf32, 3 : i32> + to memref<128x1x128xf32, strided<[256, 128, 1], offset: ?>, 3 : i32> + memref.copy %sv_hbm, %sv_sbuf + : memref<128x1x128xf32, strided<[?, ?, ?], offset: ?>, 4 : i32> + to memref<128x1x128xf32, strided<[256, 128, 1], offset: ?>, 3 : i32> + } + + // collapse_shape [[0],[1,2]] → 2D (128, 256) + // group [1,2] has srcShape=[2, 128] — multi-non-unit! 
+ %collapsed = memref.collapse_shape %alloc_3d [[0], [1, 2]] + : memref<128x2x128xf32, 3 : i32> into memref<128x256xf32, 3 : i32> + + // Output in HBM (avoids legalization issues on the output side) + %alloc_out = memref.alloc() {alignment = 64 : i64} : memref<128x256xf32, 4 : i32> + + // Tiled loop using 2D [128, 128] tiles of the collapsed view + scf.for %j = %c0 to %c2 step %c1 { + %off = arith.muli %j, %c128 : index + + %sv_in = memref.subview %collapsed[0, %off] [128, 128] [1, 1] + : memref<128x256xf32, 3 : i32> + to memref<128x128xf32, strided<[256, 1], offset: ?>, 3 : i32> + + %sv_out = memref.subview %alloc_out[0, %off] [128, 128] [1, 1] + : memref<128x256xf32, 4 : i32> + to memref<128x128xf32, strided<[256, 1], offset: ?>, 4 : i32> + + %tile_out = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32, 3 : i32> + linalg.generic {indexing_maps = [#map, #map], + iterator_types = ["parallel", "parallel"]} + ins(%sv_in + : memref<128x128xf32, strided<[256, 1], offset: ?>, 3 : i32>) + outs(%tile_out : memref<128x128xf32, 3 : i32>) { + ^bb0(%in: f32, %out: f32): + %exp = math.exp %in : f32 + linalg.yield %exp : f32 + } + + memref.copy %tile_out, %sv_out + : memref<128x128xf32, 3 : i32> + to memref<128x128xf32, strided<[256, 1], offset: ?>, 4 : i32> + } + + return %alloc_out : memref<128x256xf32, 4 : i32> + } +} +''' + + +def test_expand_tile_multi_non_unit_collapse(): + """ + expandTileShape must handle collapse groups with multiple non-unit dims. + + Alloc: memref<128x2x128, sbuf> collapsed to 2D via [[0],[1,2]]. + Linalg ops use 2D tiles [128, 128]. + + expandTileShape must expand tile=128 for group [1,2] (srcShape=[2,128]): + Correct: [1, 128] — middle tile=1, legalize-layout accepts. + Old bug: [2, 64] — middle tile=2, legalize-layout rejects. + + After legalization the 3D alloc becomes 5D physical: + tile=[128, 1, 128], numBlocks=[1, 2, 1] → [128, 1, 2, 1, 128] + """ + result = run_legalize_layout(MLIR_MULTI_NON_UNIT_COLLAPSE) + + check_patterns = ''' +CHECK: func.func @test_expand_tile_multi_non_unit +CHECK: memref.alloc(){{.*}}: memref<128x1x2x1x128xf32, 3 : i32> +CHECK: scf.for +CHECK: linalg.generic +CHECK: return{{.*}}4 : i32 +''' + run_filecheck(result, check_patterns) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/linalg_to_nisa/__init__.py b/kernelgen/tests/passes/linalg_to_nisa/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernelgen/tests/passes/linalg_to_nisa/test_basic.py b/kernelgen/tests/passes/linalg_to_nisa/test_basic.py new file mode 100644 index 0000000..7fe68c5 --- /dev/null +++ b/kernelgen/tests/passes/linalg_to_nisa/test_basic.py @@ -0,0 +1,105 @@ +""" +Tests for linalg.sqrt -> nisa.activation(op=sqrt) lowering. + +The linalg-to-nisa pass should convert linalg.sqrt into nisa.activation +with op=sqrt, running on the SCALAR engine. + +Run with: python -m pytest tests/passes/linalg_to_nisa/test_sqrt.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +def test_sqrt_basic(): + """ + Basic sqrt: linalg.sqrt should be lowered to nisa.activation(op=sqrt). 
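+
+    Inputs are generated as np.abs(np.random.randn(...)) + 0.01 so every element
+    is strictly positive and sqrt is well-defined; the BIR_SIM run then checks
+    numerical correctness, presumably against the plain NumPy result np.sqrt(A).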
+ """ + shape = (128, 256) + tile_size = [64, 128] + + @trace(input_specs=[(shape, "f32")]) + def kernel(a): + result = np.sqrt(a) + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + return result + + np.random.seed(42) + A = np.abs(np.random.randn(*shape)).astype(np.float32) + 0.01 + + # After linalg-to-nisa: sqrt should become nisa.activation + check_patterns = """ + CHECK: func.func + CHECK: nisa.activation + CHECK-SAME: op=sqrt + CHECK-NOT: linalg.sqrt + CHECK: return + """ + run_kernel_test( + kernel, + stop_after='linalg-to-nisa', + check_patterns=check_patterns, + inputs=[A], + modes=Mode.FILECHECK, + ) + + # BIR simulation: verify numerical correctness through full pipeline + run_kernel_test( + kernel, + + check_ir_contains=["nisa.activation", "op=sqrt"], + inputs=[A], + modes=Mode.BIR_SIM | Mode.STRING_CHECK, + ) + + +def test_sqrt_256x256(): + """ + Sqrt on a 256x256 tensor with 128x128 tiles. + """ + shape = (256, 256) + tile_size = [128, 128] + + @trace(input_specs=[(shape, "f32")]) + def kernel(a): + result = np.sqrt(a) + knob.knob(result, mem_space="SharedHbm", tile_size=tile_size) + return result + + np.random.seed(42) + A = np.abs(np.random.randn(*shape)).astype(np.float32) + 0.01 + + check_patterns = """ + CHECK: func.func + CHECK: nisa.activation + CHECK-SAME: op=sqrt + CHECK-NOT: linalg.sqrt + CHECK: return + """ + run_kernel_test( + kernel, + stop_after='linalg-to-nisa', + check_patterns=check_patterns, + inputs=[A], + modes=Mode.FILECHECK, + ) + + # BIR simulation: verify numerical correctness through full pipeline + run_kernel_test( + kernel, + + check_ir_contains=["nisa.activation", "op=sqrt"], + inputs=[A], + modes=Mode.BIR_SIM | Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/linalg_to_nisa/test_multi_non_unit_collapse.py b/kernelgen/tests/passes/linalg_to_nisa/test_multi_non_unit_collapse.py new file mode 100644 index 0000000..01ec2d1 --- /dev/null +++ b/kernelgen/tests/passes/linalg_to_nisa/test_multi_non_unit_collapse.py @@ -0,0 +1,205 @@ +""" +Tests for linalg-to-nisa handling of collapse_shape with multi-non-unit groups. + +After legalize-layout, a 4D SBUF tensor (e.g. (128, 2, 2, 128) from a head- +deconcat pattern) gets collapsed to 2D via collapse_shape [[0,1],[2,3]]. Both +groups have multiple non-unit dims ([128,2] and [2,128]). + +getBaseAndOffsets must trace through this collapse to reach the SBUF alloc base, +because NCC requires NISA ops to reference alloc results. For each multi-non- +unit group, the largest dim carries the data (partition or free tile), and +smaller dims are batch loop indices that should be dropped. + +Run with: python -m pytest tests/passes/linalg_to_nisa/test_multi_non_unit_collapse.py -v +""" + +import pytest + +from nkipy_kernelgen.transforms.nkipy_opt import run_nkipy_opt_passes +from nkipy_kernelgen.transforms.linalg_to_nisa_py import linalg_to_nisa +from passes.pass_utils import run_filecheck + + +# ============================================================================ +# Helpers +# ============================================================================ + +def run_linalg_to_nisa(mlir_input: str) -> str: + """Run linalg-to-nisa on the given MLIR. 
+ + simplify-linalg still runs in C++ (`nkipy-opt`); the actual linalg→NISA + lowering was moved to Python as part of open-sourcing so we call the + Python implementation directly here instead of shelling out to an + `nkipy-opt --linalg-to-nisa` pass that no longer exists. + + Use ``print_generic=False`` so FileCheck patterns can reference the + pretty form (e.g. ``nisa.tensor_tensor_arith`` rather than + ``\"nisa.tensor_tensor_arith\"``). + """ + simplified = run_nkipy_opt_passes(mlir_input, ['simplify-linalg']) + return linalg_to_nisa(simplified, target='trn2', print_generic=False) + + +# ============================================================================ +# Test: 4D SBUF alloc collapsed to 2D with multi-non-unit groups +# ============================================================================ + +# This represents the WA3 (head-deconcat) scenario after legalize-layout: +# 4D SBUF alloc: memref<128x2x2x128xf32, #sbuf> +# collapse_shape [[0,1],[2,3]] → memref<256x256xf32, #sbuf> +# subview [128, 128] tiles in a nested loop +# linalg.add on the tiles +# +# getBaseAndOffsets must decompose the collapsed indices: +# Group [0,1]: srcShape=[128,2] → dim 0 (128) is primary, dim 1 (2) is dropped +# Group [2,3]: srcShape=[2,128] → dim 3 (128) is primary, dim 2 (2) is dropped + +MLIR_MULTI_NON_UNIT_COLLAPSE = ''' +module { + func.func @test_multi_non_unit_collapse( + %arg0: memref<256x256xf32, strided<[?, ?], offset: ?>, 4 : i32> + ) -> memref<256x256xf32, 4 : i32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + + // 4D SBUF alloc — physical layout from legalize-layout + %alloc = memref.alloc() {alignment = 64 : i64} : memref<128x2x2x128xf32, 3 : i32> + + // Load data from HBM into the 4D alloc tile by tile + scf.for %bm = %c0 to %c2 step %c1 { + scf.for %bn = %c0 to %c2 step %c1 { + %hbm_off_m = arith.muli %bm, %c128 : index + %hbm_off_n = arith.muli %bn, %c128 : index + %sv_hbm = memref.subview %arg0[%hbm_off_m, %hbm_off_n] [128, 128] [1, 1] + : memref<256x256xf32, strided<[?, ?], offset: ?>, 4 : i32> + to memref<128x128xf32, strided<[?, ?], offset: ?>, 4 : i32> + %sv_sbuf = memref.subview %alloc[0, %bm, %bn, 0] [128, 1, 1, 128] [1, 1, 1, 1] + : memref<128x2x2x128xf32, 3 : i32> + to memref<128x1x1x128xf32, strided<[512, 256, 128, 1], offset: ?>, 3 : i32> + %sv_sbuf_2d = memref.collapse_shape %sv_sbuf [[0], [1, 2, 3]] + : memref<128x1x1x128xf32, strided<[512, 256, 128, 1], offset: ?>, 3 : i32> + into memref<128x128xf32, strided<[512, 1], offset: ?>, 3 : i32> + memref.copy %sv_hbm, %sv_sbuf_2d + : memref<128x128xf32, strided<[?, ?], offset: ?>, 4 : i32> + to memref<128x128xf32, strided<[512, 1], offset: ?>, 3 : i32> + } + } + + // Collapse the 4D alloc to 2D — creates multi-non-unit groups + %collapsed = memref.collapse_shape %alloc [[0, 1], [2, 3]] + : memref<128x2x2x128xf32, 3 : i32> + into memref<256x256xf32, 3 : i32> + + // Output alloc + %alloc_out = memref.alloc() {alignment = 64 : i64} : memref<128x2x2x128xf32, 3 : i32> + %collapsed_out = memref.collapse_shape %alloc_out [[0, 1], [2, 3]] + : memref<128x2x2x128xf32, 3 : i32> + into memref<256x256xf32, 3 : i32> + + // Tiled computation on collapsed 2D view + scf.for %bm = %c0 to %c2 step %c1 { + scf.for %bn = %c0 to %c2 step %c1 { + %off_m = arith.muli %bm, %c128 : index + %off_n = arith.muli %bn, %c128 : index + + %sv_in = memref.subview %collapsed[%off_m, %off_n] [128, 128] [1, 1] + : memref<256x256xf32, 3 : i32> + to memref<128x128xf32, 
strided<[256, 1], offset: ?>, 3 : i32>
+
+        %tile_out = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32, 3 : i32>
+
+        linalg.add ins(%sv_in, %sv_in
+            : memref<128x128xf32, strided<[256, 1], offset: ?>, 3 : i32>,
+              memref<128x128xf32, strided<[256, 1], offset: ?>, 3 : i32>)
+            outs(%tile_out : memref<128x128xf32, 3 : i32>)
+
+        %sv_out = memref.subview %collapsed_out[%off_m, %off_n] [128, 128] [1, 1]
+            : memref<256x256xf32, 3 : i32>
+            to memref<128x128xf32, strided<[256, 1], offset: ?>, 3 : i32>
+        memref.copy %tile_out, %sv_out
+            : memref<128x128xf32, 3 : i32>
+            to memref<128x128xf32, strided<[256, 1], offset: ?>, 3 : i32>
+      }
+    }
+
+    // Copy back to HBM
+    %alloc_hbm = memref.alloc() {alignment = 64 : i64} : memref<256x256xf32, 4 : i32>
+    %collapsed_hbm = memref.collapse_shape %alloc_out [[0, 1], [2, 3]]
+        : memref<128x2x2x128xf32, 3 : i32>
+        into memref<256x256xf32, 3 : i32>
+    memref.copy %collapsed_hbm, %alloc_hbm
+        : memref<256x256xf32, 3 : i32>
+        to memref<256x256xf32, 4 : i32>
+    return %alloc_hbm : memref<256x256xf32, 4 : i32>
+  }
+}
+'''
+
+
+def test_multi_non_unit_collapse_nisa_lowering():
+    """
+    Verify that linalg-to-nisa correctly handles collapse_shape with multi-
+    non-unit groups in SBUF.
+
+    Before fix: getBaseAndOffsets stops at the collapse (it cannot decompose
+    multi-non-unit groups), leaving a stale memref as base → wrong NISA map.
+
+    After fix: getBaseAndOffsets traces through the collapse, marking batch
+    dims as dropped. The linalg.add is lowered to nisa.tensor_tensor_arith
+    with the 4D alloc as base.
+    """
+    result = run_linalg_to_nisa(MLIR_MULTI_NON_UNIT_COLLAPSE)
+
+    # The linalg.add should be lowered to nisa.tensor_tensor_arith
+    # referencing the 4D SBUF alloc (not the collapsed 2D view).
+    check_patterns = '''
+CHECK: func.func @test_multi_non_unit_collapse
+CHECK: nisa.alloc
+CHECK: nisa.tensor_tensor_arith
+CHECK-SAME: op=add
+CHECK-NOT: linalg.add
+CHECK: return
+'''
+    run_filecheck(result, check_patterns)
+
+
+def test_multi_non_unit_collapse_correct_indices():
+    """
+    Verify that the decomposed indices are correct:
+      - dim 0 (partition, 128): base 0, iterated by d0
+      - dim 1 (batch, 2):       batch index = collapsed_offset / 128
+      - dim 2 (batch, 2):       batch index = collapsed_offset / 128
+      - dim 3 (free, 128):      base 0, iterated by d1
+
+    The NISA map should show:
+        %mem[%c0 + d0,
+             %b1 + 0,
+             %b2 + 0, %c0 + d1]
+    where %b1 and %b2 are the divui results (stand-in names here; the actual
+    SSA names vary).
+    """
+    result = run_linalg_to_nisa(MLIR_MULTI_NON_UNIT_COLLAPSE)
+
+    # Verify correct index structure in tensor_tensor_arith:
+    #   - dim 0: constant_0 + d0 (partition tile)
+    #   - dim 1: divui result + 0 (batch block, dropped)
+    #   - dim 3: constant_0 + d1 (free tile)
+    # The divui decomposes collapsed_offset / primary_size.
+    # Constants get unique names when the module is re-serialized, so use
+    # a regex wildcard to match `%c0`, `%c0_9`, `%c0_11`, etc.
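+    # Worked example from this fixture: for %bn = 1 the collapsed column offset
+    # is 1 * 128 = 128, so divui(128, 128) = 1 selects the dropped batch block
+    # while the offset within the 128-wide tile stays 0.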
+ check_patterns = ''' +CHECK: arith.divui +CHECK: nisa.tensor_tensor_arith +CHECK-SAME: %c0{{.*}} + d0 +CHECK-SAME: + 0 +CHECK-SAME: + 0 +CHECK-SAME: %c0{{.*}} + d1 +''' + run_filecheck(result, check_patterns) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/pass_utils.py b/kernelgen/tests/passes/pass_utils.py new file mode 100644 index 0000000..858ff32 --- /dev/null +++ b/kernelgen/tests/passes/pass_utils.py @@ -0,0 +1,363 @@ +""" +Common utilities for testing MLIR passes. + +Includes FileCheck support, MLIR compilation helpers, and test infrastructure. +""" + +import os +import subprocess +import tempfile +from pathlib import Path +from typing import List, Optional + +import numpy as np + +from nkipy_kernelgen import apply_passes +from nkipy_kernelgen.transforms.nkipy_opt import ( + run_nkipy_opt_passes, + apply_complete_knob_pipeline, +) +from nkipy_kernelgen.llvm import LLVMModule, extract_and_clean_func_from_module + + +# ============================================================================ +# MLIR Pass Execution +# ============================================================================ + + +def run_passes( + mlir_module: str, passes: List[str], print_ir_after_all: bool = False +) -> str: + """ + Run a list of passes on an MLIR module. + + Args: + mlir_module: Input MLIR module as string + passes: List of pass names to run + print_ir_after_all: If True, print IR after each pass for debugging + + Returns: + Transformed MLIR module as string + """ + return run_nkipy_opt_passes(mlir_module, passes, print_ir_after_all) + + +def trace_to_mlir_with_preprocessing(traced_func) -> str: + """ + Convert a traced function to MLIR string. + + Args: + traced_func: A traced function with .to_mlir() method + + Returns: + MLIR module as string + """ + mlir_module = traced_func.to_mlir() + return str(mlir_module) + + +def compile_through_passes( + traced_func, + passes: List[str], + dump_dir: Optional[str] = None, + preprocessing: bool = True, +) -> str: + """ + Compile a traced function through a specified pass pipeline. 
+ + Args: + traced_func: A traced function with .to_mlir() method + passes: List of pass names to run + dump_dir: Optional directory to save intermediate MLIR files + preprocessing: Whether to apply preprocessing (currently a no-op, kept for API compat) + + Returns: + Final MLIR module as string + """ + # Get MLIR from traced function + if preprocessing: + mlir_str = trace_to_mlir_with_preprocessing(traced_func) + else: + mlir_str = str(traced_func.to_mlir()) + + if dump_dir: + os.makedirs(dump_dir, exist_ok=True) + save_mlir_to_file(mlir_str, os.path.join(dump_dir, "00_input.mlir")) + + # Run each pass individually and save output + current_mlir = mlir_str + for i, pass_name in enumerate(passes): + try: + current_mlir = run_nkipy_opt_passes( + current_mlir, [pass_name], print_stderr=True + ) + if dump_dir: + output_filename = f"{i + 1:02d}_{pass_name.replace('-', '_')}.mlir" + save_mlir_to_file(current_mlir, os.path.join(dump_dir, output_filename)) + except RuntimeError as e: + # On error, save what we have and re-raise + if dump_dir: + error_file = os.path.join(dump_dir, f"ERROR_{pass_name}.txt") + with open(error_file, "w") as f: + f.write(str(e)) + raise + + return current_mlir + + +# ============================================================================ +# File Utilities +# ============================================================================ + + +def save_mlir_to_file(mlir_text: str, output_path: str) -> None: + """ + Save MLIR text to a file. + + Args: + mlir_text: The MLIR module text to save + output_path: Path to the output file + """ + output_dir = os.path.dirname(output_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + + with open(output_path, "w") as f: + f.write(mlir_text) + print(f"Saved MLIR to: {output_path}") + + +def get_test_output_dir(test_file: str) -> str: + """Get the output directory for a specific test file.""" + test_dir = Path(test_file).parent + output_dir = test_dir / "outputs" + os.makedirs(output_dir, exist_ok=True) + return str(output_dir) + + +# ============================================================================ +# FileCheck Support +# ============================================================================ + + +def get_filecheck_path() -> str: + """Get the path to the FileCheck executable.""" + # Check in LLVM build directory + package_dir = Path(__file__).parent.parent.parent + llvm_build = package_dir.parent / "llvm-project" / "build" / "bin" / "FileCheck" + + if llvm_build.exists(): + return str(llvm_build) + + # Check in PATH + import shutil + + filecheck = shutil.which("FileCheck") + if filecheck: + return filecheck + + raise FileNotFoundError( + "FileCheck not found. Please build LLVM or add FileCheck to PATH.\n" + f"Looked in: {llvm_build}" + ) + + +def assert_ir_unchanged( + before_file: str, after_file: str, pass_name: str = "pass" +) -> None: + """ + Assert that two MLIR files are identical. + + This is useful for testing that a pass doesn't modify IR when there's + nothing to transform (e.g., no SBUF outputs for legalize-sbuf-outputs). 
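+
+    Example (file names are illustrative):
+        assert_ir_unchanged("outputs/03_before.mlir", "outputs/04_after.mlir",
+                            pass_name="legalize-sbuf-outputs")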
+ + Args: + before_file: Path to the MLIR file before the pass + after_file: Path to the MLIR file after the pass + pass_name: Name of the pass (for error messages) + + Raises: + AssertionError: If the files don't exist + pytest.fail: If the files differ + """ + import pytest + + assert os.path.exists(before_file), f"Before file not found: {before_file}" + assert os.path.exists(after_file), f"After file not found: {after_file}" + + # Use diff to compare the files + result = subprocess.run( + ["diff", "-q", before_file, after_file], capture_output=True, text=True + ) + + if result.returncode != 0: + # Files differ - show the diff + diff_result = subprocess.run( + ["diff", "-u", before_file, after_file], capture_output=True, text=True + ) + pytest.fail( + f"{pass_name} pass modified IR when it should not have!\n" + f"Diff:\n{diff_result.stdout}" + ) + + +def run_filecheck(mlir_output: str, check_patterns: str) -> None: + """ + Run FileCheck to verify MLIR output against check patterns. + + Args: + mlir_output: The MLIR text to verify + check_patterns: String containing CHECK patterns (e.g., "CHECK: scf.for\\nCHECK: linalg.add") + + Raises: + AssertionError: If FileCheck fails, with detailed error message + + Example: + check_patterns = ''' + CHECK: func.func + CHECK: scf.for + CHECK-NOT: linalg.fill + CHECK: linalg.matmul + ''' + run_filecheck(mlir_output, check_patterns) + + FileCheck Regex Notes: + - Use {{.*}} to match any text (FileCheck regex) + - In Python f-strings, use {{{{.*}}}} to get {{.*}} + - Use %c0{{.*}} to match %c0, %c0_1, %c0_3, etc. + """ + filecheck = get_filecheck_path() + + # Write check patterns to temp file + with tempfile.NamedTemporaryFile(mode="w", suffix=".check", delete=False) as f: + f.write(check_patterns) + check_file = f.name + + try: + # Run FileCheck: reads patterns from check_file, reads input from stdin + result = subprocess.run( + [filecheck, check_file], input=mlir_output, capture_output=True, text=True + ) + + if result.returncode != 0: + # FileCheck failed - provide helpful error message + error_msg = "FileCheck verification failed!\n" + error_msg += f"FileCheck stderr:\n{result.stderr}\n" + error_msg += f"\n--- Check Patterns ---\n{check_patterns}\n" + error_msg += ( + f"\n--- MLIR Output (first 3000 chars) ---\n{mlir_output[:3000]}\n" + ) + raise AssertionError(error_msg) + + finally: + if os.path.exists(check_file): + os.unlink(check_file) + + +# ============================================================================ +# Knob Pipeline +# ============================================================================ + + +def compile_knob_pipeline(traced_func, stop_after=None, dump_dir=None, **kwargs): + """ + Trace a function and run it through the knob compilation pipeline. 
+ + Args: + traced_func: A traced function with .to_mlir() method + stop_after: Pass name (str) or index (int) to stop after, or None for all passes + dump_dir: Optional directory to save intermediate MLIR files + **kwargs: Additional arguments passed to apply_complete_knob_pipeline + + Returns: + Transformed MLIR module as string + """ + mlir_str = trace_to_mlir_with_preprocessing(traced_func) + return apply_complete_knob_pipeline( + mlir_str, stop_after=stop_after, dump_dir=dump_dir, **kwargs + ) + + +# ============================================================================ +# LLVM CPU Execution Verification +# ============================================================================ + + +def verify_tiled_mlir_with_numpy( + tiled_mlir: str, + traced_func, + rtol: float = 1e-5, + atol: float = 1e-6, + func_name: str = "top", +) -> None: + """ + Verify that compiled MLIR produces the same results as the original function. + + This function: + 1. Extracts input specs from the traced function + 2. Generates random test inputs based on those specs + 3. Runs the original function (via __wrapped__) to get expected output + 4. Compiles and executes the MLIR using LLVM JIT + 5. Compares the MLIR output with the original function output + + Args: + tiled_mlir: The transformed MLIR code + traced_func: A traced function (decorated with @trace) with __wrapped__ attribute + rtol: Relative tolerance for np.allclose comparison + atol: Absolute tolerance for np.allclose comparison + func_name: Name of the top-level function in the MLIR module + + Raises: + AssertionError: If MLIR output doesn't match original function output + """ + original_func = traced_func.__wrapped__ + input_specs = traced_func.input_specs + + dtype_map = { + "f32": np.float32, + "f64": np.float64, + "f16": np.float16, + "i32": np.int32, + "i64": np.int64, + } + + inputs = [] + for shape, dtype_str in input_specs: + if dtype_str not in dtype_map: + raise ValueError( + f"Unsupported dtype: {dtype_str}. 
Supported: {list(dtype_map.keys())}" + ) + np_dtype = dtype_map[dtype_str] + if np_dtype in [np.float32, np.float64, np.float16]: + arr = np.random.rand(*shape).astype(np_dtype) + else: + arr = np.random.randint(0, 100, size=shape).astype(np_dtype) + inputs.append(arr) + + numpy_result = original_func(*inputs) + + clean_mlir, actual_func_name = extract_and_clean_func_from_module(tiled_mlir) + runner = LLVMModule(clean_mlir, actual_func_name) + mlir_result = runner(*[inp.copy() for inp in inputs]) + + # Normalize both to lists for uniform comparison + if not isinstance(mlir_result, list): + mlir_result = [mlir_result] + if isinstance(numpy_result, tuple): + numpy_result = list(numpy_result) + elif not isinstance(numpy_result, list): + numpy_result = [numpy_result] + + assert len(mlir_result) == len(numpy_result), ( + f"Output count mismatch: MLIR returned {len(mlir_result)}, " + f"NumPy returned {len(numpy_result)}" + ) + + for i, (mr, nr) in enumerate(zip(mlir_result, numpy_result)): + if not np.allclose(mr, nr, rtol=rtol, atol=atol): + max_diff = np.max(np.abs(mr - nr)) + raise AssertionError( + f"MLIR result does not match original function output (output {i}).\n" + f"Max difference: {max_diff}\n" + f"Relative tolerance: {rtol}, Absolute tolerance: {atol}" + ) diff --git a/kernelgen/tests/passes/prepare_arithmetic/__init__.py b/kernelgen/tests/passes/prepare_arithmetic/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernelgen/tests/passes/prepare_arithmetic/test_basic.py b/kernelgen/tests/passes/prepare_arithmetic/test_basic.py new file mode 100644 index 0000000..a6dd720 --- /dev/null +++ b/kernelgen/tests/passes/prepare_arithmetic/test_basic.py @@ -0,0 +1,258 @@ +""" +Tests for the prepare-arithmetic pass. + +This pass converts division operations into multiplication by reciprocal +because NISA's tensor_tensor_arith doesn't support DIVIDE directly. + +Patterns tested: + - linalg.div(A, B) -> linalg.mul(A, linalg.reciprocal(B)) + - linalg.generic with scalar divf -> mulf with reciprocal constant + - linalg.generic with broadcast divf(block_arg, block_arg) + -> linalg.reciprocal(rhs) + linalg.generic with mulf + +Run with: python -m pytest tests/passes/prepare_arithmetic/test_basic.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import run_kernel_test, Mode + + +# ============================================================================ +# Named linalg.div (same-shape tensor / tensor) +# ============================================================================ + +def test_tensor_div_tensor_same_shape(): + """ + Same-shape division: linalg.div -> linalg.mul + linalg.reciprocal. + + After prepare-arithmetic, arith.divf should be gone, replaced by + linalg.reciprocal and linalg.mul. 
+ """ + shape = (128, 256) + tile_size = [64, 128] + + @trace(input_specs=[(shape, "f32"), (shape, "f32")]) + def kernel(a, b): + result = np.divide(a, b) + knob.knob(result, tile_size=tile_size) + return result + + np.random.seed(42) + A = np.random.randn(*shape).astype(np.float32) + B = (np.abs(np.random.randn(*shape)) + 0.5).astype(np.float32) + + # After prepare-arithmetic: div replaced by reciprocal + mul + check_patterns = """ + CHECK: func.func + CHECK: linalg.reciprocal + CHECK: linalg.mul + CHECK-NOT: linalg.div + CHECK: return + """ + run_kernel_test( + kernel, + stop_after='prepare-arithmetic', + check_patterns=check_patterns, + inputs=[A, B], + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +# ============================================================================ +# Scalar division: tensor / scalar constant +# ============================================================================ + +def test_tensor_div_scalar(): + """ + Tensor / scalar: linalg.generic { divf(%arg, %cst) } + -> linalg.generic { mulf(%arg, 1/cst) }. + + The divf in the body is replaced with mulf using the reciprocal constant. + """ + shape = (256, 256) + tile_size = [128, 128] + + @trace(input_specs=[(shape, "f32")]) + def kernel(x): + result = x / 2.0 + knob.knob(result, tile_size=tile_size) + return result + + # After prepare-arithmetic: divf replaced by mulf in body + check_patterns = """ + CHECK: func.func + CHECK: linalg.generic + CHECK: arith.mulf + CHECK-NOT: arith.divf + CHECK: return + """ + run_kernel_test( + kernel, + stop_after='prepare-arithmetic', + check_patterns=check_patterns, + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +# ============================================================================ +# Scalar / tensor (reciprocal pattern) +# ============================================================================ + +def test_scalar_div_tensor(): + """ + Scalar / tensor: linalg.generic { divf(%cst, %arg) } + -> linalg.reciprocal(input). + + The entire generic is replaced with a reciprocal op. + """ + shape = (256, 256) + tile_size = [128, 128] + + @trace(input_specs=[(shape, "f32")]) + def kernel(x): + result = 1.0 / x + knob.knob(result, tile_size=tile_size) + return result + + # After prepare-arithmetic: replaced by linalg.reciprocal + check_patterns = """ + CHECK: func.func + CHECK: linalg.reciprocal + CHECK-NOT: arith.divf + CHECK: return + """ + run_kernel_test( + kernel, + stop_after='prepare-arithmetic', + check_patterns=check_patterns, + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +# ============================================================================ +# Broadcast division: tensor / tensor +# ============================================================================ + +def test_broadcast_div_column(): + """ + Broadcast column division: tensor<256x256> / tensor<256x1>. + + The tracer emits linalg.generic with broadcast indexing maps and + arith.divf between block args. After prepare-arithmetic, the divf + should be replaced by linalg.reciprocal on the rhs + arith.mulf + in the body. 
+ """ + shape_a = (256, 256) + shape_b = (256, 1) + tile_size = [128, 128] + + @trace(input_specs=[(shape_a, "f32"), (shape_b, "f32")]) + def kernel(a, b): + result = np.divide(a, b) + knob.knob(result, tile_size=tile_size) + return result + + np.random.seed(42) + A = np.random.randn(*shape_a).astype(np.float32) + B = (np.abs(np.random.randn(*shape_b)) + 0.5).astype(np.float32) + + # After prepare-arithmetic: reciprocal of rhs, mulf in generic body + check_patterns = """ + CHECK: func.func + CHECK: linalg.reciprocal + CHECK: linalg.generic + CHECK: arith.mulf + CHECK-NOT: arith.divf + CHECK: return + """ + run_kernel_test( + kernel, + stop_after='prepare-arithmetic', + check_patterns=check_patterns, + inputs=[A, B], + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +def test_broadcast_div_row(): + """ + Broadcast row division: tensor<128x256> / tensor<1x256>. + """ + shape_a = (128, 256) + shape_b = (1, 256) + tile_size = [64, 128] + + @trace(input_specs=[(shape_a, "f32"), (shape_b, "f32")]) + def kernel(a, b): + result = np.divide(a, b) + knob.knob(result, tile_size=tile_size) + return result + + np.random.seed(42) + A = np.random.randn(*shape_a).astype(np.float32) + B = (np.abs(np.random.randn(*shape_b)) + 0.5).astype(np.float32) + + check_patterns = """ + CHECK: func.func + CHECK: linalg.reciprocal + CHECK: linalg.generic + CHECK: arith.mulf + CHECK-NOT: arith.divf + CHECK: return + """ + run_kernel_test( + kernel, + stop_after='prepare-arithmetic', + check_patterns=check_patterns, + inputs=[A, B], + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +def test_broadcast_div_rmsnorm_pattern(): + """ + RMSNorm-like pattern: tensor<128x256> / tensor<128x1>. + + This is the most common real-world use case for broadcast division, + where each row is divided by its RMS norm value. + """ + shape_a = (128, 256) + shape_b = (128, 1) + tile_size = [64, 128] + + @trace(input_specs=[(shape_a, "f32"), (shape_b, "f32")]) + def kernel(values, rms): + result = np.divide(values, rms) + knob.knob(result, tile_size=tile_size) + return result + + np.random.seed(42) + A = np.random.randn(*shape_a).astype(np.float32) + B = (np.abs(np.random.randn(*shape_b)) + 0.5).astype(np.float32) + + check_patterns = """ + CHECK: func.func + CHECK: linalg.reciprocal + CHECK: linalg.generic + CHECK: arith.mulf + CHECK-NOT: arith.divf + CHECK: return + """ + run_kernel_test( + kernel, + stop_after='prepare-arithmetic', + check_patterns=check_patterns, + inputs=[A, B], + modes=Mode.LLVM | Mode.FILECHECK, + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/passes/remove_linalg_zero_fill/__init__.py b/kernelgen/tests/passes/remove_linalg_zero_fill/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernelgen/tests/passes/resolve_custom_ops/test_basic.py b/kernelgen/tests/passes/resolve_custom_ops/test_basic.py new file mode 100644 index 0000000..175e9c4 --- /dev/null +++ b/kernelgen/tests/passes/resolve_custom_ops/test_basic.py @@ -0,0 +1,193 @@ +""" +Tests for the custom-op resolution step. + +The standalone C++ `--resolve-custom-ops` pass was removed during +open-sourcing and re-implemented as `_resolve_custom_ops()` inside the +Python `linalg_to_nisa_py` phase, so these tests drive that function +directly against small hand-written fixtures. Each test: + +1. 
Parses an MLIR module (with a `nkipy.custom_op_bodies` dict, a + body-less `func.func private` decl, and a call site) into the + NKI-wheel MLIR context. +2. Runs `_resolve_custom_ops`. +3. Asserts the call is gone, the decl is gone, the module attribute is + gone, and the inlined ops are present. + +Run with: python -m pytest tests/passes/resolve_custom_ops/test_basic.py -v +""" + +import pytest + +from nki.compiler._internal import ir as nk_ir +from nki.compiler._internal._mlir_libs import _nki + +from nkipy_kernelgen.transforms.linalg_to_nisa_py import _resolve_custom_ops +from passes.pass_utils import run_filecheck + + +def _escape_mlir_string(s: str) -> str: + """Escape a string for embedding in an MLIR string attribute.""" + return s.replace("\\", "\\\\").replace('"', '\\"') + + +def _run_resolve(mlir_input: str) -> str: + """Parse in a fresh NKI context, run _resolve_custom_ops, return IR text.""" + ctx = nk_ir.Context() + _nki.register_all_dialects(ctx) + ctx.allow_unregistered_dialects = True + with ctx: + module = nk_ir.Module.parse(mlir_input, ctx) + _resolve_custom_ops(module, ctx) + return str(module) + + +# ============================================================================ +# Basic resolution: single custom op with output-as-argument convention +# ============================================================================ + + +def test_resolve_single_custom_op(): + """ + Resolve a single custom op declaration + call site. + + The NISA body uses output-as-argument convention: trailing args are outputs. + After resolution, the body is inlined at the call site with an alloc for output. + """ + nisa_body = _escape_mlir_string( + "module attributes {nisa.target = #nisa.target} {" + " func.func @my_op(" + " %arg0: memref<128x128xf32, #nisa.mem>," + " %arg1: memref<128x128xf32, #nisa.mem>" + ' ) attributes {nki.output_names = ["output"]} {' + " return" + " }" + "}" + ) + + mlir_input = f''' +module attributes {{ + nkipy.custom_op_bodies = {{ + "__custom_op__my_op" = "{nisa_body}" + }} +}} {{ + func.func @main_kernel( + %arg0: memref<128x128xf32, #nisa.mem> + ) -> memref<128x128xf32, #nisa.mem> {{ + %result = func.call @__custom_op__my_op(%arg0) + : (memref<128x128xf32, #nisa.mem>) -> memref<128x128xf32, #nisa.mem> + return %result : memref<128x128xf32, #nisa.mem> + }} + + func.func private @__custom_op__my_op( + memref<128x128xf32, #nisa.mem> + ) -> memref<128x128xf32, #nisa.mem> + attributes {{nkipy.custom_op}} +}} +''' + + result = _run_resolve(mlir_input) + + # Body is inlined: alloc for output, no call instruction, no declaration + check_patterns = """ + CHECK: func.func @main_kernel + CHECK: memref.alloc + CHECK: return + CHECK-NOT: call @__custom_op__my_op + CHECK-NOT: func.func private @__custom_op__my_op + CHECK-NOT: nkipy.custom_op_bodies + CHECK-NOT: nkipy.custom_op + """ + run_filecheck(result, check_patterns) + + +# ============================================================================ +# No custom ops: resolve should be a no-op +# ============================================================================ + + +def test_no_custom_ops_is_noop(): + """When there's no nkipy.custom_op_bodies attribute, resolve is a no-op.""" + mlir_input = """ +module { + func.func @main_kernel( + %arg0: memref<128x128xf32, #nisa.mem> + ) -> memref<128x128xf32, #nisa.mem> { + return %arg0 : memref<128x128xf32, #nisa.mem> + } +} +""" + result = _run_resolve(mlir_input) + + check_patterns = """ + CHECK: func.func @main_kernel + CHECK: return + CHECK-NOT: nkipy.custom_op_bodies 
+ """ + run_filecheck(result, check_patterns) + + +# ============================================================================ +# Multiple call sites for the same custom op +# ============================================================================ + + +def test_multiple_call_sites(): + """ + When the same custom op is called multiple times, + each call site gets its own inlined body with separate allocs. + """ + nisa_body = _escape_mlir_string( + "module attributes {nisa.target = #nisa.target} {" + " func.func @my_op(" + " %arg0: memref<64x64xf32, #nisa.mem>," + " %arg1: memref<64x64xf32, #nisa.mem>" + ' ) attributes {nki.output_names = ["output"]} {' + " return" + " }" + "}" + ) + + mlir_input = f''' +module attributes {{ + nkipy.custom_op_bodies = {{ + "__custom_op__my_op" = "{nisa_body}" + }} +}} {{ + func.func @main_kernel( + %arg0: memref<64x64xf32, #nisa.mem>, + %arg1: memref<64x64xf32, #nisa.mem> + ) -> memref<64x64xf32, #nisa.mem> {{ + %r0 = func.call @__custom_op__my_op(%arg0) + : (memref<64x64xf32, #nisa.mem>) -> memref<64x64xf32, #nisa.mem> + %r1 = func.call @__custom_op__my_op(%arg1) + : (memref<64x64xf32, #nisa.mem>) -> memref<64x64xf32, #nisa.mem> + return %r1 : memref<64x64xf32, #nisa.mem> + }} + + func.func private @__custom_op__my_op( + memref<64x64xf32, #nisa.mem> + ) -> memref<64x64xf32, #nisa.mem> + attributes {{nkipy.custom_op}} +}} +''' + + result = _run_resolve(mlir_input) + + # Two inlined bodies = two allocs, no call instructions + check_patterns = """ + CHECK: func.func @main_kernel + CHECK: memref.alloc + CHECK: memref.alloc + CHECK: return + CHECK-NOT: call @__custom_op__my_op + CHECK-NOT: func.func private @__custom_op__my_op + """ + run_filecheck(result, check_patterns) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/python/__init__.py b/kernelgen/tests/python/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernelgen/tests/python/lit.cfg.py b/kernelgen/tests/python/lit.cfg.py new file mode 100644 index 0000000..f4608d4 --- /dev/null +++ b/kernelgen/tests/python/lit.cfg.py @@ -0,0 +1,38 @@ +# python/dialects/lit.cfg.py + +import os +import sys + +import lit.formats + +# A name for this test suite (purely cosmetic). +config.name = "NKIPy MLIR Passes Tests" + +# Use the 'shell test' format: interpret RUN lines with a shell. +config.test_format = lit.formats.ShTest(execute_external=True) + +# Treat .py files as tests. +config.suffixes = [".py"] + +# Don't treat the config file itself as a test. +config.excludes = ["lit.cfg.py", "lit.site.cfg.py", "__pycache__"] + +# Where the test sources live. +config.test_source_root = os.path.dirname(__file__) + +# Where tests are executed. Often same as source root for simple setups. +config.test_exec_root = config.test_source_root + +# --- Substitutions --- + +# %PYTHON -> the Python interpreter running lit (from the env you invoked `lit` in) +config.substitutions.append(("%PYTHON", sys.executable)) + +# Copy the current shell PYTHONPATH into lit's environment, without changing it. +# (Lit already starts from the process env, but this makes it explicit and avoids overrides.) +if "PYTHONPATH" in os.environ: + config.environment["PYTHONPATH"] = os.environ["PYTHONPATH"] + +# Use FileCheck via PATH (no hardcoded path). 
+filecheck_path = "FileCheck" +config.substitutions.append(("FileCheck", filecheck_path)) diff --git a/kernelgen/tests/python/passes/__init__.py b/kernelgen/tests/python/passes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernelgen/tests/python/passes/test_knob_annotations.py b/kernelgen/tests/python/passes/test_knob_annotations.py new file mode 100644 index 0000000..1a5ad73 --- /dev/null +++ b/kernelgen/tests/python/passes/test_knob_annotations.py @@ -0,0 +1,187 @@ +# RUN: %PYTHON %s | FileCheck %s + +import numpy as np +from mlir.ir import Context, Location + +from nkipy_kernelgen import trace +from nkipy_kernelgen.apis import knob + + +def run(f): + """Simple test runner: prints a label, runs the test.""" + print(f"\nTEST: {f.__name__}") + f() + print(f"TEST_END: {f.__name__}") + return f + + +# CHECK-LABEL: TEST: test_knob_partition_dim_only +# CHECK: nkipy.annotate(%{{.*}} : tensor<4x4xf32>, 0) +# CHECK: TEST_END: test_knob_partition_dim_only +@run +def test_knob_partition_dim_only(): + """Test that knob with only partition_dim injects annotate op correctly.""" + + @trace(input_specs=[((4, 4), "f32"), ((4, 4), "f32")]) + def func_with_partition(A, B): + temp = np.add(A, B) + temp = knob(temp, partition_dim=0) + result = np.multiply(temp, 2.0) + return result + + module = func_with_partition.to_mlir() + print(module) + + +# CHECK-LABEL: TEST: test_knob_mem_space_only +# CHECK: nkipy.annotate(%{{.*}} : tensor<4x4xf32>, Hbm) +# CHECK: TEST_END: test_knob_mem_space_only +@run +def test_knob_mem_space_only(): + """Test that knob with only mem_space injects annotate op correctly.""" + + @trace(input_specs=[((4, 4), "f32"), ((4, 4), "f32")]) + def func_with_mem_space(A, B): + temp = np.add(A, B) + temp = knob(temp, mem_space="Hbm") + result = np.multiply(temp, 2.0) + return result + + module = func_with_mem_space.to_mlir() + print(module) + + +# CHECK-LABEL: TEST: test_knob_both_params +# CHECK: nkipy.annotate(%{{.*}} : tensor<8x8xf32>, Sbuf, 1) +# CHECK: TEST_END: test_knob_both_params +@run +def test_knob_both_params(): + """Test that knob with both parameters injects annotate op correctly.""" + + @trace(input_specs=[((8, 8), "f32"), ((8, 8), "f32")]) + def func_with_both(A, B): + temp = np.add(A, B) + temp = knob(temp, partition_dim=1, mem_space="Sbuf") + result = np.multiply(temp, 3.0) + return result + + module = func_with_both.to_mlir() + print(module) + + +# CHECK-LABEL: TEST: test_knob_no_params +# CHECK-NOT: nkipy.annotate +# CHECK: TEST_END: test_knob_no_params +@run +def test_knob_no_params(): + """Test that knob without parameters does not inject annotate op.""" + + @trace(input_specs=[((4, 4), "f32"), ((4, 4), "f32")]) + def func_no_knob_params(A, B): + temp = np.add(A, B) + temp = knob(temp) # No parameters, should be no-op + result = np.multiply(temp, 2.0) + return result + + module = func_no_knob_params.to_mlir() + print(module) + + +# CHECK-LABEL: TEST: test_knob_multiple_annotations +# CHECK: nkipy.annotate(%{{.*}} : tensor<8x8xf32>, Hbm) +# CHECK: nkipy.annotate(%{{.*}} : tensor<8x8xf32>, 0) +# CHECK: nkipy.annotate(%{{.*}} : tensor<8x8xf32>, Sbuf, 1) +# CHECK: TEST_END: test_knob_multiple_annotations +@run +def test_knob_multiple_annotations(): + """Test multiple knob annotations in a single function.""" + + @trace(input_specs=[((8, 8), "f32"), ((8, 8), "f32")]) + def func_multiple_knobs(A, B): + temp0 = np.add(A, B) + temp0 = knob(temp0, mem_space="Hbm") + + temp1 = np.multiply(temp0, 2.0) + temp1 = knob(temp1, partition_dim=0) + + temp2 = 
np.square(temp1) + temp2 = knob(temp2, partition_dim=1, mem_space="Sbuf") + + result = np.multiply(temp2, 3.0) + return result + + module = func_multiple_knobs.to_mlir() + print(module) + + +# CHECK-LABEL: TEST: test_knob_mem_space_values +# CHECK: TEST: Hbm +# CHECK: nkipy.annotate(%{{.*}} : tensor<4x4xf32>, Hbm) +# CHECK: TEST: Psum +# CHECK: nkipy.annotate(%{{.*}} : tensor<4x4xf32>, Psum) +# CHECK: TEST: Sbuf +# CHECK: nkipy.annotate(%{{.*}} : tensor<4x4xf32>, Sbuf) +# CHECK: TEST: SharedHbm +# CHECK: nkipy.annotate(%{{.*}} : tensor<4x4xf32>, SharedHbm) +# CHECK: TEST_END: test_knob_mem_space_values +@run +def test_knob_mem_space_values(): + """Test that different mem_space values map to correct enum values.""" + + # Test Hbm (0) + @trace(input_specs=[((4, 4), "f32"), ((4, 4), "f32")]) + def func_hbm(A, B): + result = np.add(A, B) + result = knob(result, mem_space="Hbm") + return result + + print("TEST: Hbm") + print(func_hbm.to_mlir()) + + # Test Psum (1) + @trace(input_specs=[((4, 4), "f32"), ((4, 4), "f32")]) + def func_psum(A, B): + result = np.add(A, B) + result = knob(result, mem_space="Psum") + return result + + print("TEST: Psum") + print(func_psum.to_mlir()) + + # Test Sbuf (2) + @trace(input_specs=[((4, 4), "f32"), ((4, 4), "f32")]) + def func_sbuf(A, B): + result = np.add(A, B) + result = knob(result, mem_space="Sbuf") + return result + + print("TEST: Sbuf") + print(func_sbuf.to_mlir()) + + # Test SharedHbm (3) + @trace(input_specs=[((4, 4), "f32"), ((4, 4), "f32")]) + def func_sharedhbm(A, B): + result = np.add(A, B) + result = knob(result, mem_space="SharedHbm") + return result + + print("TEST: SharedHbm") + print(func_sharedhbm.to_mlir()) + + +# CHECK-LABEL: TEST: test_knob_with_matmul +# CHECK: linalg.matmul +# CHECK: nkipy.annotate(%{{.*}} : tensor<4x5xf32>, Psum) +# CHECK: TEST_END: test_knob_with_matmul +@run +def test_knob_with_matmul(): + """Test knob annotation on matmul result.""" + + @trace(input_specs=[((4, 3), "f32"), ((3, 5), "f32")]) + def matmul_with_knob(A, B): + C = np.matmul(A, B) + C = knob(C, mem_space="Psum") + return C + + module = matmul_with_knob.to_mlir() + print(module) diff --git a/kernelgen/tests/python/rewrites/__init__.py b/kernelgen/tests/python/rewrites/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernelgen/tests/unit/__init__.py b/kernelgen/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernelgen/tests/unit/conftest.py b/kernelgen/tests/unit/conftest.py new file mode 100644 index 0000000..6a2d8f0 --- /dev/null +++ b/kernelgen/tests/unit/conftest.py @@ -0,0 +1,9 @@ +""" +Unit test conftest.py. + +Auto-applies the 'unit' marker to all tests in this directory. +""" + +import pytest + +pytestmark = [pytest.mark.unit] diff --git a/kernelgen/tests/unit/test_broadcast_ops.py b/kernelgen/tests/unit/test_broadcast_ops.py new file mode 100644 index 0000000..d08f2a1 --- /dev/null +++ b/kernelgen/tests/unit/test_broadcast_ops.py @@ -0,0 +1,232 @@ +""" +Tests for broadcasting operations. + +These tests verify that NumPy-style broadcasting works correctly with various +shape combinations. Broadcasting follows NumPy rules: +1. Align shapes from the right (trailing dimensions) +2. Dimensions with size 1 can broadcast to any size +3. Missing dimensions are treated as size 1 + +All broadcast ops emit linalg.generic (not named linalg ops). 
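+
+Example: np.add on shapes (64, 128, 256) and (256,) treats the 1D operand as
+(1, 1, 256), expands the size-1 dims, and produces a (64, 128, 256) result.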
+""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace +from harness import nkipy_kernelgen_test, run_kernel_test, Mode + + +# ============================================================================ +# Shape expansion (one operand has size-1 dims) +# ============================================================================ + +@pytest.mark.parametrize("op,shape_a,shape_b", [ + (np.add, (128, 256), (128, 1)), # column broadcast + (np.multiply, (128, 256), (1, 1)), # scalar-shaped broadcast + (np.subtract, (64, 128, 256), (64, 1, 256)), # middle dim broadcast +]) +def test_broadcast_expand(op, shape_a, shape_b): + @trace(input_specs=[(shape_a, "f32"), (shape_b, "f32")]) + def kernel(a, b): + return op(a, b) + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.generic"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Dimension addition (fewer dims in one operand) +# ============================================================================ + +@pytest.mark.parametrize("op,shape_a,shape_b", [ + (np.add, (128, 256), (256,)), # 1D to 2D + (np.multiply, (64, 128, 256), (256,)), # 1D to 3D +]) +def test_broadcast_add_dims(op, shape_a, shape_b): + @trace(input_specs=[(shape_a, "f32"), (shape_b, "f32")]) + def kernel(a, b): + return op(a, b) + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.generic"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Divide with broadcasting (custom inputs to avoid div-by-zero) +# ============================================================================ + +@pytest.mark.parametrize("shape_a,shape_b", [ + ((128, 256), (1, 256)), # row broadcast + ((64, 128, 256), (128, 256)), # 2D to 3D + ((64, 128, 256), (128, 1)), # complex 3D +]) +def test_broadcast_divide(shape_a, shape_b): + @trace(input_specs=[(shape_a, "f32"), (shape_b, "f32")]) + def kernel(a, b): + return np.divide(a, b) + + np.random.seed(42) + A = np.random.randn(*shape_a).astype(np.float32) + B = np.random.randn(*shape_b).astype(np.float32) + 1.0 + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.generic"], + inputs=[A, B], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Both operands broadcast +# ============================================================================ + +@pytest.mark.parametrize("op,shape_a,shape_b", [ + (np.multiply, (1, 256), (128, 1)), + (np.add, (256,), (128, 1)), +]) +def test_broadcast_both_operands(op, shape_a, shape_b): + @trace(input_specs=[(shape_a, "f32"), (shape_b, "f32")]) + def kernel(a, b): + return op(a, b) + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.generic"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + +# ============================================================================ +# 1D broadcast with expansion +# ============================================================================ + +@nkipy_kernelgen_test( + input_specs=[((128, 256), "f32"), ((1,), "f32")], + stop_after="trace", + check_ir_contains=["linalg.generic"], + modes=Mode.LLVM | Mode.STRING_CHECK, +) +def test_broadcast_1d_to_2d_with_expansion(A, B): + """Broadcasting (1,) to (128, 256) - add dimension AND expand.""" + return np.add(A, B) + + +# 
============================================================================ +# Float16 broadcast +# ============================================================================ + +@pytest.mark.parametrize("op,shape_a,shape_b", [ + (np.add, (128, 256), (128, 1)), + (np.multiply, (128, 256), (256,)), +]) +def test_broadcast_f16(op, shape_a, shape_b): + @trace(input_specs=[(shape_a, "f16"), (shape_b, "f16")]) + def kernel(a, b): + return op(a, b) + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.generic"], + rtol=0.01, atol=0.01, + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Incompatible shapes (error tests) +# ============================================================================ + +def test_incompatible_shapes_no_size_1(): + """(128, 256) and (128,) are incompatible (no size-1 to expand).""" + def add_func(A, B): + return np.add(A, B) + + traced = trace(input_specs=[((128, 256), "f32"), ((128,), "f32")])(add_func) + + try: + mlir_module = traced.to_mlir() + A = np.random.randn(128, 256).astype(np.float32) + B = np.random.randn(128).astype(np.float32) + try: + add_func(A, B) + assert True + except ValueError: + assert False, "Tracer should have raised ValueError for incompatible shapes" + except ValueError as e: + assert "Incompatible" in str(e) or "broadcast" in str(e).lower() + + +def test_incompatible_shapes_mismatch(): + """(128, 256) and (64, 256) are incompatible.""" + def mul_func(A, B): + return np.multiply(A, B) + + traced = trace(input_specs=[((128, 256), "f32"), ((64, 256), "f32")])(mul_func) + + try: + mlir_module = traced.to_mlir() + A = np.random.randn(128, 256).astype(np.float32) + B = np.random.randn(64, 256).astype(np.float32) + try: + mul_func(A, B) + assert True + except ValueError: + assert False, "Tracer should have raised ValueError for incompatible shapes" + except ValueError as e: + assert "Incompatible" in str(e) or "broadcast" in str(e).lower() + + +# ============================================================================ +# Real-world patterns +# ============================================================================ + +def test_rmsnorm_pattern(): + """RMS normalization pattern: (M, N) / (M, 1).""" + @trace(input_specs=[((128, 256), "f32"), ((128, 1), "f32")]) + def kernel(values, rms): + return np.divide(values, rms) + + np.random.seed(42) + values = np.random.randn(128, 256).astype(np.float32) + rms = np.random.randn(128, 1).astype(np.float32) + 1.0 + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.generic"], + inputs=[values, rms], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + +@nkipy_kernelgen_test( + input_specs=[((2, 64, 128, 128), "f32"), ((1, 64, 1, 1), "f32")], + stop_after="trace", + check_ir_contains=["linalg.generic"], + modes=Mode.LLVM | Mode.STRING_CHECK, +) +def test_batch_normalization_pattern(A, bias): + """Batch normalization pattern: (B, C, H, W) with (1, C, 1, 1).""" + return np.add(A, bias) + + +@nkipy_kernelgen_test( + input_specs=[((2, 4, 128, 128), "f32"), ((1, 1, 1, 1), "f32")], + stop_after="trace", + check_ir_contains=["linalg.generic"], + modes=Mode.LLVM | Mode.STRING_CHECK, +) +def test_attention_scale_pattern(qk, scale): + """Attention scaling pattern: (B, H, S, S) * scalar.""" + return np.multiply(qk, scale) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/unit/test_custom_op.py b/kernelgen/tests/unit/test_custom_op.py new file mode 
100644 index 0000000..5a50b8b --- /dev/null +++ b/kernelgen/tests/unit/test_custom_op.py @@ -0,0 +1,355 @@ +""" +Unit tests for the CustomOp Python module. + +Tests CustomOp creation, tracing integration (func.call emission), +declaration generation, and NISA body stashing. + +Run with: python -m pytest tests/unit/test_custom_op.py -v +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace +from nkipy_kernelgen.custom_op import ( + CustomOp, + emit_custom_op_declaration, +) + + +# ============================================================================ +# CustomOp construction +# ============================================================================ + + +def test_custom_op_init(): + """Test CustomOp constructor.""" + op = CustomOp( + nisa_mlir="module {}", + func_name="test_func", + input_names=["x"], + output_names=["output"], + input_shapes=[(128, 128)], + output_shapes=[(128, 128)], + input_dtypes=["f32"], + output_dtypes=["f32"], + ) + assert op.func_name == "__custom_op__test_func" + assert op.input_shapes == [(128, 128)] + assert op.output_shapes == [(128, 128)] + + +def test_custom_op_name_prefixing(): + """Test that func_name gets __custom_op__ prefix.""" + op = CustomOp( + nisa_mlir="module {}", + func_name="my_kernel", + input_names=["x"], + output_names=["y"], + input_shapes=[(64, 64)], + output_shapes=[(64, 64)], + input_dtypes=["f32"], + output_dtypes=["f32"], + ) + assert op.func_name == "__custom_op__my_kernel" + + +# ============================================================================ +# Reference function fallback +# ============================================================================ + + +def test_reference_fn_called_outside_tracing(): + """When called with numpy arrays (not TracedArrays), use reference_fn.""" + ref_fn = lambda x: x * 2.0 + + op = CustomOp( + nisa_mlir="module {}", + func_name="double", + input_names=["x"], + output_names=["y"], + input_shapes=[(4, 4)], + output_shapes=[(4, 4)], + input_dtypes=["f32"], + output_dtypes=["f32"], + reference_fn=ref_fn, + ) + x = np.ones((4, 4), dtype=np.float32) + result = op(x) + np.testing.assert_allclose(result, x * 2.0) + + +def test_no_reference_fn_raises_outside_tracing(): + """Error when called outside tracing without reference_fn.""" + op = CustomOp( + nisa_mlir="module {}", + func_name="no_ref", + input_names=["x"], + output_names=["y"], + input_shapes=[(4, 4)], + output_shapes=[(4, 4)], + input_dtypes=["f32"], + output_dtypes=["f32"], + ) + with pytest.raises(RuntimeError, match="without a reference_fn"): + op(np.ones((4, 4), dtype=np.float32)) + + +# ============================================================================ +# Tracing integration +# ============================================================================ + + +def test_custom_op_emits_func_call_during_tracing(): + """Verify that calling a CustomOp during tracing emits func.call in IR.""" + custom_identity = CustomOp( + nisa_mlir="module {}", + func_name="identity_128x128_128x128", + input_names=["x"], + output_names=["output"], + input_shapes=[(128, 128)], + output_shapes=[(128, 128)], + input_dtypes=["f32"], + output_dtypes=["f32"], + ) + + @trace(input_specs=[((128, 128), "f32")]) + def kernel(x): + return custom_identity(x) + + module = kernel.to_mlir() + mlir_str = str(module) + + # Check that func.call is emitted + assert "call @__custom_op__identity_128x128_128x128" in mlir_str + # Check that declaration is emitted + assert "nkipy.custom_op" in mlir_str + # Check that NISA body is stashed 
+ assert "nkipy.custom_op_bodies" in mlir_str + + +def test_custom_op_shape_mismatch_raises(): + """Error when shape doesn't match during tracing.""" + custom_op = CustomOp( + nisa_mlir="module {}", + func_name="expects_64x64", + input_names=["x"], + output_names=["y"], + input_shapes=[(64, 64)], + output_shapes=[(64, 64)], + input_dtypes=["f32"], + output_dtypes=["f32"], + ) + + @trace(input_specs=[((128, 128), "f32")]) + def kernel(x): + return custom_op(x) + + with pytest.raises(ValueError, match="Shape mismatch"): + kernel.to_mlir() + + +def test_custom_op_wrong_arg_count_raises(): + """Error when wrong number of args during tracing.""" + custom_op = CustomOp( + nisa_mlir="module {}", + func_name="expects_two", + input_names=["x", "y"], + output_names=["z"], + input_shapes=[(64, 64), (64, 64)], + output_shapes=[(64, 64)], + input_dtypes=["f32", "f32"], + output_dtypes=["f32"], + ) + + @trace(input_specs=[((64, 64), "f32")]) + def kernel(x): + return custom_op(x) + + with pytest.raises(ValueError, match="expects 2 inputs, got 1"): + kernel.to_mlir() + + +# ============================================================================ +# Registry +# ============================================================================ + + +def test_registry_deduplicates_by_func_name(): + """Same CustomOp called twice should only register once in module bodies.""" + custom_op = CustomOp( + nisa_mlir="module {}", + func_name="dedup_test", + input_names=["x"], + output_names=["y"], + input_shapes=[(128, 128)], + output_shapes=[(128, 128)], + input_dtypes=["f32"], + output_dtypes=["f32"], + ) + + @trace(input_specs=[((128, 128), "f32")]) + def kernel(x): + r1 = custom_op(x) + r2 = custom_op(r1) + return r2 + + module = kernel.to_mlir() + mlir_str = str(module) + + # The custom op body should be stashed exactly once despite two call sites + assert mlir_str.count("__custom_op__dedup_test") >= 2, ( + "expected at least 2 call sites" + ) + # Only one private declaration should exist + func_decl_count = mlir_str.count("func.func private @__custom_op__dedup_test") + assert func_decl_count == 1, f"expected 1 declaration, got {func_decl_count}" + + +# ============================================================================ +# emit_custom_op_declaration +# ============================================================================ + + +def test_emit_custom_op_declaration(): + """Verify emitted func.func has correct signature and attributes.""" + from mlir import ir + + op = CustomOp( + nisa_mlir="module {}", + func_name="my_activation_128x128_128x128", + input_names=["x"], + output_names=["y"], + input_shapes=[(128, 128)], + output_shapes=[(128, 128)], + input_dtypes=["f32"], + output_dtypes=["f32"], + ) + + with ir.Context(), ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + fn = emit_custom_op_declaration(op) + + mlir_str = str(module) + # Private visibility + assert "private" in mlir_str + # Correct name with prefix + assert "@__custom_op__my_activation_128x128_128x128" in mlir_str + # nkipy.custom_op marker + assert "nkipy.custom_op" in mlir_str + # Input and output types + assert "128x128xf32" in mlir_str + + +def test_emit_custom_op_declaration_multi_io(): + """Verify declaration for a multi-input multi-output custom op.""" + from mlir import ir + + op = CustomOp( + nisa_mlir="module {}", + func_name="multi_io", + input_names=["a", "b"], + output_names=["x", "y"], + input_shapes=[(64, 64), (32, 32)], + output_shapes=[(64, 64), (32, 32)], + 
input_dtypes=["f32", "f16"], + output_dtypes=["f32", "f16"], + ) + + with ir.Context(), ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + fn = emit_custom_op_declaration(op) + + mlir_str = str(module) + assert "64x64xf32" in mlir_str + assert "32x32xf16" in mlir_str + + +# ============================================================================ +# Multi-output custom op tracing +# ============================================================================ + + +def test_multi_output_custom_op_tracing(): + """Verify that a multi-output CustomOp returns a tuple of TracedArrays.""" + custom_split = CustomOp( + nisa_mlir="module {}", + func_name="split_128x64_128x64", + input_names=["x"], + output_names=["left", "right"], + input_shapes=[(128, 128)], + output_shapes=[(128, 64), (128, 64)], + input_dtypes=["f32"], + output_dtypes=["f32", "f32"], + ) + + @trace(input_specs=[((128, 128), "f32")]) + def kernel(x): + left, right = custom_split(x) + return left + + module = kernel.to_mlir() + mlir_str = str(module) + + assert "call @__custom_op__split_128x64_128x64" in mlir_str + # Two result types in the call + assert "128x64xf32" in mlir_str + + +# ============================================================================ +# from_kernel_builder +# ============================================================================ + + +def test_from_kernel_builder(): + """Verify from_kernel_builder produces a CustomOp with real NISA MLIR.""" + import nki.compiler.kernel_builder as nb + + def relu_kernel(x_hbm, out_hbm): + """ReLU activation: load from HBM, activate in SBUF, store back.""" + x_sbuf = nb.ndarray((128, 128), x_hbm.dtype, nb.sbuf) + nb.isa.dma_copy(dst=x_sbuf, src=x_hbm[0:128, 0:128]) + + out_sbuf = nb.ndarray((128, 128), x_hbm.dtype, nb.sbuf) + bias = nb.ndarray((128, 1), x_hbm.dtype, nb.sbuf) + nb.isa.memset(dst=bias, value=0.0) + scale = nb.ndarray((128, 1), x_hbm.dtype, nb.sbuf) + nb.isa.memset(dst=scale, value=1.0) + + nb.isa.activation( + dst=out_sbuf, + src=x_sbuf, + bias=bias, + scale=scale, + op=nb.isa.activation_function.relu, + ) + nb.isa.dma_copy(dst=out_hbm[0:128, 0:128], src=out_sbuf) + + op = CustomOp.from_kernel_builder( + kernel_func=relu_kernel, + input_specs={"x_hbm": nb.Tensor((128, 128), nb.float32, nb.shared_hbm)}, + output_specs={"out_hbm": nb.Tensor((128, 128), nb.float32, nb.shared_hbm)}, + reference_fn=lambda x: np.maximum(x, 0), + ) + + assert op.func_name.startswith("__custom_op__relu_kernel_") + assert op.input_shapes == [(128, 128)] + assert op.output_shapes == [(128, 128)] + # NISA MLIR should contain real ops from the kernel (generic form) + for op_name in ["nisa.dma_copy", "nisa.activation", "nisa.memset"]: + assert op_name in op.nisa_mlir or f'"{op_name}"' in op.nisa_mlir, ( + f"expected {op_name} in NISA MLIR" + ) + # Reference fn should work + x = np.array([[1, -2], [-3, 4]], dtype=np.float32) + np.testing.assert_allclose(op(x), np.array([[1, 0], [0, 4]], dtype=np.float32)) + + +# ============================================================================ +# Test Runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/unit/test_elementwise_ops.py b/kernelgen/tests/unit/test_elementwise_ops.py new file mode 100644 index 0000000..4c511d7 --- /dev/null +++ b/kernelgen/tests/unit/test_elementwise_ops.py @@ -0,0 +1,435 @@ +""" +Tests for elementwise operations: binary, unary, scalar, and 
chained. + +These tests verify that MLIR/LLVM execution matches NumPy CPU execution +and that the tracer emits the correct linalg ops. +Tests with knobs also verify KnobDrivenTiling + linalg-to-nisa produces +correct NISA dialect ops. +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import nkipy_kernelgen_test, run_kernel_test, Mode + + +# ============================================================================ +# Binary ops (identical shapes) → linalg. +# ============================================================================ + +@pytest.mark.parametrize("op,ir_op,shape,dtype,tile_size", [ + (np.add, "linalg.add", (128, 256), "f32", [64, 128]), + (np.add, "linalg.add", (256, 512), "f32", [128, 256]), + (np.add, "linalg.add", (128, 128), "f16", [64, 64]), + (np.subtract, "linalg.sub", (128, 256), "f32", [64, 128]), + (np.subtract, "linalg.sub", (256, 512), "f32", [128, 256]), + (np.multiply, "linalg.mul", (128, 256), "f32", [64, 128]), + (np.multiply, "linalg.mul", (256, 512), "f32", [128, 256]), +]) +def test_binary_op(op, ir_op, shape, dtype, tile_size): + rtol, atol = (0.01, 0.01) if dtype == "f16" else (1e-5, 1e-8) + + @trace(input_specs=[(shape, dtype), (shape, dtype)]) + def kernel(a, b): + result = op(a, b) + knob.knob(result, tile_size=tile_size) + return result + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=[ir_op], + rtol=rtol, atol=atol, + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, stop_after="apply-and-strip-transforms", + check_ir_contains=["scf.for", "memory_space = 3 : i32", ir_op], + rtol=rtol, atol=atol, + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, stop_after="linalg-to-nisa", + check_ir_contains=["nisa.tensor_tensor_arith"], + modes=Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Divide (custom inputs to avoid division by near-zero) +# ============================================================================ + +@pytest.mark.parametrize("shape,dtype,tile_size", [ + ((128, 256), "f32", [64, 128]), + ((256, 512), "f32", [128, 256]), + ((128, 128), "f16", [64, 64]), +]) +def test_divide(shape, dtype, tile_size): + rtol, atol = (0.01, 0.01) if dtype == "f16" else (1e-5, 1e-8) + np_dtype = np.float16 if dtype == "f16" else np.float32 + + @trace(input_specs=[(shape, dtype), (shape, dtype)]) + def kernel(a, b): + result = np.divide(a, b) + knob.knob(result, tile_size=tile_size) + return result + + np.random.seed(42) + A = np.random.randn(*shape).astype(np_dtype) + B = (np.abs(np.random.randn(*shape)) + 0.5).astype(np_dtype) + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.div"], + inputs=[A, B], rtol=rtol, atol=atol, + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, stop_after="apply-and-strip-transforms", + check_ir_contains=["scf.for", "memory_space = 3 : i32"], + inputs=[A, B], rtol=rtol, atol=atol, + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Scalar ops +# ============================================================================ + +@pytest.mark.parametrize("op,scalar,tile_size", [ + (np.add, 2.5, [64, 128]), + (np.multiply, 3.0, [64, 128]), + (np.subtract, 1.0, [64, 128]), +]) +def test_scalar_op(op, scalar, tile_size): + @trace(input_specs=[((128, 256), "f32")]) + def kernel(a): + result = op(a, scalar) + knob.knob(result, 
tile_size=tile_size) + return result + + run_kernel_test(kernel, stop_after="trace", modes=Mode.LLVM) + + run_kernel_test( + kernel, stop_after="apply-and-strip-transforms", + check_ir_contains=["scf.for", "memory_space = 3 : i32"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, stop_after="linalg-to-nisa", + check_ir_contains=["nisa.tensor_scalar_arith"], + modes=Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Unary ops → linalg. +# ============================================================================ + +@pytest.mark.parametrize("op,ir_op,nisa_op,shape,dtype,tile_size", [ + (np.square, "linalg.square", "nisa.activation", (128, 256), "f32", [64, 128]), + (np.square, "linalg.square", "nisa.activation", (256, 512), "f32", [128, 256]), + (np.square, "linalg.square", "nisa.activation", (128, 128), "f16", [64, 64]), + (np.abs, "linalg.abs", None, (128, 256), "f32", [64, 128]), + (np.abs, "linalg.abs", None, (128, 128), "f16", [64, 64]), +]) +def test_unary_op(op, ir_op, nisa_op, shape, dtype, tile_size): + rtol, atol = (0.01, 0.01) if dtype == "f16" else (1e-5, 1e-8) + + @trace(input_specs=[(shape, dtype)]) + def kernel(a): + result = op(a) + knob.knob(result, tile_size=tile_size) + return result + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=[ir_op], + rtol=rtol, atol=atol, + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, stop_after="apply-and-strip-transforms", + check_ir_contains=["scf.for", "memory_space = 3 : i32", ir_op], + rtol=rtol, atol=atol, + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + if nisa_op: + run_kernel_test( + kernel, stop_after="linalg-to-nisa", + check_ir_contains=[nisa_op], + modes=Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Negative (implemented as 0 - x, emits linalg.generic) +# ============================================================================ + +@nkipy_kernelgen_test( + input_specs=[((128, 256), "f32")], + stop_after="trace", + modes=Mode.LLVM, +) +def test_negative(A): + return np.negative(A) + + +# ============================================================================ +# Sqrt (needs positive inputs) +# ============================================================================ + +@pytest.mark.parametrize("shape,dtype,tile_size", [ + ((128, 256), "f32", [64, 128]), + ((256, 512), "f32", [128, 256]), + ((128, 128), "f16", [64, 64]), +]) +def test_sqrt(shape, dtype, tile_size): + rtol, atol = (0.01, 0.01) if dtype == "f16" else (1e-5, 1e-8) + np_dtype = np.float16 if dtype == "f16" else np.float32 + + @trace(input_specs=[(shape, dtype)]) + def kernel(a): + result = np.sqrt(a) + knob.knob(result, tile_size=tile_size) + return result + + np.random.seed(42) + A = np.abs(np.random.randn(*shape)).astype(np_dtype) + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.sqrt"], + inputs=[A], rtol=rtol, atol=atol, + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, stop_after="apply-and-strip-transforms", + check_ir_contains=["scf.for", "memory_space = 3 : i32", "linalg.sqrt"], + inputs=[A], rtol=rtol, atol=atol, + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Exp (needs small inputs to avoid overflow) +# ============================================================================ + 
+@pytest.mark.parametrize("shape,dtype,tile_size", [ + ((128, 256), "f32", [64, 128]), + ((256, 512), "f32", [128, 256]), + ((128, 128), "f16", [64, 64]), +]) +def test_exp(shape, dtype, tile_size): + rtol, atol = (0.01, 0.01) if dtype == "f16" else (1e-5, 1e-8) + np_dtype = np.float16 if dtype == "f16" else np.float32 + + @trace(input_specs=[(shape, dtype)]) + def kernel(a): + result = np.exp(a) + knob.knob(result, tile_size=tile_size) + return result + + np.random.seed(42) + A = (np.random.randn(*shape) * 0.5).astype(np_dtype) + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.exp"], + inputs=[A], rtol=rtol, atol=atol, + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, stop_after="apply-and-strip-transforms", + check_ir_contains=["scf.for", "memory_space = 3 : i32", "linalg.exp"], + inputs=[A], rtol=rtol, atol=atol, + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, stop_after="linalg-to-nisa", + check_ir_contains=["nisa.activation"], + inputs=[A], + modes=Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Chained expressions +# ============================================================================ + +def test_add_then_multiply(): + @trace(input_specs=[((128, 256), "f32"), ((128, 256), "f32")]) + def kernel(A, B): + temp = np.add(A, B) + knob.knob(temp, tile_size=[64, 128]) + return np.multiply(temp, 2.0) + + run_kernel_test( + kernel, stop_after="trace", + modes=Mode.LLVM, + ) + + run_kernel_test( + kernel, stop_after="apply-and-strip-transforms", + check_ir_contains=["scf.for", "memory_space = 3 : i32"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + +def test_add_then_square(): + @trace(input_specs=[((128, 256), "f32"), ((128, 256), "f32")]) + def kernel(A, B): + result = np.square(np.add(A, B)) + knob.knob(result, tile_size=[64, 128]) + return result + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.add", "linalg.square"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, stop_after="apply-and-strip-transforms", + check_ir_contains=["scf.for", "memory_space = 3 : i32"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + +def test_square_then_divide(): + @trace(input_specs=[((128, 256), "f32"), ((128, 256), "f32")]) + def kernel(A, B): + squared = np.square(A) + knob.knob(squared, tile_size=[64, 128]) + return np.divide(squared, B) + + np.random.seed(42) + A = np.random.randn(128, 256).astype(np.float32) + B = np.random.randn(128, 256).astype(np.float32) + 1.0 + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.square", "linalg.div"], + inputs=[A, B], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, stop_after="apply-and-strip-transforms", + check_ir_contains=["scf.for", "memory_space = 3 : i32"], + inputs=[A, B], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + +def test_complex_expression(): + @trace(input_specs=[ + ((256, 256), "f32"), ((256, 256), "f32"), ((256, 256), "f32") + ]) + def kernel(A, B, C): + squared = np.square(A) + knob.knob(squared, tile_size=[128, 128]) + sum_result = np.add(squared, B) + return np.divide(sum_result, C) + + np.random.seed(42) + A = np.random.randn(256, 256).astype(np.float32) + B = np.random.randn(256, 256).astype(np.float32) + C = np.random.randn(256, 256).astype(np.float32) + 1.0 + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.square", "linalg.add", "linalg.div"], + inputs=[A, 
B, C], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, stop_after="apply-and-strip-transforms", + check_ir_contains=["scf.for", "memory_space = 3 : i32"], + inputs=[A, B, C], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + +def test_square_in_expression(): + @trace(input_specs=[((128, 256), "f32")]) + def kernel(A): + squared = np.square(A) + knob.knob(squared, tile_size=[64, 128]) + return np.add(squared, 1.0) + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.square"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, stop_after="apply-and-strip-transforms", + check_ir_contains=["scf.for", "memory_space = 3 : i32"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + +def test_exp_in_expression(): + @trace(input_specs=[((128, 128), "f32")]) + def kernel(A): + squared = np.square(A) + knob.knob(squared, tile_size=[64, 64]) + return np.exp(squared * 0.1) + + np.random.seed(42) + A = np.random.randn(128, 128).astype(np.float32) + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.exp", "linalg.square"], + inputs=[A], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, stop_after="apply-and-strip-transforms", + check_ir_contains=["scf.for", "memory_space = 3 : i32"], + inputs=[A], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + +def test_sqrt_in_expression(): + @trace(input_specs=[((128, 128), "f32")]) + def kernel(A): + squared = np.square(A) + knob.knob(squared, tile_size=[64, 64]) + return np.sqrt(squared) + + np.random.seed(42) + A = np.abs(np.random.randn(128, 128)).astype(np.float32) + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.sqrt", "linalg.square"], + inputs=[A], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, stop_after="apply-and-strip-transforms", + check_ir_contains=["scf.for", "memory_space = 3 : i32"], + inputs=[A], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/unit/test_execution_engine.py b/kernelgen/tests/unit/test_execution_engine.py new file mode 100644 index 0000000..6f935b7 --- /dev/null +++ b/kernelgen/tests/unit/test_execution_engine.py @@ -0,0 +1,79 @@ +""" +Tests for LLVM execution engine. + +These tests verify that MLIR modules can be executed directly using the LLVM +execution engine, testing low-level MLIR operations. 
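+
+The harness pattern used by both tests is, in sketch form (the MLIR module
+text varies per test; argument names are illustrative):
+
+    runner = LLVMModule(module, "main")   # JIT-compile the "main" entry point
+    out = runner(A, B)                    # invoke with NumPy arrays, get a NumPy result back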
+""" + +import pytest + +import numpy as np +from mlir.ir import Context, Module +from nkipy_kernelgen.llvm import LLVMModule + + +class TestExecutionEngine: + """Test LLVM execution engine with MLIR modules.""" + + def test_tensor_subtract(self): + """Test element-wise tensor subtraction using linalg.elementwise.""" + with Context(): + module = Module.parse( + """ + module { + func.func @main(%arg0: tensor<8x32xf32>, %arg1: tensor<8x32xf32>) -> tensor<8x32xf32> attributes { llvm.emit_c_interface } { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<8x32xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<8x32xf32>) -> tensor<8x32xf32> + %4 = linalg.elementwise kind=#linalg.elementwise_kind ins(%arg0, %arg1 : tensor<8x32xf32>, tensor<8x32xf32>) outs(%1 : tensor<8x32xf32>) -> tensor<8x32xf32> + return %4 : tensor<8x32xf32> + } + } """ + ) + + runner = LLVMModule(module, "main") + + # Inputs and NumPy reference + A = np.random.rand(8, 32).astype(np.float32) + B = np.random.rand(8, 32).astype(np.float32) + ref = A - B + + # Run and compare + out = runner(A.copy(), B.copy()) + + assert np.allclose(out, ref, rtol=1e-5, atol=1e-6), \ + f"MLIR result does not match NumPy result for tensor subtraction" + + def test_tensor_reduce_all(self): + """Test tensor reduction over all dimensions.""" + with Context(): + module = Module.parse( + r""" +module { + func.func @main(%arg0: tensor<3x4xf32>) -> f32 { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor + %t = linalg.reduce { arith.addf } ins(%arg0 : tensor<3x4xf32>) outs(%1 : tensor) dimensions = [0, 1] + %s = tensor.extract %t[] : tensor + return %s : f32 + } +} +""" + ) + + runner = LLVMModule(module, "main") + + # Input and NumPy reference + A = np.random.rand(3, 4).astype(np.float32) + ref = np.array(A.sum().astype(np.float32)) + + # Run and compare + out = runner(A.copy()) + + assert np.allclose(out, ref, rtol=1e-5, atol=1e-6), \ + f"MLIR result does not match NumPy result for tensor reduction" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/unit/test_for_loops.py b/kernelgen/tests/unit/test_for_loops.py new file mode 100644 index 0000000..c59bf93 --- /dev/null +++ b/kernelgen/tests/unit/test_for_loops.py @@ -0,0 +1,268 @@ +""" +Tests for control flow operations: fori_loop. + +These tests verify that MLIR/LLVM execution matches NumPy CPU execution +for control flow operations like fori_loop. 
+""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace +from nkipy_kernelgen.apis import fori_loop +from harness import run_kernel_test, Mode + + +# ============================================================================ +# Basic fori_loop Tests +# ============================================================================ + +def test_simple_accumulation(): + """Test simple accumulation in fori_loop.""" + @trace(input_specs=[((10,), "f32")]) + def sum_with_loop(x): + def body(i, acc): + return acc + x[i] + + init = np.zeros((1,), dtype=np.float32) + result = fori_loop(0, 10, body, init) + return result + + run_kernel_test(sum_with_loop, stop_after="trace", modes=Mode.LLVM) + + +def test_single_accumulator(): + """Test fori_loop with a single tensor accumulator.""" + @trace(input_specs=[((8,), "f32")]) + def scale_accumulate(x): + def body(i, acc): + return acc + x[i] * 2.0 + + init = np.zeros((1,), dtype=np.float32) + result = fori_loop(0, 8, body, init) + return result + + run_kernel_test(scale_accumulate, stop_after="trace", modes=Mode.LLVM) + + +# ============================================================================ +# Multiple Accumulator Tests +# ============================================================================ + +def test_two_accumulators(): + """Test fori_loop with two tensor accumulators.""" + @trace(input_specs=[((8,), "f32")]) + def dual_accumulate(x): + def body(i, accs): + acc1, acc2 = accs + return (acc1 + x[i], acc2 + x[i] * x[i]) + + init1 = np.zeros((1,), dtype=np.float32) + init2 = np.zeros((1,), dtype=np.float32) + result1, result2 = fori_loop(0, 8, body, (init1, init2)) + # Return sum of both for testing + return result1 + result2 + + run_kernel_test( + dual_accumulate, stop_after="trace", + rtol=1e-4, atol=1e-4, modes=Mode.LLVM, + ) + + +def test_multiple_tensor_operations(): + """Test fori_loop with multiple tensor operations.""" + @trace(input_specs=[((4, 6), "f32")]) + def multi_tensor_loop(x): + def body(i, accs): + acc1, acc2 = accs + # Simple operations that don't require advanced slicing + row_sum = np.sum(x[i:i+1, :]) + return (acc1 + row_sum, acc2 + row_sum * 2.0) + + init1 = np.zeros((1,), dtype=np.float32) + init2 = np.zeros((1,), dtype=np.float32) + result1, result2 = fori_loop(0, 4, body, (init1, init2)) + return result1 + result2 + + run_kernel_test( + multi_tensor_loop, stop_after="trace", + rtol=1e-4, atol=1e-4, modes=Mode.LLVM, + ) + + +# ============================================================================ +# Dynamic Slicing Tests +# ============================================================================ + +def test_tiled_accumulation(): + """Test tiled accumulation with dynamic slicing.""" + @trace(input_specs=[((8,), "f32")]) + def tiled_sum(x, TILE_SIZE=2): + def body(i, acc): + # Dynamic slicing with loop index + chunk = x[i * TILE_SIZE : (i + 1) * TILE_SIZE] + chunk_sum = np.sum(chunk) + return acc + chunk_sum + + init = np.zeros((1,), dtype=np.float32) + result = fori_loop(0, 4, body, init) + return result + + run_kernel_test( + tiled_sum, stop_after="trace", + rtol=1e-4, atol=1e-4, modes=Mode.LLVM, + ) + + +def test_2d_tiled_operations(): + """Test 2D tiled operations with dynamic slicing.""" + @trace(input_specs=[((4, 8), "f16")]) + def tiled_2d_sum(input_tensor, TILING_FACTOR=2): + M, K = input_tensor.shape + TILED_CHUNK = K // TILING_FACTOR + + sum_buffer = np.zeros((M, TILED_CHUNK), dtype=np.float16) + + def body(i, acc): + input_chunk = input_tensor[:, i * TILED_CHUNK : (i + 1) * 
TILED_CHUNK] + return np.add(acc, input_chunk) + + result = fori_loop(0, TILING_FACTOR, body, sum_buffer) + return np.sum(result) + + run_kernel_test( + tiled_2d_sum, stop_after="trace", + rtol=0.01, atol=0.01, modes=Mode.LLVM, + ) + + +# ============================================================================ +# Complex Scenario Tests +# ============================================================================ + +def test_rmsnorm_pattern(): + """Test RMSNorm-like pattern with fori_loop.""" + @trace(input_specs=[((4, 8), "f16")]) + def simple_rmsnorm(input_tensor, TILING_FACTOR=2): + M, K = input_tensor.shape + TILED_CHUNK = K // TILING_FACTOR + + square_sum_buffer = np.zeros((M, TILED_CHUNK), dtype=np.float16) + + def body(i, acc): + input_chunk = input_tensor[:, i * TILED_CHUNK : (i + 1) * TILED_CHUNK] + squared_input = np.square(input_chunk) + scaled_square = np.divide(squared_input, K) + return np.add(scaled_square, acc) + + square_sum_buffer = fori_loop(0, TILING_FACTOR, body, square_sum_buffer) + rms_sum = np.sum(square_sum_buffer, axis=1, keepdims=True) + return rms_sum + + run_kernel_test( + simple_rmsnorm, stop_after="trace", + rtol=0.01, atol=0.01, modes=Mode.LLVM, + ) + + +def test_matmul_with_loop(): + """Test matrix multiplication accumulation with fori_loop.""" + @trace(input_specs=[((4, 8), "f16"), ((8, 5), "f16")]) + def tiled_matmul(input_tensor, weight_matrix, TILING_FACTOR=2): + M, K = input_tensor.shape + K_, N = weight_matrix.shape + assert K == K_ + TILED_CHUNK = K // TILING_FACTOR + + matmul_buffer = np.zeros((M, N), dtype=np.float16) + + def body(i, acc): + input_chunk = input_tensor[:, i * TILED_CHUNK : (i + 1) * TILED_CHUNK] + weight_chunk = weight_matrix[i * TILED_CHUNK : (i + 1) * TILED_CHUNK, :] + matmul_result = np.matmul(input_chunk, weight_chunk) + return np.add(matmul_result, acc) + + result = fori_loop(0, TILING_FACTOR, body, matmul_buffer) + return result + + run_kernel_test( + tiled_matmul, stop_after="trace", + rtol=0.01, atol=0.01, modes=Mode.LLVM, + ) + + +# ============================================================================ +# MLIR Generation Tests +# ============================================================================ + +def test_mlir_contains_scf_for(): + """Verify that MLIR contains scf.for operation.""" + @trace(input_specs=[((8,), "f32")]) + def loop_func(x): + def body(i, acc): + return acc + x[i] + + init = np.zeros((1,), dtype=np.float32) + return fori_loop(0, 8, body, init) + + run_kernel_test( + loop_func, stop_after="trace", + check_ir_contains=["scf.for", "scf.yield"], + modes=Mode.STRING_CHECK, + ) + + +def test_mlir_with_dynamic_slicing(): + """Verify MLIR generation with dynamic slicing.""" + @trace(input_specs=[((4, 8), "f16")]) + def dynamic_slice_func(x, TILE=4): + def body(i, acc): + chunk = x[:, i * TILE : (i + 1) * TILE] + return acc + np.sum(chunk) + + init = np.zeros((1,), dtype=np.float16) + return fori_loop(0, 2, body, init) + + run_kernel_test( + dynamic_slice_func, stop_after="trace", + check_ir_contains=["scf.for", "tensor.extract_slice"], + modes=Mode.STRING_CHECK, + ) + + +# ============================================================================ +# Edge Case Tests +# ============================================================================ + +def test_single_iteration(): + """Test loop with single iteration.""" + @trace(input_specs=[((4,), "f32")]) + def single_iter_loop(x): + def body(i, acc): + return acc + x[0:4] + + init = np.zeros((4,), dtype=np.float32) + return fori_loop(0, 1, body, 
init) + + run_kernel_test(single_iter_loop, stop_after="trace", modes=Mode.LLVM) + + +def test_small_tiling(): + """Test with small tile sizes.""" + @trace(input_specs=[((4, 4), "f32")]) + def small_tile_loop(x): + def body(i, acc): + chunk = x[:, i * 1 : (i + 1) * 1] + return acc + np.sum(chunk) + + init = np.zeros((1,), dtype=np.float32) + return fori_loop(0, 4, body, init) + + run_kernel_test( + small_tile_loop, stop_after="trace", + rtol=1e-4, atol=1e-4, modes=Mode.LLVM, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/unit/test_gather_ops.py b/kernelgen/tests/unit/test_gather_ops.py new file mode 100644 index 0000000..43845ad --- /dev/null +++ b/kernelgen/tests/unit/test_gather_ops.py @@ -0,0 +1,92 @@ +""" +Tests for gather / np.take operations. + +These tests verify that: + 1. nkipy.gather carries a linalg reference_impl region + 2. The reference region is correctly inlined for LLVM CPU simulation + 3. LLVM JIT execution matches NumPy +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace +from harness import run_kernel_test, Mode + + +# ============================================================================ +# np.take (axis=0 gather) +# ============================================================================ + +@pytest.mark.parametrize("vocab,embed,n_idx", [ + (128, 128, 8), + (256, 64, 16), + (128, 256, 4), +]) +def test_take_axis0(vocab, embed, n_idx): + """np.take along axis 0 — basic embedding lookup.""" + @trace(input_specs=[((vocab, embed), "f32"), ((n_idx,), "i32")]) + def kernel(table, indices): + return np.take(table, indices, axis=0) + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["nkipy.gather"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + +def test_take_single_index(): + """np.take with a single-element index tensor.""" + @trace(input_specs=[((128, 64), "f32"), ((1,), "i32")]) + def kernel(table, indices): + return np.take(table, indices, axis=0) + + run_kernel_test( + kernel, stop_after="trace", + modes=Mode.LLVM, + ) + + +def test_take_large_embedding(): + """np.take with a larger table.""" + @trace(input_specs=[((512, 128), "f32"), ((32,), "i32")]) + def kernel(table, indices): + return np.take(table, indices, axis=0) + + run_kernel_test( + kernel, stop_after="trace", + modes=Mode.LLVM, + ) + + +# ============================================================================ +# np.take used in a larger computation +# ============================================================================ + +def test_take_then_add(): + """np.take followed by elementwise add — verifies gather result feeds + into downstream ops correctly after inlining.""" + vocab, embed, n_idx = 128, 128, 8 + + @trace(input_specs=[ + ((vocab, embed), "f32"), + ((n_idx,), "i32"), + ((n_idx, embed), "f32"), + ]) + def kernel(table, indices, bias): + gathered = np.take(table, indices, axis=0) + return np.add(gathered, bias) + + run_kernel_test( + kernel, stop_after="trace", + modes=Mode.LLVM, + ) + + +# ============================================================================ +# Test runner +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/unit/test_import_compatibility.py b/kernelgen/tests/unit/test_import_compatibility.py new file mode 100644 index 0000000..aad5f53 --- /dev/null +++ b/kernelgen/tests/unit/test_import_compatibility.py @@ -0,0 +1,11 @@ +"""Test that both MLIR 
IR modules can be imported without conflicts.""" + +import pytest + + +def test_dual_mlir_import(): + """Test importing both nkipy_kernelgen._mlir.ir and nki.compiler._internal.ir.""" + from nkipy_kernelgen._mlir import ir as nkipy_ir + + from nki.compiler._internal import ir as nki_ir + from nki.compiler._internal import register_all_dialects diff --git a/kernelgen/tests/unit/test_matrix_ops.py b/kernelgen/tests/unit/test_matrix_ops.py new file mode 100644 index 0000000..6ba8d8b --- /dev/null +++ b/kernelgen/tests/unit/test_matrix_ops.py @@ -0,0 +1,128 @@ +""" +Tests for matrix operations: matmul, matmul chains, and batched matmul. + +These tests verify that MLIR/LLVM execution matches NumPy CPU execution +for matrix operations. +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import nkipy_kernelgen_test, run_kernel_test, Mode + + +# ============================================================================ +# Matmul +# ============================================================================ + +@pytest.mark.parametrize("shape_a,shape_b", [ + ((128, 64), (64, 256)), # rectangular + ((256, 256), (256, 256)), # square + ((64, 128), (128, 64)), # different rectangular +]) +def test_matmul(shape_a, shape_b): + @trace(input_specs=[(shape_a, "f32"), (shape_b, "f32")]) + def kernel(A, B): + return np.matmul(A, B) + + run_kernel_test(kernel, stop_after="trace", modes=Mode.LLVM) + + +# ============================================================================ +# Chained ops +# ============================================================================ + +@nkipy_kernelgen_test( + input_specs=[((128, 64), "f32"), ((64, 256), "f32"), ((128, 256), "f32")], + stop_after="trace", + modes=Mode.LLVM, +) +def test_matmul_then_add(A, B, C): + """Matrix multiplication followed by addition.""" + temp = np.matmul(A, B) + return np.add(temp, C) + + +@nkipy_kernelgen_test( + input_specs=[((128, 64), "f32"), ((128, 64), "f32"), ((64, 256), "f32")], + stop_after="trace", + modes=Mode.LLVM, +) +def test_add_then_matmul(A, B, C): + """Addition followed by matrix multiplication.""" + temp = np.add(A, B) + return np.matmul(temp, C) + + +@nkipy_kernelgen_test( + input_specs=[((128, 64), "f32"), ((64, 256), "f32"), ((256, 128), "f32")], + stop_after="trace", + modes=Mode.LLVM, +) +def test_matmul_chain(A, B, C): + """Chained matrix multiplications.""" + temp = np.matmul(A, B) + return np.matmul(temp, C) + + +# ============================================================================ +# Batched matmul +# ============================================================================ + +@pytest.mark.parametrize("shape_a,shape_b", [ + ((4, 128, 64), (4, 64, 256)), # 3D batched + ((2, 4, 128, 64), (2, 4, 64, 256)), # 4D batched +]) +def test_batched_matmul(shape_a, shape_b): + @trace(input_specs=[(shape_a, "f32"), (shape_b, "f32")]) + def kernel(A, B): + return np.matmul(A, B) + + run_kernel_test(kernel, stop_after="trace", modes=Mode.LLVM) + + +# ============================================================================ +# Batched matmul (end-to-end with knobs) +# ============================================================================ + +BMM_CONFIGS = [ + (2, 256, 256, 256, [1, 128, 128], [128]), + (8, 256, 256, 256, [1, 128, 128], [128]), +] + + +def _bmm_config_id(val): + B, M, N, K, ts, rt = val + return f"b{B}_{M}x{N}x{K}_tile{ts[-1]}" + + +@pytest.mark.parametrize( + "B, M, N, K, tile_size, reduction_tile", + BMM_CONFIGS, + ids=[_bmm_config_id(c) for c 
in BMM_CONFIGS], +) +def test_bmm_e2e(B, M, N, K, tile_size, reduction_tile): + """End-to-end batched matmul: trace → knob-driven tiling → NISA → BIR sim.""" + @trace(input_specs=[((B, M, K), "f32"), ((B, K, N), "f32")]) + def bmm_kernel(a, b): + result = np.matmul(a, b) + knob.knob(result, mem_space="SharedHbm", + tile_size=tile_size, reduction_tile=reduction_tile) + return result + + run_kernel_test( + bmm_kernel, + stop_after="apply-and-strip-transforms", + modes=Mode.LLVM, + ) + run_kernel_test( + bmm_kernel, + check_ir_contains=["nisa.alloc", "nisa.matmul", "nisa.dma_transpose", + "nisa.dma_copy", "nisa.target"], + modes=Mode.BIR_SIM | Mode.STRING_CHECK, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/kernelgen/tests/unit/test_reduction_ops.py b/kernelgen/tests/unit/test_reduction_ops.py new file mode 100644 index 0000000..a27d12e --- /dev/null +++ b/kernelgen/tests/unit/test_reduction_ops.py @@ -0,0 +1,370 @@ +""" +Tests for reduction operations: sum, mean, max, min. + +These tests verify that MLIR/LLVM execution matches NumPy CPU execution +and that the tracer emits linalg.generic with reduction iterator types. +Tests with knobs also verify KnobDrivenTiling + apply-and-strip-transforms +produces tiled scf.for loops with SBUF promotion, and linalg-to-nisa +converts reduction generics to nisa.tensor_reduce_arith. +""" + +import pytest +import numpy as np + +from nkipy_kernelgen import trace, knob +from harness import nkipy_kernelgen_test, run_kernel_test, Mode + + +# ============================================================================ +# np.sum → linalg.generic (arith.addf) +# ============================================================================ + +@pytest.mark.parametrize("shape,axis,keepdims,tile_size,reduction_tile", [ + ((128, 256), -1, True, [64], [128]), # last axis, keepdims + ((128, 256), 0, False, [128], [64]), # first axis + ((128, 256), 1, False, [64], [128]), # last axis + ((128, 256), 1, True, [64], [128]), # last axis, keepdims + ((64, 128, 64), -1, True, [32, 64], [32]), # 3D, last axis, keepdims + ((64, 128, 64), 0, False, [64, 32], [32]), # 3D, first axis + ((64, 128, 64), 1, False, [32, 32], [64]), # 3D, middle axis +]) +def test_sum_axis(shape, axis, keepdims, tile_size, reduction_tile): + @trace(input_specs=[(shape, "f32")]) + def kernel(a): + result = np.sum(a, axis=axis, keepdims=keepdims) + knob.knob(result, tile_size=tile_size, reduction_tile=reduction_tile) + return result + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.generic"], + check_ir_not_contains=["linalg.reduce"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + check_ir_contains=["scf.for", "memory_space = 3 : i32", "linalg.generic"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + # TODO: enable linalg-to-nisa stage verification once we fix: + # - 1D output shapes (keepdims=False) crash nisa.dma_copy (needs 2D tiles) + # - non-rightmost reductions (axis=0) not supported by LinalgGenericReductionToNisaPattern + # - linalg.fill on HBM not yet lowered to NISA + # run_kernel_test( + # kernel, + # stop_after='linalg-to-nisa', + # check_ir_contains=["nisa.tensor_reduce_arith"], + # modes=Mode.STRING_CHECK, + # ) + + +@nkipy_kernelgen_test( + input_specs=[((128, 256), "f32")], + stop_after="trace", + check_ir_contains=["linalg.generic"], + check_ir_not_contains=["linalg.reduce"], + modes=Mode.LLVM | Mode.STRING_CHECK, +) +def test_sum_full_reduction(a): + 
"""Sum all elements to a scalar (no parallel dims, not tileable).""" + return np.sum(a) + + +# ============================================================================ +# np.mean → linalg.generic (sum) + scalar divide +# ============================================================================ + +@pytest.mark.parametrize("shape,axis,keepdims,tile_size,reduction_tile", [ + ((128, 256), -1, True, [64], [128]), # last axis, keepdims + ((128, 256), 0, False, [128], [64]), # first axis + ((128, 256), 1, False, [64], [128]), # last axis + ((128, 256), 1, True, [64], [128]), # last axis, keepdims + ((64, 128, 64), -1, True, [32, 64], [32]), # 3D, last axis, keepdims + ((64, 128, 64), 0, False, [64, 32], [32]), # 3D, first axis + ((64, 128, 64), 1, False, [32, 32], [64]), # 3D, middle axis +]) +def test_mean_axis(shape, axis, keepdims, tile_size, reduction_tile): + norm_axis = axis % len(shape) + N = shape[norm_axis] + + @trace(input_specs=[(shape, "f32")]) + def kernel(a): + # Decompose mean as sum * (1/N) so we can annotate the reduction + sm = np.sum(a, axis=axis, keepdims=keepdims) + knob.knob(sm, tile_size=tile_size, reduction_tile=reduction_tile) + return sm * np.float32(1.0 / N) + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.generic"], + check_ir_not_contains=["linalg.reduce"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + check_ir_contains=["scf.for", "memory_space = 3 : i32", "linalg.generic"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + # TODO: enable linalg-to-nisa stage verification (same blockers as test_sum_axis, + # plus scalar multiply after reduction needs all intermediate allocs annotated) + # run_kernel_test( + # kernel, + # stop_after='linalg-to-nisa', + # check_ir_contains=["nisa.tensor_reduce_arith", "nisa.tensor_scalar_arith"], + # modes=Mode.STRING_CHECK, + # ) + + +@nkipy_kernelgen_test( + input_specs=[((128, 256), "f32")], + stop_after="trace", + check_ir_contains=["linalg.generic"], + check_ir_not_contains=["linalg.reduce"], + modes=Mode.LLVM | Mode.STRING_CHECK, +) +def test_mean_full_reduction(a): + """Mean of all elements to a scalar (no parallel dims, not tileable).""" + return np.mean(a) + + +# ============================================================================ +# np.max → linalg.generic (arith.maximumf) +# ============================================================================ + +@pytest.mark.parametrize("shape,axis,keepdims,tile_size,reduction_tile", [ + ((128, 256), -1, True, [64], [128]), # last axis, keepdims + ((128, 256), 0, False, [128], [64]), # first axis + ((128, 256), 1, False, [64], [128]), # last axis + ((128, 256), 1, True, [64], [128]), # last axis, keepdims + ((64, 128, 64), -1, True, [32, 64], [32]), # 3D, last axis, keepdims + ((64, 128, 64), 0, False, [64, 32], [32]), # 3D, first axis + ((64, 128, 64), 1, False, [32, 32], [64]), # 3D, middle axis +]) +def test_max_axis(shape, axis, keepdims, tile_size, reduction_tile): + @trace(input_specs=[(shape, "f32")]) + def kernel(a): + result = np.max(a, axis=axis, keepdims=keepdims) + knob.knob(result, tile_size=tile_size, reduction_tile=reduction_tile) + return result + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.generic"], + check_ir_not_contains=["linalg.reduce"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + check_ir_contains=["scf.for", "memory_space = 3 : i32", 
"linalg.generic"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + # TODO: enable linalg-to-nisa stage verification (same blockers as test_sum_axis) + # run_kernel_test( + # kernel, + # stop_after='linalg-to-nisa', + # check_ir_contains=["nisa.tensor_reduce_arith"], + # modes=Mode.STRING_CHECK, + # ) + + +@nkipy_kernelgen_test( + input_specs=[((128, 256), "f32")], + stop_after="trace", + check_ir_contains=["linalg.generic"], + check_ir_not_contains=["linalg.reduce"], + modes=Mode.LLVM | Mode.STRING_CHECK, +) +def test_max_full_reduction(a): + """Max of all elements to a scalar (no parallel dims, not tileable).""" + return np.max(a) + + +# ============================================================================ +# np.min → linalg.generic (arith.minimumf) +# ============================================================================ + +@pytest.mark.parametrize("shape,axis,keepdims,tile_size,reduction_tile", [ + ((128, 256), -1, True, [64], [128]), # last axis, keepdims + ((128, 256), 0, False, [128], [64]), # first axis + ((128, 256), 1, False, [64], [128]), # last axis + ((128, 256), 1, True, [64], [128]), # last axis, keepdims + ((64, 128, 64), -1, True, [32, 64], [32]), # 3D, last axis, keepdims + ((64, 128, 64), 0, False, [64, 32], [32]), # 3D, first axis + ((64, 128, 64), 1, False, [32, 32], [64]), # 3D, middle axis +]) +def test_min_axis(shape, axis, keepdims, tile_size, reduction_tile): + @trace(input_specs=[(shape, "f32")]) + def kernel(a): + result = np.min(a, axis=axis, keepdims=keepdims) + knob.knob(result, tile_size=tile_size, reduction_tile=reduction_tile) + return result + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.generic"], + check_ir_not_contains=["linalg.reduce"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + check_ir_contains=["scf.for", "memory_space = 3 : i32", "linalg.generic"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + # TODO: enable linalg-to-nisa stage verification (same blockers as test_sum_axis) + # run_kernel_test( + # kernel, + # stop_after='linalg-to-nisa', + # check_ir_contains=["nisa.tensor_reduce_arith"], + # modes=Mode.STRING_CHECK, + # ) + + +@nkipy_kernelgen_test( + input_specs=[((128, 256), "f32")], + stop_after="trace", + check_ir_contains=["linalg.generic"], + check_ir_not_contains=["linalg.reduce"], + modes=Mode.LLVM | Mode.STRING_CHECK, +) +def test_min_full_reduction(a): + """Min of all elements to a scalar (no parallel dims, not tileable).""" + return np.min(a) + + +# ============================================================================ +# Chained reductions (real-world patterns) +# ============================================================================ + +def test_sum_of_squares(): + """Pattern: sum(x^2) - used in RMSNorm variance computation.""" + @trace(input_specs=[((128, 256), "f32")]) + def kernel(a): + sq = np.square(a) + knob.knob(sq, tile_size=[64, 128]) + + result = np.sum(sq, axis=-1, keepdims=True) + knob.knob(result, tile_size=[64], reduction_tile=[128]) + return result + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.square", "linalg.generic"], + check_ir_not_contains=["linalg.reduce"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + check_ir_contains=["scf.for", "memory_space = 3 : i32"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + # TODO: enable linalg-to-nisa stage — crashes because intermediate + # 
tensor.empty() (from chaining square→reduce) gets memref.alloc without + # NISA memory space annotation + # run_kernel_test( + # kernel, + # stop_after='linalg-to-nisa', + # check_ir_contains=["nisa.activation", "nisa.tensor_reduce_arith"], + # modes=Mode.STRING_CHECK, + # ) + + +def test_mean_of_squares(): + """Pattern: mean(x^2) = sum(x^2) * (1/N) - variance without centering.""" + N = 256 + + @trace(input_specs=[((128, N), "f32")]) + def kernel(a): + sq = np.square(a) + knob.knob(sq, tile_size=[64, 128]) + + sm = np.sum(sq, axis=-1, keepdims=True) + knob.knob(sm, tile_size=[64], reduction_tile=[128]) + return sm * np.float32(1.0 / N) + + run_kernel_test( + kernel, stop_after="trace", + check_ir_contains=["linalg.square", "linalg.generic"], + check_ir_not_contains=["linalg.reduce"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + check_ir_contains=["scf.for", "memory_space = 3 : i32"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + # TODO: enable linalg-to-nisa stage — same intermediate alloc issue as + # test_sum_of_squares, plus scalar multiply chain + # run_kernel_test( + # kernel, + # stop_after='linalg-to-nisa', + # check_ir_contains=[ + # "nisa.activation", + # "nisa.tensor_reduce_arith", + # "nisa.tensor_scalar_arith", + # ], + # modes=Mode.STRING_CHECK, + # ) + + +def test_softmax_reductions(): + """Pattern: exp(x - max(x)) / sum(exp(x - max(x))) - softmax.""" + @trace(input_specs=[((128, 256), "f32")]) + def kernel(a): + a_max = np.max(a, axis=-1, keepdims=True) + knob.knob(a_max, tile_size=[64], reduction_tile=[128]) + + exp_a = np.exp(a - a_max) + + exp_sum = np.sum(exp_a, axis=-1, keepdims=True) + knob.knob(exp_sum, tile_size=[64], reduction_tile=[128]) + + return exp_a / exp_sum + + np.random.seed(42) + A = (np.random.randn(128, 256) * 0.5).astype(np.float32) + + run_kernel_test( + kernel, stop_after="trace", + inputs=[A], + modes=Mode.LLVM, + ) + + run_kernel_test( + kernel, + stop_after='apply-and-strip-transforms', + inputs=[A], + check_ir_contains=["scf.for", "memory_space = 3 : i32", "linalg.generic"], + modes=Mode.LLVM | Mode.STRING_CHECK, + ) + + # TODO: enable linalg-to-nisa stage — multi-output chained kernel with + # intermediate allocs that lack NISA memory space annotation + # run_kernel_test( + # kernel, + # stop_after='linalg-to-nisa', + # inputs=[A], + # check_ir_contains=["nisa.tensor_reduce_arith", "nisa.activation"], + # modes=Mode.STRING_CHECK, + # ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])