Skip to content

Commit 376ad49

Browse files
committed
WIP at() method
1 parent 343ebf4 commit 376ad49

File tree

8 files changed

+467
-8
lines changed

8 files changed

+467
-8
lines changed

docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
at
910
atleast_nd
1011
cov
1112
create_diagonal

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ ignore = [
226226
"ISC001", # Conflicts with formatter
227227
"N802", # Function name should be lowercase
228228
"N806", # Variable in function should be lowercase
229+
"PD008", # pandas-use-of-dot-at
229230
]
230231

231232
[tool.ruff.lint.per-file-ignores]

src/array_api_extra/__init__.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
from __future__ import annotations
22

3-
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
3+
from ._funcs import (
4+
at,
5+
atleast_nd,
6+
cov,
7+
create_diagonal,
8+
expand_dims,
9+
kron,
10+
setdiff1d,
11+
sinc,
12+
)
413

514
__version__ = "0.3.3.dev0"
615

716
__all__ = [
817
"__version__",
18+
"at",
919
"atleast_nd",
1020
"cov",
1121
"create_diagonal",

src/array_api_extra/_funcs.py

+284-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
from __future__ import annotations
22

3+
import operator
34
import warnings
5+
from collections.abc import Callable
6+
from typing import Any
47

58
from ._lib import _utils
6-
from ._lib._compat import array_namespace
9+
from ._lib._compat import (
10+
array_namespace,
11+
is_array_api_obj,
12+
is_dask_array,
13+
is_writeable_array,
14+
)
715
from ._lib._typing import Array, ModuleType
816

917
__all__ = [
18+
"at",
1019
"atleast_nd",
1120
"cov",
1221
"create_diagonal",
@@ -545,3 +554,277 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
545554
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device),
546555
)
547556
return xp.sin(y) / y
557+
558+
559+
_undef = object()
560+
561+
562+
class at: # noqa: N801
563+
"""
564+
Update operations for read-only arrays.
565+
566+
This implements ``jax.numpy.ndarray.at`` for all backends.
567+
568+
Parameters
569+
----------
570+
x : array
571+
Input array.
572+
idx : index, optional
573+
You may use two alternate syntaxes::
574+
575+
at(x, idx).set(value) # or get(), add(), etc.
576+
at(x)[idx].set(value)
577+
578+
copy : bool, optional
579+
True (default)
580+
Ensure that the inputs are not modified.
581+
False
582+
Ensure that the update operation writes back to the input.
583+
Raise ValueError if a copy cannot be avoided.
584+
None
585+
The array parameter *may* be modified in place if it is possible and
586+
beneficial for performance.
587+
You should not reuse it after calling this function.
588+
xp : array_namespace, optional
589+
The standard-compatible namespace for `x`. Default: infer
590+
591+
Additionally, if the backend supports an `at` method, any additional keyword
592+
arguments are passed to it verbatim; e.g. this allows passing
593+
``indices_are_sorted=True`` to JAX.
594+
595+
Returns
596+
-------
597+
Updated input array.
598+
599+
Examples
600+
--------
601+
Given either of these equivalent expressions::
602+
603+
x = at(x)[1].add(2, copy=None)
604+
x = at(x, 1).add(2, copy=None)
605+
606+
If x is a JAX array, they are the same as::
607+
608+
x = x.at[1].add(2)
609+
610+
If x is a read-only numpy array, they are the same as::
611+
612+
x = x.copy()
613+
x[1] += 2
614+
615+
Otherwise, they are the same as::
616+
617+
x[1] += 2
618+
619+
Warning
620+
-------
621+
When you use copy=None, you should always immediately overwrite
622+
the parameter array::
623+
624+
x = at(x, 0).set(2, copy=None)
625+
626+
The anti-pattern below must be avoided, as it will result in different behaviour
627+
on read-only versus writeable arrays::
628+
629+
x = xp.asarray([0, 0, 0])
630+
y = at(x, 0).set(2, copy=None)
631+
z = at(x, 1).set(3, copy=None)
632+
633+
In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
634+
when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!
635+
636+
Warning
637+
-------
638+
The behaviour of update methods when the index is an array of integers which
639+
contains multiple occurrences of the same index is undefined;
640+
e.g. ``at(x, [0, 0]).set(2)``
641+
642+
Note
643+
----
644+
`sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.
645+
646+
See Also
647+
--------
648+
`jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
649+
"""
650+
651+
x: Array
652+
idx: Any
653+
__slots__ = ("idx", "x")
654+
655+
def __init__(self, x: Array, idx: Any = _undef, /):
656+
self.x = x
657+
self.idx = idx
658+
659+
def __getitem__(self, idx: Any) -> Any:
660+
"""Allow for the alternate syntax ``at(x)[start:stop:step]``,
661+
which looks prettier than ``at(x, slice(start, stop, step))``
662+
and feels more intuitive coming from the JAX documentation.
663+
"""
664+
if self.idx is not _undef:
665+
msg = "Index has already been set"
666+
raise ValueError(msg)
667+
self.idx = idx
668+
return self
669+
670+
def _common(
671+
self,
672+
at_op: str,
673+
y: Array = _undef,
674+
/,
675+
copy: bool | None = True,
676+
xp: ModuleType | None = None,
677+
_is_update: bool = True,
678+
**kwargs: Any,
679+
) -> tuple[Any, None] | tuple[None, Array]:
680+
"""Perform common prepocessing.
681+
682+
Returns
683+
-------
684+
If the operation can be resolved by at[], (return value, None)
685+
Otherwise, (None, preprocessed x)
686+
"""
687+
if self.idx is _undef:
688+
msg = (
689+
"Index has not been set.\n"
690+
"Usage: either\n"
691+
" at(x, idx).set(value)\n"
692+
"or\n"
693+
" at(x)[idx].set(value)\n"
694+
"(same for all other methods)."
695+
)
696+
raise TypeError(msg)
697+
698+
x = self.x
699+
700+
if copy is True:
701+
writeable = None
702+
elif copy is False:
703+
writeable = is_writeable_array(x)
704+
if not writeable:
705+
msg = "Cannot modify parameter in place"
706+
raise ValueError(msg)
707+
elif copy is None:
708+
writeable = is_writeable_array(x)
709+
copy = _is_update and not writeable
710+
else:
711+
msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable]
712+
raise ValueError(msg)
713+
714+
if copy:
715+
try:
716+
at_ = x.at
717+
except AttributeError:
718+
# Emulate at[] behaviour for non-JAX arrays
719+
# with a copy followed by an update
720+
if xp is None:
721+
xp = array_namespace(x)
722+
# Create writeable copy of read-only numpy array
723+
x = xp.asarray(x, copy=True)
724+
if writeable is False:
725+
# A copy of a read-only numpy array is writeable
726+
writeable = None
727+
else:
728+
# Use JAX's at[] or other library that with the same duck-type API
729+
args = (y,) if y is not _undef else ()
730+
return getattr(at_[self.idx], at_op)(*args, **kwargs), None
731+
732+
if _is_update:
733+
if writeable is None:
734+
writeable = is_writeable_array(x)
735+
if not writeable:
736+
# sparse crashes here
737+
msg = f"Array {x} has no `at` method and is read-only"
738+
raise ValueError(msg)
739+
740+
return None, x
741+
742+
def get(self, **kwargs: Any) -> Any:
743+
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
744+
that the output is either a copy or a view; it also allows passing
745+
keyword arguments to the backend.
746+
"""
747+
if kwargs.get("copy") is False:
748+
if is_array_api_obj(self.idx):
749+
# Boolean index. Note that the array API spec
750+
# https://data-apis.org/array-api/latest/API_specification/indexing.html
751+
# does not allow for list, tuple, and tuples of slices plus one or more
752+
# one-dimensional array indices, although many backends support them.
753+
# So this check will encounter a lot of false negatives in real life,
754+
# which can be caught by testing the user code vs. array-api-strict.
755+
msg = "get() with an array index always returns a copy"
756+
raise ValueError(msg)
757+
if is_dask_array(self.x):
758+
msg = "get() on Dask arrays always returns a copy"
759+
raise ValueError(msg)
760+
761+
res, x = self._common("get", _is_update=False, **kwargs)
762+
if res is not None:
763+
return res
764+
assert x is not None
765+
return x[self.idx]
766+
767+
def set(self, y: Array, /, **kwargs: Any) -> Array:
768+
"""Apply ``x[idx] = y`` and return the update array"""
769+
res, x = self._common("set", y, **kwargs)
770+
if res is not None:
771+
return res
772+
assert x is not None
773+
x[self.idx] = y
774+
return x
775+
776+
def _iop(
777+
self,
778+
at_op: str,
779+
elwise_op: Callable[[Array, Array], Array],
780+
y: Array,
781+
/,
782+
**kwargs: Any,
783+
) -> Array:
784+
"""x[idx] += y or equivalent in-place operation on a subset of x
785+
786+
which is the same as saying
787+
x[idx] = x[idx] + y
788+
Note that this is not the same as
789+
operator.iadd(x[idx], y)
790+
Consider for example when x is a numpy array and idx is a fancy index, which
791+
triggers a deep copy on __getitem__.
792+
"""
793+
res, x = self._common(at_op, y, **kwargs)
794+
if res is not None:
795+
return res
796+
assert x is not None
797+
x[self.idx] = elwise_op(x[self.idx], y)
798+
return x
799+
800+
def add(self, y: Array, /, **kwargs: Any) -> Array:
801+
"""Apply ``x[idx] += y`` and return the updated array"""
802+
return self._iop("add", operator.add, y, **kwargs)
803+
804+
def subtract(self, y: Array, /, **kwargs: Any) -> Array:
805+
"""Apply ``x[idx] -= y`` and return the updated array"""
806+
return self._iop("subtract", operator.sub, y, **kwargs)
807+
808+
def multiply(self, y: Array, /, **kwargs: Any) -> Array:
809+
"""Apply ``x[idx] *= y`` and return the updated array"""
810+
return self._iop("multiply", operator.mul, y, **kwargs)
811+
812+
def divide(self, y: Array, /, **kwargs: Any) -> Array:
813+
"""Apply ``x[idx] /= y`` and return the updated array"""
814+
return self._iop("divide", operator.truediv, y, **kwargs)
815+
816+
def power(self, y: Array, /, **kwargs: Any) -> Array:
817+
"""Apply ``x[idx] **= y`` and return the updated array"""
818+
return self._iop("power", operator.pow, y, **kwargs)
819+
820+
def min(self, y: Array, /, **kwargs: Any) -> Array:
821+
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
822+
xp = array_namespace(self.x)
823+
y = xp.asarray(y)
824+
return self._iop("min", xp.minimum, y, **kwargs)
825+
826+
def max(self, y: Array, /, **kwargs: Any) -> Array:
827+
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
828+
xp = array_namespace(self.x)
829+
y = xp.asarray(y)
830+
return self._iop("max", xp.maximum, y, **kwargs)

src/array_api_extra/_lib/_compat.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,26 @@
33
from __future__ import annotations
44

55
try:
6-
from ..._array_api_compat_vendor import (
7-
array_namespace,
8-
device,
6+
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
7+
array_namespace, # pyright: ignore[reportUnknownVariableType]
8+
device, # pyright: ignore[reportUnknownVariableType]
9+
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
10+
is_dask_array, # pyright: ignore[reportUnknownVariableType]
11+
is_writeable_array, # pyright: ignore[reportUnknownVariableType]
912
)
1013
except ImportError:
1114
from array_api_compat import (
1215
array_namespace,
1316
device,
17+
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
18+
is_dask_array, # pyright: ignore[reportUnknownVariableType]
19+
is_writeable_array, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
1420
)
1521

16-
__all__ = [
22+
__all__ = (
1723
"array_namespace",
1824
"device",
19-
]
25+
"is_array_api_obj",
26+
"is_dask_array",
27+
"is_writeable_array",
28+
)

src/array_api_extra/_lib/_compat.pyi

+3
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ def array_namespace(
1111
use_compat: bool | None = None,
1212
) -> ArrayModule: ...
1313
def device(x: Array, /) -> Device: ...
14+
def is_array_api_obj(x: object, /) -> bool: ...
15+
def is_dask_array(x: object, /) -> bool: ...
16+
def is_writeable_array(x: object, /) -> bool: ...

0 commit comments

Comments
 (0)