File tree Expand file tree Collapse file tree 3 files changed +41
-9
lines changed Expand file tree Collapse file tree 3 files changed +41
-9
lines changed Original file line number Diff line number Diff line change 7171 "container-options" : " --device=/dev/kfd --device=/dev/dri" ,
7272 "pytorch-version" : " pytorch-nightly" ,
7373 "alias" : " mi325x"
74+ },
75+ {
76+ "runner" : " linux.g5.4xlarge.nvidia.gpu" ,
77+ "python-version" : " 3.12" ,
78+ "ref-eager" : false ,
79+ "image" : " nvidia/cuda:12.8.1-devel-ubuntu24.04" ,
80+ "runtime-version" : " cpu" ,
81+ "container-options" : " --gpus all" ,
82+ "pytorch-version" : " pytorch-nightly" ,
83+ "alias" : " cpu"
7484 }
7585 ]
7686}
Original file line number Diff line number Diff line change 9797 fi
9898
9999 - name : Install Triton
100- if : steps.cache.outputs.cache-hit != 'true' && matrix.pytorch-version != 'pytorch-2.9'
100+ if : steps.cache.outputs.cache-hit != 'true' && ( matrix.pytorch-version != 'pytorch-2.9' || contains(matrix.alias, 'cpu'))
101101 run : |
102102 set -x
103103 source .venv/bin/activate
@@ -110,7 +110,14 @@ jobs:
110110 cd /tmp/$USER
111111 uv pip uninstall triton pytorch-triton || true
112112 rm -rf triton/ || true
113- git clone https://github.com/triton-lang/triton.git
113+ if [[ "${{ matrix.alias }}" == *cpu* ]]; then
114+ REPO_URL="https://github.com/triton-lang/triton-cpu.git"
115+ REPO_BRANCH="main-merged"
116+ else
117+ REPO_URL="https://github.com/triton-lang/triton.git"
118+ REPO_BRANCH="main"
119+ fi
120+ git clone -b "$REPO_BRANCH" "$REPO_URL" triton
114121 cd triton/
115122 uv pip install -r python/requirements.txt
116123 MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 uv pip install .
Original file line number Diff line number Diff line change 3434 from .runtime .kernel import Kernel
3535
3636
37- DEVICE = torch .device ("xpu" ) if torch .xpu .is_available () else torch .device ("cuda" )
38- PROJECT_ROOT : Path = Path (__file__ ).parent .parent
39- EXAMPLES_DIR : Path = PROJECT_ROOT / "examples"
def _get_triton_backend() -> str | None:
    """Best-effort lookup of the active Triton backend name.

    Returns the backend string reported by the active Triton driver
    (e.g. "cuda" or "cpu"), or None when no active target can be
    queried — presumably when Triton has no usable driver; any error
    during the lookup is treated as "unknown backend".
    """
    try:
        target = triton.runtime.driver.active.get_current_target()  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
        return target.backend
    except Exception:
        # Deliberate best-effort: swallow driver/initialization errors
        # and report "no backend" instead of crashing at import time.
        return None
42+
43+
def is_cpu() -> bool:
    """Return True if running on Triton CPU backend."""
    backend = _get_triton_backend()
    return backend == "cpu"
4047
4148
def is_cuda() -> bool:
    """Return True if running on CUDA (NVIDIA GPU)."""
    # Both conditions must hold: Triton reports a CUDA target AND
    # PyTorch actually sees a CUDA device.
    if _get_triton_backend() != "cuda":
        return False
    return torch.cuda.is_available()
52+
53+
# Repository layout, resolved relative to this file's location.
PROJECT_ROOT: Path = Path(__file__).parent.parent
EXAMPLES_DIR: Path = PROJECT_ROOT / "examples"

# Default torch device for this module, chosen in priority order:
# Triton CPU backend -> "cpu"; otherwise XPU when available -> "xpu";
# otherwise fall back to "cuda".
if is_cpu():
    _device_type = "cpu"
elif torch.xpu.is_available():
    _device_type = "xpu"
else:
    _device_type = "cuda"
DEVICE = torch.device(_device_type)
4863
4964
5065def get_nvidia_gpu_model () -> str :
You can’t perform that action at this time.
0 commit comments