Skip to content

Commit ca33a66

Browse files
committed
Abstractions for read-only arrays
1 parent ee25aae commit ca33a66

File tree

5 files changed

+449
-7
lines changed

5 files changed

+449
-7
lines changed

array_api_compat/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""
22
NumPy Array API compatibility library
33
4-
This is a small wrapper around NumPy and CuPy that is compatible with the
5-
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
6-
https://numpy.org/neps/nep-0047-array-api-standard.html.
4+
This is a small wrapper around NumPy, CuPy, JAX and others that is compatible
5+
with the Array API standard https://data-apis.org/array-api/latest/.
6+
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
77
88
Unlike array_api_strict, this is not a strict minimal implementation of the
99
Array API, but rather just an extension of the main NumPy namespace with

array_api_compat/common/_helpers.py

+262-3
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
"""
88
from __future__ import annotations
99

10+
import operator
1011
from typing import TYPE_CHECKING
1112

1213
if TYPE_CHECKING:
13-
from typing import Optional, Union, Any
14+
from typing import Callable, Literal, Optional, Union, Any
1415
from ._typing import Array, Device
1516

1617
import sys
@@ -91,7 +92,7 @@ def is_cupy_array(x):
9192
import cupy as cp
9293

9394
# TODO: Should we reject ndarray subclasses?
94-
return isinstance(x, (cp.ndarray, cp.generic))
95+
return isinstance(x, cp.ndarray)
9596

9697
def is_torch_array(x):
9798
"""
@@ -787,6 +788,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
787788
return x
788789
return x.to_device(device, stream=stream)
789790

791+
790792
def size(x):
791793
"""
792794
Return the total number of elements of x.
@@ -801,6 +803,261 @@ def size(x):
801803
return None
802804
return math.prod(x.shape)
803805

806+
807+
def is_writeable_array(x) -> bool:
808+
"""
809+
Return False if x.__setitem__ is expected to raise; True otherwise
810+
"""
811+
if is_numpy_array(x):
812+
return x.flags.writeable
813+
if is_jax_array(x) or is_pydata_sparse_array(x):
814+
return False
815+
return True
816+
817+
818+
def _is_fancy_index(idx) -> bool:
819+
if not isinstance(idx, tuple):
820+
idx = (idx,)
821+
return any(
822+
isinstance(i, (list, tuple)) or is_array_api_obj(i)
823+
for i in idx
824+
)
825+
826+
827+
_undef = object()
828+
829+
830+
class at:
831+
"""
832+
Update operations for read-only arrays.
833+
834+
This implements ``jax.numpy.ndarray.at`` for all backends.
835+
Writeable arrays may be updated in place; you should not rely on it.
836+
837+
Keyword arguments (e.g. ``indices_are_sorted``) are passed to JAX and are
838+
quietly ignored for backends that don't support them.
839+
840+
Additionally, this introduces support for the `copy` keyword for all backends:
841+
842+
None
843+
x *may* be modified in place if it is possible and beneficial
844+
for performance. You should not use x after calling this function.
845+
True
846+
Ensure that the inputs are not modified. This is the default.
847+
False
848+
Raise ValueError if a copy cannot be avoided.
849+
850+
Examples
851+
--------
852+
Given either of these equivalent expressions::
853+
854+
x = at(x)[1].add(2, copy=None)
855+
x = at(x, 1).add(2, copy=None)
856+
857+
If x is a JAX array, they are the same as::
858+
859+
x = x.at[1].add(2)
860+
861+
If x is a read-only numpy array, they are the same as::
862+
863+
x = x.copy()
864+
x[1] += 2
865+
866+
Otherwise, they are the same as::
867+
868+
x[1] += 2
869+
870+
Warning
871+
-------
872+
When you use copy=None, you should always immediately overwrite
873+
the parameter array::
874+
875+
x = at(x, 0).set(2, copy=None)
876+
877+
The anti-pattern below must be avoided, as it will result in different behaviour
878+
on read-only versus writeable arrays:
879+
880+
x = xp.asarray([0, 0, 0])
881+
y = at(x, 0).set(2, copy=None)
882+
z = at(x, 1).set(3, copy=None)
883+
884+
In the above example, y == [2, 0, 0] and z == [0, 3, 0] when x is read-only,
885+
whereas y == z == [2, 3, 0] when x is writeable!
886+
887+
Caveat
888+
------
889+
The behaviour of methods other than `get()` when the index is an array of
890+
integers which contains multiple occurrences of the same index is undefined.
891+
892+
**Undefined behaviour:** ``at(x, [0, 0]).set(2)``
893+
894+
See Also
895+
--------
896+
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
897+
"""
898+
899+
__slots__ = ("x", "idx")
900+
901+
def __init__(self, x, idx=_undef):
902+
self.x = x
903+
self.idx = idx
904+
905+
def __getitem__(self, idx):
906+
"""
907+
Allow for the alternate syntax ``at(x)[start:stop:step]``,
908+
which looks prettier than ``at(x, slice(start, stop, step))``
909+
and feels more intuitive coming from the JAX documentation.
910+
"""
911+
if self.idx is not _undef:
912+
raise ValueError("Index has already been set")
913+
self.idx = idx
914+
return self
915+
916+
def _common(
917+
self,
918+
at_op: str,
919+
y=_undef,
920+
copy: bool | None | Literal["_force_false"] = True,
921+
**kwargs,
922+
):
923+
"""Perform common prepocessing.
924+
925+
Returns
926+
-------
927+
If the operation can be resolved by at[], (return value, None)
928+
Otherwise, (None, preprocessed x)
929+
"""
930+
if self.idx is _undef:
931+
raise TypeError(
932+
"Index has not been set.\n"
933+
"Usage: either\n"
934+
" at(x, idx).set(value)\n"
935+
"or\n"
936+
" at(x)[idx].set(value)\n"
937+
"(same for all other methods)."
938+
)
939+
940+
x = self.x
941+
942+
if copy is False:
943+
if not is_writeable_array(x) or is_dask_array(x):
944+
raise ValueError("Cannot modify parameter in place")
945+
elif copy is None:
946+
copy = not is_writeable_array(x)
947+
elif copy == "_force_false":
948+
copy = False
949+
elif copy is not True:
950+
raise ValueError(f"Invalid value for copy: {copy!r}")
951+
952+
if is_jax_array(x):
953+
# Use JAX's at[]
954+
at_ = x.at[self.idx]
955+
args = (y,) if y is not _undef else ()
956+
return getattr(at_, at_op)(*args, **kwargs), None
957+
958+
# Emulate at[] behaviour for non-JAX arrays
959+
if copy:
960+
# FIXME We blindly expect the output of x.copy() to be always writeable.
961+
# This holds true for read-only numpy arrays, but not necessarily for
962+
# other backends.
963+
xp = get_namespace(x)
964+
x = xp.asarray(x, copy=True)
965+
966+
return None, x
967+
968+
def get(self, copy: bool | None = True, **kwargs):
969+
"""
970+
Return x[idx]. In addition to plain __getitem__, this allows ensuring
971+
that the output is (not) a copy and kwargs are passed to the backend.
972+
"""
973+
# __getitem__ with a fancy index always returns a copy.
974+
# Avoid an unnecessary double copy.
975+
# If copy is forced to False, raise.
976+
if _is_fancy_index(self.idx):
977+
if copy is False:
978+
raise ValueError(
979+
"Indexing a numpy array with a fancy index always "
980+
"results in a copy"
981+
)
982+
# Skip copy inside _common, even if array is not writeable
983+
copy = "_force_false" # type: ignore
984+
985+
res, x = self._common("get", copy=copy, **kwargs)
986+
if res is not None:
987+
return res
988+
return x[self.idx]
989+
990+
def set(self, y, /, **kwargs):
991+
"""x[idx] = y"""
992+
res, x = self._common("set", y, **kwargs)
993+
if res is not None:
994+
return res
995+
x[self.idx] = y
996+
return x
997+
998+
def apply(self, ufunc, /, **kwargs):
999+
"""ufunc.at(x, idx)"""
1000+
if is_cupy_array(self.x) or is_torch_array(self.x) or is_dask_array(self.x):
1001+
# ufunc.at not implemented
1002+
return self.set(ufunc(self.x[self.idx]), **kwargs)
1003+
1004+
res, x = self._common("apply", ufunc, **kwargs)
1005+
if res is not None:
1006+
return res
1007+
ufunc.at(x, self.idx)
1008+
return x
1009+
1010+
def _iop(
1011+
self, at_op: str, elwise_op: Callable[[Array, Array], Array], y: Array, **kwargs
1012+
):
1013+
"""x[idx] += y or equivalent in-place operation on a subset of x
1014+
1015+
which is the same as saying
1016+
x[idx] = x[idx] + y
1017+
Note that this is not the same as
1018+
operator.iadd(x[idx], y)
1019+
Consider for example when x is a numpy array and idx is a fancy index, which
1020+
triggers a deep copy on __getitem__.
1021+
"""
1022+
res, x = self._common(at_op, y, **kwargs)
1023+
if res is not None:
1024+
return res
1025+
x[self.idx] = elwise_op(x[self.idx], y)
1026+
return x
1027+
1028+
def add(self, y, /, **kwargs):
1029+
"""x[idx] += y"""
1030+
return self._iop("add", operator.add, y, **kwargs)
1031+
1032+
def subtract(self, y, /, **kwargs):
1033+
"""x[idx] -= y"""
1034+
return self._iop("subtract", operator.sub, y, **kwargs)
1035+
1036+
def multiply(self, y, /, **kwargs):
1037+
"""x[idx] *= y"""
1038+
return self._iop("multiply", operator.mul, y, **kwargs)
1039+
1040+
def divide(self, y, /, **kwargs):
1041+
"""x[idx] /= y"""
1042+
return self._iop("divide", operator.truediv, y, **kwargs)
1043+
1044+
def power(self, y, /, **kwargs):
1045+
"""x[idx] **= y"""
1046+
return self._iop("power", operator.pow, y, **kwargs)
1047+
1048+
def min(self, y, /, **kwargs):
1049+
"""x[idx] = minimum(x[idx], y)"""
1050+
import numpy as np
1051+
1052+
return self._iop("min", np.minimum, y, **kwargs)
1053+
1054+
def max(self, y, /, **kwargs):
1055+
"""x[idx] = maximum(x[idx], y)"""
1056+
import numpy as np
1057+
1058+
return self._iop("max", np.maximum, y, **kwargs)
1059+
1060+
8041061
__all__ = [
8051062
"array_namespace",
8061063
"device",
@@ -821,8 +1078,10 @@ def size(x):
8211078
"is_ndonnx_namespace",
8221079
"is_pydata_sparse_array",
8231080
"is_pydata_sparse_namespace",
1081+
"is_writeable_array",
8241082
"size",
8251083
"to_device",
1084+
"at",
8261085
]
8271086

828-
_all_ignore = ['sys', 'math', 'inspect', 'warnings']
1087+
_all_ignore = ['inspect', 'math', 'operator', 'warnings', 'sys']

docs/helper-functions.rst

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ instead, which would be wrapped.
3636
.. autofunction:: device
3737
.. autofunction:: to_device
3838
.. autofunction:: size
39+
.. autofunction:: at
3940

4041
Inspection Helpers
4142
------------------
@@ -51,6 +52,7 @@ yet.
5152
.. autofunction:: is_jax_array
5253
.. autofunction:: is_pydata_sparse_array
5354
.. autofunction:: is_ndonnx_array
55+
.. autofunction:: is_writeable_array
5456
.. autofunction:: is_numpy_namespace
5557
.. autofunction:: is_cupy_namespace
5658
.. autofunction:: is_torch_namespace

0 commit comments

Comments
 (0)