Skip to content

Commit 619c679

Browse files
committed
Abstractions for read-only arrays
1 parent ee25aae commit 619c679

File tree

3 files changed

+176
-3
lines changed

3 files changed

+176
-3
lines changed

array_api_compat/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""
22
NumPy Array API compatibility library
33
4-
This is a small wrapper around NumPy and CuPy that is compatible with the
5-
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
6-
https://numpy.org/neps/nep-0047-array-api-standard.html.
4+
This is a small wrapper around NumPy, CuPy, JAX and others that is compatible
5+
with the Array API standard https://data-apis.org/array-api/latest/.
6+
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
77
88
Unlike array_api_strict, this is not a strict minimal implementation of the
99
Array API, but rather just an extension of the main NumPy namespace with

array_api_compat/common/_helpers.py

+170
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,173 @@ def size(x):
801801
return None
802802
return math.prod(x.shape)
803803

804+
def is_writeable_array(x):
805+
"""
806+
Return False if x.__setitem__ is expected to raise; True otherwise
807+
"""
808+
if is_numpy_array(x):
809+
return x.flags.writeable
810+
if is_jax_array(x):
811+
return False
812+
return True
813+
814+
_undef = object()
815+
816+
def at(x, idx=_undef, /):
817+
"""
818+
Update operations for read-only arrays.
819+
820+
This implements ``jax.numpy.ndarray.at`` for all backends.
821+
Writeable arrays may be updated in place; you should not rely on it.
822+
823+
Keyword arguments (e.g. ``indices_are_sorted``) are passed to JAX and are
824+
quietly ignored for backends that don't support them.
825+
826+
Examples
827+
--------
828+
Given either of these equivalent expressions::
829+
830+
x = at(x)[1].add(2)
831+
x = at(x, 1).add(2)
832+
833+
If x is a JAX array, they are the same as::
834+
835+
x = x.at[1].add(x)
836+
837+
If x is a read-only numpy array, they are the same as::
838+
839+
x = x.copy()
840+
x[1] += 2
841+
842+
Otherwise, they are the same as::
843+
844+
x[1] += 2
845+
846+
Warning
847+
-------
848+
You should always immediately overwrite the parameter array::
849+
850+
x = at(x, 0).set(2)
851+
852+
The anti-pattern below must be avoided, as it will result in different behaviour
853+
on read-only versus writeable arrays:
854+
855+
x = xp.asarray([0, 0, 0])
856+
y = at(x, 0).set(2)
857+
z = at(x, 1).set(3)
858+
859+
In the above example, y == [2, 0, 0] and z == [0, 3, 0] when x is read-only,
860+
whereas y == z == [2, 3, 0] when x is writeable!
861+
862+
See Also
863+
--------
864+
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
865+
"""
866+
if is_jax_array(x):
867+
return x.at
868+
if is_numpy_array(x) and not x.flags.writeable:
869+
x = x.copy()
870+
return _DummyAt(x, idx)
871+
872+
class _DummyAt:
873+
"""Helper of at().
874+
875+
Trivially implement jax.numpy.ndarray.at for other backends.
876+
x is updated in place.
877+
"""
878+
__slots__ = ("x", "idx")
879+
880+
def __init__(self, x, idx=_undef):
881+
self.x = x
882+
self.idx = idx
883+
884+
def __getitem__(self, idx):
885+
"""
886+
Allow for the alternate syntax ``at(x)[start:stop:step]``,
887+
which looks prettier than ``at(x, slice(start, stop, step))``
888+
and feels more intuitive coming from the JAX documentation.
889+
"""
890+
self.idx = idx
891+
return self
892+
893+
def _check_args(self, mode="promise_in_bounds", **kwargs):
894+
if self.idx is _undef:
895+
raise TypeError(
896+
"Index has not been set.\n"
897+
"Usage: either\n"
898+
" at(x, idx).set(value)\n"
899+
"or\n"
900+
" at(x)[idx].set(value)\n"
901+
"(same for all other methods)."
902+
)
903+
if mode != "promise_in_bounds":
904+
xp = array_namespace(self.x)
905+
raise NotImplementedError(
906+
f"mode='{mode}' is not supported for backend {xp.__name__}"
907+
)
908+
909+
def set(self, y, /, **kwargs):
910+
self._check_args(**kwargs)
911+
self.x[self.idx] = y
912+
return self.x
913+
914+
def add(self, y, /, **kwargs):
915+
self._check_args(**kwargs)
916+
self.x[self.idx] += y
917+
return self.x
918+
919+
def subtract(self, y, /, **kwargs):
920+
self._check_args(**kwargs)
921+
self.x[self.idx] -= y
922+
return self.x
923+
924+
def multiply(self, y, /, **kwargs):
925+
self._check_args(**kwargs)
926+
self.x[self.idx] *= y
927+
return self.x
928+
929+
def divide(self, y, /, **kwargs):
930+
self._check_args(**kwargs)
931+
self.x[self.idx] /= y
932+
return self.x
933+
934+
def power(self, y, /, **kwargs):
935+
self._check_args(**kwargs)
936+
self.x[self.idx] **= y
937+
return self.x
938+
939+
def min(self, y, /, **kwargs):
940+
self._check_args(**kwargs)
941+
xp = array_namespace(self.x)
942+
self.x[self.idx] = xp.minimum(self.x[self.idx], y)
943+
return self.x
944+
945+
def max(self, y, /, **kwargs):
946+
self._check_args(**kwargs)
947+
xp = array_namespace(self.x)
948+
self.x[self.idx] = xp.maximum(self.x[self.idx], y)
949+
return self.x
950+
951+
def apply(self, ufunc, /, **kwargs):
952+
self._check_args(**kwargs)
953+
ufunc.at(self.x, self.idx)
954+
return self.x
955+
956+
def get(self, **kwargs):
957+
self._check_args(**kwargs)
958+
return self.x[self.idx]
959+
960+
def iwhere(condition, x, y, /):
961+
"""Variant of xp.where(condition, x, y) which may or may not update
962+
x in place, if it's possible and beneficial for performance.
963+
"""
964+
if is_writeable_array(x):
965+
x[condition] = y
966+
return x
967+
else:
968+
xp = array_namespace(x)
969+
return xp.where(condition, x, y)
970+
804971
__all__ = [
805972
"array_namespace",
806973
"device",
@@ -821,8 +988,11 @@ def size(x):
821988
"is_ndonnx_namespace",
822989
"is_pydata_sparse_array",
823990
"is_pydata_sparse_namespace",
991+
"is_writeable_array",
824992
"size",
825993
"to_device",
994+
"at",
995+
"iwhere",
826996
]
827997

828998
_all_ignore = ['sys', 'math', 'inspect', 'warnings']

docs/helper-functions.rst

+3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ instead, which would be wrapped.
3636
.. autofunction:: device
3737
.. autofunction:: to_device
3838
.. autofunction:: size
39+
.. autofunction:: at
40+
.. autofunction:: iwhere
3941

4042
Inspection Helpers
4143
------------------
@@ -51,6 +53,7 @@ yet.
5153
.. autofunction:: is_jax_array
5254
.. autofunction:: is_pydata_sparse_array
5355
.. autofunction:: is_ndonnx_array
56+
.. autofunction:: is_writeable_array
5457
.. autofunction:: is_numpy_namespace
5558
.. autofunction:: is_cupy_namespace
5659
.. autofunction:: is_torch_namespace

0 commit comments

Comments
 (0)