22
33import operator
44import warnings
5- from collections .abc import Callable
6- from typing import Any
5+
6+ # https://github.com/pylint-dev/pylint/issues/10112
7+ from collections .abc import Callable # pylint: disable=import-error
8+ from typing import ClassVar
79
810from ._lib import _utils
911from ._lib ._compat import (
1214 is_dask_array ,
1315 is_writeable_array ,
1416)
15- from ._lib ._typing import Array , ModuleType
17+ from ._lib ._typing import Array , Index , ModuleType , Untyped
1618
1719__all__ = [
1820 "at" ,
@@ -559,7 +561,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
559561_undef = object ()
560562
561563
562- class at :
564+ class at : # pylint: disable=invalid-name
563565 """
564566 Update operations for read-only arrays.
565567
@@ -651,14 +653,14 @@ class at:
651653 """
652654
653655 x : Array
654- idx : Any
655- __slots__ = ("idx" , "x" )
656+ idx : Index
657+ __slots__ : ClassVar [ tuple [ str , str ]] = ("idx" , "x" )
656658
657- def __init__ (self , x : Array , idx : Any = _undef , / ):
659+ def __init__ (self , x : Array , idx : Index = _undef , / ):
658660 self .x = x
659661 self .idx = idx
660662
661- def __getitem__ (self , idx : Any ) -> Any :
663+ def __getitem__ (self , idx : Index ) -> at :
662664 """Allow for the alternate syntax ``at(x)[start:stop:step]``,
663665 which looks prettier than ``at(x, slice(start, stop, step))``
664666 and feels more intuitive coming from the JAX documentation.
@@ -677,8 +679,8 @@ def _common(
677679 copy : bool | None = True ,
678680 xp : ModuleType | None = None ,
679681 _is_update : bool = True ,
680- ** kwargs : Any ,
681- ) -> tuple [Any , None ] | tuple [None , Array ]:
682+ ** kwargs : Untyped ,
683+ ) -> tuple [Untyped , None ] | tuple [None , Array ]:
682684 """Perform common prepocessing.
683685
684686 Returns
@@ -706,11 +708,11 @@ def _common(
706708 if not writeable :
707709 msg = "Cannot modify parameter in place"
708710 raise ValueError (msg )
709- elif copy is None :
711+ elif copy is None : # type: ignore[redundant-expr]
710712 writeable = is_writeable_array (x )
711713 copy = _is_update and not writeable
712714 else :
713- msg = f"Invalid value for copy: { copy !r} " # type: ignore[unreachable]
715+ msg = f"Invalid value for copy: { copy !r} " # type: ignore[unreachable] # pyright: ignore[reportUnreachable]
714716 raise ValueError (msg )
715717
716718 if copy :
@@ -741,7 +743,7 @@ def _common(
741743
742744 return None , x
743745
744- def get (self , ** kwargs : Any ) -> Any :
746+ def get (self , ** kwargs : Untyped ) -> Untyped :
745747 """Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
746748 that the output is either a copy or a view; it also allows passing
747749 keyword arguments to the backend.
@@ -766,7 +768,7 @@ def get(self, **kwargs: Any) -> Any:
766768 assert x is not None
767769 return x [self .idx ]
768770
769- def set (self , y : Array , / , ** kwargs : Any ) -> Array :
771+ def set (self , y : Array , / , ** kwargs : Untyped ) -> Array :
770772 """Apply ``x[idx] = y`` and return the update array"""
771773 res , x = self ._common ("set" , y , ** kwargs )
772774 if res is not None :
@@ -781,7 +783,7 @@ def _iop(
781783 elwise_op : Callable [[Array , Array ], Array ],
782784 y : Array ,
783785 / ,
784- ** kwargs : Any ,
786+ ** kwargs : Untyped ,
785787 ) -> Array :
786788 """x[idx] += y or equivalent in-place operation on a subset of x
787789
@@ -799,33 +801,33 @@ def _iop(
799801 x [self .idx ] = elwise_op (x [self .idx ], y )
800802 return x
801803
802- def add (self , y : Array , / , ** kwargs : Any ) -> Array :
804+ def add (self , y : Array , / , ** kwargs : Untyped ) -> Array :
803805 """Apply ``x[idx] += y`` and return the updated array"""
804806 return self ._iop ("add" , operator .add , y , ** kwargs )
805807
806- def subtract (self , y : Array , / , ** kwargs : Any ) -> Array :
808+ def subtract (self , y : Array , / , ** kwargs : Untyped ) -> Array :
807809 """Apply ``x[idx] -= y`` and return the updated array"""
808810 return self ._iop ("subtract" , operator .sub , y , ** kwargs )
809811
810- def multiply (self , y : Array , / , ** kwargs : Any ) -> Array :
812+ def multiply (self , y : Array , / , ** kwargs : Untyped ) -> Array :
811813 """Apply ``x[idx] *= y`` and return the updated array"""
812814 return self ._iop ("multiply" , operator .mul , y , ** kwargs )
813815
814- def divide (self , y : Array , / , ** kwargs : Any ) -> Array :
816+ def divide (self , y : Array , / , ** kwargs : Untyped ) -> Array :
815817 """Apply ``x[idx] /= y`` and return the updated array"""
816818 return self ._iop ("divide" , operator .truediv , y , ** kwargs )
817819
818- def power (self , y : Array , / , ** kwargs : Any ) -> Array :
820+ def power (self , y : Array , / , ** kwargs : Untyped ) -> Array :
819821 """Apply ``x[idx] **= y`` and return the updated array"""
820822 return self ._iop ("power" , operator .pow , y , ** kwargs )
821823
822- def min (self , y : Array , / , ** kwargs : Any ) -> Array :
824+ def min (self , y : Array , / , ** kwargs : Untyped ) -> Array :
823825 """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
824826 xp = array_namespace (self .x )
825827 y = xp .asarray (y )
826828 return self ._iop ("min" , xp .minimum , y , ** kwargs )
827829
828- def max (self , y : Array , / , ** kwargs : Any ) -> Array :
830+ def max (self , y : Array , / , ** kwargs : Untyped ) -> Array :
829831 """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
830832 xp = array_namespace (self .x )
831833 y = xp .asarray (y )
0 commit comments