Skip to content

Commit 1d06772

Browse files
committed
Fix caching for CPUs
stack-info: PR: #1055, branch: oulgen/stack/164
1 parent 0bafd91 commit 1d06772

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

helion/autotuner/local_cache.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
import os
88
from pathlib import Path
9+
import platform
910
import textwrap
1011
from typing import TYPE_CHECKING
1112
import uuid
@@ -58,17 +59,34 @@ def _generate_key(self) -> LooseAutotuneCacheKey:
5859

5960
for arg in self.args:
6061
if isinstance(arg, torch.Tensor):
61-
nms = torch.xpu if torch.xpu.is_available() else torch.cuda
62-
device_properties = nms.get_device_properties(arg.device)
63-
if torch.version.cuda is not None: # pyright: ignore[reportAttributeAccessIssue]
64-
hardware = device_properties.name
65-
runtime_name = str(torch.version.cuda)
66-
elif torch.version.hip is not None: # pyright: ignore[reportAttributeAccessIssue]
67-
hardware = device_properties.gcnArchName
68-
runtime_name = torch.version.hip # pyright: ignore[reportAttributeAccessIssue]
69-
else:
62+
dev = arg.device
63+
# CPU support
64+
if dev.type == "cpu":
65+
hardware = "cpu"
66+
runtime_name = platform.machine().lower()
67+
break
68+
69+
# XPU (Intel) path
70+
if (
71+
dev.type == "xpu"
72+
and getattr(torch, "xpu", None) is not None
73+
and torch.xpu.is_available()
74+
): # pyright: ignore[reportAttributeAccessIssue]
75+
device_properties = torch.xpu.get_device_properties(dev)
7076
hardware = device_properties.name
7177
runtime_name = device_properties.driver_version # pyright: ignore[reportAttributeAccessIssue]
78+
break
79+
80+
# CUDA/ROCm path
81+
if dev.type == "cuda" and torch.cuda.is_available():
82+
device_properties = torch.cuda.get_device_properties(dev)
83+
if torch.version.cuda is not None: # pyright: ignore[reportAttributeAccessIssue]
84+
hardware = device_properties.name
85+
runtime_name = str(torch.version.cuda)
86+
elif torch.version.hip is not None: # pyright: ignore[reportAttributeAccessIssue]
87+
hardware = device_properties.gcnArchName
88+
runtime_name = torch.version.hip # pyright: ignore[reportAttributeAccessIssue]
89+
break
7290

7391
assert hardware is not None and runtime_name is not None
7492
return LooseAutotuneCacheKey(

0 commit comments

Comments
 (0)