Skip to content

Commit 2b18097

Browse files
committed
Self-review
1 parent 18096a5 commit 2b18097

File tree

8 files changed

+6379
-1008
lines changed

8 files changed

+6379
-1008
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
strategy:
4949
fail-fast: false
5050
matrix:
51-
environment: [ci-py310, ci-py313]
51+
environment: [ci-py310, ci-py313, ci-backends]
5252
runs-on: [ubuntu-latest]
5353

5454
steps:

pixi.lock

+6,259-911
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+23
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,27 @@ python = "~=3.10.0"
127127
[tool.pixi.feature.py313.dependencies]
128128
python = "~=3.13.0"
129129

130+
[tool.pixi.feature.backends.target.linux-64.dependencies]
131+
cupy = "*"
132+
pytorch = "*"
133+
dask = "*"
134+
sparse = ">=0.15"
135+
jax = "*"
136+
137+
[tool.pixi.feature.backends.target.osx-arm64.dependencies]
138+
# cupy = "*"
139+
pytorch = "*"
140+
dask = "*"
141+
sparse = ">=0.15"
142+
jax = "*"
143+
144+
[tool.pixi.feature.backends.target.win-64.dependencies]
145+
cupy = "*"
146+
# pytorch = "*"
147+
dask = "*"
148+
sparse = ">=0.15"
149+
# jax = "*"
150+
130151
[tool.pixi.environments]
131152
default = { solve-group = "default" }
132153
lint = { features = ["lint"], solve-group = "default" }
@@ -135,6 +156,7 @@ docs = { features = ["docs"], solve-group = "default" }
135156
dev = { features = ["lint", "tests", "docs", "dev"], solve-group = "default" }
136157
ci-py310 = ["py310", "tests"]
137158
ci-py313 = ["py313", "tests"]
159+
ci-backends = ["py310", "tests", "backends"]
138160

139161

140162
# pytest
@@ -232,6 +254,7 @@ ignore = [
232254
"PLR09", # Too many <...>
233255
"PLR2004", # Magic value used in comparison
234256
"ISC001", # Conflicts with formatter
257+
"PD008", # Use `.loc` instead of `.at`
235258
]
236259

237260
[tool.ruff.lint.per-file-ignores]

src/array_api_extra/_funcs.py

+83-95
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
array_namespace,
1414
is_array_api_obj,
1515
is_dask_array,
16+
is_jax_array,
17+
is_pydata_sparse_array,
1618
is_writeable_array,
1719
)
1820

1921
if typing.TYPE_CHECKING:
20-
from ._lib._typing import Array, Index, ModuleType, Untyped
22+
from ._lib._typing import Array, Index, ModuleType
2123

2224
__all__ = [
2325
"at",
@@ -593,11 +595,6 @@ class at: # pylint: disable=invalid-name
593595
xp : array_namespace, optional
594596
The standard-compatible namespace for `x`. Default: infer
595597
596-
**kwargs:
597-
If the backend supports an `at` method, any additional keyword
598-
arguments are passed to it verbatim; e.g. this allows passing
599-
``indices_are_sorted=True`` to JAX.
600-
601598
Returns
602599
-------
603600
Updated input array.
@@ -674,23 +671,7 @@ def __getitem__(self, idx: Index, /) -> at:
674671
self.idx = idx
675672
return self
676673

677-
def _common(
678-
self,
679-
at_op: str,
680-
y: Array = _undef,
681-
/,
682-
copy: bool | None = True,
683-
xp: ModuleType | None = None,
684-
_is_update: bool = True,
685-
**kwargs: Untyped,
686-
) -> tuple[Array, None] | tuple[None, Array]:
687-
"""Perform common prepocessing.
688-
689-
Returns
690-
-------
691-
If the operation can be resolved by at[], (return value, None)
692-
Otherwise, (None, preprocessed x)
693-
"""
674+
def _check_args(self, /, copy: bool | None) -> None:
694675
if self.idx is _undef:
695676
msg = (
696677
"Index has not been set.\n"
@@ -702,64 +683,23 @@ def _common(
702683
)
703684
raise TypeError(msg)
704685

705-
x = self.x
706-
707686
if copy not in (True, False, None):
708687
msg = f"copy must be True, False, or None; got {copy!r}" # pyright: ignore[reportUnreachable]
709688
raise ValueError(msg)
710689

711-
if copy is None:
712-
writeable = is_writeable_array(x)
713-
copy = _is_update and not writeable
714-
elif copy:
715-
writeable = None
716-
elif _is_update:
717-
writeable = is_writeable_array(x)
718-
if not writeable:
719-
msg = "Cannot modify parameter in place"
720-
raise ValueError(msg)
721-
else:
722-
writeable = None
723-
724-
if copy:
725-
try:
726-
at_ = x.at
727-
except AttributeError:
728-
# Emulate at[] behaviour for non-JAX arrays
729-
# with a copy followed by an update
730-
if xp is None:
731-
xp = array_namespace(x)
732-
x = xp.asarray(x, copy=True)
733-
if writeable is False:
734-
# A copy of a read-only numpy array is writeable
735-
# Note: this assumes that a copy of a writeable array is writeable
736-
writeable = None
737-
else:
738-
# Use JAX's at[] or other library that with the same duck-type API
739-
args = (y,) if y is not _undef else ()
740-
return getattr(at_[self.idx], at_op)(*args, **kwargs), None
741-
742-
if _is_update:
743-
if writeable is None:
744-
writeable = is_writeable_array(x)
745-
if not writeable:
746-
# sparse crashes here
747-
msg = f"Array {x} has no `at` method and is read-only"
748-
raise ValueError(msg)
749-
750-
return None, x
751-
752690
def get(
753691
self,
754692
/,
755693
copy: bool | None = True,
756694
xp: ModuleType | None = None,
757-
**kwargs: Untyped,
758-
) -> Untyped:
759-
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
760-
that the output is either a copy or a view; it also allows passing
695+
) -> Array:
696+
"""Return ``xp.asarray(x[idx])``. In addition to plain ``__getitem__``, this allows
697+
ensuring that the output is either a copy or a view; it also allows passing
761698
keyword arguments to the backend.
762699
"""
700+
self._check_args(copy=copy)
701+
x = self.x
702+
763703
if copy is False:
764704
if is_array_api_obj(self.idx):
765705
# Boolean index. Note that the array API spec
@@ -782,26 +722,81 @@ def get(
782722
msg = "get() with a scalar index typically returns a copy"
783723
raise ValueError(msg)
784724

785-
if is_dask_array(self.x):
786-
msg = "get() on Dask arrays always returns a copy"
725+
# Note: this is not the same list of backends as is_writeable_array()
726+
if is_dask_array(x) or is_jax_array(x) or is_pydata_sparse_array(x):
727+
msg = f"get() on {array_namespace(x)} arrays always returns a copy"
787728
raise ValueError(msg)
788729

789-
res, x = self._common("get", copy=copy, xp=xp, _is_update=False, **kwargs)
790-
if res is not None:
791-
return res
792-
assert x is not None
793-
return x[self.idx]
730+
if is_jax_array(x):
731+
# Use JAX's at[] or other library that with the same duck-type API
732+
return x.at[self.idx].get()
733+
734+
if xp is None:
735+
xp = array_namespace(x)
736+
# Note: when self.idx is a boolean mask, numpy always returns a deep copy.
737+
# However, some backends may legitimately return a view when the mask can
738+
# be downgraded to a slice, e.g. a[[True, True, False]] -> a[:2].
739+
# Err on the side of caution and perform a double-copy in numpy.
740+
return xp.asarray(x[self.idx], copy=copy)
741+
742+
def _update_common(
743+
self,
744+
at_op: str,
745+
y: Array = _undef,
746+
/,
747+
copy: bool | None = True,
748+
xp: ModuleType | None = None,
749+
) -> tuple[Array, None] | tuple[None, Array]:
750+
"""Perform common prepocessing to all update operations.
751+
752+
Returns
753+
-------
754+
If the operation can be resolved by at[], (return value, None)
755+
Otherwise, (None, preprocessed x)
756+
"""
757+
x = self.x
758+
if copy is None:
759+
writeable = is_writeable_array(x)
760+
copy = not writeable
761+
elif copy:
762+
writeable = None
763+
else:
764+
writeable = is_writeable_array(x)
765+
766+
if copy:
767+
if is_jax_array(x):
768+
# Use JAX's at[] or other library that with the same duck-type API
769+
func = getattr(x.at[self.idx], at_op)
770+
return func(y) if y is not _undef else func(), None
771+
# Emulate at[] behaviour for non-JAX arrays
772+
# with a copy followed by an update
773+
if xp is None:
774+
xp = array_namespace(x)
775+
x = xp.asarray(x, copy=True)
776+
if writeable is False:
777+
# A copy of a read-only numpy array is writeable
778+
# Note: this assumes that a copy of a writeable array is writeable
779+
writeable = None
780+
781+
if writeable is None:
782+
writeable = is_writeable_array(x)
783+
if not writeable:
784+
# sparse crashes here
785+
msg = f"Array {x} has no `at` method and is read-only"
786+
raise ValueError(msg)
787+
788+
return None, x
794789

795790
def set(
796791
self,
797792
y: Array,
798793
/,
799794
copy: bool | None = True,
800795
xp: ModuleType | None = None,
801-
**kwargs: Untyped,
802796
) -> Array:
803797
"""Apply ``x[idx] = y`` and return the update array"""
804-
res, x = self._common("set", y, copy=copy, xp=xp, **kwargs)
798+
self._check_args(copy=copy)
799+
res, x = self._update_common("set", y, copy=copy, xp=xp)
805800
if res is not None:
806801
return res
807802
assert x is not None
@@ -818,7 +813,6 @@ def _iop(
818813
/,
819814
copy: bool | None = True,
820815
xp: ModuleType | None = None,
821-
**kwargs: Untyped,
822816
) -> Array:
823817
"""x[idx] += y or equivalent in-place operation on a subset of x
824818
@@ -829,7 +823,8 @@ def _iop(
829823
Consider for example when x is a numpy array and idx is a fancy index, which
830824
triggers a deep copy on __getitem__.
831825
"""
832-
res, x = self._common(at_op, y, copy=copy, xp=xp, **kwargs)
826+
self._check_args(copy=copy)
827+
res, x = self._update_common(at_op, y, copy=copy, xp=xp)
833828
if res is not None:
834829
return res
835830
assert x is not None
@@ -842,79 +837,72 @@ def add(
842837
/,
843838
copy: bool | None = True,
844839
xp: ModuleType | None = None,
845-
**kwargs: Untyped,
846840
) -> Array:
847841
"""Apply ``x[idx] += y`` and return the updated array"""
848-
return self._iop("add", operator.add, y, copy=copy, xp=xp, **kwargs)
842+
return self._iop("add", operator.add, y, copy=copy, xp=xp)
849843

850844
def subtract(
851845
self,
852846
y: Array,
853847
/,
854848
copy: bool | None = True,
855849
xp: ModuleType | None = None,
856-
**kwargs: Untyped,
857850
) -> Array:
858851
"""Apply ``x[idx] -= y`` and return the updated array"""
859-
return self._iop("subtract", operator.sub, y, copy=copy, xp=xp, **kwargs)
852+
return self._iop("subtract", operator.sub, y, copy=copy, xp=xp)
860853

861854
def multiply(
862855
self,
863856
y: Array,
864857
/,
865858
copy: bool | None = True,
866859
xp: ModuleType | None = None,
867-
**kwargs: Untyped,
868860
) -> Array:
869861
"""Apply ``x[idx] *= y`` and return the updated array"""
870-
return self._iop("multiply", operator.mul, y, copy=copy, xp=xp, **kwargs)
862+
return self._iop("multiply", operator.mul, y, copy=copy, xp=xp)
871863

872864
def divide(
873865
self,
874866
y: Array,
875867
/,
876868
copy: bool | None = True,
877869
xp: ModuleType | None = None,
878-
**kwargs: Untyped,
879870
) -> Array:
880871
"""Apply ``x[idx] /= y`` and return the updated array"""
881-
return self._iop("divide", operator.truediv, y, copy=copy, xp=xp, **kwargs)
872+
return self._iop("divide", operator.truediv, y, copy=copy, xp=xp)
882873

883874
def power(
884875
self,
885876
y: Array,
886877
/,
887878
copy: bool | None = True,
888879
xp: ModuleType | None = None,
889-
**kwargs: Untyped,
890880
) -> Array:
891881
"""Apply ``x[idx] **= y`` and return the updated array"""
892-
return self._iop("power", operator.pow, y, copy=copy, xp=xp, **kwargs)
882+
return self._iop("power", operator.pow, y, copy=copy, xp=xp)
893883

894884
def min(
895885
self,
896886
y: Array,
897887
/,
898888
copy: bool | None = True,
899889
xp: ModuleType | None = None,
900-
**kwargs: Untyped,
901890
) -> Array:
902891
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
903892
if xp is None:
904893
xp = array_namespace(self.x)
905894
y = xp.asarray(y)
906-
return self._iop("min", xp.minimum, y, copy=copy, xp=xp, **kwargs)
895+
return self._iop("min", xp.minimum, y, copy=copy, xp=xp)
907896

908897
def max(
909898
self,
910899
y: Array,
911900
/,
912901
copy: bool | None = True,
913902
xp: ModuleType | None = None,
914-
**kwargs: Untyped,
915903
) -> Array:
916904
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
917905
if xp is None:
918906
xp = array_namespace(self.x)
919907
y = xp.asarray(y)
920-
return self._iop("max", xp.maximum, y, copy=copy, xp=xp, **kwargs)
908+
return self._iop("max", xp.maximum, y, copy=copy, xp=xp)

src/array_api_extra/_lib/_compat.py

+6
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
device,
99
is_array_api_obj,
1010
is_dask_array,
11+
is_jax_array,
12+
is_pydata_sparse_array,
1113
is_writeable_array,
1214
)
1315
except ImportError:
@@ -16,6 +18,8 @@
1618
device,
1719
is_array_api_obj,
1820
is_dask_array,
21+
is_jax_array,
22+
is_pydata_sparse_array,
1923
is_writeable_array,
2024
)
2125

@@ -24,5 +28,7 @@
2428
"device",
2529
"is_array_api_obj",
2630
"is_dask_array",
31+
"is_jax_array",
32+
"is_pydata_sparse_array",
2733
"is_writeable_array",
2834
)

0 commit comments

Comments
 (0)