|
6 | 6 |
|
7 | 7 | # https://github.com/pylint-dev/pylint/issues/10112 |
8 | 8 | from collections.abc import Callable # pylint: disable=import-error |
9 | | -from typing import ClassVar |
| 9 | +from typing import Literal |
10 | 10 |
|
11 | 11 | from ._lib import _utils |
12 | 12 | from ._lib._compat import ( |
@@ -657,13 +657,13 @@ class at: # pylint: disable=invalid-name |
657 | 657 |
|
658 | 658 | x: Array |
659 | 659 | idx: Index |
660 | | - __slots__: ClassVar[tuple[str, str]] = ("idx", "x") |
| 660 | + __slots__ = ("idx", "x") |
661 | 661 |
|
662 | | - def __init__(self, x: Array, idx: Index = _undef, /): |
| 662 | + def __init__(self, x: Array, idx: Index = _undef, /) -> None: |
663 | 663 | self.x = x |
664 | 664 | self.idx = idx |
665 | 665 |
|
666 | | - def __getitem__(self, idx: Index) -> at: |
| 666 | + def __getitem__(self, idx: Index, /) -> at: |
667 | 667 | """Allow for the alternate syntax ``at(x)[start:stop:step]``, |
668 | 668 | which looks prettier than ``at(x, slice(start, stop, step))`` |
669 | 669 | and feels more intuitive coming from the JAX documentation. |
@@ -704,19 +704,16 @@ def _common( |
704 | 704 |
|
705 | 705 | x = self.x |
706 | 706 |
|
707 | | - if copy is True: |
| 707 | + if copy is None: |
| 708 | + writeable = is_writeable_array(x) |
| 709 | + copy = _is_update and not writeable |
| 710 | + elif copy: |
708 | 711 | writeable = None |
709 | | - elif copy is False: |
| 712 | + else: |
710 | 713 | writeable = is_writeable_array(x) |
711 | 714 | if not writeable: |
712 | 715 | msg = "Cannot modify parameter in place" |
713 | 716 | raise ValueError(msg) |
714 | | - elif copy is None: # type: ignore[redundant-expr] |
715 | | - writeable = is_writeable_array(x) |
716 | | - copy = _is_update and not writeable |
717 | | - else: |
718 | | - msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable] # pyright: ignore[reportUnreachable] |
719 | | - raise ValueError(msg) |
720 | 717 |
|
721 | 718 | if copy: |
722 | 719 | try: |
@@ -782,7 +779,9 @@ def set(self, y: Array, /, **kwargs: Untyped) -> Array: |
782 | 779 |
|
783 | 780 | def _iop( |
784 | 781 | self, |
785 | | - at_op: str, |
| 782 | + at_op: Literal[ |
| 783 | + "set", "add", "subtract", "multiply", "divide", "power", "min", "max" |
| 784 | + ], |
786 | 785 | elwise_op: Callable[[Array, Array], Array], |
787 | 786 | y: Array, |
788 | 787 | /, |
|
0 commit comments