Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/xtc/backends/mlir/MlirCompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from xtc.backends.mlir.MlirCompilerPasses import (
MlirProgramInsertTransformPass,
MlirProgramApplyTransformPass,
apply_bufferization_passes,
)

from xtc.backends.mlir.MlirTarget import (
Expand Down Expand Up @@ -149,6 +150,15 @@ def mlir_apply_transform_pass(self) -> None:
if self._config.print_transformed_ir:
self.dump_ir("IR Dump After transform")

def mlir_apply_tensor_lowering_pass(self) -> None:
    """Run the tensor-to-memref lowering (bufferization) on the program.

    Dumps the IR before and after the lowering when the config flag
    ``print_bufferization_ir`` is set.
    """
    dump_wanted = self._config.print_bufferization_ir
    if dump_wanted:
        self.dump_ir("IR Dump Before Tensor Lowering")
    apply_bufferization_passes(self._mlir_program)
    if dump_wanted:
        self.dump_ir("IR Dump After Tensor Lowering")

def _save_temp(self, fname: str, content: Any) -> None:
if not self._config.save_temps:
return
Expand Down Expand Up @@ -196,4 +206,6 @@ def compile(self) -> None:
self.mlir_apply_transform_pass()
save_temp(mlir_atrn_dump_file, self._mlir_program.mlir_module)

self.mlir_apply_tensor_lowering_pass()

self._target.generate_code_for_target(self._mlir_program, dump_file=dump_file)
41 changes: 41 additions & 0 deletions src/xtc/backends/mlir/MlirCompilerPasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,44 @@ def run(self) -> None:
transform_op.erase()
else:
break


class MlirProgramApplyPasses:
    """Apply a sequence of named MLIR passes to a raw MLIR program."""

    def __init__(self, mlir_program: RawMlirProgram) -> None:
        # Program whose module the passes will rewrite in place.
        self._mlir_program = mlir_program

    def run(self, pass_names: list[str]) -> None:
        """Build a pass pipeline from *pass_names* and run it on the module."""
        pipeline = PassManager(context=self._mlir_program.mlir_context)
        for pass_name in pass_names:
            pipeline.add(pass_name)  # type: ignore # no attribute add
        pipeline.run(self._mlir_program.mlir_module.operation)


def apply_bufferization_passes(mlir_program: RawMlirProgram) -> None:
    """Lower tensor-dialect IR in *mlir_program* to memrefs.

    Runs cleanup passes (canonicalize, cse, eliminate-empty-tensors),
    then MLIR's one-shot bufferization configured to bufferize across
    function boundaries with identity-layout memrefs and 256-byte
    alignment, followed by buffer placement optimizations.

    Mutates the program's MLIR module in place.
    """
    bufferize_options = " ".join(
        [
            "bufferize-function-boundaries=1",
            "function-boundary-type-conversion=identity-layout-map",
            "buffer-alignment=256",
        ]
    )
    apply_passes = MlirProgramApplyPasses(mlir_program)
    apply_passes.run(
        [
            "canonicalize",
            "cse",
            "eliminate-empty-tensors",  # causes ops to write directly to out buffer
            f"one-shot-bufferize{{{bufferize_options}}}",
            "func.func(buffer-hoisting)",
            "func.func(buffer-loop-hoisting)",
            "drop-equivalent-buffer-results",
            "func.func(promote-buffers-to-stack)",
        ]
    )


def pre_transform_tensor_passes(mlir_program: RawMlirProgram) -> None:
    """Hook for tensor-level passes to run before transform application.

    Currently a deliberate no-op: eliminate-empty-tensors was tried here
    but is run as part of the bufferization pipeline instead.
    """
    # Kept as a reference for re-enabling pre-transform passes:
    # apply_passes = MlirProgramApplyPasses(mlir_program)
    # apply_passes.run(["eliminate-empty-tensors"])
1 change: 1 addition & 0 deletions src/xtc/backends/mlir/MlirConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class MlirConfig:
print_assembly: bool = False
visualize_jumps: bool = True
print_lowered_ir: bool = False
print_bufferization_ir: bool = False
debug: bool = False
color: bool = False
concluding_passes: list[str] = field(default_factory=list)
Expand Down
61 changes: 51 additions & 10 deletions src/xtc/backends/mlir/MlirGraphBackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024-2026 The XTC Project Authors
#
from typing import cast, Any
from typing import cast, Any, Type
from typing_extensions import override

from xdsl.dialects.func import FuncOp as xdslFuncOp
from xdsl.dialects import func, memref
from xdsl.dialects.builtin import MemRefType, f32, f64
from xdsl.dialects import func, memref, tensor, bufferization
from xdsl.dialects.builtin import MemRefType, TensorType, f32, f64, UnitAttr
from xdsl.ir import Region, Block, Operation
from xdsl.builder import ImplicitBuilder

Expand All @@ -28,7 +28,11 @@ def __init__(
concluding_passes: list[str] = [],
always_vectorize: bool = False,
no_alias: bool = True,
use_tensor_dialect: bool = False,
):
self.xdsl_type: Type[TensorType] | Type[MemRefType] = (
TensorType if use_tensor_dialect else MemRefType
)
if isinstance(xdsl_func, XTCGraph):
assert nodes is None
graph = xdsl_func
Expand Down Expand Up @@ -62,13 +66,24 @@ def _init_from_xdsl(
def _xdsl_generate_node(
self, node: XTCNode, block: Block, variables: dict[str, Any]
):
operation = MlirOperation.from_operation(node.operation, name=node.name)
operation = MlirOperation.from_operation(
node.operation,
name=node.name,
op_type=self.xdsl_type, # type: ignore
)
names = [*node.inputs, *node.outputs]
assert node.inputs_types is not None and node.outputs_types is not None
types = [*node.inputs_types, *node.outputs_types]
for name, type in zip(names, types):
if name in node.outputs and self.xdsl_type == TensorType:
with ImplicitBuilder(block):
variables[name] = tensor.EmptyOp(
dynamic_sizes=[],
tensor_type=self._xdsl_type_from_tensortype(type),
).results[0]
if name in variables:
continue
assert self.xdsl_type != TensorType
with ImplicitBuilder(block):
elt_type, shape = self._xdsl_elt_shape_from_tensortype(type)
alloca = memref.AllocaOp.get(
Expand All @@ -79,6 +94,11 @@ def _xdsl_generate_node(
variables[name] = alloca.results[0]
args = [variables[name] for name in names]
_, attrs = operation.generate(block=block, args=args)
# the tensor dialect needs the result of the op, not the alloca
if self.xdsl_type == TensorType:
assert len(node.outputs) == len(attrs["output_nodes"])
for name, output in zip(node.outputs, attrs["output_nodes"]):
variables[name] = output.results[0]
return attrs

def _init_from_graph(
Expand All @@ -95,18 +115,34 @@ def _init_from_graph(
)
params_types = [
self._xdsl_type_from_tensortype(cast(XTCTensorType, tensor_type))
for tensor_type in [*inputs_types, *outputs_types]
for tensor_type in inputs_types
]
# graph output types are always memrefs
params_types.extend(
self._memref_type_from_tensortype(cast(XTCTensorType, tensor_type))
for tensor_type in outputs_types
)
inlined_block = Block(arg_types=params_types)
variables = {
name: arg
for name, arg in zip([*graph.inputs, *graph.outputs], inlined_block.args)
}
block_attrs = []

for node in graph.nodes.values():
node_attrs = self._xdsl_generate_node(node, inlined_block, variables)
block_attrs.append(node_attrs)
with ImplicitBuilder(inlined_block):
if self.xdsl_type == TensorType:
# write the final tensor values to the output buffers
for name, out_arg in zip(
graph.outputs, inlined_block.args[-len(graph.outputs) :]
):
bufferization.MaterializeInDestinationOp(
operands=((variables[name],), (out_arg,)),
result_types=((),),
attributes={"writable": UnitAttr(), "restrict": UnitAttr()},
)
func.ReturnOp()
region = Region([inlined_block]) # type: ignore # issue with mypy
payload = xdslFuncOp.from_region(
Expand All @@ -128,6 +164,7 @@ def _init_from_graph(
always_vectorize=always_vectorize,
concluding_passes=concluding_passes,
id=f"__xtc_id_{node_id}_",
xdsl_type=self.xdsl_type,
)
return payload, nodes_dict

Expand All @@ -136,11 +173,15 @@ def _xdsl_elt_shape_from_tensortype(self, type: XTCTensorType) -> tuple[Any, Any
return (elt_type, type.constant_shape)

def _xdsl_type_from_tensortype(self, tensor_type: XTCTensorType) -> Any:
    """Convert an XTC tensor type to the backend's xDSL type.

    Builds either a TensorType or a MemRefType (whichever ``self.xdsl_type``
    was set to) from the element type and constant shape of *tensor_type*.
    """
    # Renamed parameter from `type` to avoid shadowing the builtin;
    # all call sites in this file pass it positionally.
    elt_type, shape = self._xdsl_elt_shape_from_tensortype(tensor_type)
    return self.xdsl_type(elt_type, shape)

def _memref_type_from_tensortype(self, type: XTCTensorType) -> Any:
    """Build a MemRefType from the element type and shape of *type*.

    Always yields a memref regardless of ``self.xdsl_type``; used where
    the interface demands buffer (memref) parameters.
    """
    elt, dims = self._xdsl_elt_shape_from_tensortype(type)
    return MemRefType(elt, dims)

def _np_types_spec(
self, types: list[MemRefType]
self, types: list[MemRefType] | list[TensorType]
) -> list[dict[str, tuple[int, ...] | str]]:
types_map = {"f32": "float32", "f64": "float64"}
types_spec: list[dict[str, tuple[int, ...] | str]] = [
Expand All @@ -156,12 +197,12 @@ def _np_types_spec(
def np_inputs_spec(self) -> list[dict[str, Any]]:
# Assume inputs are first, and output is single last param
inputs_args_types = [arg.type for arg in self.xdsl_func.args[:-1]]
list_memref_tys = cast(list[MemRefType], inputs_args_types)
return self._np_types_spec(list_memref_tys)
list_xdsl_tys = cast(list[self.xdsl_type], inputs_args_types) # type: ignore
return self._np_types_spec(list_xdsl_tys)

@override
def np_outputs_spec(self) -> list[dict[str, Any]]:
# Assume inputs are first, and output is single last param
outputs_args_types = [arg.type for arg in self.xdsl_func.args[-1:]]
list_memref_tys = cast(list[MemRefType], outputs_args_types)
return self._np_types_spec(list_memref_tys)
list_xdsl_tys = cast(list[MemRefType], outputs_args_types)
return self._np_types_spec(list_xdsl_tys)
16 changes: 9 additions & 7 deletions src/xtc/backends/mlir/MlirNodeBackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024-2026 The XTC Project Authors
#
from typing import cast, Any
from typing import cast, Any, Type
from typing_extensions import override

from xdsl.ir import Operation as xdslOperation
from xdsl.dialects.builtin import MemRefType as xdslAnyMemRefType
from xdsl.dialects.builtin import MemRefType, TensorType
from xdsl.dialects.builtin import UnitAttr as xdslUnitAttr
from xtc.utils.xdsl_aux import xdsl_operator_to_function

Expand All @@ -26,8 +26,10 @@ def __init__(
always_vectorize: bool = False,
no_alias: bool = True,
id: str | None = None,
xdsl_type: Type[TensorType] | Type[MemRefType] = MemRefType,
):
self._graph = None
self.xdsl_type = xdsl_type
if id is None:
self.op_id_attribute = f"__id{MlirNodeBackend.count}__"
MlirNodeBackend.count += 1
Expand All @@ -48,7 +50,7 @@ def __init__(
self.loop_stamps = loop_stamps

def _np_types_spec(
self, types: list[xdslAnyMemRefType]
self, types: list[MemRefType | TensorType]
) -> list[dict[str, tuple[int, ...] | str]]:
types_map = {"f32": "float32", "f64": "float64"}
types_spec: list[dict[str, tuple[int, ...] | str]] = [
Expand All @@ -63,11 +65,11 @@ def _np_types_spec(
@override
def np_inputs_spec(self) -> list[dict[str, Any]]:
list_attr_tys = [i.type for i in self.source_op.inputs] # type: ignore
list_memref_tys = cast(list[xdslAnyMemRefType], list_attr_tys)
return self._np_types_spec(list_memref_tys)
list_xdsl_tys = cast(list[self.xdsl_type], list_attr_tys) # type: ignore
return self._np_types_spec(list_xdsl_tys)

@override
def np_outputs_spec(self) -> list[dict[str, Any]]:
list_attr_tys = [i.type for i in self.source_op.outputs] # type: ignore
list_memref_tys = cast(list[xdslAnyMemRefType], list_attr_tys)
return self._np_types_spec(list_memref_tys)
list_xdsl_tys = cast(list[self.xdsl_type], list_attr_tys) # type: ignore
return self._np_types_spec(list_xdsl_tys)
Loading
Loading