From 1d067728abb36d40ab8687772917dca66678784a Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Thu, 30 Oct 2025 14:40:01 -0700 Subject: [PATCH] Fix caching for CPUs stack-info: PR: https://github.com/pytorch/helion/pull/1055, branch: oulgen/stack/164 --- helion/autotuner/local_cache.py | 36 ++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/helion/autotuner/local_cache.py b/helion/autotuner/local_cache.py index 1113dffc9..bb4110ab9 100644 --- a/helion/autotuner/local_cache.py +++ b/helion/autotuner/local_cache.py @@ -6,6 +6,7 @@ import logging import os from pathlib import Path +import platform import textwrap from typing import TYPE_CHECKING import uuid @@ -58,17 +59,34 @@ def _generate_key(self) -> LooseAutotuneCacheKey: for arg in self.args: if isinstance(arg, torch.Tensor): - nms = torch.xpu if torch.xpu.is_available() else torch.cuda - device_properties = nms.get_device_properties(arg.device) - if torch.version.cuda is not None: # pyright: ignore[reportAttributeAccessIssue] - hardware = device_properties.name - runtime_name = str(torch.version.cuda) - elif torch.version.hip is not None: # pyright: ignore[reportAttributeAccessIssue] - hardware = device_properties.gcnArchName - runtime_name = torch.version.hip # pyright: ignore[reportAttributeAccessIssue] - else: + dev = arg.device + # CPU support + if dev.type == "cpu": + hardware = "cpu" + runtime_name = platform.machine().lower() + break + + # XPU (Intel) path + if ( + dev.type == "xpu" + and getattr(torch, "xpu", None) is not None + and torch.xpu.is_available() + ): # pyright: ignore[reportAttributeAccessIssue] + device_properties = torch.xpu.get_device_properties(dev) hardware = device_properties.name runtime_name = device_properties.driver_version # pyright: ignore[reportAttributeAccessIssue] + break + + # CUDA/ROCm path + if dev.type == "cuda" and torch.cuda.is_available(): + device_properties = torch.cuda.get_device_properties(dev) + if torch.version.cuda is not None: # pyright: ignore[reportAttributeAccessIssue] + hardware = device_properties.name + runtime_name = str(torch.version.cuda) + elif torch.version.hip is not None: # pyright: ignore[reportAttributeAccessIssue] + hardware = device_properties.gcnArchName + runtime_name = torch.version.hip # pyright: ignore[reportAttributeAccessIssue] + break assert hardware is not None and runtime_name is not None return LooseAutotuneCacheKey(