@@ -801,6 +801,173 @@ def size(x):
801
801
return None
802
802
return math .prod (x .shape )
803
803
804
+ def is_writeable_array (x ):
805
+ """
806
+ Return False if x.__setitem__ is expected to raise; True otherwise
807
+ """
808
+ if is_numpy_array (x ):
809
+ return x .flags .writeable
810
+ if is_jax_array (x ):
811
+ return False
812
+ return True
813
+
814
+ _undef = object ()
815
+
816
+ def at (x , idx = _undef , / ):
817
+ """
818
+ Update operations for read-only arrays.
819
+
820
+ This implements ``jax.numpy.ndarray.at`` for all backends.
821
+ Writeable arrays may be updated in place; you should not rely on it.
822
+
823
+ Keyword arguments (e.g. ``indices_are_sorted``) are passed to JAX and are
824
+ quietly ignored for backends that don't support them.
825
+
826
+ Examples
827
+ --------
828
+ Given either of these equivalent expressions::
829
+
830
+ x = at(x)[1].add(2)
831
+ x = at(x, 1).add(2)
832
+
833
+ If x is a JAX array, they are the same as::
834
+
835
+ x = x.at[1].add(x)
836
+
837
+ If x is a read-only numpy array, they are the same as::
838
+
839
+ x = x.copy()
840
+ x[1] += 2
841
+
842
+ Otherwise, they are the same as::
843
+
844
+ x[1] += 2
845
+
846
+ Warning
847
+ -------
848
+ You should always immediately overwrite the parameter array::
849
+
850
+ x = at(x, 0).set(2)
851
+
852
+ The anti-pattern below must be avoided, as it will result in different behaviour
853
+ on read-only versus writeable arrays:
854
+
855
+ x = xp.asarray([0, 0, 0])
856
+ y = at(x, 0).set(2)
857
+ z = at(x, 1).set(3)
858
+
859
+ In the above example, y == [2, 0, 0] and z == [0, 3, 0] when x is read-only,
860
+ whereas y == z == [2, 3, 0] when x is writeable!
861
+
862
+ See Also
863
+ --------
864
+ https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
865
+ """
866
+ if is_jax_array (x ):
867
+ return x .at
868
+ if is_numpy_array (x ) and not x .flags .writeable :
869
+ x = x .copy ()
870
+ return _DummyAt (x , idx )
871
+
872
+ class _DummyAt :
873
+ """Helper of at().
874
+
875
+ Trivially implement jax.numpy.ndarray.at for other backends.
876
+ x is updated in place.
877
+ """
878
+ __slots__ = ("x" , "idx" )
879
+
880
+ def __init__ (self , x , idx = _undef ):
881
+ self .x = x
882
+ self .idx = idx
883
+
884
+ def __getitem__ (self , idx ):
885
+ """
886
+ Allow for the alternate syntax ``at(x)[start:stop:step]``,
887
+ which looks prettier than ``at(x, slice(start, stop, step))``
888
+ and feels more intuitive coming from the JAX documentation.
889
+ """
890
+ self .idx = idx
891
+ return self
892
+
893
+ def _check_args (self , mode = "promise_in_bounds" , ** kwargs ):
894
+ if self .idx is _undef :
895
+ raise TypeError (
896
+ "Index has not been set.\n "
897
+ "Usage: either\n "
898
+ " at(x, idx).set(value)\n "
899
+ "or\n "
900
+ " at(x)[idx].set(value)\n "
901
+ "(same for all other methods)."
902
+ )
903
+ if mode != "promise_in_bounds" :
904
+ xp = array_namespace (self .x )
905
+ raise NotImplementedError (
906
+ f"mode='{ mode } ' is not supported for backend { xp .__name__ } "
907
+ )
908
+
909
+ def set (self , y , / , ** kwargs ):
910
+ self ._check_args (** kwargs )
911
+ self .x [self .idx ] = y
912
+ return self .x
913
+
914
+ def add (self , y , / , ** kwargs ):
915
+ self ._check_args (** kwargs )
916
+ self .x [self .idx ] += y
917
+ return self .x
918
+
919
+ def subtract (self , y , / , ** kwargs ):
920
+ self ._check_args (** kwargs )
921
+ self .x [self .idx ] -= y
922
+ return self .x
923
+
924
+ def multiply (self , y , / , ** kwargs ):
925
+ self ._check_args (** kwargs )
926
+ self .x [self .idx ] *= y
927
+ return self .x
928
+
929
+ def divide (self , y , / , ** kwargs ):
930
+ self ._check_args (** kwargs )
931
+ self .x [self .idx ] /= y
932
+ return self .x
933
+
934
+ def power (self , y , / , ** kwargs ):
935
+ self ._check_args (** kwargs )
936
+ self .x [self .idx ] **= y
937
+ return self .x
938
+
939
+ def min (self , y , / , ** kwargs ):
940
+ self ._check_args (** kwargs )
941
+ xp = array_namespace (self .x )
942
+ self .x [self .idx ] = xp .minimum (self .x [self .idx ], y )
943
+ return self .x
944
+
945
+ def max (self , y , / , ** kwargs ):
946
+ self ._check_args (** kwargs )
947
+ xp = array_namespace (self .x )
948
+ self .x [self .idx ] = xp .maximum (self .x [self .idx ], y )
949
+ return self .x
950
+
951
+ def apply (self , ufunc , / , ** kwargs ):
952
+ self ._check_args (** kwargs )
953
+ ufunc .at (self .x , self .idx )
954
+ return self .x
955
+
956
+ def get (self , ** kwargs ):
957
+ self ._check_args (** kwargs )
958
+ return self .x [self .idx ]
959
+
960
+ def iwhere (condition , x , y , / ):
961
+ """Variant of xp.where(condition, x, y) which may or may not update
962
+ x in place, if it's possible and beneficial for performance.
963
+ """
964
+ if is_writeable_array (x ):
965
+ x [condition ] = y
966
+ return x
967
+ else :
968
+ xp = array_namespace (x )
969
+ return xp .where (condition , x , y )
970
+
804
971
__all__ = [
805
972
"array_namespace" ,
806
973
"device" ,
@@ -821,8 +988,11 @@ def size(x):
821
988
"is_ndonnx_namespace" ,
822
989
"is_pydata_sparse_array" ,
823
990
"is_pydata_sparse_namespace" ,
991
+ "is_writeable_array" ,
824
992
"size" ,
825
993
"to_device" ,
994
+ "at" ,
995
+ "iwhere" ,
826
996
]
827
997
828
998
_all_ignore = ['sys' , 'math' , 'inspect' , 'warnings' ]
0 commit comments