1212from typing import NamedTuple
1313import inspect
1414
15- from ._helpers import array_namespace , _check_device , device , is_cupy_namespace
15+ from ._helpers import (
16+ array_namespace ,
17+ _check_device ,
18+ device as _get_device ,
19+ is_cupy_namespace as _is_cupy_namespace
20+ )
1621
1722# These functions are modified from the NumPy versions.
1823
@@ -287,7 +292,7 @@ def cumulative_sum(
287292 initial_shape = list (x .shape )
288293 initial_shape [axis ] = 1
289294 res = xp .concatenate (
290- [wrapped_xp .zeros (shape = initial_shape , dtype = res .dtype , device = device (res )), res ],
295+ [wrapped_xp .zeros (shape = initial_shape , dtype = res .dtype , device = _get_device (res )), res ],
291296 axis = axis ,
292297 )
293298 return res
@@ -317,7 +322,7 @@ def cumulative_prod(
317322 initial_shape = list (x .shape )
318323 initial_shape [axis ] = 1
319324 res = xp .concatenate (
320- [wrapped_xp .ones (shape = initial_shape , dtype = res .dtype , device = device (res )), res ],
325+ [wrapped_xp .ones (shape = initial_shape , dtype = res .dtype , device = _get_device (res )), res ],
321326 axis = axis ,
322327 )
323328 return res
@@ -369,7 +374,7 @@ def _isscalar(a):
369374 if type (max ) is int and max >= wrapped_xp .iinfo (x .dtype ).max :
370375 max = None
371376
372- dev = device (x )
377+ dev = _get_device (x )
373378 if out is None :
374379 out = wrapped_xp .empty (result_shape , dtype = x .dtype , device = dev )
375380 out [()] = x
@@ -567,7 +572,7 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
567572 out = xp .sign (x , ** kwargs )
568573 # CuPy sign() does not propagate nans. See
569574 # https://github.com/data-apis/array-api-compat/issues/136
570- if is_cupy_namespace (xp ) and isdtype (x .dtype , 'real floating' , xp = xp ):
575+ if _is_cupy_namespace (xp ) and isdtype (x .dtype , 'real floating' , xp = xp ):
571576 out [xp .isnan (x )] = xp .nan
572577 return out [()]
573578
@@ -579,3 +584,5 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
579584 'reshape' , 'argsort' , 'sort' , 'nonzero' , 'ceil' , 'floor' , 'trunc' ,
580585 'matmul' , 'matrix_transpose' , 'tensordot' , 'vecdot' , 'isdtype' ,
581586 'unstack' , 'sign' ]
587+
588+ _all_ignore = ['inspect' , 'array_namespace' , 'NamedTuple' ]
0 commit comments