@@ -801,6 +801,173 @@ def size(x):
801801 return None
802802 return math .prod (x .shape )
803803
804+ def is_writeable_array (x ):
805+ """
806+ Return False if x.__setitem__ is expected to raise; True otherwise
807+ """
808+ if is_numpy_array (x ):
809+ return x .flags .writeable
810+ if is_jax_array (x ):
811+ return False
812+ return True
813+
814+ _undef = object ()
815+
816+ def at (x , idx = _undef , / ):
817+ """
818+ Update operations for read-only arrays.
819+
820+ This implements ``jax.numpy.ndarray.at`` for all backends.
821+ Writeable arrays may be updated in place; you should not rely on it.
822+
823+ Keyword arguments (e.g. ``indices_are_sorted``) are passed to JAX and are
824+ quietly ignored for backends that don't support them.
825+
826+ Examples
827+ --------
828+ Given either of these equivalent expressions::
829+
830+ x = at(x)[1].add(2)
831+ x = at(x, 1).add(2)
832+
833+ If x is a JAX array, they are the same as::
834+
835+ x = x.at[1].add(x)
836+
837+ If x is a read-only numpy array, they are the same as::
838+
839+ x = x.copy()
840+ x[1] += 2
841+
842+ Otherwise, they are the same as::
843+
844+ x[1] += 2
845+
846+ Warning
847+ -------
848+ You should always immediately overwrite the parameter array::
849+
850+ x = at(x, 0).set(2)
851+
852+ The anti-pattern below must be avoided, as it will result in different behaviour
853+ on read-only versus writeable arrays:
854+
855+ x = xp.asarray([0, 0, 0])
856+ y = at(x, 0).set(2)
857+ z = at(x, 1).set(3)
858+
859+ In the above example, y == [2, 0, 0] and z == [0, 3, 0] when x is read-only,
860+ whereas y == z == [2, 3, 0] when x is writeable!
861+
862+ See Also
863+ --------
864+ https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
865+ """
866+ if is_jax_array (x ):
867+ return x .at
868+ if is_numpy_array (x ) and not x .flags .writeable :
869+ x = x .copy ()
870+ return _DummyAt (x , idx )
871+
872+ class _DummyAt :
873+ """Helper of at().
874+
875+ Trivially implement jax.numpy.ndarray.at for other backends.
876+ x is updated in place.
877+ """
878+ __slots__ = ("x" , "idx" )
879+
880+ def __init__ (self , x , idx = _undef ):
881+ self .x = x
882+ self .idx = idx
883+
884+ def __getitem__ (self , idx ):
885+ """
886+ Allow for the alternate syntax ``at(x)[start:stop:step]``,
887+ which looks prettier than ``at(x, slice(start, stop, step))``
888+ and feels more intuitive coming from the JAX documentation.
889+ """
890+ self .idx = idx
891+ return self
892+
893+ def _check_args (self , mode = "promise_in_bounds" , ** kwargs ):
894+ if self .idx is _undef :
895+ raise TypeError (
896+ "Index has not been set.\n "
897+ "Usage: either\n "
898+ " at(x, idx).set(value)\n "
899+ "or\n "
900+ " at(x)[idx].set(value)\n "
901+ "(same for all other methods)."
902+ )
903+ if mode != "promise_in_bounds" :
904+ xp = array_namespace (self .x )
905+ raise NotImplementedError (
906+ f"mode='{ mode } ' is not supported for backend { xp .__name__ } "
907+ )
908+
909+ def set (self , y , / , ** kwargs ):
910+ self ._check_args (** kwargs )
911+ self .x [self .idx ] = y
912+ return self .x
913+
914+ def add (self , y , / , ** kwargs ):
915+ self ._check_args (** kwargs )
916+ self .x [self .idx ] += y
917+ return self .x
918+
919+ def subtract (self , y , / , ** kwargs ):
920+ self ._check_args (** kwargs )
921+ self .x [self .idx ] -= y
922+ return self .x
923+
924+ def multiply (self , y , / , ** kwargs ):
925+ self ._check_args (** kwargs )
926+ self .x [self .idx ] *= y
927+ return self .x
928+
929+ def divide (self , y , / , ** kwargs ):
930+ self ._check_args (** kwargs )
931+ self .x [self .idx ] /= y
932+ return self .x
933+
934+ def power (self , y , / , ** kwargs ):
935+ self ._check_args (** kwargs )
936+ self .x [self .idx ] **= y
937+ return self .x
938+
939+ def min (self , y , / , ** kwargs ):
940+ self ._check_args (** kwargs )
941+ xp = array_namespace (self .x )
942+ self .x [self .idx ] = xp .minimum (self .x [self .idx ], y )
943+ return self .x
944+
945+ def max (self , y , / , ** kwargs ):
946+ self ._check_args (** kwargs )
947+ xp = array_namespace (self .x )
948+ self .x [self .idx ] = xp .maximum (self .x [self .idx ], y )
949+ return self .x
950+
951+ def apply (self , ufunc , / , ** kwargs ):
952+ self ._check_args (** kwargs )
953+ ufunc .at (self .x , self .idx )
954+ return self .x
955+
956+ def get (self , ** kwargs ):
957+ self ._check_args (** kwargs )
958+ return self .x [self .idx ]
959+
960+ def iwhere (condition , x , y , / ):
961+ """Variant of xp.where(condition, x, y) which may or may not update
962+ x in place, if it's possible and beneficial for performance.
963+ """
964+ if is_writeable_array (x ):
965+ x [condition ] = y
966+ return x
967+ else :
968+ xp = array_namespace (x )
969+ return xp .where (condition , x , y )
970+
804971__all__ = [
805972 "array_namespace" ,
806973 "device" ,
@@ -821,8 +988,11 @@ def size(x):
821988 "is_ndonnx_namespace" ,
822989 "is_pydata_sparse_array" ,
823990 "is_pydata_sparse_namespace" ,
991+ "is_writeable_array" ,
824992 "size" ,
825993 "to_device" ,
994+ "at" ,
995+ "iwhere" ,
826996]
827997
828998_all_ignore = ['sys' , 'math' , 'inspect' , 'warnings' ]
0 commit comments