Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,9 @@ bash compile_on_mlu.sh
cd build/triton/tutorials
python 01-vector-add.py
```

## 基于沐曦芯片

```bash
bash compile_on_maca.sh
```
64 changes: 47 additions & 17 deletions backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(self, target: str) -> None:
assert isinstance(self.capability, int)
self.binary_ext = "cnbin"
elif self.driver.target == "maca":
self.capability = 80
self.binary_ext = "mcfatbin"
elif self.driver.target == "ascend":
self.binary_ext = "npubin"
Expand Down Expand Up @@ -163,7 +164,7 @@ def get_attrs_descriptor(self, params, args):
f"backend {self.driver.target} not supported for get_attrs_descriptor."
)

def add_stages(self, stages, options, language):
def add_stages(self, stages, options, language=None):
if self.driver.target not in ["ascend", "mlu"]:
stages["ttir"] = lambda src, metadata: self.make_ttir(
src, metadata, options
Expand Down Expand Up @@ -210,24 +211,25 @@ def add_stages(self, stages, options, language):
# stages["cnbin"] = lambda src, metadata: ttir_to_cnfatbin(src, metadata, get_architecture_descriptor(self.driver, options), False, True)
elif self.driver.target == "maca":
from triton.backends.dicp_triton.maca import (
ttir_to_ttgir,
optimize_ttgir,
ttgir_to_llir,
llir_to_mcfatbin,
get_architecture_descriptor,
make_ttir,
make_ttgir,
make_mlir,
make_llir,
make_mcfatbin,
)

arch = get_architecture_descriptor()
extern_libs = dict()
stages["ttgir"] = lambda src, metadata: optimize_ttgir(
ttir_to_ttgir(src, 4), options.num_stages, arch
stages["ttir"] = lambda src, metadata: make_ttir(src, metadata, options)
stages["ttgir"] = lambda src, metadata: make_ttgir(
src, metadata, options, self.capability
)
stages["llir"] = lambda src, metadata: ttgir_to_llir(src, arch)
mxcc_arch = os.environ.get("MACA_PATH") + "/mxgpu_llvm/bin/mxcc"
if mxcc_arch is None:
raise RuntimeError("mxcc_arch is None (not specified)")
stages["mcfatbin"] = lambda src, metadata: llir_to_mcfatbin(
src, mxcc_arch, os.environ.get("MACA_PATH")
stages["mlir"] = lambda src, metadata: make_mlir(
src, metadata, options, self.capability
)
stages["llir"] = lambda src, metadata: make_llir(
src, metadata, options, self.capability
)
stages["mcfatbin"] = lambda src, metadata: make_mcfatbin(
src, metadata, options, self.capability
)
elif self.driver.target == "ascend":
from triton.backends.dicp_triton.npu import (
Expand Down Expand Up @@ -329,6 +331,24 @@ def parse_options(self, options: dict) -> Any:
os.getenv("TRITON_ENABLE_MLU_BOUND_CHECK", "0") == "1"
)
return MLUOptions(**args)
elif self.target.backend == "maca":
from triton.backends.dicp_triton.maca import MACAOptions

# args = {k: options[k] for k in MACAOptions.__dataclass_fields__.keys() if k in options}
# return MACAOptions(**args)
args = {
k: options[k]
for k in MACAOptions.__dataclass_fields__.keys()
if k in options
}
# USE_MACA: support allow_fp8e4nv(i.e. float8_e4m3fn)
args["allow_fp8e4nv"] = True
# args["allow_fp8e4nv"] = False
args["allow_fp8e4b15"] = False
args["max_num_imprecise_acc_default"] = (
2**30 if self.capability == 90 else 0
)
return MACAOptions(**args)
else:
args = {"arch": self.target}
args.update(
Expand All @@ -340,7 +360,7 @@ def parse_options(self, options: dict) -> Any:
)
return DICPOptions(**args)

def get_codegen_implementation(self, options):
def get_codegen_implementation(self, options=None):
codegen_fns = dict()
if self.target.backend == "ascend":
from triton.backends.dicp_triton.npu import min_dot_size
Expand All @@ -353,6 +373,16 @@ def get_codegen_implementation(self, options):
"convert_custom_types": lambda arg, dst_ty: arg,
"min_dot_size": min_dot_size(self.target),
}
elif self.target.backend == "maca":
import triton.language.extra.cuda as cuda

codegen_fns = {
"convert_custom_types": (
cuda.convert_custom_float8_sm80
if self.capability >= 80
else cuda.convert_custom_float8_sm70
)
}
return codegen_fns

def pack_metadata(self, metadata):
Expand Down
11 changes: 10 additions & 1 deletion backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,17 @@ def test_npucompiler():
reset = "\x1b[0m"
warnings.warn(red + str(e_npucompiler) + reset)
return False
elif self.target == "muxi":
import torch

return True
except Exception as e:
raise RuntimeError(f"dicp triton exception:{e}")
try:
import torch

return True
except Exception as e:
raise RuntimeError(f"dicp triton exception:{e}")
return True

def launch_as_union_task(self, device, grid):
Expand Down
Loading