Skip to content

Commit 4a0943a

Browse files
committed
WIP at() method
1 parent 343ebf4 commit 4a0943a

File tree

8 files changed

+481
-8
lines changed

8 files changed

+481
-8
lines changed

docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
at
910
atleast_nd
1011
cov
1112
create_diagonal

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ ignore = [
226226
"ISC001", # Conflicts with formatter
227227
"N802", # Function name should be lowercase
228228
"N806", # Variable in function should be lowercase
229+
"PD008", # pandas-use-of-dot-at
229230
]
230231

231232
[tool.ruff.lint.per-file-ignores]

src/array_api_extra/__init__.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
from __future__ import annotations
22

3-
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
3+
from ._funcs import (
4+
at,
5+
atleast_nd,
6+
cov,
7+
create_diagonal,
8+
expand_dims,
9+
kron,
10+
setdiff1d,
11+
sinc,
12+
)
413

514
__version__ = "0.3.3.dev0"
615

716
__all__ = [
817
"__version__",
18+
"at",
919
"atleast_nd",
1020
"cov",
1121
"create_diagonal",

src/array_api_extra/_funcs.py

+286-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
from __future__ import annotations
22

3+
import operator
34
import warnings
5+
from collections.abc import Callable
6+
from typing import Any
47

58
from ._lib import _utils
6-
from ._lib._compat import array_namespace
9+
from ._lib._compat import (
10+
array_namespace,
11+
is_array_api_obj,
12+
is_dask_array,
13+
is_writeable_array,
14+
)
715
from ._lib._typing import Array, ModuleType
816

917
__all__ = [
18+
"at",
1019
"atleast_nd",
1120
"cov",
1221
"create_diagonal",
@@ -545,3 +554,279 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
545554
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device),
546555
)
547556
return xp.sin(y) / y
557+
558+
559+
_undef = object()
560+
561+
562+
class at: # noqa: N801
563+
"""
564+
Update operations for read-only arrays.
565+
566+
This implements ``jax.numpy.ndarray.at`` for all backends.
567+
568+
Parameters
569+
----------
570+
x : array
571+
Input array.
572+
idx : index, optional
573+
You may use two alternate syntaxes::
574+
575+
at(x, idx).set(value) # or get(), add(), etc.
576+
at(x)[idx].set(value)
577+
578+
copy : bool, optional
579+
True (default)
580+
Ensure that the inputs are not modified.
581+
False
582+
Ensure that the update operation writes back to the input.
583+
Raise ValueError if a copy cannot be avoided.
584+
None
585+
The array parameter *may* be modified in place if it is possible and
586+
beneficial for performance.
587+
You should not reuse it after calling this function.
588+
xp : array_namespace, optional
589+
The standard-compatible namespace for `x`. Default: infer
590+
591+
**kwargs:
592+
If the backend supports an `at` method, any additional keyword
593+
arguments are passed to it verbatim; e.g. this allows passing
594+
``indices_are_sorted=True`` to JAX.
595+
596+
Returns
597+
-------
598+
Updated input array.
599+
600+
Examples
601+
--------
602+
Given either of these equivalent expressions::
603+
604+
x = at(x)[1].add(2, copy=None)
605+
x = at(x, 1).add(2, copy=None)
606+
607+
If x is a JAX array, they are the same as::
608+
609+
x = x.at[1].add(2)
610+
611+
If x is a read-only numpy array, they are the same as::
612+
613+
x = x.copy()
614+
x[1] += 2
615+
616+
Otherwise, they are the same as::
617+
618+
x[1] += 2
619+
620+
Warning
621+
-------
622+
When you use copy=None, you should always immediately overwrite
623+
the parameter array::
624+
625+
x = at(x, 0).set(2, copy=None)
626+
627+
The anti-pattern below must be avoided, as it will result in different behaviour
628+
on read-only versus writeable arrays::
629+
630+
x = xp.asarray([0, 0, 0])
631+
y = at(x, 0).set(2, copy=None)
632+
z = at(x, 1).set(3, copy=None)
633+
634+
In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
635+
when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!
636+
637+
Warning
638+
-------
639+
The array API standard does not support integer array indices.
640+
The behaviour of update methods when the index is an array of integers
641+
is undefined; this is particularly true when the index contains multiple
642+
occurrences of the same index, e.g. ``at(x, [0, 0]).set(2)``.
643+
644+
Note
645+
----
646+
`sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.
647+
648+
See Also
649+
--------
650+
`jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
651+
"""
652+
653+
x: Array
654+
idx: Any
655+
__slots__ = ("idx", "x")
656+
657+
def __init__(self, x: Array, idx: Any = _undef, /):
658+
self.x = x
659+
self.idx = idx
660+
661+
def __getitem__(self, idx: Any) -> Any:
662+
"""Allow for the alternate syntax ``at(x)[start:stop:step]``,
663+
which looks prettier than ``at(x, slice(start, stop, step))``
664+
and feels more intuitive coming from the JAX documentation.
665+
"""
666+
if self.idx is not _undef:
667+
msg = "Index has already been set"
668+
raise ValueError(msg)
669+
self.idx = idx
670+
return self
671+
672+
def _common(
673+
self,
674+
at_op: str,
675+
y: Array = _undef,
676+
/,
677+
copy: bool | None = True,
678+
xp: ModuleType | None = None,
679+
_is_update: bool = True,
680+
**kwargs: Any,
681+
) -> tuple[Any, None] | tuple[None, Array]:
682+
"""Perform common prepocessing.
683+
684+
Returns
685+
-------
686+
If the operation can be resolved by at[], (return value, None)
687+
Otherwise, (None, preprocessed x)
688+
"""
689+
if self.idx is _undef:
690+
msg = (
691+
"Index has not been set.\n"
692+
"Usage: either\n"
693+
" at(x, idx).set(value)\n"
694+
"or\n"
695+
" at(x)[idx].set(value)\n"
696+
"(same for all other methods)."
697+
)
698+
raise TypeError(msg)
699+
700+
x = self.x
701+
702+
if copy is True:
703+
writeable = None
704+
elif copy is False:
705+
writeable = is_writeable_array(x)
706+
if not writeable:
707+
msg = "Cannot modify parameter in place"
708+
raise ValueError(msg)
709+
elif copy is None:
710+
writeable = is_writeable_array(x)
711+
copy = _is_update and not writeable
712+
else:
713+
msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable]
714+
raise ValueError(msg)
715+
716+
if copy:
717+
try:
718+
at_ = x.at
719+
except AttributeError:
720+
# Emulate at[] behaviour for non-JAX arrays
721+
# with a copy followed by an update
722+
if xp is None:
723+
xp = array_namespace(x)
724+
# Create writeable copy of read-only numpy array
725+
x = xp.asarray(x, copy=True)
726+
if writeable is False:
727+
# A copy of a read-only numpy array is writeable
728+
writeable = None
729+
else:
730+
# Use JAX's at[] or other library that with the same duck-type API
731+
args = (y,) if y is not _undef else ()
732+
return getattr(at_[self.idx], at_op)(*args, **kwargs), None
733+
734+
if _is_update:
735+
if writeable is None:
736+
writeable = is_writeable_array(x)
737+
if not writeable:
738+
# sparse crashes here
739+
msg = f"Array {x} has no `at` method and is read-only"
740+
raise ValueError(msg)
741+
742+
return None, x
743+
744+
def get(self, **kwargs: Any) -> Any:
745+
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
746+
that the output is either a copy or a view; it also allows passing
747+
keyword arguments to the backend.
748+
"""
749+
if kwargs.get("copy") is False:
750+
if is_array_api_obj(self.idx):
751+
# Boolean index. Note that the array API spec
752+
# https://data-apis.org/array-api/latest/API_specification/indexing.html
753+
# does not allow for list, tuple, and tuples of slices plus one or more
754+
# one-dimensional array indices, although many backends support them.
755+
# So this check will encounter a lot of false negatives in real life,
756+
# which can be caught by testing the user code vs. array-api-strict.
757+
msg = "get() with an array index always returns a copy"
758+
raise ValueError(msg)
759+
if is_dask_array(self.x):
760+
msg = "get() on Dask arrays always returns a copy"
761+
raise ValueError(msg)
762+
763+
res, x = self._common("get", _is_update=False, **kwargs)
764+
if res is not None:
765+
return res
766+
assert x is not None
767+
return x[self.idx]
768+
769+
def set(self, y: Array, /, **kwargs: Any) -> Array:
770+
"""Apply ``x[idx] = y`` and return the update array"""
771+
res, x = self._common("set", y, **kwargs)
772+
if res is not None:
773+
return res
774+
assert x is not None
775+
x[self.idx] = y
776+
return x
777+
778+
def _iop(
779+
self,
780+
at_op: str,
781+
elwise_op: Callable[[Array, Array], Array],
782+
y: Array,
783+
/,
784+
**kwargs: Any,
785+
) -> Array:
786+
"""x[idx] += y or equivalent in-place operation on a subset of x
787+
788+
which is the same as saying
789+
x[idx] = x[idx] + y
790+
Note that this is not the same as
791+
operator.iadd(x[idx], y)
792+
Consider for example when x is a numpy array and idx is a fancy index, which
793+
triggers a deep copy on __getitem__.
794+
"""
795+
res, x = self._common(at_op, y, **kwargs)
796+
if res is not None:
797+
return res
798+
assert x is not None
799+
x[self.idx] = elwise_op(x[self.idx], y)
800+
return x
801+
802+
def add(self, y: Array, /, **kwargs: Any) -> Array:
803+
"""Apply ``x[idx] += y`` and return the updated array"""
804+
return self._iop("add", operator.add, y, **kwargs)
805+
806+
def subtract(self, y: Array, /, **kwargs: Any) -> Array:
807+
"""Apply ``x[idx] -= y`` and return the updated array"""
808+
return self._iop("subtract", operator.sub, y, **kwargs)
809+
810+
def multiply(self, y: Array, /, **kwargs: Any) -> Array:
811+
"""Apply ``x[idx] *= y`` and return the updated array"""
812+
return self._iop("multiply", operator.mul, y, **kwargs)
813+
814+
def divide(self, y: Array, /, **kwargs: Any) -> Array:
815+
"""Apply ``x[idx] /= y`` and return the updated array"""
816+
return self._iop("divide", operator.truediv, y, **kwargs)
817+
818+
def power(self, y: Array, /, **kwargs: Any) -> Array:
819+
"""Apply ``x[idx] **= y`` and return the updated array"""
820+
return self._iop("power", operator.pow, y, **kwargs)
821+
822+
def min(self, y: Array, /, **kwargs: Any) -> Array:
823+
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
824+
xp = array_namespace(self.x)
825+
y = xp.asarray(y)
826+
return self._iop("min", xp.minimum, y, **kwargs)
827+
828+
def max(self, y: Array, /, **kwargs: Any) -> Array:
829+
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
830+
xp = array_namespace(self.x)
831+
y = xp.asarray(y)
832+
return self._iop("max", xp.maximum, y, **kwargs)

src/array_api_extra/_lib/_compat.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,26 @@
33
from __future__ import annotations
44

55
try:
6-
from ..._array_api_compat_vendor import (
7-
array_namespace,
8-
device,
6+
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
7+
array_namespace, # pyright: ignore[reportUnknownVariableType]
8+
device, # pyright: ignore[reportUnknownVariableType]
9+
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
10+
is_dask_array, # pyright: ignore[reportUnknownVariableType]
11+
is_writeable_array, # pyright: ignore[reportUnknownVariableType]
912
)
1013
except ImportError:
1114
from array_api_compat import (
1215
array_namespace,
1316
device,
17+
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
18+
is_dask_array, # pyright: ignore[reportUnknownVariableType]
19+
is_writeable_array, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
1420
)
1521

16-
__all__ = [
22+
__all__ = (
1723
"array_namespace",
1824
"device",
19-
]
25+
"is_array_api_obj",
26+
"is_dask_array",
27+
"is_writeable_array",
28+
)

src/array_api_extra/_lib/_compat.pyi

+3
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ def array_namespace(
1111
use_compat: bool | None = None,
1212
) -> ArrayModule: ...
1313
def device(x: Array, /) -> Device: ...
14+
def is_array_api_obj(x: object, /) -> bool: ...
15+
def is_dask_array(x: object, /) -> bool: ...
16+
def is_writeable_array(x: object, /) -> bool: ...

0 commit comments

Comments
 (0)