From 1b4b48fa2d8c8cb844bf910146c1b8dead4c0021 Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Wed, 3 Dec 2025 10:47:43 +0800 Subject: [PATCH] support muxi --- backend/compiler.py | 37 +- backend/driver.py | 4 + backend/maca.c | 251 +++++++--- backend/maca.py | 1093 +++++++++++++++++++++---------------------- setup_on_maca.py | 260 ++++++++++ 5 files changed, 1007 insertions(+), 638 deletions(-) create mode 100644 setup_on_maca.py diff --git a/backend/compiler.py b/backend/compiler.py index 4ae18dde..1fa392fe 100644 --- a/backend/compiler.py +++ b/backend/compiler.py @@ -114,6 +114,8 @@ def __init__(self, target:str) -> None: assert isinstance(self.capability, int) self.binary_ext = "cnbin" elif self.driver.target == 'maca': + self.capability = target.arch + self.capability = 80 self.binary_ext = "mcfatbin" elif self.driver.target == 'ascend': self.binary_ext = "npubin" @@ -173,15 +175,21 @@ def add_stages(self, stages, options): # from triton.backends.dicp_triton.mlu import ttir_to_cnfatbin, get_architecture_descriptor # stages["cnbin"] = lambda src, metadata: ttir_to_cnfatbin(src, metadata, get_architecture_descriptor(self.driver, options), False, True) elif self.driver.target == 'maca': - from triton.backends.dicp_triton.maca import ttir_to_ttgir, optimize_ttgir, ttgir_to_llir, llir_to_mcfatbin, get_architecture_descriptor - arch = get_architecture_descriptor() - extern_libs = dict() - stages["ttgir"] = lambda src, metadata: optimize_ttgir(ttir_to_ttgir(src, 4), options.num_stages, arch) - stages["llir"] = lambda src, metadata: ttgir_to_llir(src, arch) - mxcc_arch = os.environ.get('MACA_PATH') + "/mxgpu_llvm/bin/mxcc" - if mxcc_arch is None: - raise RuntimeError('mxcc_arch is None (not specified)') - stages["mcfatbin"] = lambda src, metadata: llir_to_mcfatbin(src, mxcc_arch, os.environ.get('MACA_PATH')) + from triton.backends.dicp_triton.maca import make_ttir, make_ttgir, make_mlir, make_llir, make_mcfatbin + stages["ttir"] = lambda src, metadata: make_ttir(src, metadata, options) + stages["ttgir"] = lambda src, metadata: make_ttgir(src, metadata, options, self.capability) + stages["mlir"] = lambda src, metadata: make_mlir(src, metadata, options, self.capability) + stages["llir"] = lambda src, metadata: make_llir(src, metadata, options, self.capability) + stages["mcfatbin"] = lambda src, metadata: make_mcfatbin(src, metadata, options, self.capability) + # from triton.backends.dicp_triton.maca import ttir_to_ttgir, optimize_ttgir, ttgir_to_llir, llir_to_mcfatbin, get_architecture_descriptor + # arch = get_architecture_descriptor() + # extern_libs = dict() + # stages["ttgir"] = lambda src, metadata: optimize_ttgir(ttir_to_ttgir(src, 4), options.num_stages, arch) + # stages["llir"] = lambda src, metadata: ttgir_to_llir(src, arch) + # mxcc_arch = os.environ.get('MACA_PATH') + "/mxgpu_llvm/bin/mxcc" + # if mxcc_arch is None: + # raise RuntimeError('mxcc_arch is None (not specified)') + # stages["mcfatbin"] = lambda src, metadata: llir_to_mcfatbin(src, mxcc_arch, os.environ.get('MACA_PATH')) elif self.driver.target =='ascend': from triton.backends.dicp_triton.npu import make_ttir, ttir_to_linalg, linalg_to_bin_enable_npu_compile stages["ttir"] = lambda src, metadata: make_ttir(src, metadata, options) @@ -235,6 +243,17 @@ def parse_options(self, options: dict) -> Any: args["enable_mlu_bound_check"] = os.getenv("TRITON_ENABLE_MLU_BOUND_CHECK", "0") == "1" return MLUOptions(**args) + elif self.target.backend == 'maca': + from triton.backends.dicp_triton.maca import MACAOptions + 
# args = {k: options[k] for k in MACAOptions.__dataclass_fields__.keys() if k in options} + # return MACAOptions(**args) + args = {k: options[k] for k in MACAOptions.__dataclass_fields__.keys() if k in options} + # USE_MACA: support allow_fp8e4nv(i.e. float8_e4m3fn) + args["allow_fp8e4nv"] = True + # args["allow_fp8e4nv"] = False + args["allow_fp8e4b15"] = False + args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0 + return MACAOptions(**args) else: args = {'arch': self.target} args.update({k: options[k] for k in DICPOptions.__dataclass_fields__.keys() if k in options}) diff --git a/backend/driver.py b/backend/driver.py index 529d0349..5b440992 100644 --- a/backend/driver.py +++ b/backend/driver.py @@ -161,8 +161,12 @@ def test_npucompiler(): reset = "\x1b[0m" warnings.warn(red + str(e_npucompiler) + reset) return False + elif self.target == "muxi": + import torch + return True except Exception as e: import torch + return True try: if torch.mlu: return True diff --git a/backend/maca.c b/backend/maca.c index 539ccfb2..bb9e42b8 100644 --- a/backend/maca.c +++ b/backend/maca.c @@ -1,95 +1,196 @@ #include +#include +#include #define PY_SSIZE_T_CLEAN #include -#include -#include -static inline void gpuAssert(mcError_t code, const char *file, int line) -{ - if (code != mcSuccess) - { - const char* prefix = "Triton Error [MACA]: "; - const char* str = mcGetErrorString(code); - char err[1024] = {0}; - strcat(err, prefix); - strcat(err, str); - PyErr_SetString(PyExc_RuntimeError, err); - } +#include + +// Raises a Python exception and returns false if code is not MC_SUCCESS. +static bool gpuAssert(mcError_t code, const char *file, int line) { + if (code == mcSuccess) + return true; + + const char *prefix = "Triton Error [MACA]: "; + const char *str = mcGetErrorString(code); + char err[1024] = {0}; + strcat(err, prefix); + strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + return false; } -#define MACA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); if(PyErr_Occurred()) return NULL; } - -static PyObject* getDeviceProperties(PyObject* self, PyObject* args){ - int device_id; - if(!PyArg_ParseTuple(args, "i", &device_id)) - return NULL; - // Get device handle - MCdevice device; - mcDeviceGet(&device, device_id); - - // create a struct to hold device properties - int max_shared_mem; - int multiprocessor_count; - int sm_clock_rate; - int mem_clock_rate; - int mem_bus_width; - MACA_CHECK(mcDeviceGetAttribute(&max_shared_mem, mcDeviceAttributeMaxSharedMemoryPerBlock, device)); - MACA_CHECK(mcDeviceGetAttribute(&multiprocessor_count, mcDeviceAttributeMultiProcessorCount, device)); - MACA_CHECK(mcDeviceGetAttribute(&sm_clock_rate, mcDeviceAttributeClockRate, device)); - MACA_CHECK(mcDeviceGetAttribute(&mem_clock_rate, mcDeviceAttributeMemoryClockRate, device)); - MACA_CHECK(mcDeviceGetAttribute(&mem_bus_width, mcDeviceAttributeMemoryBusWidth, device)); - - - return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", max_shared_mem, - "multiprocessor_count", multiprocessor_count, - "sm_clock_rate", sm_clock_rate, - "mem_clock_rate", mem_clock_rate, - "mem_bus_width", mem_bus_width); +// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block. +#define MACA_CHECK_AND_RETURN_NULL(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) \ + return NULL; \ + } while (0) + +// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block. 
+#define MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) { \ + PyEval_RestoreThread(_save); \ + return NULL; \ + } \ + } while (0) + +// Used to check if functions exist in old CUDA driver versions. +#define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction) \ + do { \ + if ((funcPointer) == NULL) { \ + (funcPointer) = (initializerFunction)(); \ + if ((funcPointer) == NULL) { \ + return NULL; \ + } \ + } \ + } while (0) + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + // Get device handle + MCdevice device; + mcDeviceGet(&device, device_id); + + // create a struct to hold device properties + int max_shared_mem = 64 * 1024; // 64KB, no CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN + int max_num_regs; + int multiprocessor_count; + int warp_size = 64; + int sm_clock_rate; + int mem_clock_rate; + int mem_bus_width; + MACA_CHECK_AND_RETURN_NULL(mcDeviceGetAttribute( + &max_num_regs, mcDeviceAttributeMaxSharedMemoryPerBlock, device)); + MACA_CHECK_AND_RETURN_NULL(mcDeviceGetAttribute( + &multiprocessor_count, mcDeviceAttributeMultiProcessorCount, device)); + MACA_CHECK_AND_RETURN_NULL(mcDeviceGetAttribute( + &sm_clock_rate, mcDeviceAttributeClockRate, device)); + MACA_CHECK_AND_RETURN_NULL(mcDeviceGetAttribute( + &mem_clock_rate, mcDeviceAttributeMemoryClockRate, device)); + MACA_CHECK_AND_RETURN_NULL(mcDeviceGetAttribute( + &mem_bus_width, mcDeviceAttributeMemoryBusWidth, device)); + + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", + max_shared_mem, "max_num_regs", max_num_regs, + "multiprocessor_count", multiprocessor_count, "warpSize", + warp_size, "sm_clock_rate", sm_clock_rate, + "mem_clock_rate", mem_clock_rate, "mem_bus_width", + mem_bus_width); } -static PyObject* loadBinary(PyObject* self, PyObject* args) { - const char* name; - const char* data; - Py_ssize_t data_size; - int shared; - int device; - if(!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, &device)) { - return NULL; - } - mcFunction_t fun; - mcModule_t mod; - // create driver handles - MACA_CHECK(mcModuleLoadData(&mod, data)); - MACA_CHECK(mcModuleGetFunction(&fun, mod, name)); - - // get allocated registers and spilled registers from the function - int n_regs = 0; - int n_spills = 0; - - if(PyErr_Occurred()) { - return NULL; - } - return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, n_spills); +static PyObject *loadBinary(PyObject *self, PyObject *args) { + const char *name; + const char *data; + Py_ssize_t data_size; + int shared; + int device; + if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, + &device)) { + return NULL; + } + mcFunction_t fun; + mcModule_t mod; + int32_t n_regs = 0; + int32_t n_spills = 0; + // create driver handles + MCcontext pctx = 0; + + Py_BEGIN_ALLOW_THREADS; + // TODO: MCcontext implement not found + MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(mcCtxGetCurrent(&pctx)); + if (!pctx) { + MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + mcDevicePrimaryCtxRetain(&pctx, device)); + MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(mcCtxSetCurrent(pctx)); + } + MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(mcModuleLoadData(&mod, data)); + MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(mcModuleGetFunction(&fun, mod, name)); + // get allocated registers and spilled registers from the function + MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + 
mcFuncGetAttribute(&n_regs, MC_FUNC_ATTRIBUTE_NUM_REGS, fun));
+  MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
+      mcFuncGetAttribute(&n_spills, MC_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
+  n_spills /= 4;
+  Py_END_ALLOW_THREADS;
+
+  if (PyErr_Occurred()) {
+    return NULL;
+  }
+  return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, n_spills);
+}
+
+static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
+  long size;
+  if (!PyArg_ParseTuple(args, "l", &size)) {
+    return NULL;
+  }
+  if (size < 0) {
+    PyErr_SetString(PyExc_ValueError, "fifo size must be non-negative");
+    return NULL;
+  }
+
+  Py_BEGIN_ALLOW_THREADS;
+
+  // Ensure we have an active context.
+  // MCcontext ctx = NULL;
+  // TODO: CU_LIMIT_PRINTF_FIFO_SIZE implement not found
+  // MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(mcCtxGetCurrent(&ctx));
+  // if (!ctx) {
+  //   MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
+  //       mcDevicePrimaryCtxRetain(&ctx, /*device=*/0));
+  //   MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(mcCtxSetCurrent(ctx));
+  // }
+
+  // // We can't set the fifo size after running a kernel that calls printf. This
+  // // is true even if the set() call is a nop and the new size is the same as the
+  // // old size.
+  // //
+  // // This is unfriendly, so check if the old size matches the new size, and skip
+  // // the set() call if so.
+  // size_t oldSize = 0;
+  // MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
+  //     mcCtxGetLimit(&oldSize, CU_LIMIT_PRINTF_FIFO_SIZE));
+  // if (oldSize != size) {
+  //   MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
+  //       mcCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, size));
+  // }
+
+  Py_END_ALLOW_THREADS;
+  Py_RETURN_NONE;
 }
 
 static PyMethodDef ModuleMethods[] = {
-    {"load_binary", loadBinary, METH_VARARGS, "Load provided mcfatbin into MACA driver"},
-    {"get_device_properties", getDeviceProperties, METH_VARARGS, "Get the properties for a given device"},
-    {NULL, NULL, 0, NULL} // sentinel
+    {"load_binary", loadBinary, METH_VARARGS,
+     "Load provided mcfatbin into MACA driver"},
+    {"get_device_properties", getDeviceProperties, METH_VARARGS,
+     "Get the properties for a given device"},
+    {"set_printf_fifo_size", setPrintfFifoSize, METH_VARARGS,
+     "Python interface for cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, x), which "
+     "controls how many bytes can be streamed from kernels before data starts "
+     "being dropped. 
This inherits all the limitations of this call; in " + "particular it's an error to change this value after launching any kernel " + "that calls printf()."}, + {NULL, NULL, 0, NULL} // sentinel }; -static struct PyModuleDef ModuleDef = { - PyModuleDef_HEAD_INIT, - "maca_utils", - NULL, //documentation - -1, //size - ModuleMethods -}; +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "maca_utils", + NULL, // documentation + -1, // size + ModuleMethods}; PyMODINIT_FUNC PyInit_maca_utils(void) { PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) { + if (m == NULL) { return NULL; } + PyModule_AddFunctions(m, ModuleMethods); + return m; } diff --git a/backend/maca.py b/backend/maca.py index 6d10c9ab..d541b30f 100644 --- a/backend/maca.py +++ b/backend/maca.py @@ -1,129 +1,477 @@ +from triton.backends.compiler import BaseBackend, GPUTarget +from triton._C.libtriton import ir, passes, llvm, metax + +from dataclasses import dataclass +import functools +from typing import Any, Tuple, Optional import hashlib -import os +import re import tempfile -import shutil +import signal +import os import subprocess -import sysconfig -import contextlib -import sys -import io -import functools -import importlib -import setuptools from pathlib import Path -from triton.runtime.cache import get_cache_manager -from triton.runtime import JITFunction -from .utils import quiet - -from triton.backends.dicp_triton.libtriton.triton import ( - add_external_libs, compile_ptx_to_cubin, - get_shared_memory_size, ir, - translate_llvmir_to_hsaco, translate_llvmir_to_ptx, - translate_triton_gpu_to_llvmir, translate_llvmir_to_mcfatbin) - -if os.environ.get('MACA_PATH') is not None: - USE_MACA = True -else: - USE_MACA = False -assert USE_MACA, "Please set MACA_PATH!" - -def get_architecture_descriptor(capability=None): - try: - import torch - except ImportError: - raise ImportError("Triton requires PyTorch to be installed") - if capability is None: - if torch.version.hip is None: - device = torch.cuda.current_device() - capability = torch.cuda.get_device_capability(device) - capability = capability[0] * 10 + capability[1] - else: - raise ImportError("HIP not supported") - return capability - -def parse_mlir_module(mod): - context = ir.context() - ttir_code = str(mod) - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "tt.mlir") - Path(src_path).write_text(ttir_code) - src_path = "/home/pujiang/tangding/Triton/python/op/add_kernel.ttir.mx" - module = ir.parse_mlir_module(src_path, context) - # module takes ownership of the context - module.context = context - return module - -def ttir_to_ttgir(mod, num_warps): - mod = parse_mlir_module(mod) + + +@functools.lru_cache() +def _path_to_binary(binary: str): + paths = [ + os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), + os.path.join(os.path.dirname(__file__), "bin", binary), + ] + + for bin in paths: + if os.path.exists(bin) and os.path.isfile(bin): + result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) + if result is not None: + version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) + if version is not None: + return bin, version.group(1) + raise RuntimeError(f"Cannot find {binary}") + + +@functools.lru_cache() +def get_ptxas_version(): + version = subprocess.check_output([_path_to_binary("ptxas")[0], "--version"]).decode("utf-8") + return version + + +@functools.lru_cache() +def ptx_get_version(cuda_version) -> int: + ''' + Get the highest PTX version supported by the 
current CUDA driver.
+    '''
+    assert isinstance(cuda_version, str)
+    major, minor = map(int, cuda_version.split('.'))
+    if major == 12:
+        return 80 + minor
+    if major == 11:
+        return 70 + minor
+    if major == 10:
+        return 63 + minor
+    raise RuntimeError("Triton only supports CUDA 10.0 or higher")
+
+
+@functools.lru_cache(None)
+def file_hash(path):
+    with open(path, "rb") as f:
+        return hashlib.sha256(f.read()).hexdigest()
+
+
+def maca_get_kernel_name(src: str) -> str:
+    '''
+    Get the kernel name from LLVM IR.
+    This kernel name is required when launching the kernel.
+    '''
+    assert src
+    import re
+    for line in src.split('\n'):
+        line = line.strip()
+        if line.startswith('define metaxgpu_kernel void @'):
+            return re.match(r"define metaxgpu_kernel void @(.+?)\(", line).groups()[0]
+
+def parse_option(string):
+    return [item for item in string.split(';') if item]
+
+
+@dataclass(frozen=True)
+class MACAOptions:
+    num_warps: int = 4
+    num_ctas: int = 1
+    num_stages: int = 3
+    # maxnreg corresponds to the ptx parameter .maxnreg, which controls the
+    # maximum number of 32-bit registers used by one thread.
+    maxnreg: Optional[int] = None
+    cluster_dims: tuple = (1, 1, 1)
+    ptx_version: int = None
+    enable_fp_fusion: bool = True
+    allow_fp8e4nv: bool = False
+    allow_fp8e4b15: bool = False
+    default_dot_input_precision: str = "tf32"
+    allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
+    max_num_imprecise_acc_default: bool = None
+    extern_libs: dict = None
+    debug: bool = False
+    backend_name: str = 'maca'
+    # MACA: new args
+    pipeline: str = "basic"
+    scenario: str = ""
+    extra_options: str = ""
+    pipeline_load_num: int = -1
+
+    def __post_init__(self):
+        default_libdir = os.getenv("MACA_PATH") + '/lib'
+        ext_default_libdir = Path(__file__).parent / 'lib'
+        extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
+        if not extern_libs.get('libdevice', None):
+            # ext_maca_mathlib.bc
+            env_ext_libdevice_path = os.getenv("TRITON_EXT_LIBDEVICE_PATH", None)
+            ext_libdevice_path = env_ext_libdevice_path if env_ext_libdevice_path is not None else str(ext_default_libdir) + '/ext_maca_mathlib.bc'
+            assert os.path.exists(ext_libdevice_path), "ext_maca_mathlib.bc does not exist, please check!"
+ extern_libs['ext_libdevice'] = ext_libdevice_path + # maca_kernellib.bc + env_kernel_libdevice_path = os.getenv("TRITON_KERNEL_LIBDEVICE_PATH", None) + kernel_libdevice_path = env_kernel_libdevice_path if env_kernel_libdevice_path is not None else default_libdir + '/maca_kernellib.bc' + extern_libs['kernel_libdevice'] = kernel_libdevice_path + # maca_mathlib.bc + env_libdevice_path = os.getenv("TRITON_LIBDEVICE_PATH", None) + libdevice_path = env_libdevice_path if env_libdevice_path is not None else default_libdir + '/maca_mathlib.bc' + extern_libs['libdevice'] = libdevice_path + object.__setattr__(self, 'extern_libs', tuple(extern_libs.items())) + assert self.num_warps > 0 and self.num_warps <= 16 and (self.num_warps & (self.num_warps - 1)) == 0, \ + "num_warps must be a power of 2 or greater than 0 and less than or equal to 16" + + def hash(self): + hash_dict = dict(self.__dict__) + # hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"])) + key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +class MACABackend(BaseBackend): + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == 'maca' + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) + self.capability = target.arch + print(f"zmz debug self.capability: {self.capability}") + assert isinstance(self.capability, int) + self.binary_ext = "mcfatbin" + + def parse_options(self, opts) -> Any: + args = {k: opts[k] for k in MACAOptions.__dataclass_fields__.keys() if k in opts} + # USE_MACA: support allow_fp8e4nv(i.e. float8_e4m3fn) + args["allow_fp8e4nv"] = True + # args["allow_fp8e4nv"] = False + args["allow_fp8e4b15"] = False + args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0 + return MACAOptions(**args) + + def pack_metadata(self, metadata): + return ( + metadata.num_warps, + metadata.num_ctas, + metadata.shared, + metadata.cluster_dims[0], + metadata.cluster_dims[1], + metadata.cluster_dims[2], + ) + + def get_codegen_implementation(self): + import triton.language.extra.cuda as cuda + codegen_fns = { + "convert_custom_types": + cuda.convert_custom_float8_sm80 if self.capability >= 80 else cuda.convert_custom_float8_sm70 + } + return codegen_fns + + def load_dialects(self, ctx): + metax.load_dialects(ctx) + + +@staticmethod +def make_ttir(mod, metadata, opt): pm = ir.pass_manager(mod.context) pm.enable_debug() - pm.add_convert_triton_to_tritongpu_pass(num_warps) + passes.common.add_inliner(pm) + passes.ttir.add_rewrite_tensor_pointer(pm) + passes.ttir.add_combine(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) + passes.common.add_symbol_dce(pm) pm.run(mod) return mod -def optimize_ttgir(mod, num_stages, arch): - """ - MACA supported pass combination: - 1. tritongpu_pipeline_maca4_pass - 2. tritongpu_pipeline_maca4_pass + tritongpu_prefetch_maca2_pass - 1. tritongpu_pipeline_maca5_pass - tritongpu_optimize_dot_operands_pass - 2. tritongpu_pipeline_maca5_pass + tritongpu_prefetch_maca2_pass - tritongpu_optimize_dot_operands_pass - """ - use_opt_maca_mma = (USE_MACA and os.getenv("TRITON_ENABLE_MACA_OPT_MMA") and num_stages == 2) +@staticmethod +def make_ttgir(mod, metadata, opt, capability): + assert opt.pipeline_load_num >= -1, "invalid pipeline_load_num value!" 
+    scenarios = parse_option(opt.scenario)
+    disable_prefetch = "unprefetch" in scenarios
+    fullstage = "fullstage" in scenarios
+    store_coalesce = "storeCoalesce" in scenarios
+    mla = "mla" in scenarios
+    single_shm = "singleshm" in scenarios
+    use_opt_maca_mma = True
+    use_opt_maca_mma = (opt.pipeline != "" and not os.getenv("TRITON_DISABLE_MACA_OPT_MMA"))
+    # TTIR -> TTGIR
     pm = ir.pass_manager(mod.context)
     pm.enable_debug()
-    pm.add_tritongpu_coalesce_pass()
-    pm.add_tritongpu_remove_layout_conversions_pass()
-    if isinstance(arch, int):
-        pm.add_tritongpu_accelerate_matmul_pass(arch, num_stages)
-    pm.add_tritongpu_remove_layout_conversions_pass()
-    pm.add_tritongpu_optimize_dot_operands_pass()
-    if use_opt_maca_mma:
-        pm.add_tritongpu_pipeline_maca5_pass(num_stages)
-        pm.add_tritongpu_prefetch_maca2_pass()
-    else:
-        pm.add_tritongpu_pipeline_pass(num_stages)
-        pm.add_tritongpu_prefetch_pass()
-    # TODO(fix): add_tritongpu_pipeline_maca5_pass induces endless loop
-    if not use_opt_maca_mma:
-        pm.add_tritongpu_optimize_dot_operands_pass()
-    pm.add_tritongpu_remove_layout_conversions_pass()
-    pm.add_tritongpu_decompose_conversions_pass()
-    if not use_opt_maca_mma:
-        pm.add_tritongpu_reorder_instructions_pass()
-    pm.add_cse_pass()
-    pm.add_symbol_dce_pass()
+    passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 64, opt.num_ctas)
+    # optimize TTGIR
+    passes.ttgpuir.add_coalesce(pm)
+    if capability // 10 >= 8:
+        passes.ttgpuir.add_f32_dot_tc(pm)
+    passes.ttgpuir.add_remove_layout_conversions(pm)
+    passes.ttgpuir.add_optimize_thread_locality(pm)
+
+    if opt.pipeline == "cpasync":
+        disable_prefetch = True
+    metax.passes.ttgpuir.add_accelerate_matmul(pm, opt.num_stages, disable_prefetch, store_coalesce, "c500")
+    passes.ttgpuir.add_remove_layout_conversions(pm)
+    if store_coalesce:
+        metax.passes.ttgpuir.add_tritonmetaxgpu_change_layout_from_repn_to_elemn_pass(pm)
+        metax.passes.ttgpuir.add_tritonmetaxgpu_optimize_cstore_pass(pm, opt.num_stages)
+        passes.ttgpuir.add_remove_layout_conversions(pm)
+    passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
+    passes.common.add_cse(pm)
+    if capability // 10 >= 8:
+        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+    if use_opt_maca_mma:
+        if opt.pipeline == "basic":
+            if mla and single_shm:
+                # only mla=True and single_shm=True
+                metax.passes.ttgpuir.add_pipeline_maca_4(pm, opt.num_stages, opt.pipeline_load_num, fullstage, True)
+            else:
+                metax.passes.ttgpuir.add_pipeline_maca_4(pm, opt.num_stages, opt.pipeline_load_num, fullstage, False)
+        elif opt.pipeline == "cpasync" and not mla:
+            metax.passes.ttgpuir.add_pipeline_async_tn(pm, opt.num_stages)
+            metax.passes.ttgpuir.add_pipeline_async_tt(pm, opt.num_stages)
+            metax.passes.ttgpuir.add_pipeline_async_base(pm, opt.num_stages, fullstage)
+        elif mla and opt.num_stages == 2 and opt.pipeline == "cpasync":
+            metax.passes.ttgpuir.add_pipeline_async_multidot_mla_mixed(pm, opt.num_stages, fullstage, opt.pipeline_load_num, single_shm, True)
+        elif mla and opt.num_stages == 2 and opt.pipeline == "mixed":
+            metax.passes.ttgpuir.add_pipeline_async_multidot_mla_mixed(pm, opt.num_stages, fullstage, opt.pipeline_load_num, single_shm, False)
+        else:
+            print("no available pipeline for maca")
+    else:
+        passes.ttgpuir.add_pipeline(pm, opt.num_stages)
+    if use_opt_maca_mma and opt.pipeline == "basic" and "unprefetch" not in scenarios:
+        metax.passes.ttgpuir.add_prefetch_maca_2(pm)
+    elif not use_opt_maca_mma:
+        passes.ttgpuir.add_prefetch(pm)
+    passes.ttgpuir.add_optimize_dot_operands(pm, 
capability >= 80) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_reduce_data_duplication(pm) + passes.ttgpuir.add_reorder_instructions(pm) + if os.getenv("TRITON_ENABLE_MACA_OPT_MOVE_DOT_OPERANDS_OUT_LOOP"): + metax.passes.ttgpuir.add_tritonmetaxgpu_move_dot_operands_out_loop_pass(pm) + if os.getenv("TRITON_ENABLE_MACA_MERGE_EQUAL_SHARED_LAYOUT"): + metax.passes.ttgpuir.add_tritonmetaxgpu_merge_equal_shared_layout_pass(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + passes.common.add_canonicalizer(pm) pm.run(mod) return mod -def ttgir_to_llir(mod, arch): - return translate_triton_gpu_to_llvmir(mod, arch, False) +@staticmethod +def make_mlir(src, metadata, options, capability): + # warp-specialization mutates num_warps + num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") + if num_warp_groups is not None: + metadata["num_warps"] *= num_warp_groups + mod = src -def llir_to_mcfatbin(mod, mxcc_arch: str, maca_path: str): - ''' - Translate TritonGPU module to mcfatbin code. - :param mod: a TritonGPU dialect module - :return: - - Path to mcfatbin object - ''' - return translate_llvmir_to_mcfatbin(mod, mxcc_arch, maca_path) + # TritonGPU -> LLVM-IR (MLIR) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.ttgpuir.add_combine_tensor_select_and_if(pm) + passes.convert.add_scf_to_cf(pm) + passes.convert.add_index_to_llvmir(pm) + passes.ttgpuir.add_allocate_shared_memory(pm) + metax.passes.ttgpuir.add_to_llvmir(pm, capability) + passes.convert.add_arith_to_llvmir(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": + passes.llvmir.add_di_scope(pm) + pm.run(mod) + + # Get some metadata + metadata["shared"] = src.get_int_attr("triton_gpu.shared") + ret = str(mod) + return ret + +@staticmethod +def make_llir(src, metadata, options, capability): + mlir_opt_path = os.path.dirname(__file__) + "/bin/mlir-opt" + opted_mlir = metax.mlir_opt(src, mlir_opt_path) + mlir_translate_path = os.path.dirname(__file__) + "/bin/mlir-translate" + maca_path = os.environ.get('MACA_PATH') + assert maca_path, "Not found MACA_PATH" + llir = metax.translate_mlir_to_llir(opted_mlir, maca_path) + if options.extern_libs: + paths = [path for (name, path) in options.extern_libs] + llir = metax.link_extern_libs(llir, paths, maca_path) + metadata["name"] = maca_get_kernel_name(llir) + return llir + + +@staticmethod +def make_mcfatbin(src, metadata, opt, capability): + scenarios = parse_option(opt.scenario) + opt_mxcc = os.environ.get("TRITON_COMPILER_OPT_PATH") + mxcc_arch = os.environ.get('MACA_PATH') + "/mxgpu_llvm/bin/mxcc" + if opt_mxcc: + mxcc_arch = opt_mxcc + "/mxgpu_llvm/bin/mxcc" + if mxcc_arch is None: + raise RuntimeError('mxcc_arch is None (not specified)') + compile_options = "" + if (opt.pipeline == "basic" or opt.pipeline == "basic-prefetch") and "mla" not in scenarios: + compile_options = " -mllvm -metaxgpu-sched-regpressure=false -mllvm -metaxgpu-PostRA-Scheduler=false -mllvm -metaxgpu-mma-sched=true " + if "fullstage" in scenarios: + compile_options += " -mllvm -metaxgpu-vectorize-slp=true -mllvm -metaxgpu-igroup " + else: + compile_options += " -mllvm -metaxgpu-vectorize-slp=true -mllvm -metaxgpu-sched-mma-maxnum=3 " + if "roll" not in scenarios: + compile_options += " -mllvm -metaxgpu-mma-unroll-count=" + str(opt.num_stages) + " " + elif opt.pipeline == "cpasync" and "mla" not in scenarios: + compile_options = " -mllvm 
-metaxgpu-sched-regpressure=true " + compile_options += " -mllvm -metaxgpu-sinkload=false -mllvm -metaxgpu-vectorize-slp=true -mllvm -metaxgpu-igroup -mllvm -metaxgpu-aggressive-4g-addr-opt=true \ + -mllvm -metaxgpu-shl-add-combine=false -mllvm -misched-postra=true -mllvm -enable-post-misched=true " + + if os.getenv("TRITON_ENABLE_MACA_COMPILER_INT8_OPT"): + compile_options += " -mllvm -metaxgpu-slp-vectorize-i8=true" + + if "unroll" in scenarios: + compile_options += " -mllvm -metaxgpu-mma-unroll-count=" + str(opt.num_stages) + " " + if "flashattn-fwd" in scenarios: + compile_options = " -mllvm -metaxgpu-mma-sched=true -mllvm -metaxgpu-sched-select=metaxgpu-minreg -mllvm -map-use-pk-fma=1 " + elif "flashattn-bwd" in scenarios: + compile_options = " -mllvm -metaxgpu-sched-regpressure=true " + compile_options += " -mllvm -metaxgpu-sinkload=false -mllvm -metaxgpu-vectorize-slp=true " + if "mla" in scenarios: + # maybe will change the compile options in mla later + if opt.num_stages == 2: + if opt.pipeline == "cpasync": + compile_options = " -mllvm -metaxgpu-sched-regpressure=true " + compile_options += " -mllvm -metaxgpu-sinkload=false -mllvm -metaxgpu-vectorize-slp=true -mllvm -metaxgpu-igroup -mllvm -metaxgpu-aggressive-4g-addr-opt=true \ + -mllvm -metaxgpu-shl-add-combine=false -mllvm -misched-postra=true -mllvm -enable-post-misched=true " + if "unroll" in scenarios: + compile_options += " -mllvm -metaxgpu-mma-unroll-count=" + str(opt.num_stages) + " " + elif opt.pipeline == "basic" or opt.pipeline == "mixed": + compile_options = " -mllvm -metaxgpu-mma-sched=true -mllvm -map-use-pk-fma=1 -mllvm -metaxgpu-split-regalloc=true -mllvm -metaxgpu-aggressive-fold=true \ + -mllvm -metaxgpu-disable-licm=true " + else: + assert False, "Please set pipeline for mla!" 
+ else: + compile_options = " -mllvm -metaxgpu-mma-sched=true -mllvm -map-use-pk-fma=1 -mllvm -metaxgpu-split-regalloc=true -mllvm -metaxgpu-aggressive-fold=true " + if opt.extra_options != "": + compile_options = opt.extra_options + return metax.translate_llvmir_to_mcfatbin(src, mxcc_arch, os.environ.get('MACA_PATH'), compile_options) + + + # def add_stages(self, stages, options): + # stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) + # stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability) + # stages["mlir"] = lambda src, metadata: self.make_mlir(src, metadata, options, self.capability) + # stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability) + # stages["mcfatbin"] = lambda src, metadata: self.make_mcfatbin(src, metadata, options, self.capability) + + # @functools.lru_cache() + # def hash(self): + # mxcc_arch = os.environ.get('MACA_PATH') + "/mxgpu_llvm/bin/mxcc" + # if mxcc_arch is None: + # raise RuntimeError('mxcc_arch is None (not specified)') + # version = subprocess.check_output([mxcc_arch, "--version"]).decode("utf-8").split('\n', 1)[0] + # return f'{version}-{self.capability}' -# ----- source code generation -------- +################################################################################################## + + + +import functools +import os +import hashlib +import subprocess +import tempfile +from pathlib import Path +from triton.runtime.build import _build +from triton.runtime.cache import get_cache_manager +from triton.backends.compiler import GPUTarget +from triton.backends.driver import GPUDriver + +dirname = os.path.dirname(os.path.realpath(__file__)) +include_dir = [os.path.join(dirname, "include")] +libdevice_dir = os.path.join(dirname, "lib") +# libraries = ['cuda'] +libraries = [] + +@functools.lru_cache() +def maca_home_dirs(): + return os.getenv("MACA_PATH") + +@functools.lru_cache() +def libmaca_dirs(): + maca_path = maca_home_dirs() + return ["{}/lib/".format(maca_path)] + +maca_lib_dir = libmaca_dirs() +maca_include_dir = [os.path.join(maca_home_dirs(), "include")] + + +@functools.lru_cache() +def library_dirs(): + return [libdevice_dir, *libmaca_dirs()] + + +def compile_module_from_src(src, name): + # print(f"zmz debug compile_module_from_src: {src}") + print(f"zmz debug name, {name}") + key = hashlib.sha256(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{name}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + # TODO(MACA): fix it + so = _build(name, src_path, tmpdir, library_dirs(), maca_include_dir, libraries) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}.so", binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +# ------------------------ +# Utils +# ------------------------ + + +class MacaUtils(object): + + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(MacaUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + mod = compile_module_from_src(Path(os.path.join(dirname, "maca.c")).read_text(), "maca_utils") + self.load_binary = mod.load_binary + self.get_device_properties = mod.get_device_properties + # self.cuOccupancyMaxActiveClusters = 
mod.cuOccupancyMaxActiveClusters + self.set_printf_fifo_size = mod.set_printf_fifo_size + # self.fill_1d_tma_descriptor = mod.fill_1d_tma_descriptor + # self.fill_2d_tma_descriptor = mod.fill_2d_tma_descriptor + + +# ------------------------ +# Launcher +# ------------------------ def ty_to_cpp(ty): if ty[0] == '*': - if USE_MACA: - return "mcDeviceptr_t" - else: - return "CUdeviceptr" + return "mcDeviceptr_t" return { "i1": "int32_t", "i8": "int8_t", "i16": "int16_t", "i32": "int32_t", "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", "u32": "uint32_t", "u64": "uint64_t", "fp16": "float", @@ -134,24 +482,15 @@ def ty_to_cpp(ty): }[ty] -def generate_launcher(constants, signature): +def make_launcher(constants, signature, ids): + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) def _extracted_type(ty): if ty[0] == '*': return "PyObject*" - return { - 'i1': 'int32_t', - 'i32': 'int32_t', - 'i64': 'int64_t', - 'u32': 'uint32_t', - 'u64': 'uint64_t', - 'fp16': 'float', - 'bf16': 'float', - 'fp32': 'float', - 'f32': 'float', - 'fp64': 'double', - }[ty] + return ty_to_cpp(ty) def format_of(ty): return { @@ -159,328 +498,56 @@ def format_of(ty): "float": "f", "double": "d", "long": "l", - "uint32_t": "I", + "int8_t": "b", + "int16_t": "h", "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", "uint64_t": "K", - "int64_t": "L", }[ty] - format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + format = "iiiKKOOOO" + args_format + args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' # generate glue code - if USE_MACA: - src = f""" - #include - #include - - static inline void gpuAssert(mcError_t code, const char *file, int line) - {{ - if (code != mcSuccess) - {{ - const char* prefix = "Triton Error [MACA]: "; - const char* str = mcGetErrorString(code); - char err[1024] = {{0}}; - strcat(err, prefix); - strcat(err, str); - PyErr_SetString(PyExc_RuntimeError, err); - }} - }} - - #define MACA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} - - void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, mcStream_t stream, mcFunction_t function, {arg_decls}) {{ - void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; - if(gridX*gridY*gridZ > 0){{ - MACA_CHECK(mcModuleLaunchKernel(function, gridX, gridY, gridZ, 64*num_warps, 1, 1, shared_memory, stream, params, 0)); - }} - }} - - typedef struct _DevicePtrInfo {{ - mcDeviceptr_t dev_ptr; - bool valid; - }} DevicePtrInfo; - - static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ - DevicePtrInfo ptr_info; - ptr_info.dev_ptr = (mcDeviceptr_t)0; - ptr_info.valid = true; - if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = (mcDeviceptr_t)PyLong_AsUnsignedLongLong(obj); - return ptr_info; - }} - if (obj == Py_None) {{ - // valid nullptr - return ptr_info; - }} - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); - if(ptr){{ - PyObject *empty_tuple = PyTuple_New(0); - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(ptr); - if (!PyLong_Check(ret)) {{ - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 
64-bit int"); - ptr_info.valid = false; - return ptr_info; - }} - ptr_info.dev_ptr = (mcDeviceptr_t)PyLong_AsUnsignedLongLong(ret); - if(!ptr_info.dev_ptr) - return ptr_info; - uint64_t dev_ptr; - int status = mcPointerGetAttribute(&dev_ptr, mcPointerAttributeDevicePointer, ptr_info.dev_ptr); - if (status == mcErrorInvalidValue) {{ - PyErr_Format(PyExc_ValueError, - "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); - ptr_info.valid = false; - }} - ptr_info.dev_ptr = (mcDeviceptr_t)dev_ptr; - Py_DECREF(ret); - return ptr_info; - }} - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); - return ptr_info; - }} - - static PyObject* launch(PyObject* self, PyObject* args) {{ - int gridX, gridY, gridZ; - uint64_t _stream; - uint64_t _function; - int num_warps; - int shared_memory; - PyObject *launch_enter_hook = NULL; - PyObject *launch_exit_hook = NULL; - PyObject *compiled_kernel = NULL; - PyObject *hook_ret = NULL; - {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} - if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{ - return NULL; - }} - - if (launch_enter_hook != Py_None) {{ - PyObject *new_args = PyTuple_Pack(1, compiled_kernel); - hook_ret = PyObject_CallObject(launch_enter_hook, new_args); - Py_DECREF(new_args); - }} - - // raise exception asap - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - _launch(gridX, gridY, gridZ, num_warps, shared_memory, (mcStream_t)_stream, (mcFunction_t)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())}); - - if (launch_exit_hook != Py_None) {{ - PyObject *new_args = NULL; - if (hook_ret) {{ - new_args = PyTuple_Pack(2, compiled_kernel, hook_ret); - }} else {{ - new_args = PyTuple_Pack(1, compiled_kernel); - }} - hook_ret = PyObject_CallObject(launch_exit_hook, new_args); - Py_DECREF(new_args); - }} - - if (hook_ret) {{ - Py_DECREF(hook_ret); - }} - if(PyErr_Occurred()) {{ - return NULL; - }} - // return None - Py_INCREF(Py_None); - return Py_None; - }} - - static PyMethodDef ModuleMethods[] = {{ - {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, - {{NULL, NULL, 0, NULL}} // sentinel - }}; - - static struct PyModuleDef ModuleDef = {{ - PyModuleDef_HEAD_INIT, - \"__triton_launcher\", - NULL, //documentation - -1, //size - ModuleMethods - }}; - - PyMODINIT_FUNC PyInit___triton_launcher(void) {{ - PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) {{ - return NULL; - }} - PyModule_AddFunctions(m, ModuleMethods); - return m; - }} - """ - elif is_hip(): - src = f""" - #define __HIP_PLATFORM_AMD__ - #include - #include - #include - - static inline void gpuAssert(hipError_t code, const char *file, int line) - {{ - if (code != HIP_SUCCESS) - {{ - const char* prefix = "Triton Error [HIP]: "; - const char* str = hipGetErrorString(code); - char err[1024] = {{0}}; - snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str ); - PyErr_SetString(PyExc_RuntimeError, err); - }} - }} - - #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} - - static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, 
hipFunction_t function, {arg_decls}) {{ - void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; - if (gridX*gridY*gridZ > 0) {{ - HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, 64*num_warps, 1, 1, shared_memory, stream, params, 0)); - }} - }} - - typedef struct _DevicePtrInfo {{ - hipDeviceptr_t dev_ptr; - bool valid; - }} DevicePtrInfo; - - static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ - DevicePtrInfo ptr_info; - ptr_info.dev_ptr = 0; - ptr_info.valid = true; - - if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj); - return ptr_info; - }} - - if (obj == Py_None) {{ - // valid nullptr - return ptr_info; - }} - - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); - - if (ptr) {{ - PyObject *empty_tuple = PyTuple_New(0); - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(ptr); - - if (!PyLong_Check(ret)) {{ - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); - ptr_info.valid = false; - return ptr_info; - }} - - ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret); - - if (!ptr_info.dev_ptr) - return ptr_info; - - uint64_t dev_ptr; - hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); - if (status == hipErrorInvalidValue) {{ - PyErr_Format(PyExc_ValueError, - "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); - ptr_info.valid = false; - }} - - ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr; - return ptr_info; - }} - - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); - return ptr_info; - }} - - static PyObject* launch(PyObject* self, PyObject* args) {{ - - int gridX, gridY, gridZ; - uint64_t _stream; - uint64_t _function; - int num_warps; - int shared_memory; - PyObject *launch_enter_hook = NULL; - PyObject *launch_exit_hook = NULL; - PyObject *compiled_kernel = NULL; - - {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} - if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{ - return NULL; - }} - - if (launch_enter_hook != Py_None) {{ - PyObject_CallObject(launch_enter_hook, args); - }} - - // raise exception asap - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - _launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items())}); - if (launch_exit_hook != Py_None) {{ - PyObject_CallObject(launch_exit_hook, args); - }} - if (PyErr_Occurred()) {{ - return NULL; - }} - - // return None - Py_INCREF(Py_None); - return Py_None; - }} - - static PyMethodDef ModuleMethods[] = {{ - {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, - {{NULL, NULL, 0, NULL}} // sentinel - }}; - - static struct PyModuleDef ModuleDef = {{ - PyModuleDef_HEAD_INIT, - \"__triton_launcher\", - NULL, //documentation - -1, //size - ModuleMethods - }}; - - PyMODINIT_FUNC PyInit___triton_launcher(void) {{ - PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) {{ - return NULL; - }} 
- PyModule_AddFunctions(m, ModuleMethods); - return m; - }} - """ - else: - src = f""" -#include \"cuda.h\" + params = [i for i in signature.keys() if i not in constants] + src = f""" +#include #include #include +#include -static inline void gpuAssert(CUresult code, const char *file, int line) +static inline void gpuAssert(mcError_t code, const char *file, int line) {{ - if (code != CUDA_SUCCESS) + if (code != mcSuccess) {{ - const char* prefix = "Triton Error [CUDA]: "; - const char* str; - cuGetErrorString(code, &str); + const char* prefix = "Triton Error [MACA]: "; + const char* str = mcGetErrorString(code); char err[1024] = {{0}}; strcat(err, prefix); strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); }} }} -#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} +#define MACA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} -static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{ - void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; - if(gridX*gridY*gridZ > 0){{ - CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0)); +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, mcStream_t stream, mcFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }}; + if (gridX*gridY*gridZ > 0) {{ + assert(num_ctas == 1); + MACA_CHECK(mcModuleLaunchKernel(function, gridX, gridY, gridZ, 64*num_warps, 1, 1, shared_memory, stream, params, 0)); }} }} typedef struct _DevicePtrInfo {{ - CUdeviceptr dev_ptr; + mcDeviceptr_t dev_ptr; bool valid; }} DevicePtrInfo; @@ -489,7 +556,7 @@ def format_of(ty): ptr_info.dev_ptr = 0; ptr_info.valid = true; if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj); + ptr_info.dev_ptr = (mcDeviceptr_t)PyLong_AsUnsignedLongLong(obj); return ptr_info; }} if (obj == Py_None) {{ @@ -507,21 +574,22 @@ def format_of(ty): ptr_info.valid = false; return ptr_info; }} - ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret); + ptr_info.dev_ptr = (mcDeviceptr_t)PyLong_AsUnsignedLongLong(ret); if(!ptr_info.dev_ptr) return ptr_info; uint64_t dev_ptr; - int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); - if (status == CUDA_ERROR_INVALID_VALUE) {{ + int status = mcPointerGetAttribute(&dev_ptr, mcPointerAttributeDevicePointer, ptr_info.dev_ptr); + if (status == mcErrorInvalidValue) {{ PyErr_Format(PyExc_ValueError, "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); ptr_info.valid = false; }} - ptr_info.dev_ptr = dev_ptr; + ptr_info.dev_ptr = (mcDeviceptr_t)dev_ptr; Py_DECREF(ret); // Thanks ChatGPT! 
return ptr_info; }} PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; return ptr_info; }} @@ -529,32 +597,50 @@ def format_of(ty): int gridX, gridY, gridZ; uint64_t _stream; uint64_t _function; - int num_warps; - int shared_memory; PyObject *launch_enter_hook = NULL; PyObject *launch_exit_hook = NULL; - PyObject *compiled_kernel = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} - if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{ + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &_stream, &_function, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook {args_list})) {{ return NULL; }} - if (launch_enter_hook != Py_None) {{ - PyObject_CallObject(launch_enter_hook, args); + int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; + if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ + PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); + return NULL; }} + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} // raise exception asap {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - _launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())}); - - if (launch_exit_hook != Py_None) {{ - PyObject_CallObject(launch_exit_hook, args); + Py_BEGIN_ALLOW_THREADS; + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (mcStream_t)_stream, (mcFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); + Py_END_ALLOW_THREADS; + if (PyErr_Occurred()) {{ + return NULL; }} - if(PyErr_Occurred()) {{ - return NULL; + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + // return None Py_INCREF(Py_None); return Py_None; @@ -584,139 +670,38 @@ def format_of(ty): """ return src -@functools.lru_cache() -def maca_home_dirs(): - return os.getenv("MACA_PATH") - -@functools.lru_cache() -def libmaca_dir(): - return os.path.join(maca_home_dirs(), "lib") - -def _build(name, src, srcdir): - maca_lib_dir = libmaca_dir() - maca_include_dir = os.path.join(maca_home_dirs(), "include") - suffix = sysconfig.get_config_var('EXT_SUFFIX') - so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) - # try to avoid setuptools if possible - cc = os.environ.get("CC") - if cc is None: - # TODO: support more things here. 
- clang = shutil.which("clang") - gcc = shutil.which("gcc") - cc = gcc if gcc is not None else clang - if cc is None: - raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") - # This function was renamed and made public in Python 3.10 - if hasattr(sysconfig, 'get_default_scheme'): - scheme = sysconfig.get_default_scheme() - else: - scheme = sysconfig._get_default_scheme() - # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install - # path changes to include 'local'. This change is required to use triton with system-wide python. - if scheme == 'posix_local': - scheme = 'posix_prefix' - py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] - - ret = subprocess.check_call(["/usr/bin/g++", src, f"-I{maca_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", f"-D__MACA__", "-shared", "-fPIC", f"-L{maca_lib_dir}", "-lmcruntime", "-o", so]) - - if ret == 0: - return so - # fallback on setuptools - extra_compile_args = [] - include_dirs = [srcdir] - libraries = ['cuda'] - # extra arguments - extra_link_args = [] - # create extension module - ext = setuptools.Extension( - name=name, - language='c', - sources=[src], - include_dirs=include_dirs, - extra_compile_args=extra_compile_args + ['-O3'], - extra_link_args=extra_link_args, - libraries=libraries, - ) - # build extension module - args = ['build_ext'] - args.append('--build-temp=' + srcdir) - args.append('--build-lib=' + srcdir) - args.append('-q') - args = dict( - name=name, - ext_modules=[ext], - script_args=args, - ) - with quiet(): - setuptools.setup(**args) - return so - - -class MacaUtils(object): - - def __new__(cls): - if not hasattr(cls, 'instance'): - cls.instance = super(MacaUtils, cls).__new__(cls) - return cls.instance - - def __init__(self): - dirname = os.path.dirname(os.path.realpath(__file__)) - src = Path(os.path.join(dirname, "maca.c")).read_text() - key = hashlib.md5(src.encode("utf-8")).hexdigest() - cache = get_cache_manager(key) - fname = "maca_utils.so" - cache_path = cache.get_file(fname) - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "main.c") - with open(src_path, "w") as f: - f.write(src) - so = _build("maca_utils", src_path, tmpdir) - with open(so, "rb") as f: - cache_path = cache.put(f.read(), fname, binary=True) - import importlib.util - spec = importlib.util.spec_from_file_location("maca_utils", cache_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - self.load_binary = mod.load_binary - self.get_device_properties = mod.get_device_properties - class MacaLauncher(object): def __init__(self, src, metadata): - if isinstance(src.fn, JITFunction): - name, _ = src.fn.__name__, "ast" - else: - name, _ = os.path.basename(src.fn).split(".") + ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} constants = src.constants if hasattr(src, "constants") else dict() cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i constants = {cst_key(key): value for key, value in constants.items()} signature = {cst_key(key): value for key, value in src.signature.items()} - so_cache_key = src.hash() - self.so_path = self.make_launcher_stub(name, so_cache_key, signature, constants) - spec = importlib.util.spec_from_file_location("__triton_launcher", self.so_path) - mod = importlib.util.module_from_spec(spec) + src = make_launcher(constants, signature, ids) + mod = compile_module_from_src(src, "__triton_launcher") 
         self.launch = mod.launch
 
-    def __call__(self, grid_0, grid_1, grid_2, stream, kernel_function, kernel_packed_metadata, launch_metadata, launch_enter_hook,launch_exit_hook, *args, **kwargs):
-        self.launch(grid_0, grid_1, grid_2, 4, 0, stream, kernel_function, launch_enter_hook, launch_exit_hook, self, *args)
-
-    def make_launcher_stub(self, name, so_cache_key, signature, constants):
-        # name of files that are cached
-        so_cache_manager = get_cache_manager(so_cache_key)
-        so_name = f"{name}.so"
-        # retrieve stub from cache if it exists
-        cache_path = so_cache_manager.get_file(so_name)
-        if cache_path is None:
-            with tempfile.TemporaryDirectory() as tmpdir:
-                src = generate_launcher(constants, signature)
-                src_path = os.path.join(tmpdir, "main.c")
-                with open(src_path, "w") as f:
-                    f.write(src)
-
-                so = _build(name, src_path, tmpdir)
-                with open(so, "rb") as f:
-                    return so_cache_manager.put(f.read(), so_name, binary=True)
-        else:
-            return cache_path
\ No newline at end of file
+    def __call__(self, *args, **kwargs):
+        self.launch(*args, **kwargs)
+
+
+class MacaDriver(GPUDriver):
+
+    def __init__(self):
+        self.utils = MacaUtils()  # TODO: make static
+        self.launcher_cls = MacaLauncher
+        super().__init__()
+
+    def get_current_target(self):
+        device = self.get_current_device()
+        capability = self.get_device_capability(device)
+        capability = capability[0] * 10 + capability[1]
+        warp_size = 64
+        return GPUTarget("maca", capability, warp_size)
+
+    @staticmethod
+    def is_active():
+        import torch
+        return torch.cuda.is_available() and (torch.version.hip is None)
diff --git a/setup_on_maca.py b/setup_on_maca.py
new file mode 100644
index 00000000..0524002e
--- /dev/null
+++ b/setup_on_maca.py
@@ -0,0 +1,260 @@
+import setuptools
+import os
+import shutil
+import sys
+import glob
+from setuptools.command.build_py import build_py
+from setuptools.command.install import install
+import importlib.util
+import subprocess
+import site
+setuptools._distutils_hack = lambda *args, **kwargs: None
+from setuptools.dist import Distribution
+
+# Monkey-patch: make bdist_egg always treat the package as not zip-safe (so a directory is generated instead of a zipped egg)
+_original_zip_safe = getattr(Distribution, 'zip_safe', None)
+def _always_false_zip_safe(self):
+    return False
+Distribution.zip_safe = property(_always_false_zip_safe)
+
+ORIGIN_TRITON_PATH=None
+# check triton
+def check_triton_package():
+    try:
+        spec = importlib.util.find_spec("triton")
+        if spec is not None:
+            print("Triton package found.")
+            global ORIGIN_TRITON_PATH
+            ORIGIN_TRITON_PATH = spec.origin
+            print(f"ORIGIN_TRITON_PATH: {ORIGIN_TRITON_PATH}")
+            if ORIGIN_TRITON_PATH.endswith("__init__.py"):
+                ORIGIN_TRITON_PATH = ORIGIN_TRITON_PATH[:-12]
+                print(f"ORIGIN_TRITON_PATH: {ORIGIN_TRITON_PATH}")
+
+        else:
+            print("Triton package not found.")
+            assert False, "Triton package not found, please use an environment with triton installed."
+    except ImportError:
+        print("Triton package not found.")
+        assert False, "Triton package not found, please use an environment with triton installed."
+
+def copy_triton_package():
+    global ORIGIN_TRITON_PATH
+    source_dir = ORIGIN_TRITON_PATH
+    backup_dir = os.path.join(site.getsitepackages()[0], "triton-ori")
+    if os.path.exists(source_dir):
+        if os.path.exists(backup_dir):
+            print(f"{backup_dir} already exists, using it.")
+        else:
+            shutil.copytree(source_dir, backup_dir, ignore=shutil.ignore_patterns('__pycache__', '*.pyc'))
+
+        source_dir = backup_dir
+
+    target_dir = "./triton"
+    if not os.path.exists(target_dir):
+        print(f"Copying {source_dir} to {target_dir}")
+        shutil.copytree(source_dir, target_dir, ignore=shutil.ignore_patterns('__pycache__', '*.pyc'))
+    else:
+        print(f"{target_dir} already exists, skipping.")
+
+    if not os.listdir(target_dir):
+        assert False, f"{target_dir} is empty, please check."
+
+def copy_backend_files():
+    source_dir = "./backend"
+    target_dir = "./triton/backends/metax"
+    print(f"source_dir: {os.path.realpath(source_dir)}")
+
+    if not os.path.exists(target_dir):
+        print(f"Creating target directory: {os.path.realpath(target_dir)}")
+        # Create the target directory
+        os.makedirs(target_dir)
+        # assert False, f"Target directory {target_dir} does not exist, please check the path."
+
+    for filename in ["compiler.py", "driver.py", "maca.py", "utils.py", "maca.c"]:
+        src_path = os.path.join(source_dir, filename)
+        dest_path = os.path.join(target_dir, filename)
+
+        if os.path.exists(src_path):
+            print(f"Copying {src_path} to {dest_path}")
+            shutil.copy2(src_path, dest_path)
+        else:
+            realpath = os.path.realpath(src_path)
+            assert False, f"Source file {realpath} does not exist, please check the path."
+
+    # Copy the files under lib/ into the destination directory
+    lib_dir = os.path.join(source_dir, "lib")
+    dest_lib_dir = os.path.join(target_dir, "lib")
+    if os.path.exists(lib_dir):
+        print(f"Copying {os.path.realpath(lib_dir)} to {os.path.realpath(dest_lib_dir)}")
+        shutil.copytree(lib_dir, dest_lib_dir, ignore=shutil.ignore_patterns('__pycache__', '*.pyc'))
+
+    # Copy the bin/ directory
+    bin_dir = os.path.join(source_dir, "bin")
+    dest_bin_dir = os.path.join(target_dir, "bin")
+    if os.path.exists(bin_dir):
+        print(f"Copying {os.path.realpath(bin_dir)} to {os.path.realpath(dest_bin_dir)}")
+        shutil.copytree(bin_dir, dest_bin_dir, ignore=shutil.ignore_patterns('__pycache__', '*.pyc'))
+
+def modify_backend_name():
+    maca_dir = "./triton/backends/metax"
+    dicp_triton_dir = "./triton/backends/dicp_triton"
+    if os.path.exists(dicp_triton_dir):
+        shutil.rmtree(dicp_triton_dir)
+
+    if os.path.exists(maca_dir):
+        print(f"Renaming {maca_dir} to {dicp_triton_dir}")
+        shutil.move(maca_dir, dicp_triton_dir)
+    else:
+        assert False, f"Source directory {maca_dir} does not exist, please check the path."
+
+
+# pip uninstall triton -y
+def uninstall_triton_package():
+    try:
+        subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "triton", "-y"])
+        print("Triton package uninstalled successfully.")
+    except subprocess.CalledProcessError:
+        print("Failed to uninstall Triton package.")
+        assert False, "Failed to uninstall Triton package."
+
+# Clean up stale triton egg files
+def clean_old_egg():
+    site_packages = site.getsitepackages()[0]
+    egg_pattern = os.path.join(site_packages, "triton-*.egg")
+    egg_files = glob.glob(egg_pattern)
+    print(f"Found old egg files: {egg_files}")
+
+    for egg_file in egg_files:
+        print(f"Removing old egg file: {egg_file}")
+        if os.path.isdir(egg_file):
+            shutil.rmtree(egg_file)
+        else:
+            os.remove(egg_file)
+
+# Clean up the old triton directory in site-packages
+def clean_old_triton():
+    site_packages = site.getsitepackages()[0]
+    triton_pattern = os.path.join(site_packages, "triton")
+    triton_files = glob.glob(triton_pattern)
+    print(f"Found old triton files: {triton_files}")
+
+    for triton_file in triton_files:
+        print(f"Removing old triton file: {triton_file}")
+        if os.path.isdir(triton_file):
+            shutil.rmtree(triton_file)
+        else:
+            os.remove(triton_file)
+
+def clean_build_artifacts():
+    # Clean up build directories
+    build_dirs = ["build", "dist"]
+    for dir_name in build_dirs:
+        if os.path.exists(dir_name):
+            print(f"Removing {dir_name} directory")
+            shutil.rmtree(dir_name)
+
+    # Clean up egg-info directories
+    egg_info_dirs = glob.glob("*.egg-info")
+    for dir_name in egg_info_dirs:
+        print(f"Removing {dir_name} directory")
+        shutil.rmtree(dir_name)
+
+clean_build_artifacts()
+clean_old_egg()
+
+check_triton_package()
+copy_triton_package()
+copy_backend_files()
+modify_backend_name()
+uninstall_triton_package()
+clean_old_egg()
+clean_old_triton()
+
+def clean_build_dir():
+    build_dir = os.path.join(os.getcwd(), "build")
+    if os.path.exists(build_dir):
+        print(f"Cleaning build directory: {build_dir}")
+        shutil.rmtree(build_dir)
+
+class CustomBuildPy(build_py):
+    def run(self):
+        super().run()
+
+        self.copy_files("libtriton*.so", "triton/_C")
+        self.copy_files("mlir-opt", "triton/backends/dicp_triton/bin")
+        self.copy_files("ext_maca_mathlib.bc", "triton/backends/dicp_triton/lib")
+        # self.copy_files("*.h", "triton/_C/include")
+        # self.copy_files("*.hpp", "triton/_C/include")
+        # self.copy_files("*.c", "triton/backends/dicp_triton")
+
+    def copy_files(self, pattern, dest_subdir):
+        dest_dir = os.path.join(self.build_lib, dest_subdir)
+        os.makedirs(dest_dir, exist_ok=True)
+
+        source_files = glob.glob(os.path.join("**", pattern), recursive=True)
+
+        print(f"[DEBUG] Searching for files matching pattern: {pattern}")
+        print(f"[DEBUG] Found source files: {source_files}")
+
+        for src_path in source_files:
+            if not os.path.isfile(src_path):
+                continue
+
+            dest_path = os.path.join(dest_dir, os.path.basename(src_path))
+
+            if os.path.abspath(src_path) != os.path.abspath(dest_path):
+                print(f"Copying {src_path} to {dest_path}")
+                shutil.copy2(src_path, dest_path)
+            else:
+                print(f"Skipping copy (source and destination same): {src_path}")
+
+packages = setuptools.find_packages(where='.')
+
+package_data = {
+    'triton': [
+        '_C/*.so',                        # all shared libraries
+        '_C/include/*',                   # header files
+        'backends/**/*',                  # backend files (including .bc)
+        'backends/dicp_triton/bin/*',
+        'backends/dicp_triton/lib/*.bc',  # explicitly include bitcode files
+        'compiler/**/*',
+        'language/**/*',
+        'ops/**/*',
+        'runtime/**/*',
+        'tools/**/*',
+        # 'tutorials/*',
+        # 'tutorials/**/*',
+        # '**/*.bc',   # bitcode files from any location
+        # '**/*.h',    # all C header files
+        # '**/*.hpp'   # all C++ header files
+    ]
+}
+
+clean_build_dir()
+
+setuptools.setup(
+    name="triton",
+    version="0.0.1",
+    description="A language and compiler for custom Deep Learning operations on the MACA backend",
+    long_description="A language and compiler for custom Deep Learning operations on the MACA backend",
+    long_description_content_type="text/markdown",
+    url="https://github.com/DeepLink-org/DLCompiler.git",
+    packages=packages,
+    package_data=package_data,
+    include_package_data=True,
+    # zip_safe=False,
+    # options={
+    #     'bdist_wheel': {'universal': False},
+    #     # 'bdist_egg': {'zip_safe': lambda: False},
+    # },
+    python_requires='>=3.10',
+    cmdclass={
+        'build_py': CustomBuildPy
+    }
+)
+
+print(f"Python executable: {sys.executable}")
+print(f"Install prefix: {sys.prefix}")
+
+print("Finished.")
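
After running setup_on_maca.py, a quick sanity check is to confirm that the repackaged triton and the relocated dicp_triton backend actually import from the new install. The snippet below is a minimal sketch and is not part of the patch; it assumes the triton/backends/dicp_triton layout produced by copy_backend_files() and modify_backend_name() above.

# Hypothetical post-install smoke test; not included in the patch itself.
import importlib

def check_maca_backend():
    # Both imports should resolve to the package installed by setup_on_maca.py.
    triton = importlib.import_module("triton")
    driver = importlib.import_module("triton.backends.dicp_triton.driver")
    print("triton loaded from:", triton.__file__)
    print("dicp_triton driver loaded from:", driver.__file__)

if __name__ == "__main__":
    check_maca_backend()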