File tree Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Original file line number Diff line number Diff line change 11from __future__ import annotations
22
33import contextvars
4+ import os
45from typing import TYPE_CHECKING
56
67import torch
@@ -47,7 +48,13 @@ def get_num_sm(device: torch.device) -> int:
4748 Returns:
4849 Grid size to use for a persistent kernel on the device.
4950 """
50- assert device .type in ["cuda" , "xpu" ], "TODO: implement for other devices"
51+ assert device .type in ["cuda" , "xpu" , "cpu" ], "TODO: implement for other devices"
52+ if device .type == "cpu" :
53+ try :
54+ num_threads = int (torch .get_num_threads ())
55+ except Exception :
56+ num_threads = 0
57+ return num_threads if num_threads > 0 else int (os .cpu_count () or 1 )
5158 if device .type == "cuda" :
5259 return torch .cuda .get_device_properties (device .index ).multi_processor_count
5360 # TODO(EikanWang): gpu_subslice_count is an out-of-date term. We should update it to the XeCore number.
You can’t perform that action at this time.
0 commit comments