Skip to content

Commit dd068d2

Browse files
committed
WIP at() method
1 parent 00b2ab7 commit dd068d2

File tree

12 files changed

+452
-17
lines changed

12 files changed

+452
-17
lines changed

docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
at
910
atleast_nd
1011
cov
1112
create_diagonal

pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ run.source = ["array_api_extra"]
164164
report.exclude_also = [
165165
'\.\.\.',
166166
'if typing.TYPE_CHECKING:',
167+
'if TYPE_CHECKING:',
167168
]
168169

169170

@@ -235,6 +236,7 @@ ignore = [
235236
"PLR09", # Too many <...>
236237
"PLR2004", # Magic value used in comparison
237238
"ISC001", # Conflicts with formatter
239+
"PD008", # pandas-use-of-dot-at
238240
]
239241
isort.required-imports = ["from __future__ import annotations"]
240242

src/array_api_extra/__init__.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
22

3-
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
3+
from ._funcs import (
4+
at,
5+
atleast_nd,
6+
cov,
7+
create_diagonal,
8+
expand_dims,
9+
kron,
10+
setdiff1d,
11+
sinc,
12+
)
413

514
__version__ = "0.3.3"
615

716
# pylint: disable=duplicate-code
817
__all__ = [
918
"__version__",
19+
"at",
1020
"atleast_nd",
1121
"cov",
1222
"create_diagonal",

src/array_api_extra/_funcs.py

+279-3
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
22

3-
import typing
3+
import operator
44
import warnings
5+
from typing import TYPE_CHECKING, Any, Callable
56

6-
if typing.TYPE_CHECKING:
7+
if TYPE_CHECKING:
78
from ._lib._typing import Array, ModuleType
89

910
from ._lib import _utils
10-
from ._lib._compat import array_namespace
11+
from ._lib._compat import (
12+
array_namespace,
13+
is_array_api_obj,
14+
is_writeable_array,
15+
)
1116

1217
__all__ = [
18+
"at",
1319
"atleast_nd",
1420
"cov",
1521
"create_diagonal",
@@ -546,3 +552,273 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
546552
x, x, xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device)
547553
)
548554
return xp.sin(y) / y
555+
556+
557+
_undef = object()
558+
559+
560+
class at:
561+
"""
562+
Update operations for read-only arrays.
563+
564+
This implements ``jax.numpy.ndarray.at`` for all backends.
565+
566+
Parameters
567+
----------
568+
x : array
569+
Input array.
570+
idx : index, optional
571+
You may use two alternate syntaxes::
572+
573+
at(x, idx).set(value) # or get(), add(), etc.
574+
at(x)[idx].set(value)
575+
576+
copy : bool, optional
577+
True (default)
578+
Ensure that the inputs are not modified.
579+
False
580+
Ensure that the update operation writes back to the input.
581+
Raise ValueError if a copy cannot be avoided.
582+
None
583+
The array parameter *may* be modified in place if it is possible and
584+
beneficial for performance.
585+
You should not reuse it after calling this function.
586+
xp : array_namespace, optional
587+
The standard-compatible namespace for `x`. Default: infer
588+
589+
Additionally, if the backend supports an `at` method, any additional keyword
590+
arguments are passed to it verbatim; e.g. this allows passing
591+
``indices_are_sorted=True`` to JAX.
592+
593+
Returns
594+
-------
595+
Updated input array.
596+
597+
Examples
598+
--------
599+
Given either of these equivalent expressions::
600+
601+
x = at(x)[1].add(2, copy=None)
602+
x = at(x, 1).add(2, copy=None)
603+
604+
If x is a JAX array, they are the same as::
605+
606+
x = x.at[1].add(2)
607+
608+
If x is a read-only numpy array, they are the same as::
609+
610+
x = x.copy()
611+
x[1] += 2
612+
613+
Otherwise, they are the same as::
614+
615+
x[1] += 2
616+
617+
Warning
618+
-------
619+
When you use copy=None, you should always immediately overwrite
620+
the parameter array::
621+
622+
x = at(x, 0).set(2, copy=None)
623+
624+
The anti-pattern below must be avoided, as it will result in different behaviour
625+
on read-only versus writeable arrays::
626+
627+
x = xp.asarray([0, 0, 0])
628+
y = at(x, 0).set(2, copy=None)
629+
z = at(x, 1).set(3, copy=None)
630+
631+
In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
632+
when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!
633+
634+
Warning
635+
-------
636+
The behaviour of update methods when the index is an array of integers which
637+
contains multiple occurrences of the same index is undefined;
638+
e.g. ``at(x, [0, 0]).set(2)``
639+
640+
Note
641+
----
642+
`sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.
643+
644+
See Also
645+
--------
646+
`jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
647+
"""
648+
649+
x: Array
650+
idx: Any
651+
__slots__ = ("x", "idx")
652+
653+
def __init__(self, x: Array, idx: Any = _undef, /):
654+
self.x = x
655+
self.idx = idx
656+
657+
def __getitem__(self, idx: Any) -> Any:
658+
"""Allow for the alternate syntax ``at(x)[start:stop:step]``,
659+
which looks prettier than ``at(x, slice(start, stop, step))``
660+
and feels more intuitive coming from the JAX documentation.
661+
"""
662+
if self.idx is not _undef:
663+
msg = "Index has already been set"
664+
raise ValueError(msg)
665+
self.idx = idx
666+
return self
667+
668+
def _common(
669+
self,
670+
at_op: str,
671+
y: Array = _undef,
672+
/,
673+
copy: bool | None = True,
674+
xp: ModuleType | None = None,
675+
_is_update: bool = True,
676+
**kwargs: Any,
677+
) -> tuple[Any, None] | tuple[None, Array]:
678+
"""Perform common prepocessing.
679+
680+
Returns
681+
-------
682+
If the operation can be resolved by at[], (return value, None)
683+
Otherwise, (None, preprocessed x)
684+
"""
685+
if self.idx is _undef:
686+
msg = (
687+
"Index has not been set.\n"
688+
"Usage: either\n"
689+
" at(x, idx).set(value)\n"
690+
"or\n"
691+
" at(x)[idx].set(value)\n"
692+
"(same for all other methods)."
693+
)
694+
raise TypeError(msg)
695+
696+
x = self.x
697+
698+
if copy is True:
699+
writeable = None
700+
elif copy is False:
701+
writeable = is_writeable_array(x)
702+
if not writeable:
703+
msg = "Cannot modify parameter in place"
704+
raise ValueError(msg)
705+
elif copy is None:
706+
writeable = is_writeable_array(x)
707+
copy = _is_update and not writeable
708+
else:
709+
msg = f"Invalid value for copy: {copy!r}"
710+
raise ValueError(msg)
711+
712+
if copy:
713+
try:
714+
at_ = x.at
715+
except AttributeError:
716+
# Emulate at[] behaviour for non-JAX arrays
717+
# with a copy followed by an update
718+
if xp is None:
719+
xp = array_namespace(x)
720+
# Create writeable copy of read-only numpy array
721+
x = xp.asarray(x, copy=True)
722+
if writeable is False:
723+
# A copy of a read-only numpy array is writeable
724+
writeable = None
725+
else:
726+
# Use JAX's at[] or other library that with the same duck-type API
727+
args = (y,) if y is not _undef else ()
728+
return getattr(at_[self.idx], at_op)(*args, **kwargs), None
729+
730+
if _is_update:
731+
if writeable is None:
732+
writeable = is_writeable_array(x)
733+
if not writeable:
734+
# sparse crashes here
735+
msg = f"Array {x} has no `at` method and is read-only"
736+
raise ValueError(msg)
737+
738+
return None, x
739+
740+
def get(self, **kwargs: Any) -> Any:
741+
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
742+
that the output is either a copy or a view; it also allows passing
743+
keyword arguments to the backend.
744+
"""
745+
if kwargs.get("copy") is False and (
746+
is_array_api_obj(self.idx)
747+
or isinstance(self.idx, tuple)
748+
and any(is_array_api_obj(i) for i in self.idx)
749+
):
750+
# Fancy index. Note that the array API spec does not allow for
751+
# list, tuple, or numpy arrays although many backends support them.
752+
msg = "get() with an array index always returns in a copy"
753+
raise ValueError(msg)
754+
755+
res, x = self._common("get", _is_update=False, **kwargs)
756+
if res is not None:
757+
return res
758+
assert x is not None
759+
return x[self.idx]
760+
761+
def set(self, y: Array, /, **kwargs: Any) -> Array:
762+
"""Apply ``x[idx] = y`` and return the update array"""
763+
res, x = self._common("set", y, **kwargs)
764+
if res is not None:
765+
return res
766+
assert x is not None
767+
x[self.idx] = y
768+
return x
769+
770+
def _iop(
771+
self,
772+
at_op: str,
773+
elwise_op: Callable[[Array, Array], Array],
774+
y: Array,
775+
/,
776+
**kwargs: Any,
777+
) -> Array:
778+
"""x[idx] += y or equivalent in-place operation on a subset of x
779+
780+
which is the same as saying
781+
x[idx] = x[idx] + y
782+
Note that this is not the same as
783+
operator.iadd(x[idx], y)
784+
Consider for example when x is a numpy array and idx is a fancy index, which
785+
triggers a deep copy on __getitem__.
786+
"""
787+
res, x = self._common(at_op, y, **kwargs)
788+
if res is not None:
789+
return res
790+
assert x is not None
791+
x[self.idx] = elwise_op(x[self.idx], y)
792+
return x
793+
794+
def add(self, y: Array, /, **kwargs: Any) -> Array:
795+
"""Apply ``x[idx] += y`` and return the updated array"""
796+
return self._iop("add", operator.add, y, **kwargs)
797+
798+
def subtract(self, y: Array, /, **kwargs: Any) -> Array:
799+
"""Apply ``x[idx] -= y`` and return the updated array"""
800+
return self._iop("subtract", operator.sub, y, **kwargs)
801+
802+
def multiply(self, y: Array, /, **kwargs: Any) -> Array:
803+
"""Apply ``x[idx] *= y`` and return the updated array"""
804+
return self._iop("multiply", operator.mul, y, **kwargs)
805+
806+
def divide(self, y: Array, /, **kwargs: Any) -> Array:
807+
"""Apply ``x[idx] /= y`` and return the updated array"""
808+
return self._iop("divide", operator.truediv, y, **kwargs)
809+
810+
def power(self, y: Array, /, **kwargs: Any) -> Array:
811+
"""Apply ``x[idx] **= y`` and return the updated array"""
812+
return self._iop("power", operator.pow, y, **kwargs)
813+
814+
def min(self, y: Array, /, **kwargs: Any) -> Array:
815+
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
816+
xp = array_namespace(self.x)
817+
y = xp.asarray(y)
818+
return self._iop("min", xp.minimum, y, **kwargs)
819+
820+
def max(self, y: Array, /, **kwargs: Any) -> Array:
821+
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
822+
xp = array_namespace(self.x)
823+
y = xp.asarray(y)
824+
return self._iop("max", xp.maximum, y, **kwargs)

src/array_api_extra/_lib/_compat.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,20 @@
66
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
77
array_namespace, # pyright: ignore[reportUnknownVariableType]
88
device, # pyright: ignore[reportUnknownVariableType]
9+
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
10+
is_writeable_array, # pyright: ignore[reportUnknownVariableType]
911
)
1012
except ImportError:
1113
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
1214
array_namespace, # pyright: ignore[reportUnknownVariableType]
1315
device,
16+
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
17+
is_writeable_array, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
1418
)
1519

16-
__all__ = [
20+
__all__ = (
1721
"array_namespace",
1822
"device",
19-
]
23+
"is_array_api_obj",
24+
"is_writeable_array",
25+
)

src/array_api_extra/_lib/_compat.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,5 @@ def array_namespace(
1111
use_compat: bool | None = None,
1212
) -> ArrayModule: ...
1313
def device(x: Array, /) -> Device: ...
14+
def is_array_api_obj(x: object, /) -> bool: ...
15+
def is_writeable_array(x: object, /) -> bool: ...

src/array_api_extra/_lib/_typing.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
22

3-
import typing
43
from types import ModuleType
5-
from typing import Any
4+
from typing import TYPE_CHECKING, Any
65

7-
if typing.TYPE_CHECKING:
6+
if TYPE_CHECKING:
87
from typing_extensions import override
98

109
# To be changed to a Protocol later (see data-apis/array-api#589)
@@ -18,5 +17,5 @@ def no_op_decorator(f): # pyright: ignore[reportUnreachable]
1817
override = no_op_decorator
1918

2019
__all__ = ["ModuleType", "override"]
21-
if typing.TYPE_CHECKING:
20+
if TYPE_CHECKING:
2221
__all__ += ["Array", "Device"]

0 commit comments

Comments
 (0)