Skip to content

Commit 162264b

Browse files
committed
WIP at() method
1 parent 00b2ab7 commit 162264b

File tree

7 files changed

+477
-4
lines changed

7 files changed

+477
-4
lines changed

docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
at
910
atleast_nd
1011
cov
1112
create_diagonal

pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,9 @@ ignore = [
235235
"PLR09", # Too many <...>
236236
"PLR2004", # Magic value used in comparison
237237
"ISC001", # Conflicts with formatter
238+
"EM101", # raw-string-in-exception
239+
"EM102", # f-string-in-exception
240+
"PD008", # pandas-use-of-dot-at
238241
]
239242
isort.required-imports = ["from __future__ import annotations"]
240243

src/array_api_extra/__init__.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
22

3-
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
3+
from ._funcs import (
4+
at,
5+
atleast_nd,
6+
cov,
7+
create_diagonal,
8+
expand_dims,
9+
kron,
10+
setdiff1d,
11+
sinc,
12+
)
413

514
__version__ = "0.3.3"
615

716
# pylint: disable=duplicate-code
817
__all__ = [
918
"__version__",
19+
"at",
1020
"atleast_nd",
1121
"cov",
1222
"create_diagonal",

src/array_api_extra/_funcs.py

+281-3
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
22

3-
import typing
3+
import operator
44
import warnings
5+
from typing import TYPE_CHECKING, Any, Callable, Literal
56

6-
if typing.TYPE_CHECKING:
7+
if TYPE_CHECKING:
78
from ._lib._typing import Array, ModuleType
89

910
from ._lib import _utils
10-
from ._lib._compat import array_namespace
11+
from ._lib._compat import (
12+
array_namespace,
13+
is_array_api_obj,
14+
is_writeable_array,
15+
)
1116

1217
__all__ = [
18+
"at",
1319
"atleast_nd",
1420
"cov",
1521
"create_diagonal",
@@ -546,3 +552,275 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
546552
x, x, xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device)
547553
)
548554
return xp.sin(y) / y
555+
556+
557+
def _is_fancy_index(idx: object | tuple[object, ...]) -> bool:
558+
if not isinstance(idx, tuple):
559+
idx = (idx,)
560+
return any(isinstance(i, (list, tuple)) or is_array_api_obj(i) for i in idx)
561+
562+
563+
_undef = object()
564+
565+
566+
class at:
567+
"""
568+
Update operations for read-only arrays.
569+
570+
This implements ``jax.numpy.ndarray.at`` for all backends.
571+
572+
Parameters
573+
----------
574+
x : array
575+
Input array.
576+
577+
copy : bool, optional
578+
True (default)
579+
Ensure that the inputs are not modified.
580+
False
581+
Ensure that the update operation writes back to the input.
582+
Raise ValueError if a copy cannot be avoided.
583+
None
584+
The array parameter *may* be modified in place if it is possible and
585+
beneficial for performance.
586+
You should not reuse it after calling this function.
587+
xp : array_namespace, optional
588+
The standard-compatible namespace for `x`. Default: infer
589+
590+
Additionally, if the backend supports an `at` method, any additional keyword
591+
arguments are passed to it verbatim; e.g. this allows passing
592+
``indices_are_sorted=True`` to JAX.
593+
594+
Returns
595+
-------
596+
Updated input array.
597+
598+
Examples
599+
--------
600+
Given either of these equivalent expressions::
601+
602+
x = at(x)[1].add(2, copy=None)
603+
x = at(x, 1).add(2, copy=None)
604+
605+
If x is a JAX array, they are the same as::
606+
607+
x = x.at[1].add(2)
608+
609+
If x is a read-only numpy array, they are the same as::
610+
611+
x = x.copy()
612+
x[1] += 2
613+
614+
Otherwise, they are the same as::
615+
616+
x[1] += 2
617+
618+
Warning
619+
-------
620+
When you use copy=None, you should always immediately overwrite
621+
the parameter array::
622+
623+
x = at(x, 0).set(2, copy=None)
624+
625+
The anti-pattern below must be avoided, as it will result in different behaviour
626+
on read-only versus writeable arrays::
627+
628+
x = xp.asarray([0, 0, 0])
629+
y = at(x, 0).set(2, copy=None)
630+
z = at(x, 1).set(3, copy=None)
631+
632+
In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
633+
when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!
634+
635+
Warning
636+
-------
637+
The behaviour of update methods when the index is an array of integers which
638+
contains multiple occurrences of the same index is undefined;
639+
e.g. ``at(x, [0, 0]).set(2)``
640+
641+
Note
642+
----
643+
`sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.
644+
645+
See Also
646+
--------
647+
`jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
648+
"""
649+
650+
x: Array
651+
idx: Any
652+
__slots__ = ("x", "idx")
653+
654+
def __init__(self, x: Array, idx: Any = _undef, /):
655+
self.x = x
656+
self.idx = idx
657+
658+
def __getitem__(self, idx: Any) -> Any:
659+
"""Allow for the alternate syntax ``at(x)[start:stop:step]``,
660+
which looks prettier than ``at(x, slice(start, stop, step))``
661+
and feels more intuitive coming from the JAX documentation.
662+
"""
663+
if self.idx is not _undef:
664+
raise ValueError("Index has already been set")
665+
self.idx = idx
666+
return self
667+
668+
def _common(
669+
self,
670+
at_op: str,
671+
y: Array = _undef,
672+
/,
673+
copy: bool | None | Literal["_force_false"] = True,
674+
xp: ModuleType | None = None,
675+
_is_update: bool = True,
676+
**kwargs: Any,
677+
) -> tuple[Any, None] | tuple[None, Array]:
678+
"""Perform common prepocessing.
679+
680+
Returns
681+
-------
682+
If the operation can be resolved by at[], (return value, None)
683+
Otherwise, (None, preprocessed x)
684+
"""
685+
if self.idx is _undef:
686+
raise TypeError(
687+
"Index has not been set.\n"
688+
"Usage: either\n"
689+
" at(x, idx).set(value)\n"
690+
"or\n"
691+
" at(x)[idx].set(value)\n"
692+
"(same for all other methods)."
693+
)
694+
695+
x = self.x
696+
697+
if copy is True:
698+
writeable = None
699+
elif copy is False:
700+
writeable = is_writeable_array(x)
701+
if not writeable:
702+
raise ValueError("Cannot modify parameter in place")
703+
elif copy is None:
704+
writeable = is_writeable_array(x)
705+
copy = _is_update and not writeable
706+
elif copy == "_force_false": # type: ignore[redundant-expr]
707+
# __getitem__ with fancy index on a numpy array
708+
writeable = True
709+
copy = False
710+
else:
711+
raise ValueError(f"Invalid value for copy: {copy!r}")
712+
713+
if copy:
714+
try:
715+
at_ = x.at
716+
except AttributeError:
717+
# Emulate at[] behaviour for non-JAX arrays
718+
# with a copy followed by an update
719+
if xp is None:
720+
xp = array_namespace(x)
721+
# Create writeable copy of read-only numpy array
722+
x = xp.asarray(x, copy=True)
723+
else:
724+
# Use JAX's at[] or other library that with the same duck-type API
725+
args = (y,) if y is not _undef else ()
726+
return getattr(at_[self.idx], at_op)(*args, **kwargs), None
727+
728+
# This blindly expects that if x is writeable its copy is also writeable
729+
if _is_update:
730+
if writeable is None:
731+
writeable = is_writeable_array(x)
732+
if not writeable:
733+
# sparse crashes here
734+
raise ValueError(f"Array {x} has no `at` method and is read-only")
735+
736+
return None, x
737+
738+
def get(self, **kwargs: Any) -> Any:
739+
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
740+
that the output is either a copy or a view; it also allows passing
741+
keyword arguments to the backend.
742+
"""
743+
# __getitem__ with a fancy index always returns a copy.
744+
# Avoid an unnecessary double copy.
745+
# If copy is forced to False, raise.
746+
# FIXME This is an assumption based on numpy behaviour; it may not hold true
747+
# for other backends. Namely, a backend could decide to conditionally return a
748+
# view if the index can be coerced into a slice.
749+
if _is_fancy_index(self.idx):
750+
if kwargs.get("copy") is False:
751+
raise TypeError(
752+
"Indexing an array with a fancy index always results in a copy"
753+
)
754+
# Skip copy inside _common, even if array is not writeable
755+
kwargs["copy"] = "_force_false"
756+
757+
res, x = self._common("get", _is_update=False, **kwargs)
758+
if res is not None:
759+
return res
760+
assert x is not None
761+
return x[self.idx]
762+
763+
def set(self, y: Array, /, **kwargs: Any) -> Array:
764+
"""Apply ``x[idx] = y`` and return the update array"""
765+
res, x = self._common("set", y, **kwargs)
766+
if res is not None:
767+
return res
768+
assert x is not None
769+
x[self.idx] = y
770+
return x
771+
772+
def _iop(
773+
self,
774+
at_op: str,
775+
elwise_op: Callable[[Array, Array], Array],
776+
y: Array,
777+
/,
778+
**kwargs: Any,
779+
) -> Array:
780+
"""x[idx] += y or equivalent in-place operation on a subset of x
781+
782+
which is the same as saying
783+
x[idx] = x[idx] + y
784+
Note that this is not the same as
785+
operator.iadd(x[idx], y)
786+
Consider for example when x is a numpy array and idx is a fancy index, which
787+
triggers a deep copy on __getitem__.
788+
"""
789+
res, x = self._common(at_op, y, **kwargs)
790+
if res is not None:
791+
return res
792+
assert x is not None
793+
x[self.idx] = elwise_op(x[self.idx], y)
794+
return x
795+
796+
def add(self, y: Array, /, **kwargs: Any) -> Array:
797+
"""Apply ``x[idx] += y`` and return the updated array"""
798+
return self._iop("add", operator.add, y, **kwargs)
799+
800+
def subtract(self, y: Array, /, **kwargs: Any) -> Array:
801+
"""Apply ``x[idx] -= y`` and return the updated array"""
802+
return self._iop("subtract", operator.sub, y, **kwargs)
803+
804+
def multiply(self, y: Array, /, **kwargs: Any) -> Array:
805+
"""Apply ``x[idx] *= y`` and return the updated array"""
806+
return self._iop("multiply", operator.mul, y, **kwargs)
807+
808+
def divide(self, y: Array, /, **kwargs: Any) -> Array:
809+
"""Apply ``x[idx] /= y`` and return the updated array"""
810+
return self._iop("divide", operator.truediv, y, **kwargs)
811+
812+
def power(self, y: Array, /, **kwargs: Any) -> Array:
813+
"""Apply ``x[idx] **= y`` and return the updated array"""
814+
return self._iop("power", operator.pow, y, **kwargs)
815+
816+
def min(self, y: Array, /, **kwargs: Any) -> Array:
817+
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
818+
xp = array_namespace(self.x)
819+
y = xp.asarray(y)
820+
return self._iop("min", xp.minimum, y, **kwargs)
821+
822+
def max(self, y: Array, /, **kwargs: Any) -> Array:
823+
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
824+
xp = array_namespace(self.x)
825+
y = xp.asarray(y)
826+
return self._iop("max", xp.maximum, y, **kwargs)

src/array_api_extra/_lib/_compat.py

+5
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,19 @@
66
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
77
array_namespace, # pyright: ignore[reportUnknownVariableType]
88
device, # pyright: ignore[reportUnknownVariableType]
9+
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
10+
is_writeable_array, # pyright: ignore[reportUnknownVariableType]
911
)
1012
except ImportError:
1113
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
1214
array_namespace, # pyright: ignore[reportUnknownVariableType]
1315
device,
16+
is_writeable_array, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
1417
)
1518

1619
__all__ = [
1720
"array_namespace",
1821
"device",
22+
"is_array_api_obj",
23+
"is_writeable_array",
1924
]

src/array_api_extra/_lib/_compat.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,5 @@ def array_namespace(
1111
use_compat: bool | None = None,
1212
) -> ArrayModule: ...
1313
def device(x: Array, /) -> Device: ...
14+
def is_array_api_obj(x: object, /) -> bool: ...
15+
def is_writeable_array(x: object, /) -> bool: ...

0 commit comments

Comments
 (0)