@@ -623,8 +623,10 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
623623
624624 You may use two alternate syntaxes::
625625
626- at(x, idx).set(value) # or add(value), etc.
627- at(x)[idx].set(value)
626+ >>> import array_api_extra as xpx
627+ >>> xpx.at(x, idx).set(value) # or add(value), etc.
628+ >>> xpx.at(x)[idx].set(value)
629+
628630 copy : bool, optional
629631 True (default)
630632 Ensure that the inputs are not modified.
@@ -647,14 +649,15 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
647649 (a) When you use ``copy=None``, you should always immediately overwrite
648650 the parameter array::
649651
650- x = at(x, 0).set(2, copy=None)
652+ >>> import array_api_extra as xpx
653+ >>> x = xpx.at(x, 0).set(2, copy=None)
651654
652655 The anti-pattern below must be avoided, as it will result in different
653656 behaviour on read-only versus writeable arrays::
654657
655- x = xp.asarray([0, 0, 0])
656- y = at(x, 0).set(2, copy=None)
657- z = at(x, 1).set(3, copy=None)
658+ >>> x = xp.asarray([0, 0, 0])
659+ >>> y = xpx. at(x, 0).set(2, copy=None)
660+ >>> z = xpx. at(x, 1).set(3, copy=None)
658661
659662 In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
660663 when ``x`` is read-only, whereas ``x == y == z == [2, 3, 0]`` when ``x`` is
@@ -667,9 +670,10 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
667670
668671 >>> import numpy as np
669672 >>> import jax.numpy as jnp
670- >>> at(np.asarray([123]), np.asarray([0, 0])).add(1)
673+ >>> import array_api_extra as xpx
674+ >>> xpx.at(np.asarray([123]), np.asarray([0, 0])).add(1)
671675 array([124])
672- >>> at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
676+ >>> xpx. at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
673677 Array([125], dtype=int32)
674678
675679 See Also
@@ -686,21 +690,22 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
686690 --------
687691 Given either of these equivalent expressions::
688692
689- x = at(x)[1].add(2, copy=None)
690- x = at(x, 1).add(2, copy=None)
693+ >>> import array_api_extra as xpx
694+ >>> x = xpx.at(x)[1].add(2, copy=None)
695+ >>> x = xpx.at(x, 1).add(2, copy=None)
691696
692697 If x is a JAX array, they are the same as::
693698
694- x = x.at[1].add(2)
699+ >>> x = x.at[1].add(2)
695700
696701 If x is a read-only numpy array, they are the same as::
697702
698- x = x.copy()
699- x[1] += 2
703+ >>> x = x.copy()
704+ >>> x[1] += 2
700705
701706 For other known backends, they are the same as::
702707
703- x[1] += 2
708+ >>> x[1] += 2
704709 """
705710
706711 _x : Array
0 commit comments