2121try :
2222    # torch >=2.3 
2323    _int_dtypes  |=  {torch .uint16 , torch .uint32 , torch .uint64 }
24+     _HAS_LARGE_UINT  =  True 
2425except  AttributeError :
25-     pass 
26- 
26+     _HAS_LARGE_UINT  =  False 
2727
2828_array_api_dtypes  =  {
2929    torch .bool ,
3434    torch .complex128 ,
3535}
3636
37- _promotion_table   =  {
38-     # bool 
39-     (torch .bool , torch .bool ): torch .bool ,
37+ _promotion_table  =  {
4038    # ints 
41-     (torch .int8 , torch .int8 ): torch .int8 ,
4239    (torch .int8 , torch .int16 ): torch .int16 ,
4340    (torch .int8 , torch .int32 ): torch .int32 ,
4441    (torch .int8 , torch .int64 ): torch .int64 ,
45-     (torch .int16 , torch .int8 ): torch .int16 ,
46-     (torch .int16 , torch .int16 ): torch .int16 ,
4742    (torch .int16 , torch .int32 ): torch .int32 ,
4843    (torch .int16 , torch .int64 ): torch .int64 ,
49-     (torch .int32 , torch .int8 ): torch .int32 ,
50-     (torch .int32 , torch .int16 ): torch .int32 ,
51-     (torch .int32 , torch .int32 ): torch .int32 ,
5244    (torch .int32 , torch .int64 ): torch .int64 ,
53-     (torch .int64 , torch .int8 ): torch .int64 ,
54-     (torch .int64 , torch .int16 ): torch .int64 ,
55-     (torch .int64 , torch .int32 ): torch .int64 ,
56-     (torch .int64 , torch .int64 ): torch .int64 ,
57-     # uints 
58-     (torch .uint8 , torch .uint8 ): torch .uint8 ,
5945    # ints and uints (mixed sign) 
60-     (torch .int8 , torch .uint8 ): torch .int16 ,
61-     (torch .int16 , torch .uint8 ): torch .int16 ,
62-     (torch .int32 , torch .uint8 ): torch .int32 ,
63-     (torch .int64 , torch .uint8 ): torch .int64 ,
6446    (torch .uint8 , torch .int8 ): torch .int16 ,
6547    (torch .uint8 , torch .int16 ): torch .int16 ,
6648    (torch .uint8 , torch .int32 ): torch .int32 ,
6749    (torch .uint8 , torch .int64 ): torch .int64 ,
6850    # floats 
69-     (torch .float32 , torch .float32 ): torch .float32 ,
7051    (torch .float32 , torch .float64 ): torch .float64 ,
71-     (torch .float64 , torch .float32 ): torch .float64 ,
72-     (torch .float64 , torch .float64 ): torch .float64 ,
7352    # complexes 
74-     (torch .complex64 , torch .complex64 ): torch .complex64 ,
7553    (torch .complex64 , torch .complex128 ): torch .complex128 ,
76-     (torch .complex128 , torch .complex64 ): torch .complex128 ,
77-     (torch .complex128 , torch .complex128 ): torch .complex128 ,
7854    # Mixed float and complex 
7955    (torch .float32 , torch .complex64 ): torch .complex64 ,
8056    (torch .float32 , torch .complex128 ): torch .complex128 ,
8157    (torch .float64 , torch .complex64 ): torch .complex128 ,
8258    (torch .float64 , torch .complex128 ): torch .complex128 ,
8359}
8460
61+ if  _HAS_LARGE_UINT :  # torch >=2.3 
62+     _promotion_table .update (
63+         {
64+             # uints 
65+             (torch .uint8 , torch .uint16 ): torch .uint16 ,
66+             (torch .uint8 , torch .uint32 ): torch .uint32 ,
67+             (torch .uint8 , torch .uint64 ): torch .uint64 ,
68+             (torch .uint16 , torch .uint32 ): torch .uint32 ,
69+             (torch .uint16 , torch .uint64 ): torch .uint64 ,
70+             (torch .uint32 , torch .uint64 ): torch .uint64 ,
71+             # ints and uints (mixed sign) 
72+             (torch .uint16 , torch .int8 ): torch .int32 ,
73+             (torch .uint16 , torch .int16 ): torch .int32 ,
74+             (torch .uint16 , torch .int32 ): torch .int32 ,
75+             (torch .uint16 , torch .int64 ): torch .int64 ,
76+             (torch .uint32 , torch .int8 ): torch .int64 ,
77+             (torch .uint32 , torch .int16 ): torch .int64 ,
78+             (torch .uint32 , torch .int32 ): torch .int64 ,
79+             (torch .uint32 , torch .int64 ): torch .int64 ,
80+         }
81+     )
82+ 
83+ _promotion_table .update ({(b , a ): c  for  (a , b ), c  in  _promotion_table .items ()})
84+ _promotion_table .update ({(a , a ): a  for  a  in  _array_api_dtypes })
85+ 
8586
8687def  _two_arg (f ):
8788    @_wraps (f ) 
@@ -275,6 +276,31 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
275276            out  =  torch .unsqueeze (out , a )
276277    return  out 
277278
279+ 
280+ def  _sum_prod_no_axis (x : Array , dtype : DType  |  None ) ->  Array :
281+     """ 
282+     Implements `sum(..., axis=())` and `prod(..., axis=())`. 
283+      
284+     Works around https://github.com/pytorch/pytorch/issues/29137 
285+     """ 
286+     if  dtype  is  not None :
287+         return  x .clone () if  dtype  ==  x .dtype  else  x .to (dtype )
288+ 
289+     if  x .dtype  in  (torch .int8 , torch .int16 , torch .int32 ):
290+         return  x .to (torch .int64 )
291+ 
292+     if  _HAS_LARGE_UINT  and  x .dtype  in  (torch .uint8 , torch .uint16 , torch .uint32 ):
293+         return  x .to (torch .uint64 )
294+ 
295+     if  x .dtype  ==  torch .uint8 :
296+         # We can't upcast uint8 according to the spec because there is no 
297+         # torch.uint64, so at least upcast to int64 which is what prod does 
298+         # when axis=None. 
299+         return  x .to (torch .int64 )
300+ 
301+     return  x .clone ()
302+ 
303+ 
278304def  prod (x : Array ,
279305         / ,
280306         * ,
@@ -283,20 +309,9 @@ def prod(x: Array,
283309         keepdims : bool  =  False ,
284310         ** kwargs ) ->  Array :
285311    x  =  torch .asarray (x )
286-     ndim  =  x .ndim 
287312
288-     # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic 
289-     # below because it still needs to upcast. 
290313    if  axis  ==  ():
291-         if  dtype  is  None :
292-             # We can't upcast uint8 according to the spec because there is no 
293-             # torch.uint64, so at least upcast to int64 which is what sum does 
294-             # when axis=None. 
295-             if  x .dtype  in  [torch .int8 , torch .int16 , torch .int32 , torch .uint8 ]:
296-                 return  x .to (torch .int64 )
297-             return  x .clone ()
298-         return  x .to (dtype )
299- 
314+         return  _sum_prod_no_axis (x , dtype )
300315    # torch.prod doesn't support multiple axes 
301316    # (https://github.com/pytorch/pytorch/issues/56586). 
302317    if  isinstance (axis , tuple ):
@@ -305,7 +320,7 @@ def prod(x: Array,
305320        # torch doesn't support keepdims with axis=None 
306321        # (https://github.com/pytorch/pytorch/issues/71209) 
307322        res  =  torch .prod (x , dtype = dtype , ** kwargs )
308-         res  =  _axis_none_keepdims (res , ndim , keepdims )
323+         res  =  _axis_none_keepdims (res , x . ndim , keepdims )
309324        return  res 
310325
311326    return  torch .prod (x , axis , dtype = dtype , keepdims = keepdims , ** kwargs )
@@ -319,25 +334,14 @@ def sum(x: Array,
319334         keepdims : bool  =  False ,
320335         ** kwargs ) ->  Array :
321336    x  =  torch .asarray (x )
322-     ndim  =  x .ndim 
323337
324-     # https://github.com/pytorch/pytorch/issues/29137. 
325-     # Make sure it upcasts. 
326338    if  axis  ==  ():
327-         if  dtype  is  None :
328-             # We can't upcast uint8 according to the spec because there is no 
329-             # torch.uint64, so at least upcast to int64 which is what sum does 
330-             # when axis=None. 
331-             if  x .dtype  in  [torch .int8 , torch .int16 , torch .int32 , torch .uint8 ]:
332-                 return  x .to (torch .int64 )
333-             return  x .clone ()
334-         return  x .to (dtype )
335- 
339+         return  _sum_prod_no_axis (x , dtype )
336340    if  axis  is  None :
337341        # torch doesn't support keepdims with axis=None 
338342        # (https://github.com/pytorch/pytorch/issues/71209) 
339343        res  =  torch .sum (x , dtype = dtype , ** kwargs )
340-         res  =  _axis_none_keepdims (res , ndim , keepdims )
344+         res  =  _axis_none_keepdims (res , x . ndim , keepdims )
341345        return  res 
342346
343347    return  torch .sum (x , axis , dtype = dtype , keepdims = keepdims , ** kwargs )
0 commit comments