2323    SupportsIndex ,
2424    TypeAlias ,
2525    TypeGuard ,
26-     TypeVar ,
2726    cast ,
2827    overload ,
2928)
3029
3130from  ._typing  import  Array , Device , HasShape , Namespace , SupportsArrayNamespace 
3231
3332if  TYPE_CHECKING :
34- 
33+      import   cupy   as   cp 
3534    import  dask .array  as  da 
3635    import  jax 
3736    import  ndonnx  as  ndx 
3837    import  numpy  as  np 
3938    import  numpy .typing  as  npt 
40-     import  sparse    # pyright: ignore[reportMissingTypeStubs] 
39+     import  sparse 
4140    import  torch 
4241
4342    # TODO: import from typing (requires Python >=3.13) 
44-     from  typing_extensions  import  TypeIs , TypeVar 
45- 
46-     _SizeT  =  TypeVar ("_SizeT" , bound  =  int  |  None )
43+     from  typing_extensions  import  TypeIs 
4744
4845    _ZeroGradientArray : TypeAlias  =  npt .NDArray [np .void ]
49-     _CupyArray : TypeAlias  =  Any   # cupy has no py.typed 
5046
5147    _ArrayApiObj : TypeAlias  =  (
5248        npt .NDArray [Any ]
49+         |  cp .ndarray 
5350        |  da .Array 
5451        |  jax .Array 
5552        |  ndx .Array 
5653        |  sparse .SparseArray 
5754        |  torch .Tensor 
5855        |  SupportsArrayNamespace [Any ]
59-         |  _CupyArray 
6056    )
6157
6258_API_VERSIONS_OLD : Final  =  frozenset ({"2021.12" , "2022.12" , "2023.12" })
@@ -96,7 +92,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
9692    return  dtype  ==  jax .float0 
9793
9894
99- def  is_numpy_array (x : object ) ->  TypeGuard [npt .NDArray [Any ]]:
95+ def  is_numpy_array (x : object ) ->  TypeIs [npt .NDArray [Any ]]:
10096    """ 
10197    Return True if `x` is a NumPy array. 
10298
@@ -267,7 +263,7 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
267263    return  _issubclass_fast (cls , "sparse" , "SparseArray" )
268264
269265
270- def  is_array_api_obj (x : object ) ->  TypeIs [_ArrayApiObj ]:   # pyright: ignore[reportUnknownParameterType] 
266+ def  is_array_api_obj (x : object ) ->  TypeGuard [_ArrayApiObj ]:
271267    """ 
272268    Return True if `x` is an array API compatible array object. 
273269
@@ -748,7 +744,7 @@ def device(x: _ArrayApiObj, /) -> Device:
748744        return  "cpu" 
749745    elif  is_dask_array (x ):
750746        # Peek at the metadata of the Dask array to determine type 
751-         if  is_numpy_array (x ._meta ):   # pyright: ignore 
747+         if  is_numpy_array (x ._meta ):
752748            # Must be on CPU since backed by numpy 
753749            return  "cpu" 
754750        return  _DASK_DEVICE 
@@ -777,7 +773,7 @@ def device(x: _ArrayApiObj, /) -> Device:
777773            return  "cpu" 
778774        # Return the device of the constituent array 
779775        return  device (inner )  # pyright: ignore 
780-     return  x .device   # pyright: ignore 
776+     return  x .device   # type: ignore  #  pyright: ignore 
781777
782778
783779# Prevent shadowing, used below 
@@ -786,11 +782,11 @@ def device(x: _ArrayApiObj, /) -> Device:
786782
787783# Based on cupy.array_api.Array.to_device 
788784def  _cupy_to_device (
789-     x : _CupyArray ,
785+     x : cp . ndarray ,
790786    device : Device ,
791787    / ,
792788    stream : int  |  Any  |  None  =  None ,
793- ) ->  _CupyArray :
789+ ) ->  cp . ndarray :
794790    import  cupy  as  cp 
795791
796792    if  device  ==  "cpu" :
@@ -819,7 +815,7 @@ def _torch_to_device(
819815    x : torch .Tensor ,
820816    device : torch .device  |  str  |  int ,
821817    / ,
822-     stream : None  =  None ,
818+     stream : int   |   Any   |   None  =  None ,
823819) ->  torch .Tensor :
824820    if  stream  is  not   None :
825821        raise  NotImplementedError 
@@ -885,7 +881,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
885881        # cupy does not yet have to_device 
886882        return  _cupy_to_device (x , device , stream = stream )
887883    elif  is_torch_array (x ):
888-         return  _torch_to_device (x , device , stream = stream )   # pyright: ignore[reportArgumentType] 
884+         return  _torch_to_device (x , device , stream = stream )
889885    elif  is_dask_array (x ):
890886        if  stream  is  not   None :
891887            raise  ValueError ("The stream argument to to_device() is not supported" )
@@ -912,8 +908,6 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
912908@overload  
913909def  size (x : HasShape [Collection [SupportsIndex ]]) ->  int : ...
914910@overload  
915- def  size (x : HasShape [Collection [None ]]) ->  None : ...
916- @overload  
917911def  size (x : HasShape [Collection [SupportsIndex  |  None ]]) ->  int  |  None : ...
918912def  size (x : HasShape [Collection [SupportsIndex  |  None ]]) ->  int  |  None :
919913    """ 
@@ -948,7 +942,7 @@ def _is_writeable_cls(cls: type) -> bool | None:
948942    return  None 
949943
950944
951- def  is_writeable_array (x : object ) ->  bool :
945+ def  is_writeable_array (x : object ) ->  TypeGuard [ _ArrayApiObj ] :
952946    """ 
953947    Return False if ``x.__setitem__`` is expected to raise; True otherwise. 
954948    Return False if `x` is not an array API compatible object. 
@@ -986,7 +980,7 @@ def _is_lazy_cls(cls: type) -> bool | None:
986980    return   None 
987981
988982
989- def  is_lazy_array (x : object ) ->  bool :
983+ def  is_lazy_array (x : object ) ->  TypeGuard [ _ArrayApiObj ] :
990984    """Return True if x is potentially a future or it may be otherwise impossible or 
991985    expensive to eagerly read its contents, regardless of their size, e.g. by 
992986    calling ``bool(x)`` or ``float(x)``. 
0 commit comments