|
34 | 34 | from .runtime.kernel import Kernel |
35 | 35 |
|
36 | 36 |
|
37 | | -DEVICE = torch.device("xpu") if torch.xpu.is_available() else torch.device("cuda") |
38 | | -PROJECT_ROOT: Path = Path(__file__).parent.parent |
39 | | -EXAMPLES_DIR: Path = PROJECT_ROOT / "examples" |
| 37 | +def _get_triton_backend() -> str | None: |
| 38 | + try: |
| 39 | + return triton.runtime.driver.active.get_current_target().backend # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] |
| 40 | + except Exception: |
| 41 | + return None |
40 | 42 |
|
41 | 43 |
|
def is_cpu() -> bool:
    """Return True if running on Triton CPU backend."""
    # The environment variable is an explicit opt-in and wins without
    # consulting the Triton driver at all.
    if os.environ.get("TRITON_CPU_BACKEND", "0") == "1":
        return True
    return _get_triton_backend() == "cpu"
48 | 50 |
|
49 | 51 |
|
def is_cuda() -> bool:
    """Return True if running on CUDA (NVIDIA GPU)."""
    backend_is_cuda = _get_triton_backend() == "cuda"
    # Require both the Triton backend and a usable torch CUDA runtime.
    return backend_is_cuda and torch.cuda.is_available()
| 55 | + |
| 56 | + |
PROJECT_ROOT: Path = Path(__file__).parent.parent
EXAMPLES_DIR: Path = PROJECT_ROOT / "examples"

# Select the default device for tests/examples: an explicit Triton CPU
# backend takes priority, then Intel XPU if available, otherwise CUDA.
if is_cpu():
    _device_type = "cpu"
elif torch.xpu.is_available():  # pyright: ignore[reportAttributeAccessIssue]
    _device_type = "xpu"
else:
    _device_type = "cuda"
DEVICE = torch.device(_device_type)
| 66 | + |
| 67 | + |
50 | 68 | def get_nvidia_gpu_model() -> str: |
51 | 69 | """ |
52 | 70 | Retrieves the model of the NVIDIA GPU being used. |
@@ -80,6 +98,11 @@ def skipIfXPU(reason: str) -> Callable[[Callable], Callable]: |
80 | 98 | return unittest.skipIf(torch.xpu.is_available(), reason) # pyright: ignore[reportAttributeAccessIssue] |
81 | 99 |
|
82 | 100 |
|
def skipIfCpu(reason: str) -> Callable[[Callable], Callable]:
    """Skip test if running on Triton CPU backend."""
    # Evaluated once at decoration time, mirroring skipIfXPU above.
    on_cpu = is_cpu()
    return unittest.skipIf(on_cpu, reason)
| 104 | + |
| 105 | + |
83 | 106 | def skipIfA10G(reason: str) -> Callable[[Callable], Callable]: |
84 | 107 | """Skip test if running on A10G GPU""" |
85 | 108 | gpu_model = get_nvidia_gpu_model() |
|
0 commit comments