From 0a25818f542e6335131739eeb12dedb0e70af8a6 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Mon, 3 Nov 2025 14:25:20 +0100 Subject: [PATCH 1/7] [examples][mlir] Basic MLIR compilation and execution example Adds a simple end-to-end example demonstrating programatic transform schedule creation, MLIR JIT compilation, execution, and numerical verification of the result. Additionally, 'utils' submodule is added with basic tools to simplify creation of ctype arguments in format accepted by jitted function. --- python/examples/mlir/compile_and_run.py | 121 ++++++++---------------- python/lighthouse/utils/runtime_args.py | 3 + 2 files changed, 45 insertions(+), 79 deletions(-) diff --git a/python/examples/mlir/compile_and_run.py b/python/examples/mlir/compile_and_run.py index d0529f6..52ac45c 100644 --- a/python/examples/mlir/compile_and_run.py +++ b/python/examples/mlir/compile_and_run.py @@ -1,12 +1,22 @@ import torch +<<<<<<< HEAD import argparse +======= +import os +>>>>>>> 9647a7f ([examples][mlir] Basic MLIR compilation and execution example) from mlir import ir from mlir.dialects import transform from mlir.dialects.transform import structured from mlir.dialects.transform import interpreter from mlir.execution_engine import ExecutionEngine +<<<<<<< HEAD from mlir.passmanager import PassManager +======= +from mlir.runtime.np_to_memref import ( + get_ranked_memref_descriptor, +) +>>>>>>> 9647a7f ([examples][mlir] Basic MLIR compilation and execution example) from lighthouse import utils as lh_utils @@ -35,7 +45,7 @@ def create_kernel(ctx: ir.Context) -> ir.Module: def create_schedule(ctx: ir.Context) -> ir.Module: """ Create an MLIR module containing transformation schedule. - The schedule provides partial lowering to scalar operations. + The schedule provides necessary steps to lower the kernel to LLVM IR. Args: ctx: MLIR context. @@ -47,26 +57,25 @@ def create_schedule(ctx: ir.Context) -> ir.Module: ir.UnitAttr.get() ) - # For simplicity, use generic matchers without requiring specific types. - anytype = transform.any_op_t() - # Create entry point transformation sequence. with ir.InsertionPoint(schedule.body): named_seq = transform.NamedSequenceOp( - sym_name="__transform_main", - input_types=[anytype], - result_types=[], + "__transform_main", + [transform.AnyOpType.get()], + [], arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], ) # Create the schedule. with ir.InsertionPoint(named_seq.body): + # For simplicity, use generic transform matchers. + anytype = transform.AnyOpType.get() + # Find the kernel's function op. func = structured.MatchOp.match_op_names( named_seq.bodyTarget, ["func.func"] ) - # Use C interface wrappers - required to make function executable - # after jitting. + # Use C interface wrappers - required to make function executable after jitting. func = transform.apply_registered_pass( anytype, func, "llvm-request-c-wrappers" ) @@ -80,16 +89,22 @@ def create_schedule(ctx: ir.Context) -> ir.Module: anytype, mod, "convert-linalg-to-loops" ) # Cleanup. - transform.apply_cse(mod) + transform.ApplyCommonSubexpressionEliminationOp(mod) with ir.InsertionPoint(transform.ApplyPatternsOp(mod).patterns): - transform.apply_patterns_canonicalization() + transform.ApplyCanonicalizationPatternsOp() + # Lower to LLVM. + mod = transform.apply_registered_pass(anytype, mod, "convert-scf-to-cf") + mod = transform.apply_registered_pass(anytype, mod, "convert-to-llvm") + mod = transform.apply_registered_pass( + anytype, mod, "reconcile-unrealized-casts" + ) # Terminate the schedule. - transform.yield_([]) + transform.YieldOp() return schedule -def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None: +def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> ir.Module: """ Apply transformation schedule to a kernel module. The kernel is modified in-place. @@ -105,29 +120,8 @@ def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None: ) -def create_pass_pipeline(ctx: ir.Context) -> PassManager: - """ - Create an MLIR pass pipeline. - The pipeline lowers operations further down to LLVM dialect. - - Args: - ctx: MLIR context. - """ - with ctx: - # Create a pass manager that applies passes to the whole module. - pm = PassManager("builtin.module") - # Lower to LLVM. - pm.add("convert-scf-to-cf") - pm.add("convert-to-llvm") - pm.add("reconcile-unrealized-casts") - # Cleanup - pm.add("cse") - pm.add("canonicalize") - return pm - - # The example's entry point. -def main(args): +def main(): ### Baseline computation ### # Create inputs. a = torch.randn(16, 32, dtype=torch.float32) @@ -137,50 +131,36 @@ def main(args): out_ref = torch.add(a, b) ### MLIR payload preparation ### - # Create payload kernel. + # Create payload kernel and lowering schedule. ctx = ir.Context() kernel = create_kernel(ctx) - - # Create a transform schedule and apply initial lowering. schedule = create_schedule(ctx) + # Lower the kernel to LLVM dialect. apply_schedule(kernel, schedule) - # Create a pass pipeline and lower the kernel to LLVM dialect. - pm = create_pass_pipeline(ctx) - pm.run(kernel.operation) - ### Compilation ### - # Parse additional libraries if present. + # External shared libraries, containing MLIR runner utilities, are are generally + # required to execute the compiled module. # - # External shared libraries, runtime utilities, might be needed to execute - # the compiled module. - # The execution engine requires full paths to the libraries. - mlir_libs = [] - if args.shared_libs: - mlir_libs += args.shared_libs.split(",") + # Get paths to MLIR runner shared libraries through an environment variable. + mlir_libs = os.environ.get("LIGHTHOUSE_SHARED_LIBS").split(":") # JIT the kernel. eng = ExecutionEngine(kernel, opt_level=2, shared_libs=mlir_libs) - - # Initialize the JIT engine. - # - # The deferred initialization executes global constructors that might - # have been created by the module during engine creation (for example, - # when `gpu.module` is present) or registered afterwards. - # - # Initialization is not strictly necessary in this case. - # However, it is a good practice to perform it regardless. - eng.initialize() - # Get the kernel function. add_func = eng.lookup("add") ### Execution ### + # Create corresponding memref descriptors containing input data. + a_mem = get_ranked_memref_descriptor(a.numpy()) + b_mem = get_ranked_memref_descriptor(b.numpy()) + # Create an empty buffer to hold results. out = torch.empty_like(out_ref) + out_mem = get_ranked_memref_descriptor(out.numpy()) # Execute the kernel. - args = lh_utils.torch_to_packed_args([a, b, out]) + args = lh_utils.memrefs_to_packed_args([a_mem, b_mem, out_mem]) add_func(args) ### Verification ### @@ -192,21 +172,4 @@ def main(args): if __name__ == "__main__": - parser = argparse.ArgumentParser() - - # External shared libraries, runtime utilities, might be needed to - # execute the compiled module. - # For example, MLIR runner utils libraries such as: - # - libmlir_runner_utils.so - # - libmlir_c_runner_utils.so - # - # Full paths to the libraries should be provided. - # For example: - # --shared-libs=$LLVM_BUILD/lib/lib1.so,$LLVM_BUILD/lib/lib2.so - parser.add_argument( - "--shared-libs", - type=str, - help="Comma-separated list of libraries to link dynamically", - ) - args = parser.parse_args() - main(args) + main() diff --git a/python/lighthouse/utils/runtime_args.py b/python/lighthouse/utils/runtime_args.py index 6719896..3d88242 100644 --- a/python/lighthouse/utils/runtime_args.py +++ b/python/lighthouse/utils/runtime_args.py @@ -1,9 +1,12 @@ import ctypes +<<<<<<< HEAD import torch from mlir.runtime.np_to_memref import ( get_ranked_memref_descriptor, ) +======= +>>>>>>> 9647a7f ([examples][mlir] Basic MLIR compilation and execution example) def get_packed_arg(ctypes_args) -> list[ctypes.c_void_p]: From 4fb023f07f5a79236521732f2ef7207df4aa75a8 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Mon, 3 Nov 2025 14:36:23 +0100 Subject: [PATCH 2/7] Fix return type --- python/examples/mlir/compile_and_run.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/examples/mlir/compile_and_run.py b/python/examples/mlir/compile_and_run.py index 52ac45c..350da64 100644 --- a/python/examples/mlir/compile_and_run.py +++ b/python/examples/mlir/compile_and_run.py @@ -104,7 +104,7 @@ def create_schedule(ctx: ir.Context) -> ir.Module: return schedule -def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> ir.Module: +def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None: """ Apply transformation schedule to a kernel module. The kernel is modified in-place. @@ -135,6 +135,7 @@ def main(): ctx = ir.Context() kernel = create_kernel(ctx) schedule = create_schedule(ctx) + # Lower the kernel to LLVM dialect. apply_schedule(kernel, schedule) From 7416cbf6b37bb7b9150466a1b26748331b1d65fd Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Tue, 4 Nov 2025 10:23:41 +0100 Subject: [PATCH 3/7] Split lowering and add pass pipeline example --- python/examples/mlir/compile_and_run.py | 48 +++++++++++++++---------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/python/examples/mlir/compile_and_run.py b/python/examples/mlir/compile_and_run.py index 350da64..325bdf4 100644 --- a/python/examples/mlir/compile_and_run.py +++ b/python/examples/mlir/compile_and_run.py @@ -1,22 +1,15 @@ import torch -<<<<<<< HEAD -import argparse -======= import os ->>>>>>> 9647a7f ([examples][mlir] Basic MLIR compilation and execution example) from mlir import ir from mlir.dialects import transform from mlir.dialects.transform import structured from mlir.dialects.transform import interpreter from mlir.execution_engine import ExecutionEngine -<<<<<<< HEAD from mlir.passmanager import PassManager -======= from mlir.runtime.np_to_memref import ( get_ranked_memref_descriptor, ) ->>>>>>> 9647a7f ([examples][mlir] Basic MLIR compilation and execution example) from lighthouse import utils as lh_utils @@ -45,7 +38,7 @@ def create_kernel(ctx: ir.Context) -> ir.Module: def create_schedule(ctx: ir.Context) -> ir.Module: """ Create an MLIR module containing transformation schedule. - The schedule provides necessary steps to lower the kernel to LLVM IR. + The schedule provides partial lowering to scalar operations. Args: ctx: MLIR context. @@ -92,12 +85,6 @@ def create_schedule(ctx: ir.Context) -> ir.Module: transform.ApplyCommonSubexpressionEliminationOp(mod) with ir.InsertionPoint(transform.ApplyPatternsOp(mod).patterns): transform.ApplyCanonicalizationPatternsOp() - # Lower to LLVM. - mod = transform.apply_registered_pass(anytype, mod, "convert-scf-to-cf") - mod = transform.apply_registered_pass(anytype, mod, "convert-to-llvm") - mod = transform.apply_registered_pass( - anytype, mod, "reconcile-unrealized-casts" - ) # Terminate the schedule. transform.YieldOp() @@ -120,6 +107,27 @@ def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None: ) +def create_pass_pipeline(ctx: ir.Context) -> PassManager: + """ + Create an MLIR pass pipeline. + The pipeline lowers operations further down to LLVM dialect. + + Args: + ctx: MLIR context. + """ + with ctx: + # Create a pass manager that applies passes to the whole module. + pm = PassManager("builtin.module") + # Lower to LLVM. + pm.add("convert-scf-to-cf") + pm.add("convert-to-llvm") + pm.add("reconcile-unrealized-casts") + # Cleanup + pm.add("cse") + pm.add("canonicalize") + return pm + + # The example's entry point. def main(): ### Baseline computation ### @@ -131,20 +139,24 @@ def main(): out_ref = torch.add(a, b) ### MLIR payload preparation ### - # Create payload kernel and lowering schedule. + # Create payload kernel. ctx = ir.Context() kernel = create_kernel(ctx) - schedule = create_schedule(ctx) - # Lower the kernel to LLVM dialect. + # Create a transform schedule and apply initial lowering. + schedule = create_schedule(ctx) apply_schedule(kernel, schedule) + # Create a pass pipeline and lower the kernel to LLVM dialect. + pm = create_pass_pipeline(ctx) + pm.run(kernel.operation) + ### Compilation ### # External shared libraries, containing MLIR runner utilities, are are generally # required to execute the compiled module. # # Get paths to MLIR runner shared libraries through an environment variable. - mlir_libs = os.environ.get("LIGHTHOUSE_SHARED_LIBS").split(":") + mlir_libs = os.environ.get("LIGHTHOUSE_SHARED_LIBS", default="").split(":") # JIT the kernel. eng = ExecutionEngine(kernel, opt_level=2, shared_libs=mlir_libs) From f5c65dda93e46b7adafba45ffc3b077bcb1b4aca Mon Sep 17 00:00:00 2001 From: Petr Kurapov Date: Mon, 10 Nov 2025 16:28:32 +0100 Subject: [PATCH 4/7] Add the layout for testing --- ingress/mlir-gen/mlir_gen/test/__init__.py | 1 + ingress/mlir-gen/mlir_gen/test/conftest.py | 56 ++ ingress/mlir-gen/mlir_gen/test/model.py | 302 +++++++++ ingress/mlir-gen/mlir_gen/test/test_core.py | 651 ++++++++++++++++++++ pyproject.toml | 19 + python/lighthouse/utils/__init__.py | 1 + python/lighthouse/utils/runtime_args.py | 40 +- 7 files changed, 1067 insertions(+), 3 deletions(-) create mode 100644 ingress/mlir-gen/mlir_gen/test/__init__.py create mode 100644 ingress/mlir-gen/mlir_gen/test/conftest.py create mode 100644 ingress/mlir-gen/mlir_gen/test/model.py create mode 100644 ingress/mlir-gen/mlir_gen/test/test_core.py diff --git a/ingress/mlir-gen/mlir_gen/test/__init__.py b/ingress/mlir-gen/mlir_gen/test/__init__.py new file mode 100644 index 0000000..417887a --- /dev/null +++ b/ingress/mlir-gen/mlir_gen/test/__init__.py @@ -0,0 +1 @@ +"""Tests for MLIR generation.""" diff --git a/ingress/mlir-gen/mlir_gen/test/conftest.py b/ingress/mlir-gen/mlir_gen/test/conftest.py new file mode 100644 index 0000000..ec9d8eb --- /dev/null +++ b/ingress/mlir-gen/mlir_gen/test/conftest.py @@ -0,0 +1,56 @@ +""" +Pytest configuration for MLIR generation tests. + +This file sets up fixtures and configuration for tests in this directory. +""" + +import os +import pytest + + +@pytest.fixture(scope="session", autouse=True) +def setup_mlir_environment(): + """ + Set up MLIR environment variables for testing. + This runs once per test session. + """ + # Set environment variable for MLIR shared libraries if needed + # The default empty string is fine for most cases + if "LIGHTHOUSE_SHARED_LIBS" not in os.environ: + os.environ["LIGHTHOUSE_SHARED_LIBS"] = "" + + yield + + # Cleanup after all tests (if needed) + pass + + +@pytest.fixture +def mlir_context(): + """ + Provide a fresh MLIR context for each test. + """ + from mlir import ir + + return ir.Context() + + +@pytest.fixture +def sample_shapes(): + """ + Provide common tensor shapes for testing. + """ + return [ + (4, 16), + (8, 8), + (16, 32), + (1, 64), + ] + + +@pytest.fixture +def sample_types(): + """ + Provide common element types for testing. + """ + return ["f32", "f64"] diff --git a/ingress/mlir-gen/mlir_gen/test/model.py b/ingress/mlir-gen/mlir_gen/test/model.py new file mode 100644 index 0000000..e388c03 --- /dev/null +++ b/ingress/mlir-gen/mlir_gen/test/model.py @@ -0,0 +1,302 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import fairscale.nn.model_parallel.initialize as fs_init +import torch +import torch.nn.functional as F +from fairscale.nn.model_parallel.layers import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, +) +from torch import nn + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 500000 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + model_parallel_size = fs_init.get_model_parallel_world_size() + self.n_local_heads = args.n_heads // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + + self.wq = ColumnParallelLinear( + args.dim, + args.n_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.wk = ColumnParallelLinear( + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.wv = ColumnParallelLinear( + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.wo = RowParallelLinear( + args.n_heads * self.head_dim, + args.dim, + bias=False, + input_is_parallel=True, + init_method=lambda x: x, + ) + + self.cache_k = torch.zeros( + ( + args.max_batch_size, + args.max_seq_len, + self.n_local_kv_heads, + self.head_dim, + ) + ).cuda() + self.cache_v = torch.zeros( + ( + args.max_batch_size, + args.max_seq_len, + self.n_local_kv_heads, + self.head_dim, + ) + ).cuda() + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + bsz, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + self.cache_k = self.cache_k.to(xq) + self.cache_v = self.cache_v.to(xq) + + self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk + self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv + + keys = self.cache_k[:bsz, : start_pos + seqlen] + values = self.cache_v[:bsz, : start_pos + seqlen] + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv( + keys, self.n_rep + ) # (bs, cache_len + seqlen, n_local_heads, head_dim) + values = repeat_kv( + values, self.n_rep + ) # (bs, cache_len + seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) + values = values.transpose( + 1, 2 + ) # (bs, n_local_heads, cache_len + seqlen, head_dim) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if mask is not None: + scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = ColumnParallelLinear( + dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x + ) + self.w2 = RowParallelLinear( + hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x + ) + self.w3 = ColumnParallelLinear( + dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x + ) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, layer_id: int, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = Attention(args) + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Transformer(nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layers + + self.tok_embeddings = VocabParallelEmbedding( + params.vocab_size, params.dim, init_method=lambda x: x + ) + + self.layers = torch.nn.ModuleList() + for layer_id in range(params.n_layers): + self.layers.append(TransformerBlock(layer_id, params)) + + self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.output = ColumnParallelLinear( + params.dim, params.vocab_size, bias=False, init_method=lambda x: x + ) + + self.freqs_cis = precompute_freqs_cis( + params.dim // params.n_heads, + params.max_seq_len * 2, + params.rope_theta, + ) + + @torch.inference_mode() + def forward(self, tokens: torch.Tensor, start_pos: int): + _bsz, seqlen = tokens.shape + h = self.tok_embeddings(tokens) + self.freqs_cis = self.freqs_cis.to(h.device) + freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + + mask = None + if seqlen > 1: + mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device) + + mask = torch.triu(mask, diagonal=1) + + # When performing key-value caching, we compute the attention scores + # only for the new sequence. Thus, the matrix of scores is of size + # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for + # j > cache_len + i, since row i corresponds to token cache_len + i. + mask = torch.hstack( + [torch.zeros((seqlen, start_pos), device=tokens.device), mask] + ).type_as(h) + + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h) + output = self.output(h).float() + return output diff --git a/ingress/mlir-gen/mlir_gen/test/test_core.py b/ingress/mlir-gen/mlir_gen/test/test_core.py new file mode 100644 index 0000000..d638dc9 --- /dev/null +++ b/ingress/mlir-gen/mlir_gen/test/test_core.py @@ -0,0 +1,651 @@ +import pytest +import torch +from typing import Tuple + +from mlir import ir +from mlir.dialects import transform, func, linalg, tensor, arith, complex +from mlir.dialects.transform import structured +from mlir.dialects.transform import interpreter +from mlir.passmanager import PassManager +from mlir.runtime.np_to_memref import ( + get_ranked_memref_descriptor, +) +from mlir.execution_engine import ExecutionEngine + + +from lighthouse import utils as lh_utils + + +def affine_map(dim_count, exprs, *, symb_count=0): + return ir.AffineMap.get(dim_count, symb_count, exprs) + + +parallel = linalg.IteratorType.parallel +reduction = linalg.IteratorType.reduction + + +def create_pass_pipeline(ctx: ir.Context) -> PassManager: + with ctx: + pm = PassManager("builtin.module") + pm.add("convert-scf-to-cf") + pm.add("finalize-memref-to-llvm") + pm.add("convert-func-to-llvm") + pm.add("convert-to-llvm") + pm.add("reconcile-unrealized-casts") + pm.add("cse") + pm.add("canonicalize") + return pm + + +def create_schedule(ctx: ir.Context) -> ir.Module: + """ + Create an MLIR module containing transformation schedule. + The schedule provides partial lowering to scalar operations. + + Args: + ctx: MLIR context. + """ + with ctx, ir.Location.unknown(context=ctx): + # Create transform module. + schedule = ir.Module.create() + schedule.operation.attributes["transform.with_named_sequence"] = ( + ir.UnitAttr.get() + ) + + # Create entry point transformation sequence. + with ir.InsertionPoint(schedule.body): + named_seq = transform.NamedSequenceOp( + "__transform_main", + [transform.AnyOpType.get()], + [], + arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], + ) + + # Create the schedule. + with ir.InsertionPoint(named_seq.body): + # For simplicity, use generic transform matchers. + anytype = transform.AnyOpType.get() + + # Find the kernel's function op. + func = structured.MatchOp.match_op_names( + named_seq.bodyTarget, ["func.func"] + ) + + # Use C interface wrappers - required to make function executable after jitting. + func = transform.apply_registered_pass( + anytype, func, "llvm-request-c-wrappers" + ) + + # Find the kernel's module op. + mod = transform.get_parent_op( + anytype, func, op_name="builtin.module", deduplicate=True + ) + + # Naive lowering to loops. + mod = transform.apply_registered_pass( + anytype, mod, "convert-linalg-to-loops" + ) + # Cleanup. + transform.ApplyCommonSubexpressionEliminationOp(mod) + with ir.InsertionPoint(transform.ApplyPatternsOp(mod).patterns): + transform.ApplyCanonicalizationPatternsOp() + + # Terminate the schedule. + transform.YieldOp() + return schedule + + +def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None: + interpreter.apply_named_sequence( + payload_root=kernel, + transform_root=schedule.body.operations[0], + transform_module=schedule, + ) + + +def bufferize_module(ctx: ir.Context, kernel: ir.Module) -> None: + with ctx: + pm = PassManager("builtin.module") + pm.add("one-shot-bufferize{bufferize-function-boundaries}") + pm.run(kernel.operation) + + +#### IR builders ##### +# TODO: Move to mlir_gen module + + +def get_add(a: ir.Value, b: ir.Value, out: ir.Value) -> ir.Value: + return linalg.add(a, b, outs=(out,)) + + +def get_rsqrt(a: ir.Value, out: ir.Value) -> ir.Value: + return linalg.rsqrt(a, outs=(out,)) + + +def get_powf(a: ir.Value, out: ir.Value) -> ir.Value: + return linalg.powf(a, outs=(out,)) + + +def get_sqr(a: ir.Value, out: ir.Value) -> ir.Value: + return linalg.square(a, outs=(out,)) + + +def get_mul(a: ir.Value, b: ir.Value, out: ir.Value) -> ir.Value: + return linalg.mul(a, b, outs=(out,)) + + +# equvialent to torch.mean(-1, keepdim=True) +def get_mean(a: ir.Value, out: ir.Value) -> ir.Value: + # Need to initialize the output with zeros for accumulation + zero = arith.ConstantOp(ir.F32Type.get(), 0.0) + out_filled = linalg.fill(zero, outs=[out]) + + # Input map: (d0, d1) -> (d0, d1) + input_map = affine_map( + a.type.rank, + [ir.AffineDimExpr.get(i) for i in range(a.type.rank)], + ) + # Output map: (d0, d1) -> (d0, 0) + output_map = affine_map( + a.type.rank, + [ir.AffineDimExpr.get(i) for i in range(a.type.rank - 1)] + + [ir.AffineConstantExpr.get(0)], + ) + iterator_types = [parallel] * (a.type.rank - 1) + [reduction] + + scale = arith.ConstantOp(ir.F32Type.get(), 1.0 / a.type.shape[-1]) + + @linalg.generic( + [a], + [out_filled], + [input_map, output_map], + iterator_types, + ) + def mean_op(a_val, acc): + # Multiply input by scale factor and add to accumulator + scaled = arith.mulf(a_val, scale) + return arith.addf(scaled, acc) + + return mean_op + + +def get_l2_norm(a: ir.Value, out: ir.Value, eps: float = 1e-5) -> ir.Value: + """ + Compute x * rsqrt(mean(x^2, dim=-1, keepdim=True) + eps) + + Args: + a: Input tensor + eps: Epsilon value as a tensor with reduced shape [..., 1] + out: Output tensor + """ + elty = a.type.element_type + # Broadcast epsilon scalar to tensor with reduced shape + reduced_shape = list(a.type.shape) + reduced_shape[-1] = 1 + eps_const = arith.ConstantOp(elty, eps) + eps_tensor_uninit = tensor.EmptyOp(reduced_shape, elty) + eps_tensor = linalg.fill(eps_const, outs=[eps_tensor_uninit]) + # Square the input + squared_input = tensor.EmptyOp(a.type.shape, elty) + sqr = get_sqr(a, squared_input) + + # Compute mean along last dimension + reduced_shape = list(a.type.shape) + reduced_shape[-1] = 1 + mean_uninit = tensor.EmptyOp(reduced_shape, elty) + + mean = get_mean(sqr, mean_uninit) + mean_plus_eps = get_add(mean, eps_tensor, mean_uninit) + rsqrt_reduced = get_rsqrt(mean_plus_eps, mean_uninit) + + # (d0, d1) -> (d0, 0) for input, (d0, d1) -> (d0, d1) for output + input_map = affine_map( + a.type.rank, + [ir.AffineDimExpr.get(i) for i in range(a.type.rank - 1)] + + [ir.AffineConstantExpr.get(0)], + ) + output_map = affine_map( + a.type.rank, + [ir.AffineDimExpr.get(i) for i in range(a.type.rank)], + ) + iterator_types = [parallel] * a.type.rank + + @linalg.generic( + [rsqrt_reduced], + [out], + [input_map, output_map], + iterator_types, + ) + def broadcast_rsqrt(val, _out): + return val + + return get_mul(a, broadcast_rsqrt, out) + + +def get_rotary_emb( + xq: ir.Value, xk: ir.Value, freqs_cis: ir.Value, xq_out: ir.Value, xk_out: ir.Value +): + """ + Apply rotary embeddings to query and key tensors. + + This implements the transformation: + 1. View xq, xk as complex: [B, S, H, D] -> [B, S, H, D//2] complex + 2. Broadcast freqs_cis: [S, D//2] -> [1, S, 1, D//2] + 3. Complex multiply: xq_ * freqs_cis, xk_ * freqs_cis + 4. View back as real: [B, S, H, D//2] complex -> [B, S, H, D] real + + Args: + xq: Query tensor of shape [B, S, H, D] + xk: Key tensor of shape [B, S, H_kv, D] + freqs_cis: Rotary embeddings of shape [S, D//2] + Note: In PyTorch this is complex64, but here it's f32 + We need to interpret as pairs or pass cos/sin separately + xq_out: Output tensor for queries [B, S, H, D] + xk_out: Output tensor for keys [B, S, H_kv, D] + + TODO: Properly implement rotary embeddings + Current implementation is just a placeholder passthrough. + + For a correct implementation, we need to: + 1. Either: + a) Pass freqs_cis as [S, D] with interleaved cos/sin values, OR + b) Pass separate cos and sin tensors of shape [S, D//2] + 2. Extract pairs of elements from xq/xk to treat as complex (real, imag) + 3. Apply complex rotation: + (real', imag') = (real*cos - imag*sin, real*sin + imag*cos) + 4. Interleave results back into output + """ + elty = xq.type.element_type + + # Get shapes + xq_shape = list(xq.type.shape) # [B, S, H, D] + xk_shape = list(xk.type.shape) # [B, S, H_kv, D] + + # Placeholder implementation: just copy inputs to outputs + # This allows the test infrastructure to work but doesn't compute correct results + + b, s, h, d = [ir.AffineDimExpr.get(i) for i in range(4)] + + @linalg.generic( + [xq], + [xq_out], + [affine_map(4, [b, s, h, d]), affine_map(4, [b, s, h, d])], + [parallel] * 4, + ) + def copy_xq(x, _out): + return x + + @linalg.generic( + [xk], + [xk_out], + [ + affine_map(4, [b, s, ir.AffineDimExpr.get(2), d]), + affine_map(4, [b, s, ir.AffineDimExpr.get(2), d]), + ], + [parallel] * 4, + ) + def copy_xk(x, _out): + return x + + return (copy_xq, copy_xk) + + +def get_as_complex(x: ir.Value, out: ir.Value) -> ir.Value: + """ + Interpret the input tensor as complex numbers by grouping pairs of elements. + + Args: + x: Input tensor of shape [..., 2] representing complex numbers as pairs (real, imag) + out: Output tensor of shape [...] with complex type + """ + elty = x.type.element_type + rank = x.type.rank + shape = list(x.type.shape) + assert shape[-1] == 2, "Last dimension must be of size 2 to form complex numbers" + complex_shape = shape[:-1] + + dim_exprs_in = [ir.AffineDimExpr.get(i) for i in range(rank)] + dim_exprs_out = [ir.AffineDimExpr.get(i) for i in range(rank - 1)] + + input_map = affine_map( + rank, + dim_exprs_in, + ) + output_map = affine_map( + rank - 1, + dim_exprs_out, + ) + iterator_types = [parallel] * (rank - 1) + + @linalg.generic( + [x], + [out], + [input_map, output_map], + iterator_types, + ) + def as_complex_op(a, _out): + real_part = a[0] + imag_part = a[1] + cplx = complex.CreateOp( + complex.ComplexType.get(elty), real_part, imag_part + ).result + return cplx + + return as_complex_op + + +#### Test cases ##### + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def rotary_emb_ref( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +references = { + get_add: torch.add, + get_mul: torch.mul, + get_rsqrt: torch.rsqrt, + get_sqr: torch.square, + get_mean: lambda x: torch.mean(x, dim=-1, keepdim=True), + get_l2_norm: lambda x, eps: x + * torch.rsqrt(torch.mean(x.pow(2), dim=-1, keepdim=True) + eps), + get_rotary_emb: rotary_emb_ref, +} + + +# TODO: torch_dtype_to_mlir_type +def to_ir_type(type_str, ctx): + if type_str == "f32": + return ir.F32Type.get(context=ctx) + elif type_str == "f64": + return ir.F64Type.get(context=ctx) + else: + raise ValueError(f"Unsupported type: {type_str}") + + +@pytest.mark.parametrize( + "op,shape,elem_type", [(get_add, (4, 16), "f32"), (get_mul, (4, 16), "f32")] +) +def test_bin_op(op, shape, elem_type): + def generate_module(ctx, elty): + with ctx, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + tensor_type = ir.RankedTensorType.get(shape, elty) + + @func.FuncOp.from_py_func( + tensor_type, tensor_type, tensor_type, name="bin_op" + ) + def bin_op(a, b, out): + op(a, b, out) + + return module + + ctx = ir.Context() + ir_type = to_ir_type(elem_type, ctx) + module = generate_module(ctx, ir_type) + bufferize_module(ctx, module) + schedule = create_schedule(ctx) + apply_schedule(module, schedule) + pm = create_pass_pipeline(ctx) + pm.run(module.operation) + + eng = ExecutionEngine(module, opt_level=2) + func_ptr = eng.lookup("bin_op") + + torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + a = torch.randn(*shape, dtype=torch_dtype) + b = torch.randn(*shape, dtype=torch_dtype) + out_ref = references[op](a, b) + out = torch.empty_like(out_ref) + + a_mem = get_ranked_memref_descriptor(a.numpy()) + b_mem = get_ranked_memref_descriptor(b.numpy()) + out_mem = get_ranked_memref_descriptor(out.numpy()) + args = lh_utils.memrefs_to_packed_args([a_mem, b_mem, out_mem]) + func_ptr(args) + + assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) + + +@pytest.mark.parametrize( + "op,shape,elem_type", + [ + (get_rsqrt, (4, 16), "f32"), + (get_mean, (4, 16), "f32"), + (get_sqr, (4, 16), "f32"), + ], +) +def test_unary_op(op, shape, elem_type): + def generate_module(ctx, elty): + with ctx, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + tensor_type = ir.RankedTensorType.get(shape, elty) + + # For mean operation, output has different shape (reduction on last dim) + if op == get_mean: + out_shape = list(shape) + out_shape[-1] = 1 + out_tensor_type = ir.RankedTensorType.get(out_shape, elty) + else: + out_tensor_type = tensor_type + + @func.FuncOp.from_py_func(tensor_type, out_tensor_type, name="unary_op") + def unary_op(a, out): + op(a, out) + + return module + + ctx = ir.Context() + ir_type = to_ir_type(elem_type, ctx) + module = generate_module(ctx, ir_type) + bufferize_module(ctx, module) + schedule = create_schedule(ctx) + apply_schedule(module, schedule) + pm = create_pass_pipeline(ctx) + pm.run(module.operation) + + eng = ExecutionEngine(module, opt_level=2) + func_ptr = eng.lookup("unary_op") + + torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + a = torch.randn(*shape, dtype=torch_dtype) + out_ref = references[op](a) + out = torch.empty_like(out_ref) + + a_mem = get_ranked_memref_descriptor(a.numpy()) + out_mem = get_ranked_memref_descriptor(out.numpy()) + args = lh_utils.memrefs_to_packed_args([a_mem, out_mem]) + func_ptr(args) + + assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) + + +@pytest.mark.parametrize("shape,elem_type", [((4, 16), "f32")]) +def test_rms_norm(shape, elem_type): + eps = 1e-5 + + def generate_module(ctx, elty): + with ctx, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + input_type = ir.RankedTensorType.get(shape, elty) + + @func.FuncOp.from_py_func(input_type, input_type, name="rms_norm") + def rms_norm(a, out): + get_l2_norm(a, out, eps) + + return module + + ctx = ir.Context() + ir_type = to_ir_type(elem_type, ctx) + module = generate_module(ctx, ir_type) + print(module) + bufferize_module(ctx, module) + schedule = create_schedule(ctx) + apply_schedule(module, schedule) + pm = create_pass_pipeline(ctx) + pm.run(module.operation) + eng = ExecutionEngine(module, opt_level=2) + func_ptr = eng.lookup("rms_norm") + torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + a = torch.randn(*shape, dtype=torch_dtype) + out_ref = references[get_l2_norm](a, eps) + out = torch.empty_like(out_ref) + a_mem = get_ranked_memref_descriptor(a.numpy()) + out_mem = get_ranked_memref_descriptor(out.numpy()) + args = lh_utils.memrefs_to_packed_args([a_mem, out_mem]) + func_ptr(args) + + assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) + + +@pytest.mark.parametrize( + "batch_size,seq_len,n_heads,head_dim,n_kv_heads,elem_type", + [(2, 512, 32, 128, 8, "f32")], +) +def test_rotary_emb(batch_size, seq_len, n_heads, head_dim, n_kv_heads, elem_type): + def generate_module(ctx, elty, xq_shape, xk_shape, freqs_cis_shape): + with ctx, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + xq_type = ir.RankedTensorType.get(xq_shape, elty) + xk_type = ir.RankedTensorType.get(xk_shape, elty) + freqs_cis_type = ir.RankedTensorType.get(freqs_cis_shape, elty) + + @func.FuncOp.from_py_func( + xq_type, + xk_type, + freqs_cis_type, + xq_type, + xk_type, + name="rotary_emb", + ) + def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out): + get_rotary_emb(xq, xk, freqs_cis, xq_out, xk_out) + + return module + + ctx = ir.Context() + ir_type = to_ir_type(elem_type, ctx) + torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + xq_shape = (batch_size, seq_len, n_heads, head_dim) + xk_shape = (batch_size, seq_len, n_kv_heads, head_dim) + freqs_cis_shape = (seq_len, head_dim // 2) + xq = torch.randn(*xq_shape, dtype=torch_dtype) + xk = torch.randn(*xk_shape, dtype=torch_dtype) + freqs_cis = torch.randn(*freqs_cis_shape, dtype=torch_dtype) + xq_out, xk_out = references[get_rotary_emb](xq, xk, freqs_cis) + + module = generate_module( + ctx, + xq_shape=xq_shape, + xk_shape=xk_shape, + freqs_cis_shape=freqs_cis_shape, + elty=ir_type, + ) + bufferize_module(ctx, module) + schedule = create_schedule(ctx) + apply_schedule(module, schedule) + pm = create_pass_pipeline(ctx) + pm.run(module.operation) + + eng = ExecutionEngine(module, opt_level=2) + func_ptr = eng.lookup("rotary_emb") + + out1 = torch.empty_like(xq_out) + out2 = torch.empty_like(xk_out) + + a_mem = get_ranked_memref_descriptor(xq.numpy()) + b_mem = get_ranked_memref_descriptor(xk.numpy()) + freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis.numpy()) + out1_mem = get_ranked_memref_descriptor(out1.numpy()) + out2_mem = get_ranked_memref_descriptor(out2.numpy()) + args = lh_utils.memrefs_to_packed_args( + [a_mem, b_mem, freqs_cis_mem, out1_mem, out2_mem] + ) + func_ptr(args) + + assert torch.allclose(out1, xq_out, rtol=0.01, atol=0.01, equal_nan=True) + assert torch.allclose(out2, xk_out, rtol=0.01, atol=0.01, equal_nan=True) + + +def test_to_complex(): + def generate_module(ctx, elty): + with ctx, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + a_type = ir.RankedTensorType.get((2, 2), elty) + b_type = ir.RankedTensorType.get((2, 2), elty) + out_type = ir.RankedTensorType.get((2, 2), elty) + + @func.FuncOp.from_py_func( + a_type, b_type, out_type, name="mul_as_complex" + ) + def mul_as_complex(a, b, out): + # Convert both inputs to complex + # (d0, d1) -> (d0, d1//2) complex + # multiply with linalg.mul + # Convert back to real + # (d0, d1//2) complex -> (d0, d1) + + complex_shape = list(a.type.shape) + complex_shape[-1] = complex_shape[-1] // 2 + a_complex_uninit = ir.RankedTensorType.get( + complex_shape, complex.ComplexType.get(elty) + ) + b_complex_uninit = ir.RankedTensorType.get( + complex_shape, complex.ComplexType.get(elty) + ) + mul_out = ir.RankedTensorType.get( + complex_shape, complex.ComplexType.get(elty) + ) + mul = linalg.mul( + a_complex_uninit, b_complex_uninit, outs=(mul_out,) + ) + + return module + + a = torch.randn(2, 2, dtype=torch.float32) + b = torch.randn(2, 2, dtype=torch.float32) + x_complex = torch.view_as_complex(a) + y_complex = torch.view_as_complex(b) + res = torch.view_as_real(x_complex * y_complex).flatten(1) + + ctx = ir.Context() + ir_type = to_ir_type("f32", ctx) + module = generate_module(ctx, ir_type) + bufferize_module(ctx, module) + schedule = create_schedule(ctx) + apply_schedule(module, schedule) + pm = create_pass_pipeline(ctx) + pm.run(module.operation) + + eng = ExecutionEngine(module, opt_level=2) + func_ptr = eng.lookup("mul_as_complex") + out = torch.empty_like(a) + a_mem = get_ranked_memref_descriptor(a.numpy()) + b_mem = get_ranked_memref_descriptor(b.numpy()) + out_mem = get_ranked_memref_descriptor(out.numpy()) + args = lh_utils.memrefs_to_packed_args([a_mem, b_mem, out_mem]) + func_ptr(args) + + assert torch.allclose(out, res, rtol=0.01, atol=0.01, equal_nan=True) diff --git a/pyproject.toml b/pyproject.toml index ba0cba7..d96ef65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,6 +4,7 @@ dynamic = ["version"] requires-python = ">=3.10,<3.13" # Bounds are due to torch-mlir's packaging dependencies = [ "mlir-python-bindings==20251105+4c2a9c4ba", + "pytest>=8.2.0", ] [project.optional-dependencies] @@ -73,3 +74,21 @@ include = ["lighthouse*"] [tool.setuptools.dynamic] version = {attr = "lighthouse.__version__"} + +[tool.pytest.ini_options] +# Pytest configuration +testpaths = ["python/tests", "ingress/mlir-gen/mlir_gen/test"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = [ + "-v", + "--tb=short", + "--strict-markers", +] +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "integration: marks tests as integration tests", +] +# Set PYTHONPATH to include project directories +pythonpath = ["python", "ingress/mlir-gen"] diff --git a/python/lighthouse/utils/__init__.py b/python/lighthouse/utils/__init__.py index 22799cc..d411b17 100644 --- a/python/lighthouse/utils/__init__.py +++ b/python/lighthouse/utils/__init__.py @@ -6,4 +6,5 @@ memrefs_to_packed_args, torch_to_memref, torch_to_packed_args, + mlir_type_to_torch_dtype, ) diff --git a/python/lighthouse/utils/runtime_args.py b/python/lighthouse/utils/runtime_args.py index 3d88242..877cf9f 100644 --- a/python/lighthouse/utils/runtime_args.py +++ b/python/lighthouse/utils/runtime_args.py @@ -1,12 +1,10 @@ import ctypes -<<<<<<< HEAD import torch from mlir.runtime.np_to_memref import ( get_ranked_memref_descriptor, ) -======= ->>>>>>> 9647a7f ([examples][mlir] Basic MLIR compilation and execution example) +from mlir import ir def get_packed_arg(ctypes_args) -> list[ctypes.c_void_p]: @@ -63,3 +61,39 @@ def torch_to_packed_args(inputs: list[torch.Tensor]) -> list[ctypes.c_void_p]: """ memrefs = [torch_to_memref(input) for input in inputs] return memrefs_to_packed_args(memrefs) + + +def mlir_type_to_torch_dtype(mlir_type: ir.Type): + """ + Convert an MLIR type to a PyTorch dtype. + + Args: + mlir_type: An MLIR type (e.g., ir.F32Type, ir.F64Type) + + Returns: + Corresponding PyTorch dtype + """ + import torch + + if isinstance(mlir_type, ir.F32Type): + return torch.float32 + elif isinstance(mlir_type, ir.F64Type): + return torch.float64 + elif isinstance(mlir_type, ir.F16Type): + return torch.float16 + elif isinstance(mlir_type, ir.BF16Type): + return torch.bfloat16 + elif isinstance(mlir_type, ir.IntegerType): + width = mlir_type.width + if width == 64: + return torch.int64 + elif width == 32: + return torch.int32 + elif width == 16: + return torch.int16 + elif width == 8: + return torch.int8 + elif width == 1: + return torch.bool + + raise ValueError(f"Unsupported MLIR type: {mlir_type}") From 2f80f1792f99f9e5dc544801fb3cd9ad4055cf36 Mon Sep 17 00:00:00 2001 From: Petr Kurapov Date: Fri, 14 Nov 2025 13:29:34 +0100 Subject: [PATCH 5/7] Add helpers and tests for rotary embeddings --- ingress/mlir-gen/mlir_gen/test/test_core.py | 886 ++++++++++++++++---- 1 file changed, 740 insertions(+), 146 deletions(-) diff --git a/ingress/mlir-gen/mlir_gen/test/test_core.py b/ingress/mlir-gen/mlir_gen/test/test_core.py index d638dc9..43bb044 100644 --- a/ingress/mlir-gen/mlir_gen/test/test_core.py +++ b/ingress/mlir-gen/mlir_gen/test/test_core.py @@ -3,7 +3,7 @@ from typing import Tuple from mlir import ir -from mlir.dialects import transform, func, linalg, tensor, arith, complex +from mlir.dialects import transform, func, linalg, tensor, arith, complex, math from mlir.dialects.transform import structured from mlir.dialects.transform import interpreter from mlir.passmanager import PassManager @@ -28,6 +28,8 @@ def create_pass_pipeline(ctx: ir.Context) -> PassManager: with ctx: pm = PassManager("builtin.module") pm.add("convert-scf-to-cf") + pm.add("expand-strided-metadata") + pm.add("lower-affine") pm.add("finalize-memref-to-llvm") pm.add("convert-func-to-llvm") pm.add("convert-to-llvm") @@ -110,6 +112,18 @@ def bufferize_module(ctx: ir.Context, kernel: ir.Module) -> None: pm.run(kernel.operation) +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + #### IR builders ##### # TODO: Move to mlir_gen module @@ -169,6 +183,184 @@ def mean_op(a_val, acc): return mean_op +# repeat_kv +def get_repeat_kv(x: ir.Value, n_rep: int, out: ir.Value) -> ir.Value: + bs, slen, n_kv_heads, head_dim = x.type.shape + if n_rep == 1: + return x + + b, s, h_out, d = [ir.AffineDimExpr.get(i) for i in range(4)] + + # For output head h_out, we read from input head h_out // n_rep + # This is equivalent to: x[:, :, :, None, :].expand(...).reshape(...) + h_in = ir.AffineExpr.get_floor_div(h_out, ir.AffineConstantExpr.get(n_rep)) + + # Affine maps + x_map = affine_map(4, [b, s, h_in, d]) + out_map = affine_map(4, [b, s, h_out, d]) + + @linalg.generic( + [x], + [out], + [x_map, out_map], + [parallel] * 4, + ) + def repeat_kv_op(a, _out): + return a + + return repeat_kv_op + + +# equivalent to torch.nn.functional.silu +def get_silu(inputs: ir.Value, out: ir.Value) -> ir.Value: + elty = inputs.type.element_type + one = arith.constant(elty, 1.0) + + dims = [ir.AffineDimExpr.get(i) for i in range(inputs.type.rank)] + par_affine_map = affine_map(inputs.type.rank, dims) + par_iterator_types = [parallel] * inputs.type.rank + + @linalg.generic( + [inputs], + [out], + [par_affine_map, par_affine_map], + par_iterator_types, + ) + def silu_op(a, _out): + sigmoid = arith.DivFOp( + one, + arith.AddFOp( + one, + math.exp(arith.NegFOp(a).result), + ).result, + ).result + return arith.MulFOp(a, sigmoid).result + + return silu_op + + +# equivalent to torch.softmax(a, dim=-1) +# this should be just linalg.softmax, but there's no decomposition +def get_softmax(a: ir.Value, out: ir.Value) -> ir.Value: + elty = a.type.element_type + + reduced_shape = list(a.type.shape) + reduced_shape[-1] = 1 + max_uninit = tensor.EmptyOp(reduced_shape, elty) + + neg_inf = arith.ConstantOp(elty, float("-inf")) + max_init = linalg.fill(neg_inf, outs=[max_uninit.result]) + + reduce_map = affine_map( + a.type.rank, + [ir.AffineDimExpr.get(i) for i in range(a.type.rank - 1)] + + [ir.AffineConstantExpr.get(0)], + ) + identity_map = affine_map( + a.type.rank, + [ir.AffineDimExpr.get(i) for i in range(a.type.rank)], + ) + + iterator_types = [parallel] * (a.type.rank - 1) + [reduction] + + @linalg.generic( + [a], + [max_init], + [identity_map, reduce_map], + iterator_types, + ) + def compute_max(val, acc): + return arith.MaximumFOp(val, acc).result + + shifted_uninit = tensor.EmptyOp(a.type.shape, elty) + + @linalg.generic( + [a, compute_max], + [shifted_uninit.result], + [identity_map, reduce_map, identity_map], + [parallel] * a.type.rank, + ) + def subtract_max(val, max_val, _out): + return arith.SubFOp(val, max_val).result + + exp_uninit = tensor.EmptyOp(a.type.shape, elty) + + @linalg.generic( + [subtract_max], + [exp_uninit.result], + [identity_map, identity_map], + [parallel] * a.type.rank, + ) + def compute_exp(val, _out): + return math.exp(val) + + sum_uninit = tensor.EmptyOp(reduced_shape, elty) + zero = arith.ConstantOp(elty, 0.0) + sum_init = linalg.fill(zero, outs=[sum_uninit.result]) + + @linalg.generic( + [compute_exp], + [sum_init], + [identity_map, reduce_map], + iterator_types, + ) + def compute_sum(val, acc): + return arith.AddFOp(val, acc).result + + @linalg.generic( + [compute_exp, compute_sum], + [out], + [identity_map, reduce_map, identity_map], + [parallel] * a.type.rank, + ) + def divide_by_sum(exp_val, sum_val, _out): + return arith.DivFOp(exp_val, sum_val).result + + return divide_by_sum + + +# torch.matmul +def get_matmul(a: ir.Value, b: ir.Value, out: ir.Value) -> ir.Value: + return linalg.matmul(a, b, outs=[out]) + + +# torch.nn.functional.linear +def get_linear(a: ir.Value, w: ir.Value, b: ir.Value, out: ir.Value) -> ir.Value: + # a[i, k] * w[j, k] -> out[i, j] + i, j, k = [ir.AffineDimExpr.get(d) for d in range(3)] + a_map = affine_map(3, [i, k]) # (batch, in_feat) + w_map = affine_map(3, [j, k]) # (out_feat, in_feat) - note: we use j for first dim + out_map = affine_map(3, [i, j]) # (batch, out_feat) + + # First compute the matmul into out (which will accumulate) + @linalg.generic( + [a, w], + [out], + [a_map, w_map, out_map], + [parallel, parallel, reduction], + ) + def matmul_op(a_elem, w_elem, out_elem): + prod = arith.MulFOp(a_elem, w_elem).result + return arith.AddFOp(out_elem, prod).result + + # Step 2: Add bias using broadcasting + # b[j] -> out[i, j] + i2, j2 = [ir.AffineDimExpr.get(d) for d in range(2)] + b_map = affine_map(2, [j2]) # (out_feat,) + out_map2 = affine_map(2, [i2, j2]) # (batch, out_feat) + + @linalg.generic( + [matmul_op, b], + [out], + [out_map2, b_map, out_map2], + [parallel, parallel], + ) + def add_bias_op(matmul_elem, b_elem, _out): + return arith.AddFOp(matmul_elem, b_elem).result + + return add_bias_op + + def get_l2_norm(a: ir.Value, out: ir.Value, eps: float = 1e-5) -> ir.Value: """ Compute x * rsqrt(mean(x^2, dim=-1, keepdim=True) + eps) @@ -222,116 +414,299 @@ def broadcast_rsqrt(val, _out): return get_mul(a, broadcast_rsqrt, out) -def get_rotary_emb( - xq: ir.Value, xk: ir.Value, freqs_cis: ir.Value, xq_out: ir.Value, xk_out: ir.Value -): +# equivalent to torch.polar +def get_polar(abs: ir.Value, angle: ir.Value, out: ir.Value) -> ir.Value: + """ + Convert magnitude and angle to complex number: out = abs * (cos(angle) + i*sin(angle)) """ - Apply rotary embeddings to query and key tensors. + elty = abs.type.element_type + shape = abs.type.shape + rank = len(shape) - This implements the transformation: - 1. View xq, xk as complex: [B, S, H, D] -> [B, S, H, D//2] complex - 2. Broadcast freqs_cis: [S, D//2] -> [1, S, 1, D//2] - 3. Complex multiply: xq_ * freqs_cis, xk_ * freqs_cis - 4. View back as real: [B, S, H, D//2] complex -> [B, S, H, D] real + # Identity map for element-wise operations + id_map = affine_map(rank, [ir.AffineDimExpr.get(i) for i in range(rank)]) - Args: - xq: Query tensor of shape [B, S, H, D] - xk: Key tensor of shape [B, S, H_kv, D] - freqs_cis: Rotary embeddings of shape [S, D//2] - Note: In PyTorch this is complex64, but here it's f32 - We need to interpret as pairs or pass cos/sin separately - xq_out: Output tensor for queries [B, S, H, D] - xk_out: Output tensor for keys [B, S, H_kv, D] - - TODO: Properly implement rotary embeddings - Current implementation is just a placeholder passthrough. - - For a correct implementation, we need to: - 1. Either: - a) Pass freqs_cis as [S, D] with interleaved cos/sin values, OR - b) Pass separate cos and sin tensors of shape [S, D//2] - 2. Extract pairs of elements from xq/xk to treat as complex (real, imag) - 3. Apply complex rotation: - (real', imag') = (real*cos - imag*sin, real*sin + imag*cos) - 4. Interleave results back into output + # Compute cos(angle) and sin(angle), then multiply by abs to get real and imag parts + @linalg.generic( + [abs, angle], + [out], + [id_map, id_map, id_map], + [parallel] * rank, + ) + def polar_convert(abs_val, angle_val, _out): + cos_val = math.CosOp(angle_val).result + sin_val = math.SinOp(angle_val).result + real_part = arith.MulFOp(abs_val, cos_val).result + imag_part = arith.MulFOp(abs_val, sin_val).result + return complex.CreateOp(ir.ComplexType.get(elty), real_part, imag_part).result + + return polar_convert + + +# equivalent to torch.outer +def get_outer(a: ir.Value, b: ir.Value, out: ir.Value) -> ir.Value: """ - elty = xq.type.element_type + Compute outer product: out[i,j] = a[i] * b[j] - # Get shapes - xq_shape = list(xq.type.shape) # [B, S, H, D] - xk_shape = list(xk.type.shape) # [B, S, H_kv, D] + Assumes inputs are 1-D tensors. + """ + # Affine maps for outer product: a[i] broadcasts to (i,j), b[j] broadcasts to (i,j) + a_map = affine_map(2, [ir.AffineDimExpr.get(0)]) + b_map = affine_map(2, [ir.AffineDimExpr.get(1)]) + out_map = affine_map(2, [ir.AffineDimExpr.get(0), ir.AffineDimExpr.get(1)]) - # Placeholder implementation: just copy inputs to outputs - # This allows the test infrastructure to work but doesn't compute correct results + @linalg.generic( + [a, b], + [out], + [a_map, b_map, out_map], + [parallel, parallel], + ) + def outer_product(a_val, b_val, _out): + return arith.MulFOp(a_val, b_val).result + + return outer_product + + +# with b broadcasting, assuming it has smaller rank +def get_complex_mul(a: ir.Value, b: ir.Value, out: ir.Value) -> ir.Value: + rank_b = b.type.rank + rank_out = out.type.rank - b, s, h, d = [ir.AffineDimExpr.get(i) for i in range(4)] + dim_exprs_a = [ir.AffineDimExpr.get(i) for i in range(rank_out)] + + if rank_b < rank_out: + offset = rank_out - rank_b + dim_exprs_b = [ir.AffineConstantExpr.get(0)] * offset + [ + ir.AffineDimExpr.get(i) for i in range(offset, rank_out) + ] + else: + b_shape = list(b.type.shape) + dim_exprs_b = [] + for i in range(rank_out): + if i < len(b_shape) and b_shape[i] == 1: + dim_exprs_b.append(ir.AffineConstantExpr.get(0)) + else: + dim_exprs_b.append(ir.AffineDimExpr.get(i)) + + dim_exprs_out = [ir.AffineDimExpr.get(i) for i in range(rank_out)] + + map_a = affine_map(rank_out, dim_exprs_a) + map_b = affine_map(rank_out, dim_exprs_b) + map_out = affine_map(rank_out, dim_exprs_out) @linalg.generic( - [xq], - [xq_out], - [affine_map(4, [b, s, h, d]), affine_map(4, [b, s, h, d])], - [parallel] * 4, + [a, b], + [out], + [map_a, map_b, map_out], + [parallel] * rank_out, + ) + def complex_mul_op(a_val, b_val, _out): + result = complex.MulOp(a_val, b_val).result + return result + + return complex_mul_op + + +def get_rotary_emb( + xq: ir.Value, xk: ir.Value, freqs_cis: ir.Value, xq_out: ir.Value, xk_out: ir.Value +): + elty = xq.type.element_type + + xq_shape = list(xq.type.shape) + xk_shape = list(xk.type.shape) + batch, seq_len, n_heads, head_dim = xq_shape + n_kv_heads = xk_shape[2] + + # Reshape xq to (batch, seq_len, n_heads, head_dim//2, 2) + xq_reshaped_shape = [batch, seq_len, n_heads, head_dim // 2, 2] + xq_reshaped_type = ir.RankedTensorType.get(xq_reshaped_shape, elty) + xq_reshaped = tensor.expand_shape( + xq_reshaped_type, + xq, + reassociation=[[0], [1], [2], [3, 4]], + output_shape=[], + static_output_shape=xq_reshaped_shape, + ) + + # View xq as complex: (batch, seq_len, n_heads, head_dim//2, 2) -> (batch, seq_len, n_heads, head_dim//2) complex + xq_complex_shape = [batch, seq_len, n_heads, head_dim // 2] + xq_complex_uninit = tensor.EmptyOp( + xq_complex_shape, ir.ComplexType.get(elty) + ).result + xq_complex = get_view_as_complex(xq_reshaped, xq_complex_uninit) + + # same for xk + xk_reshaped_shape = [batch, seq_len, n_kv_heads, head_dim // 2, 2] + xk_reshaped_type = ir.RankedTensorType.get(xk_reshaped_shape, elty) + xk_reshaped = tensor.expand_shape( + xk_reshaped_type, + xk, + reassociation=[[0], [1], [2], [3, 4]], + output_shape=[], + static_output_shape=xk_reshaped_shape, + ) + + xk_complex_shape = [batch, seq_len, n_kv_heads, head_dim // 2] + xk_complex_uninit = tensor.EmptyOp( + xk_complex_shape, ir.ComplexType.get(elty) + ).result + xk_complex = get_view_as_complex(xk_reshaped, xk_complex_uninit) + + # Reshape freqs_cis for broadcasting: (seq_len, head_dim//2) -> (1, seq_len, 1, head_dim//2) + freqs_broadcast_shape = [1, seq_len, 1, head_dim // 2] + freqs_broadcast_uninit = tensor.EmptyOp(freqs_broadcast_shape, elty).result + freqs_broadcast = get_reshape_for_broadcast( + freqs_cis, xq_complex, freqs_broadcast_uninit ) - def copy_xq(x, _out): - return x + + # cast freqs_broadcast to complex + freqs_broadcast_complex_uninit = tensor.EmptyOp( + freqs_broadcast_shape, ir.ComplexType.get(elty) + ).result + + d0, d1, d2, d3 = [ir.AffineDimExpr.get(i) for i in range(4)] + indexing_maps = [ + ir.AffineMap.get(4, 0, [d0, d1, d2, d3]), + ir.AffineMap.get(4, 0, [d0, d1, d2, d3]), + ] @linalg.generic( - [xk], - [xk_out], - [ - affine_map(4, [b, s, ir.AffineDimExpr.get(2), d]), - affine_map(4, [b, s, ir.AffineDimExpr.get(2), d]), - ], - [parallel] * 4, + inputs=[freqs_broadcast], + outputs=[freqs_broadcast_complex_uninit], + indexing_maps=indexing_maps, + iterator_types=["parallel", "parallel", "parallel", "parallel"], + ) + def real_to_complex(r, out): + zero = arith.constant(elty, 0.0) + return complex.CreateOp(ir.ComplexType.get(elty), r, zero).result + + freqs_broadcast_complex = real_to_complex + + # Multiply xq_complex with freqs_broadcast_complex + xq_rotated_uninit = tensor.EmptyOp( + xq_complex_shape, ir.ComplexType.get(elty) + ).result + xq_rotated = get_complex_mul(xq_complex, freqs_broadcast_complex, xq_rotated_uninit) + + xk_rotated_uninit = tensor.EmptyOp( + xk_complex_shape, ir.ComplexType.get(elty) + ).result + xk_rotated = get_complex_mul(xk_complex, freqs_broadcast_complex, xk_rotated_uninit) + + # view as real + xq_real_shape = [batch, seq_len, n_heads, head_dim // 2, 2] + xq_real_uninit = tensor.EmptyOp(xq_real_shape, elty).result + xq_real = get_view_as_real(xq_rotated, xq_real_uninit) + + xk_real_shape = [batch, seq_len, n_kv_heads, head_dim // 2, 2] + xk_real_uninit = tensor.EmptyOp(xk_real_shape, elty).result + xk_real = get_view_as_real(xk_rotated, xk_real_uninit) + + # flatten back to original shape + xq_final = tensor.collapse_shape( + xq.type, + xq_real, + reassociation=[[0], [1], [2], [3, 4]], ) - def copy_xk(x, _out): - return x - return (copy_xq, copy_xk) + xk_final = tensor.collapse_shape( + xk.type, + xk_real, + reassociation=[[0], [1], [2], [3, 4]], + ) + linalg.copy(xq_final, outs=[xq_out]) + linalg.copy(xk_final, outs=[xk_out]) -def get_as_complex(x: ir.Value, out: ir.Value) -> ir.Value: - """ - Interpret the input tensor as complex numbers by grouping pairs of elements. - Args: - x: Input tensor of shape [..., 2] representing complex numbers as pairs (real, imag) - out: Output tensor of shape [...] with complex type - """ +def get_reshape_for_broadcast(freqs_cis: ir.Value, x: ir.Value, out: ir.Value): + # broadcast freqs_cis[seq, head] -> out[0, seq, 0, head] + d0, d1, d2, d3 = [ir.AffineDimExpr.get(i) for i in range(4)] + + in_map = affine_map(4, [d1, d3]) + out_map = affine_map(4, [d0, d1, d2, d3]) + + @linalg.generic( + [freqs_cis], + [out], + [in_map, out_map], + [parallel, parallel, parallel, parallel], + ) + def reshape_op(val, _out): + return val + + return reshape_op + + +# torch.view_as_complex +def get_view_as_complex(x: ir.Value, out: ir.Value) -> ir.Value: elty = x.type.element_type rank = x.type.rank shape = list(x.type.shape) assert shape[-1] == 2, "Last dimension must be of size 2 to form complex numbers" - complex_shape = shape[:-1] - dim_exprs_in = [ir.AffineDimExpr.get(i) for i in range(rank)] - dim_exprs_out = [ir.AffineDimExpr.get(i) for i in range(rank - 1)] + rank_out = rank - 1 + dim_exprs_out = [ir.AffineDimExpr.get(i) for i in range(rank_out)] - input_map = affine_map( - rank, - dim_exprs_in, - ) - output_map = affine_map( - rank - 1, - dim_exprs_out, + # real part: access input[d0, d1, ..., d_{rank-2}, 0] + dim_exprs_real = dim_exprs_out + [ir.AffineConstantExpr.get(0)] + # imag part: access input[d0, d1, ..., d_{rank-2}, 1] + dim_exprs_imag = dim_exprs_out + [ir.AffineConstantExpr.get(1)] + + input_map_real = affine_map(rank_out, dim_exprs_real) + input_map_imag = affine_map(rank_out, dim_exprs_imag) + output_map = affine_map(rank_out, dim_exprs_out) + + @linalg.generic( + [x, x], # Same input tensor accessed twice with different maps + [out], + [input_map_real, input_map_imag, output_map], + [parallel] * rank_out, ) - iterator_types = [parallel] * (rank - 1) + def view_as_complex_op(r, i, _out): + cplx = complex.CreateOp(ir.ComplexType.get(elty), r, i).result + return cplx + + return view_as_complex_op + + +# torch.view_as_real +def get_view_as_real(x: ir.Value, out: ir.Value) -> ir.Value: + rank = x.type.rank + + # Output has shape [..., 2] + # extract real part to [..., 0] and imag part to [..., 1] + + dim_exprs_in = [ir.AffineDimExpr.get(i) for i in range(rank)] + + # For real part: write to output[..., 0] + dim_exprs_real = dim_exprs_in + [ir.AffineConstantExpr.get(0)] + # For imag part: write to output[..., 1] + dim_exprs_imag = dim_exprs_in + [ir.AffineConstantExpr.get(1)] + + input_map = affine_map(rank, dim_exprs_in) + output_map_real = affine_map(rank, dim_exprs_real) + output_map_imag = affine_map(rank, dim_exprs_imag) @linalg.generic( [x], [out], - [input_map, output_map], - iterator_types, + [input_map, output_map_real], + [parallel] * rank, ) - def as_complex_op(a, _out): - real_part = a[0] - imag_part = a[1] - cplx = complex.CreateOp( - complex.ComplexType.get(elty), real_part, imag_part - ).result - return cplx + def write_real(cplx, _out): + return complex.ReOp(cplx).result + + @linalg.generic( + [x], + [write_real], + [input_map, output_map_imag], + [parallel] * rank, + ) + def write_imag(cplx, _out): + return complex.ImOp(cplx).result - return as_complex_op + return write_imag #### Test cases ##### @@ -361,9 +736,16 @@ def rotary_emb_ref( references = { get_add: torch.add, get_mul: torch.mul, + get_matmul: torch.matmul, get_rsqrt: torch.rsqrt, get_sqr: torch.square, get_mean: lambda x: torch.mean(x, dim=-1, keepdim=True), + get_silu: lambda x: torch.nn.functional.silu(x), + get_softmax: lambda x: torch.softmax(x, dim=-1), + get_polar: torch.polar, + get_outer: torch.outer, + get_linear: torch.nn.functional.linear, + get_repeat_kv: repeat_kv, get_l2_norm: lambda x, eps: x * torch.rsqrt(torch.mean(x.pow(2), dim=-1, keepdim=True) + eps), get_rotary_emb: rotary_emb_ref, @@ -381,7 +763,13 @@ def to_ir_type(type_str, ctx): @pytest.mark.parametrize( - "op,shape,elem_type", [(get_add, (4, 16), "f32"), (get_mul, (4, 16), "f32")] + "op,shape,elem_type", + [ + (get_add, (4, 16), "f32"), + (get_mul, (4, 16), "f32"), + (get_matmul, (16, 16), "f32"), + (get_outer, (16,), "f32"), + ], ) def test_bin_op(op, shape, elem_type): def generate_module(ctx, elty): @@ -390,8 +778,15 @@ def generate_module(ctx, elty): with ir.InsertionPoint(module.body): tensor_type = ir.RankedTensorType.get(shape, elty) + # Outer product produces [M, M] output for 1-D input of size M + if op == get_outer: + out_shape = (shape[0], shape[0]) + out_tensor_type = ir.RankedTensorType.get(out_shape, elty) + else: + out_tensor_type = tensor_type + @func.FuncOp.from_py_func( - tensor_type, tensor_type, tensor_type, name="bin_op" + tensor_type, tensor_type, out_tensor_type, name="bin_op" ) def bin_op(a, b, out): op(a, b, out) @@ -415,6 +810,7 @@ def bin_op(a, b, out): b = torch.randn(*shape, dtype=torch_dtype) out_ref = references[op](a, b) out = torch.empty_like(out_ref) + out.zero_() a_mem = get_ranked_memref_descriptor(a.numpy()) b_mem = get_ranked_memref_descriptor(b.numpy()) @@ -431,6 +827,8 @@ def bin_op(a, b, out): (get_rsqrt, (4, 16), "f32"), (get_mean, (4, 16), "f32"), (get_sqr, (4, 16), "f32"), + (get_silu, (4, 16), "f32"), + (get_softmax, (4, 16), "f32"), ], ) def test_unary_op(op, shape, elem_type): @@ -518,6 +916,265 @@ def rms_norm(a, out): assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) +def test_linear(): + def generate_module(ctx, elty): + with ctx, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + input_type = ir.RankedTensorType.get((4, 16), elty) + weight_type = ir.RankedTensorType.get((32, 16), elty) + bias_type = ir.RankedTensorType.get((32,), elty) + output_type = ir.RankedTensorType.get((4, 32), elty) + + @func.FuncOp.from_py_func( + input_type, weight_type, bias_type, output_type, name="linear_op" + ) + def linear_op(x, w, b, out): + get_linear(x, w, b, out) + + return module + + ctx = ir.Context() + ir_type = to_ir_type("f32", ctx) + module = generate_module(ctx, ir_type) + bufferize_module(ctx, module) + schedule = create_schedule(ctx) + apply_schedule(module, schedule) + pm = create_pass_pipeline(ctx) + pm.run(module.operation) + + eng = ExecutionEngine(module, opt_level=2) + func_ptr = eng.lookup("linear_op") + torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + x = torch.randn(4, 16, dtype=torch_dtype) + w = torch.randn(32, 16, dtype=torch_dtype) + b = torch.randn(32, dtype=torch_dtype) + out_ref = references[get_linear](x, w, b) + out = torch.empty_like(out_ref) + out.zero_() + x_mem = get_ranked_memref_descriptor(x.numpy()) + w_mem = get_ranked_memref_descriptor(w.numpy()) + b_mem = get_ranked_memref_descriptor(b.numpy()) + out_mem = get_ranked_memref_descriptor(out.numpy()) + args = lh_utils.memrefs_to_packed_args([x_mem, w_mem, b_mem, out_mem]) + func_ptr(args) + assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) + + +def test_polar(): + def generate_module(ctx, elty): + with ctx, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + tensor_type = ir.RankedTensorType.get((4, 16), elty) + complex_tensor_type = ir.RankedTensorType.get( + (4, 16), ir.ComplexType.get(elty) + ) + + @func.FuncOp.from_py_func( + tensor_type, tensor_type, complex_tensor_type, name="polar_op" + ) + def polar_op(magnitude, angle, out): + get_polar(magnitude, angle, out) + + return module + + ctx = ir.Context() + ir_type = to_ir_type("f32", ctx) + module = generate_module(ctx, ir_type) + bufferize_module(ctx, module) + schedule = create_schedule(ctx) + apply_schedule(module, schedule) + pm = create_pass_pipeline(ctx) + pm.run(module.operation) + + eng = ExecutionEngine(module, opt_level=2) + func_ptr = eng.lookup("polar_op") + torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + magnitude = torch.randn(4, 16, dtype=torch_dtype) + angle = torch.randn(4, 16, dtype=torch_dtype) + out_ref = references[get_polar](magnitude, angle) + out = torch.empty_like(out_ref) + magnitude_mem = get_ranked_memref_descriptor(magnitude.numpy()) + angle_mem = get_ranked_memref_descriptor(angle.numpy()) + out_mem = get_ranked_memref_descriptor(out.numpy()) + args = lh_utils.memrefs_to_packed_args([magnitude_mem, angle_mem, out_mem]) + func_ptr(args) + assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) + + +def test_repeat_kv(): + def generate_module(ctx, elty, n_rep): + with ctx, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + x_type = ir.RankedTensorType.get((2, 512, 8, 64), elty) + out_type = ir.RankedTensorType.get((2, 512, 8 * n_rep, 64), elty) + + @func.FuncOp.from_py_func(x_type, out_type, name="repeat_kv_op") + def repeat_kv_op(x, out): + get_repeat_kv(x, n_rep, out) + + return module + + n_rep = 4 + ctx = ir.Context() + ir_type = to_ir_type("f32", ctx) + module = generate_module(ctx, ir_type, n_rep) + bufferize_module(ctx, module) + schedule = create_schedule(ctx) + apply_schedule(module, schedule) + pm = create_pass_pipeline(ctx) + pm.run(module.operation) + + eng = ExecutionEngine(module, opt_level=2) + func_ptr = eng.lookup("repeat_kv_op") + + torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + x = torch.randn(2, 512, 8, 64, dtype=torch_dtype) + out_ref = references[get_repeat_kv](x, n_rep) + out = torch.empty_like(out_ref) + + x_mem = get_ranked_memref_descriptor(x.numpy()) + out_mem = get_ranked_memref_descriptor(out.numpy()) + args = lh_utils.memrefs_to_packed_args([x_mem, out_mem]) + func_ptr(args) + + assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) + + +def test_reshape_for_broadcast(): + def generate_module(ctx, elty): + with ctx, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + freqs_cis_type = ir.RankedTensorType.get((512, 64), elty) + x_type = ir.RankedTensorType.get((2, 512, 32, 128), elty) + out_type = ir.RankedTensorType.get((1, 512, 1, 64), elty) + + @func.FuncOp.from_py_func( + freqs_cis_type, x_type, out_type, name="reshape_for_broadcast" + ) + def reshape_for_broadcast_op(freqs_cis, x, out): + get_reshape_for_broadcast(freqs_cis, x, out) + + return module + + ctx = ir.Context() + ir_type = to_ir_type("f32", ctx) + module = generate_module(ctx, ir_type) + bufferize_module(ctx, module) + schedule = create_schedule(ctx) + apply_schedule(module, schedule) + pm = create_pass_pipeline(ctx) + pm.run(module.operation) + + eng = ExecutionEngine(module, opt_level=2) + func_ptr = eng.lookup("reshape_for_broadcast") + + torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + freqs_cis = torch.randn(512, 64, dtype=torch_dtype) + x = torch.randn(2, 512, 32, 128, dtype=torch_dtype) + # Convert x to complex view as expected by reshape_for_broadcast + x_complex = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + out_ref = reshape_for_broadcast(freqs_cis, x_complex) + out = torch.empty_like(out_ref) + + freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis.numpy()) + x_mem = get_ranked_memref_descriptor(x.numpy()) + out_mem = get_ranked_memref_descriptor(out.numpy()) + args = lh_utils.memrefs_to_packed_args([freqs_cis_mem, x_mem, out_mem]) + func_ptr(args) + + assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) + + +def test_view_as_complex(): + def generate_module(ctx, elty): + with ctx, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + # Input should be reshaped to have last dim = 2 + x_type = ir.RankedTensorType.get((2, 512, 32, 64, 2), elty) + out_type = ir.RankedTensorType.get( + (2, 512, 32, 64), ir.ComplexType.get(elty) + ) + + @func.FuncOp.from_py_func(x_type, out_type, name="view_as_complex_op") + def view_as_complex_op(x, out): + get_view_as_complex(x, out) + + return module + + ctx = ir.Context() + ir_type = to_ir_type("f32", ctx) + module = generate_module(ctx, ir_type) + bufferize_module(ctx, module) + schedule = create_schedule(ctx) + apply_schedule(module, schedule) + pm = create_pass_pipeline(ctx) + pm.run(module.operation) + + eng = ExecutionEngine(module, opt_level=2) + func_ptr = eng.lookup("view_as_complex_op") + + torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + x = torch.randn(2, 512, 32, 128, dtype=torch_dtype) + # Reshape to (2, 512, 32, 64, 2) before passing to the function + x_reshaped = x.reshape(2, 512, 32, 64, 2) + out_ref = torch.view_as_complex(x_reshaped) + out = torch.empty_like(out_ref) + + x_mem = get_ranked_memref_descriptor(x_reshaped.numpy()) + out_mem = get_ranked_memref_descriptor(out.numpy()) + args = lh_utils.memrefs_to_packed_args([x_mem, out_mem]) + func_ptr(args) + + assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) + + +def test_view_as_real(): + def generate_module(ctx, elty): + with ctx, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + x_type = ir.RankedTensorType.get( + (2, 512, 32, 64), ir.ComplexType.get(elty) + ) + out_type = ir.RankedTensorType.get((2, 512, 32, 64, 2), elty) + + @func.FuncOp.from_py_func(x_type, out_type, name="as_real_op") + def as_real_op(x, out): + get_view_as_real(x, out) + + return module + + ctx = ir.Context() + ir_type = to_ir_type("f32", ctx) + module = generate_module(ctx, ir_type) + bufferize_module(ctx, module) + schedule = create_schedule(ctx) + apply_schedule(module, schedule) + pm = create_pass_pipeline(ctx) + pm.run(module.operation) + + eng = ExecutionEngine(module, opt_level=2) + func_ptr = eng.lookup("as_real_op") + + torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + x = torch.randn(2, 512, 32, 64, 2, dtype=torch_dtype) + x_complex = torch.view_as_complex(x) + out_ref = torch.view_as_real(x_complex) + out = torch.empty_like(out_ref) + + x_mem = get_ranked_memref_descriptor(x_complex.numpy()) + out_mem = get_ranked_memref_descriptor(out.numpy()) + args = lh_utils.memrefs_to_packed_args([x_mem, out_mem]) + func_ptr(args) + + assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) + + @pytest.mark.parametrize( "batch_size,seq_len,n_heads,head_dim,n_kv_heads,elem_type", [(2, 512, 32, 128, 8, "f32")], @@ -586,66 +1243,3 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out): assert torch.allclose(out1, xq_out, rtol=0.01, atol=0.01, equal_nan=True) assert torch.allclose(out2, xk_out, rtol=0.01, atol=0.01, equal_nan=True) - - -def test_to_complex(): - def generate_module(ctx, elty): - with ctx, ir.Location.unknown(): - module = ir.Module.create() - with ir.InsertionPoint(module.body): - a_type = ir.RankedTensorType.get((2, 2), elty) - b_type = ir.RankedTensorType.get((2, 2), elty) - out_type = ir.RankedTensorType.get((2, 2), elty) - - @func.FuncOp.from_py_func( - a_type, b_type, out_type, name="mul_as_complex" - ) - def mul_as_complex(a, b, out): - # Convert both inputs to complex - # (d0, d1) -> (d0, d1//2) complex - # multiply with linalg.mul - # Convert back to real - # (d0, d1//2) complex -> (d0, d1) - - complex_shape = list(a.type.shape) - complex_shape[-1] = complex_shape[-1] // 2 - a_complex_uninit = ir.RankedTensorType.get( - complex_shape, complex.ComplexType.get(elty) - ) - b_complex_uninit = ir.RankedTensorType.get( - complex_shape, complex.ComplexType.get(elty) - ) - mul_out = ir.RankedTensorType.get( - complex_shape, complex.ComplexType.get(elty) - ) - mul = linalg.mul( - a_complex_uninit, b_complex_uninit, outs=(mul_out,) - ) - - return module - - a = torch.randn(2, 2, dtype=torch.float32) - b = torch.randn(2, 2, dtype=torch.float32) - x_complex = torch.view_as_complex(a) - y_complex = torch.view_as_complex(b) - res = torch.view_as_real(x_complex * y_complex).flatten(1) - - ctx = ir.Context() - ir_type = to_ir_type("f32", ctx) - module = generate_module(ctx, ir_type) - bufferize_module(ctx, module) - schedule = create_schedule(ctx) - apply_schedule(module, schedule) - pm = create_pass_pipeline(ctx) - pm.run(module.operation) - - eng = ExecutionEngine(module, opt_level=2) - func_ptr = eng.lookup("mul_as_complex") - out = torch.empty_like(a) - a_mem = get_ranked_memref_descriptor(a.numpy()) - b_mem = get_ranked_memref_descriptor(b.numpy()) - out_mem = get_ranked_memref_descriptor(out.numpy()) - args = lh_utils.memrefs_to_packed_args([a_mem, b_mem, out_mem]) - func_ptr(args) - - assert torch.allclose(out, res, rtol=0.01, atol=0.01, equal_nan=True) From bd23099b965f8c9e3deb9176792ec51d09e20f59 Mon Sep 17 00:00:00 2001 From: Petr Kurapov Date: Fri, 14 Nov 2025 14:18:56 +0100 Subject: [PATCH 6/7] Add feed forward --- ingress/mlir-gen/mlir_gen/test/test_core.py | 102 +++++++++++++++++++- 1 file changed, 97 insertions(+), 5 deletions(-) diff --git a/ingress/mlir-gen/mlir_gen/test/test_core.py b/ingress/mlir-gen/mlir_gen/test/test_core.py index 43bb044..db85892 100644 --- a/ingress/mlir-gen/mlir_gen/test/test_core.py +++ b/ingress/mlir-gen/mlir_gen/test/test_core.py @@ -326,16 +326,19 @@ def get_matmul(a: ir.Value, b: ir.Value, out: ir.Value) -> ir.Value: # torch.nn.functional.linear def get_linear(a: ir.Value, w: ir.Value, b: ir.Value, out: ir.Value) -> ir.Value: + elty = out.type.element_type + zero = arith.constant(elty, 0.0) + out_zeroed = linalg.fill(zero, outs=[out]) + # a[i, k] * w[j, k] -> out[i, j] i, j, k = [ir.AffineDimExpr.get(d) for d in range(3)] a_map = affine_map(3, [i, k]) # (batch, in_feat) - w_map = affine_map(3, [j, k]) # (out_feat, in_feat) - note: we use j for first dim + w_map = affine_map(3, [j, k]) # (out_feat, in_feat) out_map = affine_map(3, [i, j]) # (batch, out_feat) - # First compute the matmul into out (which will accumulate) @linalg.generic( [a, w], - [out], + [out_zeroed], [a_map, w_map, out_map], [parallel, parallel, reduction], ) @@ -343,7 +346,6 @@ def matmul_op(a_elem, w_elem, out_elem): prod = arith.MulFOp(a_elem, w_elem).result return arith.AddFOp(out_elem, prod).result - # Step 2: Add bias using broadcasting # b[j] -> out[i, j] i2, j2 = [ir.AffineDimExpr.get(d) for d in range(2)] b_map = affine_map(2, [j2]) # (out_feat,) @@ -351,7 +353,7 @@ def matmul_op(a_elem, w_elem, out_elem): @linalg.generic( [matmul_op, b], - [out], + [out_zeroed], [out_map2, b_map, out_map2], [parallel, parallel], ) @@ -1243,3 +1245,93 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out): assert torch.allclose(out1, xq_out, rtol=0.01, atol=0.01, equal_nan=True) assert torch.allclose(out2, xk_out, rtol=0.01, atol=0.01, equal_nan=True) + + +def test_feed_forward(): + def generate_module(ctx, elty): + with ctx, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + input_type = ir.RankedTensorType.get((4, 16), elty) + hidden_type = ir.RankedTensorType.get((4, 64), elty) + output_type = ir.RankedTensorType.get((4, 16), elty) + weight1_type = ir.RankedTensorType.get((64, 16), elty) + bias1_type = ir.RankedTensorType.get((64,), elty) + weight2_type = ir.RankedTensorType.get((16, 64), elty) + bias2_type = ir.RankedTensorType.get((16,), elty) + weight3_type = ir.RankedTensorType.get((64, 16), elty) + bias3_type = ir.RankedTensorType.get((64,), elty) + + @func.FuncOp.from_py_func( + input_type, + weight1_type, + bias1_type, + weight2_type, + bias2_type, + weight3_type, + bias3_type, + output_type, + name="feed_forward", + ) + def feed_forward(x, w1, b1, w2, b2, w3, b3, out): + # Compute hidden = linear(x, w1, b1) + hidden_uninit = tensor.EmptyOp(hidden_type.shape, elty).result + hidden = get_linear(x, w1, b1, hidden_uninit) + + # Compute hidden_silu = silu(hidden) + hidden_silu_uninit = tensor.EmptyOp(hidden_type.shape, elty).result + hidden_silu = get_silu(hidden, hidden_silu_uninit) + + # Compute gate = linear(x, w3, b3) + gate_uninit = tensor.EmptyOp(hidden_type.shape, elty).result + gate = get_linear(x, w3, b3, gate_uninit) + + # Compute activated = hidden_silu * gate + activated_uninit = tensor.EmptyOp(hidden_type.shape, elty).result + activated = get_mul(hidden_silu, gate, activated_uninit) + + # Compute out = linear(activated, w2, b2) + get_linear(activated, w2, b2, out) + + return module + + ctx = ir.Context() + ir_type = to_ir_type("f32", ctx) + module = generate_module(ctx, ir_type) + bufferize_module(ctx, module) + schedule = create_schedule(ctx) + apply_schedule(module, schedule) + pm = create_pass_pipeline(ctx) + pm.run(module.operation) + + eng = ExecutionEngine(module, opt_level=2) + func_ptr = eng.lookup("feed_forward") + + torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + x = torch.randn(4, 16, dtype=torch_dtype) + w1 = torch.randn(64, 16, dtype=torch_dtype) + b1 = torch.randn(64, dtype=torch_dtype) + w2 = torch.randn(16, 64, dtype=torch_dtype) + b2 = torch.randn(16, dtype=torch_dtype) + w3 = torch.randn(64, 16, dtype=torch_dtype) + b3 = torch.randn(64, dtype=torch_dtype) + + hidden_ref = torch.nn.functional.linear(x, w1, b1) + activated_ref = torch.nn.functional.silu(hidden_ref) + activated_ref *= torch.nn.functional.linear(x, w3, b3) + out_ref = torch.nn.functional.linear(activated_ref, w2, b2) + out = torch.empty_like(out_ref) + out.zero_() + x_mem = get_ranked_memref_descriptor(x.numpy()) + w1_mem = get_ranked_memref_descriptor(w1.numpy()) + b1_mem = get_ranked_memref_descriptor(b1.numpy()) + w2_mem = get_ranked_memref_descriptor(w2.numpy()) + b2_mem = get_ranked_memref_descriptor(b2.numpy()) + w3_mem = get_ranked_memref_descriptor(w3.numpy()) + b3_mem = get_ranked_memref_descriptor(b3.numpy()) + out_mem = get_ranked_memref_descriptor(out.numpy()) + args = lh_utils.memrefs_to_packed_args( + [x_mem, w1_mem, b1_mem, w2_mem, b2_mem, w3_mem, b3_mem, out_mem] + ) + func_ptr(args) + assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) From f7db6c4272eac75c9f6a18f81de3b941a104a0c8 Mon Sep 17 00:00:00 2001 From: Petr Kurapov Date: Fri, 14 Nov 2025 14:40:46 +0100 Subject: [PATCH 7/7] fix rebase mess --- ingress/mlir-gen/mlir_gen/test/__init__.py | 1 - ingress/mlir-gen/mlir_gen/test/conftest.py | 33 +--------- python/examples/mlir/compile_and_run.py | 76 ++++++++++++++-------- 3 files changed, 51 insertions(+), 59 deletions(-) diff --git a/ingress/mlir-gen/mlir_gen/test/__init__.py b/ingress/mlir-gen/mlir_gen/test/__init__.py index 417887a..e69de29 100644 --- a/ingress/mlir-gen/mlir_gen/test/__init__.py +++ b/ingress/mlir-gen/mlir_gen/test/__init__.py @@ -1 +0,0 @@ -"""Tests for MLIR generation.""" diff --git a/ingress/mlir-gen/mlir_gen/test/conftest.py b/ingress/mlir-gen/mlir_gen/test/conftest.py index ec9d8eb..2a44d9f 100644 --- a/ingress/mlir-gen/mlir_gen/test/conftest.py +++ b/ingress/mlir-gen/mlir_gen/test/conftest.py @@ -21,36 +21,5 @@ def setup_mlir_environment(): yield - # Cleanup after all tests (if needed) + # todo: cleanup pass - - -@pytest.fixture -def mlir_context(): - """ - Provide a fresh MLIR context for each test. - """ - from mlir import ir - - return ir.Context() - - -@pytest.fixture -def sample_shapes(): - """ - Provide common tensor shapes for testing. - """ - return [ - (4, 16), - (8, 8), - (16, 32), - (1, 64), - ] - - -@pytest.fixture -def sample_types(): - """ - Provide common element types for testing. - """ - return ["f32", "f64"] diff --git a/python/examples/mlir/compile_and_run.py b/python/examples/mlir/compile_and_run.py index 325bdf4..d0529f6 100644 --- a/python/examples/mlir/compile_and_run.py +++ b/python/examples/mlir/compile_and_run.py @@ -1,5 +1,5 @@ import torch -import os +import argparse from mlir import ir from mlir.dialects import transform @@ -7,9 +7,6 @@ from mlir.dialects.transform import interpreter from mlir.execution_engine import ExecutionEngine from mlir.passmanager import PassManager -from mlir.runtime.np_to_memref import ( - get_ranked_memref_descriptor, -) from lighthouse import utils as lh_utils @@ -50,25 +47,26 @@ def create_schedule(ctx: ir.Context) -> ir.Module: ir.UnitAttr.get() ) + # For simplicity, use generic matchers without requiring specific types. + anytype = transform.any_op_t() + # Create entry point transformation sequence. with ir.InsertionPoint(schedule.body): named_seq = transform.NamedSequenceOp( - "__transform_main", - [transform.AnyOpType.get()], - [], + sym_name="__transform_main", + input_types=[anytype], + result_types=[], arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], ) # Create the schedule. with ir.InsertionPoint(named_seq.body): - # For simplicity, use generic transform matchers. - anytype = transform.AnyOpType.get() - # Find the kernel's function op. func = structured.MatchOp.match_op_names( named_seq.bodyTarget, ["func.func"] ) - # Use C interface wrappers - required to make function executable after jitting. + # Use C interface wrappers - required to make function executable + # after jitting. func = transform.apply_registered_pass( anytype, func, "llvm-request-c-wrappers" ) @@ -82,12 +80,12 @@ def create_schedule(ctx: ir.Context) -> ir.Module: anytype, mod, "convert-linalg-to-loops" ) # Cleanup. - transform.ApplyCommonSubexpressionEliminationOp(mod) + transform.apply_cse(mod) with ir.InsertionPoint(transform.ApplyPatternsOp(mod).patterns): - transform.ApplyCanonicalizationPatternsOp() + transform.apply_patterns_canonicalization() # Terminate the schedule. - transform.YieldOp() + transform.yield_([]) return schedule @@ -129,7 +127,7 @@ def create_pass_pipeline(ctx: ir.Context) -> PassManager: # The example's entry point. -def main(): +def main(args): ### Baseline computation ### # Create inputs. a = torch.randn(16, 32, dtype=torch.float32) @@ -152,28 +150,37 @@ def main(): pm.run(kernel.operation) ### Compilation ### - # External shared libraries, containing MLIR runner utilities, are are generally - # required to execute the compiled module. + # Parse additional libraries if present. # - # Get paths to MLIR runner shared libraries through an environment variable. - mlir_libs = os.environ.get("LIGHTHOUSE_SHARED_LIBS", default="").split(":") + # External shared libraries, runtime utilities, might be needed to execute + # the compiled module. + # The execution engine requires full paths to the libraries. + mlir_libs = [] + if args.shared_libs: + mlir_libs += args.shared_libs.split(",") # JIT the kernel. eng = ExecutionEngine(kernel, opt_level=2, shared_libs=mlir_libs) + + # Initialize the JIT engine. + # + # The deferred initialization executes global constructors that might + # have been created by the module during engine creation (for example, + # when `gpu.module` is present) or registered afterwards. + # + # Initialization is not strictly necessary in this case. + # However, it is a good practice to perform it regardless. + eng.initialize() + # Get the kernel function. add_func = eng.lookup("add") ### Execution ### - # Create corresponding memref descriptors containing input data. - a_mem = get_ranked_memref_descriptor(a.numpy()) - b_mem = get_ranked_memref_descriptor(b.numpy()) - # Create an empty buffer to hold results. out = torch.empty_like(out_ref) - out_mem = get_ranked_memref_descriptor(out.numpy()) # Execute the kernel. - args = lh_utils.memrefs_to_packed_args([a_mem, b_mem, out_mem]) + args = lh_utils.torch_to_packed_args([a, b, out]) add_func(args) ### Verification ### @@ -185,4 +192,21 @@ def main(): if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + + # External shared libraries, runtime utilities, might be needed to + # execute the compiled module. + # For example, MLIR runner utils libraries such as: + # - libmlir_runner_utils.so + # - libmlir_c_runner_utils.so + # + # Full paths to the libraries should be provided. + # For example: + # --shared-libs=$LLVM_BUILD/lib/lib1.so,$LLVM_BUILD/lib/lib2.so + parser.add_argument( + "--shared-libs", + type=str, + help="Comma-separated list of libraries to link dynamically", + ) + args = parser.parse_args() + main(args)