@@ -170,78 +170,58 @@ def default_dtypes(self, *, device=None):
170170            "indexing" : default_integral ,
171171        }
172172
173- 
174173    def  _dtypes (self , kind ):
175-         bool  =  torch .bool 
176-         int8  =  torch .int8 
177-         int16  =  torch .int16 
178-         int32  =  torch .int32 
179-         int64  =  torch .int64 
180-         uint8  =  torch .uint8 
181-         # uint16, uint32, and uint64 are present in newer versions of pytorch, 
182-         # but they aren't generally supported by the array API functions, so 
183-         # we omit them from this function. 
184-         float32  =  torch .float32 
185-         float64  =  torch .float64 
186-         complex64  =  torch .complex64 
187-         complex128  =  torch .complex128 
188- 
189174        if  kind  is  None :
190-             return  {
191-                 "bool" : bool ,
192-                 "int8" : int8 ,
193-                 "int16" : int16 ,
194-                 "int32" : int32 ,
195-                 "int64" : int64 ,
196-                 "uint8" : uint8 ,
197-                 "float32" : float32 ,
198-                 "float64" : float64 ,
199-                 "complex64" : complex64 ,
200-                 "complex128" : complex128 ,
201-             }
175+             return  self ._dtypes (
176+                 (
177+                     "bool" ,
178+                     "signed integer" ,
179+                     "unsigned integer" ,
180+                     "real floating" ,
181+                     "complex floating" ,
182+                 )
183+             )
202184        if  kind  ==  "bool" :
203-             return  {"bool" : bool }
185+             return  {"bool" : torch . bool }
204186        if  kind  ==  "signed integer" :
205187            return  {
206-                 "int8" : int8 ,
207-                 "int16" : int16 ,
208-                 "int32" : int32 ,
209-                 "int64" : int64 ,
188+                 "int8" : torch . int8 ,
189+                 "int16" : torch . int16 ,
190+                 "int32" : torch . int32 ,
191+                 "int64" : torch . int64 ,
210192            }
211193        if  kind  ==  "unsigned integer" :
212-             return  {
213-                 "uint8" : uint8 ,
214-             }
194+             try :
195+                 # torch >=2.3 
196+                 return  {
197+                     "uint8" : torch .uint8 ,
198+                     "uint16" : torch .uint16 ,
199+                     "uint32" : torch .uint32 ,
200+                     "uint64" : torch .uint32 ,
201+                 }
202+             except  AttributeError :
203+                 return  {"uint8" : torch .uint8 }
215204        if  kind  ==  "integral" :
216-             return  {
217-                 "int8" : int8 ,
218-                 "int16" : int16 ,
219-                 "int32" : int32 ,
220-                 "int64" : int64 ,
221-                 "uint8" : uint8 ,
222-             }
205+             return  self ._dtypes (("signed integer" , "unsigned integer" ))
223206        if  kind  ==  "real floating" :
224207            return  {
225-                 "float32" : float32 ,
226-                 "float64" : float64 ,
208+                 "float32" : torch . float32 ,
209+                 "float64" : torch . float64 ,
227210            }
228211        if  kind  ==  "complex floating" :
229212            return  {
230-                 "complex64" : complex64 ,
231-                 "complex128" : complex128 ,
213+                 "complex64" : torch . complex64 ,
214+                 "complex128" : torch . complex128 ,
232215            }
233216        if  kind  ==  "numeric" :
234-             return  {
235-                 "int8" : int8 ,
236-                 "int16" : int16 ,
237-                 "int32" : int32 ,
238-                 "int64" : int64 ,
239-                 "uint8" : uint8 ,
240-                 "float32" : float32 ,
241-                 "float64" : float64 ,
242-                 "complex64" : complex64 ,
243-                 "complex128" : complex128 ,
244-             }
217+             return  self ._dtypes (
218+                 (
219+                     "signed integer" ,
220+                     "unsigned integer" ,
221+                     "real floating" ,
222+                     "complex floating" ,
223+                 )
224+             )
245225        if  isinstance (kind , tuple ):
246226            res  =  {}
247227            for  k  in  kind :
@@ -261,7 +241,6 @@ def dtypes(self, *, device=None, kind=None):
261241        ---------- 
262242        device : Device, optional 
263243            The device to get the data types for. 
264-             Unused for PyTorch, as all devices use the same dtypes. 
265244        kind : str or tuple of str, optional 
266245            The kind of data types to return. If ``None``, all data types are 
267246            returned. If a string, only data types of that kind are returned. 
0 commit comments