13
13
array_namespace ,
14
14
is_array_api_obj ,
15
15
is_dask_array ,
16
+ is_jax_array ,
17
+ is_pydata_sparse_array ,
16
18
is_writeable_array ,
17
19
)
18
20
19
21
if typing .TYPE_CHECKING :
20
- from ._lib ._typing import Array , Index , ModuleType , Untyped
22
+ from ._lib ._typing import Array , Index , ModuleType
21
23
22
24
__all__ = [
23
25
"at" ,
@@ -593,11 +595,6 @@ class at: # pylint: disable=invalid-name
593
595
xp : array_namespace, optional
594
596
The standard-compatible namespace for `x`. Default: infer
595
597
596
- **kwargs:
597
- If the backend supports an `at` method, any additional keyword
598
- arguments are passed to it verbatim; e.g. this allows passing
599
- ``indices_are_sorted=True`` to JAX.
600
-
601
598
Returns
602
599
-------
603
600
Updated input array.
@@ -674,23 +671,7 @@ def __getitem__(self, idx: Index, /) -> at:
674
671
self .idx = idx
675
672
return self
676
673
677
- def _common (
678
- self ,
679
- at_op : str ,
680
- y : Array = _undef ,
681
- / ,
682
- copy : bool | None = True ,
683
- xp : ModuleType | None = None ,
684
- _is_update : bool = True ,
685
- ** kwargs : Untyped ,
686
- ) -> tuple [Array , None ] | tuple [None , Array ]:
687
- """Perform common prepocessing.
688
-
689
- Returns
690
- -------
691
- If the operation can be resolved by at[], (return value, None)
692
- Otherwise, (None, preprocessed x)
693
- """
674
+ def _check_args (self , / , copy : bool | None ) -> None :
694
675
if self .idx is _undef :
695
676
msg = (
696
677
"Index has not been set.\n "
@@ -702,64 +683,23 @@ def _common(
702
683
)
703
684
raise TypeError (msg )
704
685
705
- x = self .x
706
-
707
686
if copy not in (True , False , None ):
708
687
msg = f"copy must be True, False, or None; got { copy !r} " # pyright: ignore[reportUnreachable]
709
688
raise ValueError (msg )
710
689
711
- if copy is None :
712
- writeable = is_writeable_array (x )
713
- copy = _is_update and not writeable
714
- elif copy :
715
- writeable = None
716
- elif _is_update :
717
- writeable = is_writeable_array (x )
718
- if not writeable :
719
- msg = "Cannot modify parameter in place"
720
- raise ValueError (msg )
721
- else :
722
- writeable = None
723
-
724
- if copy :
725
- try :
726
- at_ = x .at
727
- except AttributeError :
728
- # Emulate at[] behaviour for non-JAX arrays
729
- # with a copy followed by an update
730
- if xp is None :
731
- xp = array_namespace (x )
732
- x = xp .asarray (x , copy = True )
733
- if writeable is False :
734
- # A copy of a read-only numpy array is writeable
735
- # Note: this assumes that a copy of a writeable array is writeable
736
- writeable = None
737
- else :
738
- # Use JAX's at[] or other library that with the same duck-type API
739
- args = (y ,) if y is not _undef else ()
740
- return getattr (at_ [self .idx ], at_op )(* args , ** kwargs ), None
741
-
742
- if _is_update :
743
- if writeable is None :
744
- writeable = is_writeable_array (x )
745
- if not writeable :
746
- # sparse crashes here
747
- msg = f"Array { x } has no `at` method and is read-only"
748
- raise ValueError (msg )
749
-
750
- return None , x
751
-
752
690
def get (
753
691
self ,
754
692
/ ,
755
693
copy : bool | None = True ,
756
694
xp : ModuleType | None = None ,
757
- ** kwargs : Untyped ,
758
- ) -> Untyped :
759
- """Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
760
- that the output is either a copy or a view; it also allows passing
695
+ ) -> Array :
696
+ """Return ``xp.asarray(x[idx])``. In addition to plain ``__getitem__``, this allows
697
+ ensuring that the output is either a copy or a view; it also allows passing
761
698
keyword arguments to the backend.
762
699
"""
700
+ self ._check_args (copy = copy )
701
+ x = self .x
702
+
763
703
if copy is False :
764
704
if is_array_api_obj (self .idx ):
765
705
# Boolean index. Note that the array API spec
@@ -782,26 +722,81 @@ def get(
782
722
msg = "get() with a scalar index typically returns a copy"
783
723
raise ValueError (msg )
784
724
785
- if is_dask_array (self .x ):
786
- msg = "get() on Dask arrays always returns a copy"
725
+ # Note: this is not the same list of backends as is_writeable_array()
726
+ if is_dask_array (x ) or is_jax_array (x ) or is_pydata_sparse_array (x ):
727
+ msg = f"get() on { array_namespace (x )} arrays always returns a copy"
787
728
raise ValueError (msg )
788
729
789
- res , x = self ._common ("get" , copy = copy , xp = xp , _is_update = False , ** kwargs )
790
- if res is not None :
791
- return res
792
- assert x is not None
793
- return x [self .idx ]
730
+ if is_jax_array (x ):
731
+ # Use JAX's at[] or other library that with the same duck-type API
732
+ return x .at [self .idx ].get ()
733
+
734
+ if xp is None :
735
+ xp = array_namespace (x )
736
+ # Note: when self.idx is a boolean mask, numpy always returns a deep copy.
737
+ # However, some backends may legitimately return a view when the mask can
738
+ # be downgraded to a slice, e.g. a[[True, True, False]] -> a[:2].
739
+ # Err on the side of caution and perform a double-copy in numpy.
740
+ return xp .asarray (x [self .idx ], copy = copy )
741
+
742
+ def _update_common (
743
+ self ,
744
+ at_op : str ,
745
+ y : Array = _undef ,
746
+ / ,
747
+ copy : bool | None = True ,
748
+ xp : ModuleType | None = None ,
749
+ ) -> tuple [Array , None ] | tuple [None , Array ]:
750
+ """Perform common prepocessing to all update operations.
751
+
752
+ Returns
753
+ -------
754
+ If the operation can be resolved by at[], (return value, None)
755
+ Otherwise, (None, preprocessed x)
756
+ """
757
+ x = self .x
758
+ if copy is None :
759
+ writeable = is_writeable_array (x )
760
+ copy = not writeable
761
+ elif copy :
762
+ writeable = None
763
+ else :
764
+ writeable = is_writeable_array (x )
765
+
766
+ if copy :
767
+ if is_jax_array (x ):
768
+ # Use JAX's at[] or other library that with the same duck-type API
769
+ func = getattr (x .at [self .idx ], at_op )
770
+ return func (y ) if y is not _undef else func (), None
771
+ # Emulate at[] behaviour for non-JAX arrays
772
+ # with a copy followed by an update
773
+ if xp is None :
774
+ xp = array_namespace (x )
775
+ x = xp .asarray (x , copy = True )
776
+ if writeable is False :
777
+ # A copy of a read-only numpy array is writeable
778
+ # Note: this assumes that a copy of a writeable array is writeable
779
+ writeable = None
780
+
781
+ if writeable is None :
782
+ writeable = is_writeable_array (x )
783
+ if not writeable :
784
+ # sparse crashes here
785
+ msg = f"Array { x } has no `at` method and is read-only"
786
+ raise ValueError (msg )
787
+
788
+ return None , x
794
789
795
790
def set (
796
791
self ,
797
792
y : Array ,
798
793
/ ,
799
794
copy : bool | None = True ,
800
795
xp : ModuleType | None = None ,
801
- ** kwargs : Untyped ,
802
796
) -> Array :
803
797
"""Apply ``x[idx] = y`` and return the update array"""
804
- res , x = self ._common ("set" , y , copy = copy , xp = xp , ** kwargs )
798
+ self ._check_args (copy = copy )
799
+ res , x = self ._update_common ("set" , y , copy = copy , xp = xp )
805
800
if res is not None :
806
801
return res
807
802
assert x is not None
@@ -818,7 +813,6 @@ def _iop(
818
813
/ ,
819
814
copy : bool | None = True ,
820
815
xp : ModuleType | None = None ,
821
- ** kwargs : Untyped ,
822
816
) -> Array :
823
817
"""x[idx] += y or equivalent in-place operation on a subset of x
824
818
@@ -829,7 +823,8 @@ def _iop(
829
823
Consider for example when x is a numpy array and idx is a fancy index, which
830
824
triggers a deep copy on __getitem__.
831
825
"""
832
- res , x = self ._common (at_op , y , copy = copy , xp = xp , ** kwargs )
826
+ self ._check_args (copy = copy )
827
+ res , x = self ._update_common (at_op , y , copy = copy , xp = xp )
833
828
if res is not None :
834
829
return res
835
830
assert x is not None
@@ -842,79 +837,72 @@ def add(
842
837
/ ,
843
838
copy : bool | None = True ,
844
839
xp : ModuleType | None = None ,
845
- ** kwargs : Untyped ,
846
840
) -> Array :
847
841
"""Apply ``x[idx] += y`` and return the updated array"""
848
- return self ._iop ("add" , operator .add , y , copy = copy , xp = xp , ** kwargs )
842
+ return self ._iop ("add" , operator .add , y , copy = copy , xp = xp )
849
843
850
844
def subtract (
851
845
self ,
852
846
y : Array ,
853
847
/ ,
854
848
copy : bool | None = True ,
855
849
xp : ModuleType | None = None ,
856
- ** kwargs : Untyped ,
857
850
) -> Array :
858
851
"""Apply ``x[idx] -= y`` and return the updated array"""
859
- return self ._iop ("subtract" , operator .sub , y , copy = copy , xp = xp , ** kwargs )
852
+ return self ._iop ("subtract" , operator .sub , y , copy = copy , xp = xp )
860
853
861
854
def multiply (
862
855
self ,
863
856
y : Array ,
864
857
/ ,
865
858
copy : bool | None = True ,
866
859
xp : ModuleType | None = None ,
867
- ** kwargs : Untyped ,
868
860
) -> Array :
869
861
"""Apply ``x[idx] *= y`` and return the updated array"""
870
- return self ._iop ("multiply" , operator .mul , y , copy = copy , xp = xp , ** kwargs )
862
+ return self ._iop ("multiply" , operator .mul , y , copy = copy , xp = xp )
871
863
872
864
def divide (
873
865
self ,
874
866
y : Array ,
875
867
/ ,
876
868
copy : bool | None = True ,
877
869
xp : ModuleType | None = None ,
878
- ** kwargs : Untyped ,
879
870
) -> Array :
880
871
"""Apply ``x[idx] /= y`` and return the updated array"""
881
- return self ._iop ("divide" , operator .truediv , y , copy = copy , xp = xp , ** kwargs )
872
+ return self ._iop ("divide" , operator .truediv , y , copy = copy , xp = xp )
882
873
883
874
def power (
884
875
self ,
885
876
y : Array ,
886
877
/ ,
887
878
copy : bool | None = True ,
888
879
xp : ModuleType | None = None ,
889
- ** kwargs : Untyped ,
890
880
) -> Array :
891
881
"""Apply ``x[idx] **= y`` and return the updated array"""
892
- return self ._iop ("power" , operator .pow , y , copy = copy , xp = xp , ** kwargs )
882
+ return self ._iop ("power" , operator .pow , y , copy = copy , xp = xp )
893
883
894
884
def min (
895
885
self ,
896
886
y : Array ,
897
887
/ ,
898
888
copy : bool | None = True ,
899
889
xp : ModuleType | None = None ,
900
- ** kwargs : Untyped ,
901
890
) -> Array :
902
891
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
903
892
if xp is None :
904
893
xp = array_namespace (self .x )
905
894
y = xp .asarray (y )
906
- return self ._iop ("min" , xp .minimum , y , copy = copy , xp = xp , ** kwargs )
895
+ return self ._iop ("min" , xp .minimum , y , copy = copy , xp = xp )
907
896
908
897
def max (
909
898
self ,
910
899
y : Array ,
911
900
/ ,
912
901
copy : bool | None = True ,
913
902
xp : ModuleType | None = None ,
914
- ** kwargs : Untyped ,
915
903
) -> Array :
916
904
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
917
905
if xp is None :
918
906
xp = array_namespace (self .x )
919
907
y = xp .asarray (y )
920
- return self ._iop ("max" , xp .maximum , y , copy = copy , xp = xp , ** kwargs )
908
+ return self ._iop ("max" , xp .maximum , y , copy = copy , xp = xp )
0 commit comments