Skip to content

Commit e719cd7

Browse files
committed
Add get_num_sm for cpu
stack-info: PR: #1056, branch: oulgen/stack/165
1 parent 1d06772 commit e719cd7

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

helion/runtime/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import contextvars
4+
import os
45
from typing import TYPE_CHECKING
56

67
import torch
@@ -47,7 +48,13 @@ def get_num_sm(device: torch.device) -> int:
4748
Returns:
4849
Grid size to use for a persistent kernel on the device.
4950
"""
50-
assert device.type in ["cuda", "xpu"], "TODO: implement for other devices"
51+
assert device.type in ["cuda", "xpu", "cpu"], "TODO: implement for other devices"
52+
if device.type == "cpu":
53+
try:
54+
num_threads = int(torch.get_num_threads())
55+
except Exception:
56+
num_threads = 0
57+
return num_threads if num_threads > 0 else int(os.cpu_count() or 1)
5158
if device.type == "cuda":
5259
return torch.cuda.get_device_properties(device.index).multi_processor_count
5360
# TODO(EikanWang): gpu_subslice_count is an out-of-date term. we change update it to XeCore number.

0 commit comments

Comments
 (0)