Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions jir_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
--index-url https://gitlab.inria.fr/api/v4/groups/corse/-/packages/pypi/simple
mlir==21.1.2.2025091603
mlir-python-bindings==21.1.2.2025091603
xtc-llvm-tools==21.1.2.6
xtc-mlir-tools==21.1.2.7
xtc-mlir-python-bindings==21.1.2.7
polygeist==18.0.0.2024042201
jir==0.3.4
jir==0.3.5
8 changes: 4 additions & 4 deletions macos_mlir_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
--index-url https://gitlab.inria.fr/api/v4/groups/corse/-/packages/pypi/simple
llvm==21.1.2.2025091603+b708aea0
mlir-python-bindings==21.1.2.2025091604+b708aea0
mlir==21.1.2.2025091604+b708aea0
xtc-llvm-tools==21.1.2.6
xtc-mlir-tools==21.1.2.7
xtc-mlir-python-bindings==21.1.2.7
xtc-mlir-extra-tools==21.1.2.7
5 changes: 2 additions & 3 deletions macos_tvm_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
--index-url https://gitlab.inria.fr/api/v4/groups/corse/-/packages/pypi/simple
llvm==21.1.2.2025091603+b708aea0
tvm==0.19.0.2025010906+c4dc0c29
xtc-llvm-tools==21.1.2.6
xtc-tvm-python-bindings==0.19.0.9
8 changes: 4 additions & 4 deletions mlir_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
--index-url https://gitlab.inria.fr/api/v4/groups/corse/-/packages/pypi/simple
mlir==21.1.2.2025091603
mlir-python-bindings==21.1.2.2025091603
xtc-mlir==21.1.2.2
xtc-llvm-tools==21.1.2.6
xtc-mlir-tools==21.1.2.7
xtc-mlir-python-bindings==21.1.2.7
xtc-mlir-extra-tools==21.1.2.7
10 changes: 6 additions & 4 deletions src/xtc/backends/jir/JIRCompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from xtc.utils.tools import (
get_mlir_prefix,
get_llvm_prefix,
)
from xtc.utils.ext_tools import (
cc_bin,
Expand Down Expand Up @@ -98,7 +99,8 @@ def __init__(
assert not self.executable, f"executable generation not supported yet for TVM"
assert self.shared_lib, f"shared_lib generation is mandatory for TVM"
self.mlir_install_dir = get_mlir_prefix()
self._jir_llvm_config = f"{self.mlir_install_dir}/bin/llvm-config"
self.llvm_install_dir = get_llvm_prefix()
self._jir_llvm_config = f"{self.llvm_install_dir}/bin/llvm-config"
self._target_triple = kwargs.get(
"target_triple", get_host_target_triple(self._jir_llvm_config)
)
Expand Down Expand Up @@ -132,7 +134,7 @@ def compile(self, schedule: itf.schd.Schedule) -> itf.comp.Module:
mlir_lowering = MLIRLowering(f"{self.mlir_install_dir}/bin/mlir-opt")
mlir2llvm = MLIR2LLVMConversion(f"{self.mlir_install_dir}/bin/mlir-translate")
llvm_compiler = LLVMSharedLibraryCompiler(
f"{self.mlir_install_dir}/bin/clang",
f"{self.llvm_install_dir}/bin/clang",
f"{self.mlir_install_dir}/lib",
None,
self._target_triple,
Expand Down Expand Up @@ -239,13 +241,13 @@ def _shared_path(self):

@property
def _cmd_opt(self):
opt = [f"{self.mlir_install_dir}/bin/opt"]
opt = [f"{self.llvm_install_dir}/bin/opt"]
arch_opts = [f"-march={self._target_arch}", f"--mcpu={self._target_cpu}"]
return opt + opt_opts + arch_opts

@property
def _cmd_llc(self):
llc = [f"{self.mlir_install_dir}/bin/llc"]
llc = [f"{self.llvm_install_dir}/bin/llc"]
if self._target_arch == "native":
arch_opts = [f"--mcpu={self._target_cpu}"]
else:
Expand Down
7 changes: 7 additions & 0 deletions src/xtc/backends/mlir/MlirConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass, field
from xtc.utils.tools import (
get_mlir_prefix,
get_llvm_prefix,
)


Expand All @@ -13,6 +14,7 @@ class MlirConfig:
shared_lib: bool = False
executable: bool = False
mlir_install_dir: str | None = None
llvm_install_dir: str | None = None
to_disassemble: str = ""
save_temps: bool = False
save_temps_dir: str = "./save_temps_dir"
Expand All @@ -32,6 +34,11 @@ class MlirConfig:
selected_device: int | None = None

def __post_init__(self):
object.__setattr__(
self,
"llvm_install_dir",
get_llvm_prefix(self.llvm_install_dir or self.mlir_install_dir),
)
object.__setattr__(
self, "mlir_install_dir", get_mlir_prefix(self.mlir_install_dir)
)
4 changes: 2 additions & 2 deletions src/xtc/backends/mlir/MlirTarget/MlirLLVMTarget.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def cmd_cc(self):

@property
def cmd_opt(self):
opt = [f"{self._config.mlir_install_dir}/bin/opt"]
opt = [f"{self._config.llvm_install_dir}/bin/opt"]
return (
opt
+ opt_opts
Expand All @@ -203,7 +203,7 @@ def cmd_opt(self):

@property
def cmd_llc(self):
llc = [f"{self._config.mlir_install_dir}/bin/llc"]
llc = [f"{self._config.llvm_install_dir}/bin/llc"]
if self._config.arch == "native":
llc_arch = [f"--mcpu={self._config.cpu}"]
else:
Expand Down
4 changes: 2 additions & 2 deletions src/xtc/backends/mlir/MlirTarget/MlirNVGPUTarget.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def cmd_cc(self):

@property
def cmd_opt(self):
opt = [f"{self._config.mlir_install_dir}/bin/opt"]
opt = [f"{self._config.llvm_install_dir}/bin/opt"]
return (
opt
+ opt_opts
Expand All @@ -365,7 +365,7 @@ def cmd_opt(self):

@property
def cmd_llc(self):
llc = [f"{self._config.mlir_install_dir}/bin/llc"]
llc = [f"{self._config.llvm_install_dir}/bin/llc"]
if self._config.arch == "native":
llc_arch = [f"--mcpu={self._config.cpu}"]
else:
Expand Down
25 changes: 21 additions & 4 deletions src/xtc/cli/mlir_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@

import argparse
import os

from xtc.backends.mlir.MlirTarget import (
MlirTarget,
get_default_target,
)
from xtc.backends.mlir.MlirConfig import MlirConfig
from xtc.backends.mlir.MlirProgram import RawMlirProgram
from xtc.backends.mlir.MlirCompiler import MlirProgramCompiler

Expand All @@ -18,10 +24,15 @@ def main():
type=str,
help="The source file.",
)
parser.add_argument(
"--mlir-dir",
type=str,
help="The prefix for MLIR tools, or autodetected.",
)
parser.add_argument(
"--llvm-dir",
type=str,
help="The prefix for LLVM/MLIR tools, or autodetected.",
help="The prefix for LLVM tools, or --mlir-dir or autodetected.",
)
parser.add_argument(
"--print-source-ir",
Expand Down Expand Up @@ -71,14 +82,20 @@ def main():
args.print_assembly,
]
)
compiler = MlirProgramCompiler(
mlir_program=mlir_program,
mlir_install_dir=args.llvm_dir,
config = MlirConfig(
mlir_install_dir=args.mlir_dir,
llvm_install_dir=args.llvm_dir,
print_source_ir=print_source,
print_transformed_ir=args.print_transformed_ir,
print_lowered_ir=args.print_lowered_ir,
print_assembly=args.print_assembly,
color=args.color,
debug=args.debug,
)
target = get_default_target()(config)
compiler = MlirProgramCompiler(
mlir_program=mlir_program,
target=target,
config=config,
)
compiler.compile()
10 changes: 8 additions & 2 deletions src/xtc/cli/mlir_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def main():
dump_file = str(Path(tdir) / Path(args.filename).stem)

compiler_args = {
"mlir_install_dir": args.llvm_dir,
"mlir_install_dir": args.mlir_dir,
"llvm_install_dir": args.llvm_dir,
"print_source_ir": print_source,
"print_transformed_ir": args.print_transformed_ir,
"print_lowered_ir": args.print_lowered_ir,
Expand Down Expand Up @@ -225,10 +226,15 @@ def parse_args() -> argparse.Namespace:
type=str,
help="The source file.",
)
parser.add_argument(
"--mlir-dir",
type=str,
help="The prefix for MLIR tools, or autodetected.",
)
parser.add_argument(
"--llvm-dir",
type=str,
help="The prefix for LLVM/MLIR tools, or autodetected.",
help="The prefix for LLVM tools, or --mlir-dir or autodetected.",
)
parser.add_argument(
"--arch",
Expand Down
3 changes: 2 additions & 1 deletion src/xtc/graphs/xtc/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def forward_types(

@override
def forward(self, inputs: Sequence[Tensor]) -> Sequence[XTCTensor]:
matmul = XTCTensor(np.matmul(inputs[0].numpy(), inputs[1].numpy()))
# Note, use np.dot instead of np.matmul which may be buggy on Mac accelerators
matmul = XTCTensor(np.dot(inputs[0].numpy(), inputs[1].numpy()))
expected_type = self.forward_types([inp.type for inp in inputs])[0]
assert matmul.type == expected_type, (
f"output type mismatch expect: {matmul.type} != {expected_type}"
Expand Down
74 changes: 62 additions & 12 deletions src/xtc/utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pathlib import Path


def get_mlir_prefix(prefix: Path | str | None = None):
def get_mlir_prefix(prefix: Path | str | None = None) -> Path:
"""
Tentatively return the mlir compiler prefix where
{prefix}/bin/mlir-opt can be found.
Expand All @@ -18,35 +18,85 @@ def get_mlir_prefix(prefix: Path | str | None = None):
- mlir python package prefix if installed
- mlir-opt binary prefix in PATH
"""
how = None
if prefix is None:
prefix_var = os.environ.get("XTC_MLIR_PREFIX")
if prefix_var:
prefix = Path(prefix_var)
how, prefix = "XTC_MLIR_PREFIX envvar", Path(prefix_var)
else:
try:
import mlir

prefix = Path(mlir.__path__[0])
how, prefix = "mlir package", Path(mlir.__path__[0])
except:
mlir_exe = shutil.which("mlir-opt")
if mlir_exe:
prefix = Path(mlir_exe).parents[1]
how, prefix = "mlir-opt PATH", Path(mlir_exe).parents[1]
else:
prefix = Path(prefix)
how, prefix = "explicit prefix", Path(prefix)
if prefix is None:
raise RuntimeError("could not find MLIR installation")
if not prefix.exists():
raise RuntimeError(f"could not find MLIR prefix at: {prefix}")
raise RuntimeError(f"could not find MLIR prefix at: {prefix}, method; {how}")
mlir_opt = prefix / "bin" / "mlir-opt"
if not mlir_opt.exists():
prefix = prefix.parents[2].resolve()
mlir_opt2 = prefix / "bin" / "mlir-opt"
if not mlir_opt2.exists():
raise RuntimeError(f"could not find mlir-opt at: {mlir_opt}")
if how == "mlir package":
# Try to find prefix from MLIR standard python package install
prefix2 = prefix.parents[2].resolve()
mlir_opt2 = prefix2 / "bin" / "mlir-opt"
raise RuntimeError(
f"could not find mlir-opt at: {mlir_opt}, nor: {mlir_opt2}, method: {how}"
)
else:
raise RuntimeError(f"could not find mlir-opt at: {mlir_opt}, method: {how}")
return prefix


def get_llvm_prefix(prefix: Path | str | None = None) -> Path:
"""
Tentatively return the llvm compiler prefix where
{prefix}/bin/opt can be found.
Raise on error.
Defined in order as:
- passed prefix if not None
- env var XTC_LLVM_PREFIX
- get_mlir_prefix() if successfull
- llvm python package prefix if installed
- opt binary prefix in PATH
"""
how = None
if prefix is None:
prefix_var = os.environ.get("XTC_LLVM_PREFIX")
if prefix_var:
how, prefix = "XTC_LLVM_PREFIX envvar", Path(prefix_var)
else:
try:
mlir_prefix = get_mlir_prefix()
if not (mlir_prefix / "bin" / "opt").exists():
raise RuntimeError()
how, prefix = "mlir prefix", mlir_prefix
except RuntimeError:
try:
import llvm

how, prefix = "llvm package", Path(llvm.__path__[0])
except:
opt_exe = shutil.which("opt")
if opt_exe:
how, prefix = "opt PATH", Path(opt_exe).parents[1]
else:
how, prefix = "explicit prefix", Path(prefix)
if prefix is None:
raise RuntimeError("could not find LLVM installation")
if not prefix.exists():
raise RuntimeError(f"could not find LLVM prefix at: {prefix}, method; {how}")
llvm_opt = prefix / "bin" / "opt"
if not llvm_opt.exists():
raise RuntimeError(f"could not find opt at: {llvm_opt}, method: {how}")
return prefix


def get_geist_prefix(prefix: Path | str | None = None):
def get_geist_prefix(prefix: Path | str | None = None) -> Path:
"""
Tentatively return the mlir polygeist prefix where
{prefix}/bin/cgeist can be found.
Expand Down Expand Up @@ -82,7 +132,7 @@ def get_geist_prefix(prefix: Path | str | None = None):
return prefix


def get_cuda_prefix(prefix: Path | str | None = None):
def get_cuda_prefix(prefix: Path | str | None = None) -> Path:
"""
Tentatively return the cuda installation dir
Raise on error.
Expand Down
4 changes: 2 additions & 2 deletions tvm_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
--index-url https://gitlab.inria.fr/api/v4/groups/corse/-/packages/pypi/simple
tvm==0.19.0.2025010905
xtc-llvm-tools==21.1.2.6
xtc-tvm-python-bindings==0.19.0.9