File tree Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Original file line number Diff line number Diff line change 11from  __future__ import  annotations 
22
33import  contextvars 
4+ import  os 
45from  typing  import  TYPE_CHECKING 
56
67import  torch 
@@ -47,7 +48,13 @@ def get_num_sm(device: torch.device) -> int:
4748    Returns: 
4849        Grid size to use for a persistent kernel on the device. 
4950    """ 
50-     assert  device .type  in  ["cuda" , "xpu" ], "TODO: implement for other devices" 
51+     assert  device .type  in  ["cuda" , "xpu" , "cpu" ], "TODO: implement for other devices" 
52+     if  device .type  ==  "cpu" :
53+         try :
54+             num_threads  =  int (torch .get_num_threads ())
55+         except  Exception :
56+             num_threads  =  0 
57+         return  num_threads  if  num_threads  >  0  else  int (os .cpu_count () or  1 )
5158    if  device .type  ==  "cuda" :
5259        return  torch .cuda .get_device_properties (device .index ).multi_processor_count 
5360    # TODO(EikanWang): gpu_subslice_count is an out-of-date term. we change update it to XeCore number. 
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments