|
6 | 6 | import logging |
7 | 7 | import os |
8 | 8 | from pathlib import Path |
| 9 | +import platform |
9 | 10 | import textwrap |
10 | 11 | from typing import TYPE_CHECKING |
11 | 12 | import uuid |
@@ -58,17 +59,34 @@ def _generate_key(self) -> LooseAutotuneCacheKey: |
58 | 59 |
|
59 | 60 | for arg in self.args: |
60 | 61 | if isinstance(arg, torch.Tensor): |
61 | | - nms = torch.xpu if torch.xpu.is_available() else torch.cuda |
62 | | - device_properties = nms.get_device_properties(arg.device) |
63 | | - if torch.version.cuda is not None: # pyright: ignore[reportAttributeAccessIssue] |
64 | | - hardware = device_properties.name |
65 | | - runtime_name = str(torch.version.cuda) |
66 | | - elif torch.version.hip is not None: # pyright: ignore[reportAttributeAccessIssue] |
67 | | - hardware = device_properties.gcnArchName |
68 | | - runtime_name = torch.version.hip # pyright: ignore[reportAttributeAccessIssue] |
69 | | - else: |
| 62 | + dev = arg.device |
| 63 | + # CPU support |
| 64 | + if dev.type == "cpu": |
| 65 | + hardware = "cpu" |
| 66 | + runtime_name = platform.machine().lower() |
| 67 | + break |
| 68 | + |
| 69 | + # XPU (Intel) path |
| 70 | + if ( |
| 71 | + dev.type == "xpu" |
| 72 | + and getattr(torch, "xpu", None) is not None |
| 73 | + and torch.xpu.is_available() |
| 74 | + ): # pyright: ignore[reportAttributeAccessIssue] |
| 75 | + device_properties = torch.xpu.get_device_properties(dev) |
70 | 76 | hardware = device_properties.name |
71 | 77 | runtime_name = device_properties.driver_version # pyright: ignore[reportAttributeAccessIssue] |
| 78 | + break |
| 79 | + |
| 80 | + # CUDA/ROCm path |
| 81 | + if dev.type == "cuda" and torch.cuda.is_available(): |
| 82 | + device_properties = torch.cuda.get_device_properties(dev) |
| 83 | + if torch.version.cuda is not None: # pyright: ignore[reportAttributeAccessIssue] |
| 84 | + hardware = device_properties.name |
| 85 | + runtime_name = str(torch.version.cuda) |
| 86 | + elif torch.version.hip is not None: # pyright: ignore[reportAttributeAccessIssue] |
| 87 | + hardware = device_properties.gcnArchName |
| 88 | + runtime_name = torch.version.hip # pyright: ignore[reportAttributeAccessIssue] |
| 89 | + break |
72 | 90 |
|
73 | 91 | assert hardware is not None and runtime_name is not None |
74 | 92 | return LooseAutotuneCacheKey( |
|
0 commit comments