7
7
"""
8
8
from __future__ import annotations
9
9
10
+ import operator
10
11
from typing import TYPE_CHECKING
11
12
12
13
if TYPE_CHECKING :
13
- from typing import Optional , Union , Any
14
+ from typing import Callable , Literal , Optional , Union , Any
14
15
from ._typing import Array , Device
15
16
16
17
import sys
@@ -91,7 +92,7 @@ def is_cupy_array(x):
91
92
import cupy as cp
92
93
93
94
# TODO: Should we reject ndarray subclasses?
94
- return isinstance (x , ( cp .ndarray , cp . generic ) )
95
+ return isinstance (x , cp .ndarray )
95
96
96
97
def is_torch_array (x ):
97
98
"""
@@ -787,6 +788,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
787
788
return x
788
789
return x .to_device (device , stream = stream )
789
790
791
+
790
792
def size (x ):
791
793
"""
792
794
Return the total number of elements of x.
@@ -801,6 +803,261 @@ def size(x):
801
803
return None
802
804
return math .prod (x .shape )
803
805
806
+
807
+ def is_writeable_array (x ) -> bool :
808
+ """
809
+ Return False if x.__setitem__ is expected to raise; True otherwise
810
+ """
811
+ if is_numpy_array (x ):
812
+ return x .flags .writeable
813
+ if is_jax_array (x ) or is_pydata_sparse_array (x ):
814
+ return False
815
+ return True
816
+
817
+
818
+ def _is_fancy_index (idx ) -> bool :
819
+ if not isinstance (idx , tuple ):
820
+ idx = (idx ,)
821
+ return any (
822
+ isinstance (i , (list , tuple )) or is_array_api_obj (i )
823
+ for i in idx
824
+ )
825
+
826
+
827
+ _undef = object ()
828
+
829
+
830
+ class at :
831
+ """
832
+ Update operations for read-only arrays.
833
+
834
+ This implements ``jax.numpy.ndarray.at`` for all backends.
835
+ Writeable arrays may be updated in place; you should not rely on it.
836
+
837
+ Keyword arguments (e.g. ``indices_are_sorted``) are passed to JAX and are
838
+ quietly ignored for backends that don't support them.
839
+
840
+ Additionally, this introduces support for the `copy` keyword for all backends:
841
+
842
+ None
843
+ x *may* be modified in place if it is possible and beneficial
844
+ for performance. You should not use x after calling this function.
845
+ True
846
+ Ensure that the inputs are not modified. This is the default.
847
+ False
848
+ Raise ValueError if a copy cannot be avoided.
849
+
850
+ Examples
851
+ --------
852
+ Given either of these equivalent expressions::
853
+
854
+ x = at(x)[1].add(2, copy=None)
855
+ x = at(x, 1).add(2, copy=None)
856
+
857
+ If x is a JAX array, they are the same as::
858
+
859
+ x = x.at[1].add(2)
860
+
861
+ If x is a read-only numpy array, they are the same as::
862
+
863
+ x = x.copy()
864
+ x[1] += 2
865
+
866
+ Otherwise, they are the same as::
867
+
868
+ x[1] += 2
869
+
870
+ Warning
871
+ -------
872
+ When you use copy=None, you should always immediately overwrite
873
+ the parameter array::
874
+
875
+ x = at(x, 0).set(2, copy=None)
876
+
877
+ The anti-pattern below must be avoided, as it will result in different behaviour
878
+ on read-only versus writeable arrays:
879
+
880
+ x = xp.asarray([0, 0, 0])
881
+ y = at(x, 0).set(2, copy=None)
882
+ z = at(x, 1).set(3, copy=None)
883
+
884
+ In the above example, y == [2, 0, 0] and z == [0, 3, 0] when x is read-only,
885
+ whereas y == z == [2, 3, 0] when x is writeable!
886
+
887
+ Caveat
888
+ ------
889
+ The behaviour of methods other than `get()` when the index is an array of
890
+ integers which contains multiple occurrences of the same index is undefined.
891
+
892
+ **Undefined behaviour:** ``at(x, [0, 0]).set(2)``
893
+
894
+ See Also
895
+ --------
896
+ https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
897
+ """
898
+
899
+ __slots__ = ("x" , "idx" )
900
+
901
+ def __init__ (self , x , idx = _undef ):
902
+ self .x = x
903
+ self .idx = idx
904
+
905
+ def __getitem__ (self , idx ):
906
+ """
907
+ Allow for the alternate syntax ``at(x)[start:stop:step]``,
908
+ which looks prettier than ``at(x, slice(start, stop, step))``
909
+ and feels more intuitive coming from the JAX documentation.
910
+ """
911
+ if self .idx is not _undef :
912
+ raise ValueError ("Index has already been set" )
913
+ self .idx = idx
914
+ return self
915
+
916
+ def _common (
917
+ self ,
918
+ at_op : str ,
919
+ y = _undef ,
920
+ copy : bool | None | Literal ["_force_false" ] = True ,
921
+ ** kwargs ,
922
+ ):
923
+ """Perform common prepocessing.
924
+
925
+ Returns
926
+ -------
927
+ If the operation can be resolved by at[], (return value, None)
928
+ Otherwise, (None, preprocessed x)
929
+ """
930
+ if self .idx is _undef :
931
+ raise TypeError (
932
+ "Index has not been set.\n "
933
+ "Usage: either\n "
934
+ " at(x, idx).set(value)\n "
935
+ "or\n "
936
+ " at(x)[idx].set(value)\n "
937
+ "(same for all other methods)."
938
+ )
939
+
940
+ x = self .x
941
+
942
+ if copy is False :
943
+ if not is_writeable_array (x ) or is_dask_array (x ):
944
+ raise ValueError ("Cannot modify parameter in place" )
945
+ elif copy is None :
946
+ copy = not is_writeable_array (x )
947
+ elif copy == "_force_false" :
948
+ copy = False
949
+ elif copy is not True :
950
+ raise ValueError (f"Invalid value for copy: { copy !r} " )
951
+
952
+ if is_jax_array (x ):
953
+ # Use JAX's at[]
954
+ at_ = x .at [self .idx ]
955
+ args = (y ,) if y is not _undef else ()
956
+ return getattr (at_ , at_op )(* args , ** kwargs ), None
957
+
958
+ # Emulate at[] behaviour for non-JAX arrays
959
+ if copy :
960
+ # FIXME We blindly expect the output of x.copy() to be always writeable.
961
+ # This holds true for read-only numpy arrays, but not necessarily for
962
+ # other backends.
963
+ xp = get_namespace (x )
964
+ x = xp .asarray (x , copy = True )
965
+
966
+ return None , x
967
+
968
+ def get (self , copy : bool | None = True , ** kwargs ):
969
+ """
970
+ Return x[idx]. In addition to plain __getitem__, this allows ensuring
971
+ that the output is (not) a copy and kwargs are passed to the backend.
972
+ """
973
+ # __getitem__ with a fancy index always returns a copy.
974
+ # Avoid an unnecessary double copy.
975
+ # If copy is forced to False, raise.
976
+ if _is_fancy_index (self .idx ):
977
+ if copy is False :
978
+ raise ValueError (
979
+ "Indexing a numpy array with a fancy index always "
980
+ "results in a copy"
981
+ )
982
+ # Skip copy inside _common, even if array is not writeable
983
+ copy = "_force_false" # type: ignore
984
+
985
+ res , x = self ._common ("get" , copy = copy , ** kwargs )
986
+ if res is not None :
987
+ return res
988
+ return x [self .idx ]
989
+
990
+ def set (self , y , / , ** kwargs ):
991
+ """x[idx] = y"""
992
+ res , x = self ._common ("set" , y , ** kwargs )
993
+ if res is not None :
994
+ return res
995
+ x [self .idx ] = y
996
+ return x
997
+
998
+ def apply (self , ufunc , / , ** kwargs ):
999
+ """ufunc.at(x, idx)"""
1000
+ if is_cupy_array (self .x ) or is_torch_array (self .x ) or is_dask_array (self .x ):
1001
+ # ufunc.at not implemented
1002
+ return self .set (ufunc (self .x [self .idx ]), ** kwargs )
1003
+
1004
+ res , x = self ._common ("apply" , ufunc , ** kwargs )
1005
+ if res is not None :
1006
+ return res
1007
+ ufunc .at (x , self .idx )
1008
+ return x
1009
+
1010
+ def _iop (
1011
+ self , at_op : str , elwise_op : Callable [[Array , Array ], Array ], y : Array , ** kwargs
1012
+ ):
1013
+ """x[idx] += y or equivalent in-place operation on a subset of x
1014
+
1015
+ which is the same as saying
1016
+ x[idx] = x[idx] + y
1017
+ Note that this is not the same as
1018
+ operator.iadd(x[idx], y)
1019
+ Consider for example when x is a numpy array and idx is a fancy index, which
1020
+ triggers a deep copy on __getitem__.
1021
+ """
1022
+ res , x = self ._common (at_op , y , ** kwargs )
1023
+ if res is not None :
1024
+ return res
1025
+ x [self .idx ] = elwise_op (x [self .idx ], y )
1026
+ return x
1027
+
1028
+ def add (self , y , / , ** kwargs ):
1029
+ """x[idx] += y"""
1030
+ return self ._iop ("add" , operator .add , y , ** kwargs )
1031
+
1032
+ def subtract (self , y , / , ** kwargs ):
1033
+ """x[idx] -= y"""
1034
+ return self ._iop ("subtract" , operator .sub , y , ** kwargs )
1035
+
1036
+ def multiply (self , y , / , ** kwargs ):
1037
+ """x[idx] *= y"""
1038
+ return self ._iop ("multiply" , operator .mul , y , ** kwargs )
1039
+
1040
+ def divide (self , y , / , ** kwargs ):
1041
+ """x[idx] /= y"""
1042
+ return self ._iop ("divide" , operator .truediv , y , ** kwargs )
1043
+
1044
+ def power (self , y , / , ** kwargs ):
1045
+ """x[idx] **= y"""
1046
+ return self ._iop ("power" , operator .pow , y , ** kwargs )
1047
+
1048
+ def min (self , y , / , ** kwargs ):
1049
+ """x[idx] = minimum(x[idx], y)"""
1050
+ import numpy as np
1051
+
1052
+ return self ._iop ("min" , np .minimum , y , ** kwargs )
1053
+
1054
+ def max (self , y , / , ** kwargs ):
1055
+ """x[idx] = maximum(x[idx], y)"""
1056
+ import numpy as np
1057
+
1058
+ return self ._iop ("max" , np .maximum , y , ** kwargs )
1059
+
1060
+
804
1061
__all__ = [
805
1062
"array_namespace" ,
806
1063
"device" ,
@@ -821,8 +1078,10 @@ def size(x):
821
1078
"is_ndonnx_namespace" ,
822
1079
"is_pydata_sparse_array" ,
823
1080
"is_pydata_sparse_namespace" ,
1081
+ "is_writeable_array" ,
824
1082
"size" ,
825
1083
"to_device" ,
1084
+ "at" ,
826
1085
]
827
1086
828
- _all_ignore = ['sys ' , 'math' , 'inspect ' , 'warnings' ]
1087
+ _all_ignore = ['inspect ' , 'math' , 'operator ' , 'warnings' , 'sys ' ]
0 commit comments