|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import operator |
3 | 4 | import warnings
|
| 5 | +from collections.abc import Callable |
| 6 | +from typing import Any |
4 | 7 |
|
5 | 8 | from ._lib import _utils
|
6 |
| -from ._lib._compat import array_namespace |
| 9 | +from ._lib._compat import ( |
| 10 | + array_namespace, |
| 11 | + is_array_api_obj, |
| 12 | + is_dask_array, |
| 13 | + is_writeable_array, |
| 14 | +) |
7 | 15 | from ._lib._typing import Array, ModuleType
|
8 | 16 |
|
9 | 17 | __all__ = [
|
| 18 | + "at", |
10 | 19 | "atleast_nd",
|
11 | 20 | "cov",
|
12 | 21 | "create_diagonal",
|
@@ -545,3 +554,279 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
|
545 | 554 | xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device),
|
546 | 555 | )
|
547 | 556 | return xp.sin(y) / y
|
| 557 | + |
| 558 | + |
| 559 | +_undef = object() |
| 560 | + |
| 561 | + |
| 562 | +class at: # noqa: N801 |
| 563 | + """ |
| 564 | + Update operations for read-only arrays. |
| 565 | +
|
| 566 | + This implements ``jax.numpy.ndarray.at`` for all backends. |
| 567 | +
|
| 568 | + Parameters |
| 569 | + ---------- |
| 570 | + x : array |
| 571 | + Input array. |
| 572 | + idx : index, optional |
| 573 | + You may use two alternate syntaxes:: |
| 574 | +
|
| 575 | + at(x, idx).set(value) # or get(), add(), etc. |
| 576 | + at(x)[idx].set(value) |
| 577 | +
|
| 578 | + copy : bool, optional |
| 579 | + True (default) |
| 580 | + Ensure that the inputs are not modified. |
| 581 | + False |
| 582 | + Ensure that the update operation writes back to the input. |
| 583 | + Raise ValueError if a copy cannot be avoided. |
| 584 | + None |
| 585 | + The array parameter *may* be modified in place if it is possible and |
| 586 | + beneficial for performance. |
| 587 | + You should not reuse it after calling this function. |
| 588 | + xp : array_namespace, optional |
| 589 | + The standard-compatible namespace for `x`. Default: infer |
| 590 | +
|
| 591 | + **kwargs: |
| 592 | + If the backend supports an `at` method, any additional keyword |
| 593 | + arguments are passed to it verbatim; e.g. this allows passing |
| 594 | + ``indices_are_sorted=True`` to JAX. |
| 595 | +
|
| 596 | + Returns |
| 597 | + ------- |
| 598 | + Updated input array. |
| 599 | +
|
| 600 | + Examples |
| 601 | + -------- |
| 602 | + Given either of these equivalent expressions:: |
| 603 | +
|
| 604 | + x = at(x)[1].add(2, copy=None) |
| 605 | + x = at(x, 1).add(2, copy=None) |
| 606 | +
|
| 607 | + If x is a JAX array, they are the same as:: |
| 608 | +
|
| 609 | + x = x.at[1].add(2) |
| 610 | +
|
| 611 | + If x is a read-only numpy array, they are the same as:: |
| 612 | +
|
| 613 | + x = x.copy() |
| 614 | + x[1] += 2 |
| 615 | +
|
| 616 | + Otherwise, they are the same as:: |
| 617 | +
|
| 618 | + x[1] += 2 |
| 619 | +
|
| 620 | + Warning |
| 621 | + ------- |
| 622 | + When you use copy=None, you should always immediately overwrite |
| 623 | + the parameter array:: |
| 624 | +
|
| 625 | + x = at(x, 0).set(2, copy=None) |
| 626 | +
|
| 627 | + The anti-pattern below must be avoided, as it will result in different behaviour |
| 628 | + on read-only versus writeable arrays:: |
| 629 | +
|
| 630 | + x = xp.asarray([0, 0, 0]) |
| 631 | + y = at(x, 0).set(2, copy=None) |
| 632 | + z = at(x, 1).set(3, copy=None) |
| 633 | +
|
| 634 | + In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]`` |
| 635 | + when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable! |
| 636 | +
|
| 637 | + Warning |
| 638 | + ------- |
| 639 | + The array API standard does not support integer array indices. |
| 640 | + The behaviour of update methods when the index is an array of integers |
| 641 | + is undefined; this is particularly true when the index contains multiple |
| 642 | + occurrences of the same index, e.g. ``at(x, [0, 0]).set(2)``. |
| 643 | +
|
| 644 | + Note |
| 645 | + ---- |
| 646 | + `sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet. |
| 647 | +
|
| 648 | + See Also |
| 649 | + -------- |
| 650 | + `jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_ |
| 651 | + """ |
| 652 | + |
| 653 | + x: Array |
| 654 | + idx: Any |
| 655 | + __slots__ = ("idx", "x") |
| 656 | + |
| 657 | + def __init__(self, x: Array, idx: Any = _undef, /): |
| 658 | + self.x = x |
| 659 | + self.idx = idx |
| 660 | + |
| 661 | + def __getitem__(self, idx: Any) -> Any: |
| 662 | + """Allow for the alternate syntax ``at(x)[start:stop:step]``, |
| 663 | + which looks prettier than ``at(x, slice(start, stop, step))`` |
| 664 | + and feels more intuitive coming from the JAX documentation. |
| 665 | + """ |
| 666 | + if self.idx is not _undef: |
| 667 | + msg = "Index has already been set" |
| 668 | + raise ValueError(msg) |
| 669 | + self.idx = idx |
| 670 | + return self |
| 671 | + |
| 672 | + def _common( |
| 673 | + self, |
| 674 | + at_op: str, |
| 675 | + y: Array = _undef, |
| 676 | + /, |
| 677 | + copy: bool | None = True, |
| 678 | + xp: ModuleType | None = None, |
| 679 | + _is_update: bool = True, |
| 680 | + **kwargs: Any, |
| 681 | + ) -> tuple[Any, None] | tuple[None, Array]: |
| 682 | + """Perform common prepocessing. |
| 683 | +
|
| 684 | + Returns |
| 685 | + ------- |
| 686 | + If the operation can be resolved by at[], (return value, None) |
| 687 | + Otherwise, (None, preprocessed x) |
| 688 | + """ |
| 689 | + if self.idx is _undef: |
| 690 | + msg = ( |
| 691 | + "Index has not been set.\n" |
| 692 | + "Usage: either\n" |
| 693 | + " at(x, idx).set(value)\n" |
| 694 | + "or\n" |
| 695 | + " at(x)[idx].set(value)\n" |
| 696 | + "(same for all other methods)." |
| 697 | + ) |
| 698 | + raise TypeError(msg) |
| 699 | + |
| 700 | + x = self.x |
| 701 | + |
| 702 | + if copy is True: |
| 703 | + writeable = None |
| 704 | + elif copy is False: |
| 705 | + writeable = is_writeable_array(x) |
| 706 | + if not writeable: |
| 707 | + msg = "Cannot modify parameter in place" |
| 708 | + raise ValueError(msg) |
| 709 | + elif copy is None: |
| 710 | + writeable = is_writeable_array(x) |
| 711 | + copy = _is_update and not writeable |
| 712 | + else: |
| 713 | + msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable] |
| 714 | + raise ValueError(msg) |
| 715 | + |
| 716 | + if copy: |
| 717 | + try: |
| 718 | + at_ = x.at |
| 719 | + except AttributeError: |
| 720 | + # Emulate at[] behaviour for non-JAX arrays |
| 721 | + # with a copy followed by an update |
| 722 | + if xp is None: |
| 723 | + xp = array_namespace(x) |
| 724 | + # Create writeable copy of read-only numpy array |
| 725 | + x = xp.asarray(x, copy=True) |
| 726 | + if writeable is False: |
| 727 | + # A copy of a read-only numpy array is writeable |
| 728 | + writeable = None |
| 729 | + else: |
| 730 | + # Use JAX's at[] or other library that with the same duck-type API |
| 731 | + args = (y,) if y is not _undef else () |
| 732 | + return getattr(at_[self.idx], at_op)(*args, **kwargs), None |
| 733 | + |
| 734 | + if _is_update: |
| 735 | + if writeable is None: |
| 736 | + writeable = is_writeable_array(x) |
| 737 | + if not writeable: |
| 738 | + # sparse crashes here |
| 739 | + msg = f"Array {x} has no `at` method and is read-only" |
| 740 | + raise ValueError(msg) |
| 741 | + |
| 742 | + return None, x |
| 743 | + |
| 744 | + def get(self, **kwargs: Any) -> Any: |
| 745 | + """Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring |
| 746 | + that the output is either a copy or a view; it also allows passing |
| 747 | + keyword arguments to the backend. |
| 748 | + """ |
| 749 | + if kwargs.get("copy") is False: |
| 750 | + if is_array_api_obj(self.idx): |
| 751 | + # Boolean index. Note that the array API spec |
| 752 | + # https://data-apis.org/array-api/latest/API_specification/indexing.html |
| 753 | + # does not allow for list, tuple, and tuples of slices plus one or more |
| 754 | + # one-dimensional array indices, although many backends support them. |
| 755 | + # So this check will encounter a lot of false negatives in real life, |
| 756 | + # which can be caught by testing the user code vs. array-api-strict. |
| 757 | + msg = "get() with an array index always returns a copy" |
| 758 | + raise ValueError(msg) |
| 759 | + if is_dask_array(self.x): |
| 760 | + msg = "get() on Dask arrays always returns a copy" |
| 761 | + raise ValueError(msg) |
| 762 | + |
| 763 | + res, x = self._common("get", _is_update=False, **kwargs) |
| 764 | + if res is not None: |
| 765 | + return res |
| 766 | + assert x is not None |
| 767 | + return x[self.idx] |
| 768 | + |
| 769 | + def set(self, y: Array, /, **kwargs: Any) -> Array: |
| 770 | + """Apply ``x[idx] = y`` and return the update array""" |
| 771 | + res, x = self._common("set", y, **kwargs) |
| 772 | + if res is not None: |
| 773 | + return res |
| 774 | + assert x is not None |
| 775 | + x[self.idx] = y |
| 776 | + return x |
| 777 | + |
| 778 | + def _iop( |
| 779 | + self, |
| 780 | + at_op: str, |
| 781 | + elwise_op: Callable[[Array, Array], Array], |
| 782 | + y: Array, |
| 783 | + /, |
| 784 | + **kwargs: Any, |
| 785 | + ) -> Array: |
| 786 | + """x[idx] += y or equivalent in-place operation on a subset of x |
| 787 | +
|
| 788 | + which is the same as saying |
| 789 | + x[idx] = x[idx] + y |
| 790 | + Note that this is not the same as |
| 791 | + operator.iadd(x[idx], y) |
| 792 | + Consider for example when x is a numpy array and idx is a fancy index, which |
| 793 | + triggers a deep copy on __getitem__. |
| 794 | + """ |
| 795 | + res, x = self._common(at_op, y, **kwargs) |
| 796 | + if res is not None: |
| 797 | + return res |
| 798 | + assert x is not None |
| 799 | + x[self.idx] = elwise_op(x[self.idx], y) |
| 800 | + return x |
| 801 | + |
| 802 | + def add(self, y: Array, /, **kwargs: Any) -> Array: |
| 803 | + """Apply ``x[idx] += y`` and return the updated array""" |
| 804 | + return self._iop("add", operator.add, y, **kwargs) |
| 805 | + |
| 806 | + def subtract(self, y: Array, /, **kwargs: Any) -> Array: |
| 807 | + """Apply ``x[idx] -= y`` and return the updated array""" |
| 808 | + return self._iop("subtract", operator.sub, y, **kwargs) |
| 809 | + |
| 810 | + def multiply(self, y: Array, /, **kwargs: Any) -> Array: |
| 811 | + """Apply ``x[idx] *= y`` and return the updated array""" |
| 812 | + return self._iop("multiply", operator.mul, y, **kwargs) |
| 813 | + |
| 814 | + def divide(self, y: Array, /, **kwargs: Any) -> Array: |
| 815 | + """Apply ``x[idx] /= y`` and return the updated array""" |
| 816 | + return self._iop("divide", operator.truediv, y, **kwargs) |
| 817 | + |
| 818 | + def power(self, y: Array, /, **kwargs: Any) -> Array: |
| 819 | + """Apply ``x[idx] **= y`` and return the updated array""" |
| 820 | + return self._iop("power", operator.pow, y, **kwargs) |
| 821 | + |
| 822 | + def min(self, y: Array, /, **kwargs: Any) -> Array: |
| 823 | + """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array""" |
| 824 | + xp = array_namespace(self.x) |
| 825 | + y = xp.asarray(y) |
| 826 | + return self._iop("min", xp.minimum, y, **kwargs) |
| 827 | + |
| 828 | + def max(self, y: Array, /, **kwargs: Any) -> Array: |
| 829 | + """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array""" |
| 830 | + xp = array_namespace(self.x) |
| 831 | + y = xp.asarray(y) |
| 832 | + return self._iop("max", xp.maximum, y, **kwargs) |
0 commit comments