
Commit d4ce200

Add best effort triton-cpu support
Fixes #163
stack-info: PR: #1037, branch: oulgen/stack/163
Parent commit: 2d7c237

18 files changed (+89, -10 lines)
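The CI changes below drive the new job with the TRITON_CPU_BACKEND=1 environment variable and a more verbose pytest run. A minimal local sketch of the same flow, assuming a triton-cpu build is already installed and the pytest-timeout plugin is available:

    import os

    # Set before helion._testing is imported so its DEVICE constant resolves to torch.device("cpu").
    os.environ["TRITON_CPU_BACKEND"] = "1"

    import pytest

    # Mirrors the pytest invocation added to .github/workflows/test.yml.
    raise SystemExit(pytest.main(["-rf", "-vv", "--timeout=60"]))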

.github/matrix.json

Lines changed: 10 additions & 0 deletions
@@ -71,6 +71,16 @@
       "container-options": "--device=/dev/kfd --device=/dev/dri",
       "pytorch-version": "pytorch-nightly",
       "alias": "mi325x"
+    },
+    {
+      "runner": "linux.g5.4xlarge.nvidia.gpu",
+      "python-version": "3.12",
+      "ref-eager": false,
+      "image": "nvidia/cuda:12.8.1-devel-ubuntu24.04",
+      "runtime-version": "cpu",
+      "container-options": "--gpus all",
+      "pytorch-version": "pytorch-nightly",
+      "alias": "cpu"
     }
   ]
 }

.github/workflows/test.yml

Lines changed: 8 additions & 3 deletions
@@ -97,7 +97,7 @@ jobs:
           fi
 
       - name: Install Triton
-        if: steps.cache.outputs.cache-hit != 'true' && matrix.pytorch-version != 'pytorch-2.9'
+        if: steps.cache.outputs.cache-hit != 'true' && (matrix.pytorch-version != 'pytorch-2.9' || contains(matrix.alias, 'cpu'))
         run: |
           set -x
           source .venv/bin/activate
@@ -110,7 +110,11 @@ jobs:
           cd /tmp/$USER
           uv pip uninstall triton pytorch-triton || true
           rm -rf triton/ || true
-          git clone https://github.com/triton-lang/triton.git
+          if [[ "${{ matrix.alias }}" == *cpu* ]]; then
+            git clone --recursive -b main-merged https://github.com/triton-lang/triton-cpu.git triton
+          else
+            git clone https://github.com/triton-lang/triton.git triton
+          fi
           cd triton/
           uv pip install -r python/requirements.txt
           MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 uv pip install .
@@ -131,9 +135,10 @@ jobs:
           if [[ "${{ matrix.dtype-asserts }}" == "true" ]]; then export HELION_DEBUG_DTYPE_ASSERTS=1; fi
           if [[ "${{ matrix.expecttest-accept }}" == "true" ]]; then export EXPECTTEST_ACCEPT=1; fi
           if [[ "${{ matrix.ref-eager }}" == "true" ]]; then export HELION_INTERPRET=1; fi
+          if [[ "${{ contains(matrix.alias, 'cpu') }}" == "true" ]]; then export TRITON_CPU_BACKEND=1; fi
           # -rf: print failed tests
           # --timeout: max allowed time for each test
-          pytest -rf --timeout=60
+          pytest -rf -vv --timeout=60
 
   test-notebooks:
     name: test-notebooks-cu128-py3.12-pytorch-2.9-a10g

helion/_testing.py

Lines changed: 30 additions & 7 deletions
@@ -34,19 +34,37 @@
     from .runtime.kernel import Kernel
 
 
-DEVICE = torch.device("xpu") if torch.xpu.is_available() else torch.device("cuda")
-PROJECT_ROOT: Path = Path(__file__).parent.parent
-EXAMPLES_DIR: Path = PROJECT_ROOT / "examples"
+def _get_triton_backend() -> str | None:
+    try:
+        return triton.runtime.driver.active.get_current_target().backend  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+    except Exception:
+        return None
 
 
-def is_cuda() -> bool:
-    """Return True if running on CUDA (NVIDIA GPU)."""
+def is_cpu() -> bool:
+    """Return True if running on Triton CPU backend."""
     return (
-        triton.runtime.driver.active.get_current_target().backend == "cuda"  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
-        and DEVICE.type == "cuda"
+        os.environ.get("TRITON_CPU_BACKEND", "0") == "1"
+        or _get_triton_backend() == "cpu"
     )
 
 
+def is_cuda() -> bool:
+    """Return True if running on CUDA (NVIDIA GPU)."""
+    return _get_triton_backend() == "cuda" and torch.cuda.is_available()
+
+
+PROJECT_ROOT: Path = Path(__file__).parent.parent
+EXAMPLES_DIR: Path = PROJECT_ROOT / "examples"
+
+if is_cpu():
+    DEVICE = torch.device("cpu")
+elif torch.xpu.is_available():
+    DEVICE = torch.device("xpu")
+else:
+    DEVICE = torch.device("cuda")
+
+
 def get_nvidia_gpu_model() -> str:
     """
     Retrieves the model of the NVIDIA GPU being used.
@@ -80,6 +98,11 @@ def skipIfXPU(reason: str) -> Callable[[Callable], Callable]:
     return unittest.skipIf(torch.xpu.is_available(), reason)  # pyright: ignore[reportAttributeAccessIssue]
 
 
+def skipIfCpu(reason: str) -> Callable[[Callable], Callable]:
+    """Skip test if running on Triton CPU backend."""
+    return unittest.skipIf(is_cpu(), reason)
+
+
 def skipIfA10G(reason: str) -> Callable[[Callable], Callable]:
     """Skip test if running on A10G GPU"""
     gpu_model = get_nvidia_gpu_model()
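
The test files below consume these helpers directly; a minimal sketch of the pattern, with ExampleTest as a hypothetical test class rather than part of this change:

    import unittest

    import torch

    from helion._testing import DEVICE  # cpu, xpu, or cuda, as selected above
    from helion._testing import TestCase
    from helion._testing import skipIfCpu


    class ExampleTest(TestCase):
        @skipIfCpu("hypothetical reason: not yet supported on the Triton CPU backend")
        def test_elementwise_path(self):
            x = torch.randn([16, 16], device=DEVICE)
            torch.testing.assert_close(x + x, 2 * x)


    if __name__ == "__main__":
        unittest.main()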

test/test_autotuner.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@
 from helion._testing import RefEagerTestDisabled
 from helion._testing import TestCase
 from helion._testing import import_path
+from helion._testing import skipIfCpu
 from helion._testing import skipIfRocm
 from helion.autotuner import DifferentialEvolutionSearch
 from helion.autotuner import PatternSearch
@@ -347,6 +348,7 @@ def add(a, b):
         torch.testing.assert_close(add(*args), sum(args))
 
     @skipIfRocm("too slow on rocm")
+    @skipIfCpu("TritonError: Error from Triton code")
     def test_random_search(self):
         args = (
             torch.randn([512, 512], device=DEVICE),

test/test_dot.py

Lines changed: 8 additions & 0 deletions
@@ -14,6 +14,7 @@
 from helion._testing import TestCase
 from helion._testing import code_and_output
 from helion._testing import is_cuda
+from helion._testing import skipIfCpu
 from helion._testing import skipIfRefEager
 from helion._testing import skipIfRocm
 from helion._testing import skipIfXPU
@@ -293,6 +294,7 @@ def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
 
     @skipIfRefEager("Debug dtype codegen checks rely on compiled code")
     @skipIfXPU("Failed on XPU - https://github.com/pytorch/helion/issues/772")
+    @skipIfCpu("Failed: Timeout (>10.0s) from pytest-timeout.")
     def test_baddbmm_pipeline_debug_dtype_asserts(self):
         # Reproduces scripts/repro512.py within the test suite and asserts
         # the kernel compiles and runs with debug dtype asserts enabled.
@@ -981,6 +983,12 @@ def test_matmul_reshape_n_2(self):
             "float16 accumulator not supported for bf16/f32 in ref eager mode"
         )(_test_func)
 
+    # CPU backend skip for specific failing dynamic-shape case
+    if test_name == "test_input_float16_acc_float16_dynamic_shape":
+        _test_func = skipIfCpu("AssertionError: Tensor-likes are not close!")(
+            _test_func
+        )
+
     setattr(TestDot, test_name, _test_func)
 
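The last hunk wraps only one of the dynamically generated dot tests with a CPU skip; a simplified sketch of that conditional wrapping, with _wrap_generated_test as a hypothetical helper standing in for the generation loop:

    from helion._testing import skipIfCpu


    def _wrap_generated_test(test_name, _test_func):
        # Apply the CPU-backend skip only to the known-failing case,
        # leaving every other generated test untouched.
        if test_name == "test_input_float16_acc_float16_dynamic_shape":
            _test_func = skipIfCpu("AssertionError: Tensor-likes are not close!")(_test_func)
        return _test_func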

test/test_examples.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from helion._testing import check_example
1717
from helion._testing import import_path
1818
from helion._testing import skipIfA10G
19+
from helion._testing import skipIfCpu
1920
from helion._testing import skipIfRefEager
2021
from helion._testing import skipIfRocm
2122
from helion._testing import skipIfXPU
@@ -24,6 +25,7 @@
2425
torch.backends.cudnn.conv.fp32_precision = "tf32"
2526

2627

28+
@skipIfCpu("needs to be debugged")
2729
class TestExamples(RefEagerTestBase, TestCase):
2830
def test_add(self):
2931
args = (

test/test_generate_ast.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@
 from helion._testing import TestCase
 from helion._testing import code_and_output
 from helion._testing import import_path
+from helion._testing import skipIfCpu
 from helion._testing import skipIfRefEager
 import helion.language as hl
 
@@ -213,6 +214,7 @@ def test_final_cast_enforced_for_to_dtype(self):
         # Ensure codegen emits a final tl.cast(..., tl.bfloat16)
         assert "tl.cast" in code and "tl.bfloat16" in code
 
+    @skipIfCpu("Failed: Timeout (>10.0s) from pytest-timeout.")
     def test_sigmoid_scalar_autocast(self):
         @helion.kernel(
             config=helion.Config(

test/test_inline_asm_elementwise.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
 from helion._testing import RefEagerTestDisabled
 from helion._testing import TestCase
 from helion._testing import code_and_output
+from helion._testing import skipIfCpu
 from helion._testing import skipIfRocm
 import helion.language as hl
 
@@ -221,6 +222,7 @@ def kernel_empty_args(x: torch.Tensor) -> torch.Tensor:
         torch.testing.assert_close(result, expected)
 
     @skipIfRocm("only works on cuda")
+    @skipIfCpu("RuntimeError: failed to translate module to LLVM IR")
     def test_inline_asm_basic_compilation(self):
         """Test that inline_asm_elementwise compiles without errors (no CUDA requirement)"""
 
test/test_loops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from helion._testing import TestCase
1515
from helion._testing import code_and_output
1616
from helion._testing import import_path
17+
from helion._testing import skipIfCpu
1718
from helion._testing import skipIfLowVRAM
1819
from helion._testing import skipIfRefEager
1920
import helion.language as hl
@@ -372,6 +373,7 @@ def fn(x: torch.Tensor) -> torch.Tensor:
372373
self.assertEqual(spec.min_size, 32)
373374
self.assertEqual(spec.max_size, 256)
374375

376+
@skipIfCpu("Failed: Timeout (>10.0s) from pytest-timeout.")
375377
def test_register_block_size_codegen_size_hint(self):
376378
@helion.kernel(static_shapes=True)
377379
def kernel_fixed_block_size(
@@ -1153,6 +1155,7 @@ def kernel_with_dynamic_fill(
11531155
expected = x + fill_value[0]
11541156
torch.testing.assert_close(result, expected)
11551157

1158+
@skipIfCpu("codegen mismatch on CPU")
11561159
def test_nested_loop_accumulator(self):
11571160
"""Test variable scoping with nested loops and accumulator pattern."""
11581161

@@ -1202,6 +1205,7 @@ def nested_loop_accumulator(x: torch.Tensor) -> torch.Tensor:
12021205
torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)
12031206
self.assertExpectedJournal(code)
12041207

1208+
@skipIfCpu("codegen mismatch on CPU")
12051209
def test_three_pass_kernel(self):
12061210
"""Test variable scoping with three-pass pattern like layer norm."""
12071211

test/test_masking.py

Lines changed: 4 additions & 0 deletions
@@ -11,6 +11,7 @@
 from helion._testing import RefEagerTestBase
 from helion._testing import TestCase
 from helion._testing import code_and_output
+from helion._testing import skipIfCpu
 from helion._testing import skipIfRefEager
 import helion.language as hl
 
@@ -42,6 +43,7 @@ def add1mm(x, y):
             result, (args[0] + 1) @ (args[1] + 1), rtol=1e-2, atol=1e-1
         )
 
+    @skipIfCpu("AssertionError: Tensor-likes are not close!")
     def test_no_mask_views0(self):
         @helion.kernel(config={"block_sizes": [32]})
         def fn(x):
@@ -59,6 +61,7 @@ def fn(x):
         torch.testing.assert_close(result, args[0].sum(dim=1, keepdim=True))
         self.assertNotIn("tl.where", code)
 
+    @skipIfCpu("AssertionError: Tensor-likes are not close!")
     def test_no_mask_views1(self):
         @helion.kernel(config={"block_sizes": [32]})
         def fn(x):
@@ -130,6 +133,7 @@ def fn(x):
         torch.testing.assert_close(result, (args[0] + 1).sum(dim=1))
         self.assertIn("tl.where", code)
 
+    @skipIfCpu("AssertionError: Tensor-likes are not close!")
     def test_no_mask_inductor_ops(self):
         @helion.kernel(config={"block_sizes": [32]})
         def fn(x):
