|
1 | 1 | from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
|
2 | 2 |
|
3 |
| -import typing |
| 3 | +import operator |
4 | 4 | import warnings
|
| 5 | +from typing import TYPE_CHECKING, Any, Callable |
5 | 6 |
|
6 |
| -if typing.TYPE_CHECKING: |
| 7 | +if TYPE_CHECKING: |
7 | 8 | from ._lib._typing import Array, ModuleType
|
8 | 9 |
|
9 | 10 | from ._lib import _utils
|
10 |
| -from ._lib._compat import array_namespace |
| 11 | +from ._lib._compat import ( |
| 12 | + array_namespace, |
| 13 | + is_array_api_obj, |
| 14 | + is_writeable_array, |
| 15 | +) |
11 | 16 |
|
12 | 17 | __all__ = [
|
| 18 | + "at", |
13 | 19 | "atleast_nd",
|
14 | 20 | "cov",
|
15 | 21 | "create_diagonal",
|
@@ -546,3 +552,273 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
|
546 | 552 | x, x, xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device)
|
547 | 553 | )
|
548 | 554 | return xp.sin(y) / y
|
| 555 | + |
| 556 | + |
| 557 | +_undef = object() |
| 558 | + |
| 559 | + |
| 560 | +class at: |
| 561 | + """ |
| 562 | + Update operations for read-only arrays. |
| 563 | +
|
| 564 | + This implements ``jax.numpy.ndarray.at`` for all backends. |
| 565 | +
|
| 566 | + Parameters |
| 567 | + ---------- |
| 568 | + x : array |
| 569 | + Input array. |
| 570 | + idx : index, optional |
| 571 | + You may use two alternate syntaxes:: |
| 572 | +
|
| 573 | + at(x, idx).set(value) # or get(), add(), etc. |
| 574 | + at(x)[idx].set(value) |
| 575 | +
|
| 576 | + copy : bool, optional |
| 577 | + True (default) |
| 578 | + Ensure that the inputs are not modified. |
| 579 | + False |
| 580 | + Ensure that the update operation writes back to the input. |
| 581 | + Raise ValueError if a copy cannot be avoided. |
| 582 | + None |
| 583 | + The array parameter *may* be modified in place if it is possible and |
| 584 | + beneficial for performance. |
| 585 | + You should not reuse it after calling this function. |
| 586 | + xp : array_namespace, optional |
| 587 | + The standard-compatible namespace for `x`. Default: infer |
| 588 | +
|
| 589 | + Additionally, if the backend supports an `at` method, any additional keyword |
| 590 | + arguments are passed to it verbatim; e.g. this allows passing |
| 591 | + ``indices_are_sorted=True`` to JAX. |
| 592 | +
|
| 593 | + Returns |
| 594 | + ------- |
| 595 | + Updated input array. |
| 596 | +
|
| 597 | + Examples |
| 598 | + -------- |
| 599 | + Given either of these equivalent expressions:: |
| 600 | +
|
| 601 | + x = at(x)[1].add(2, copy=None) |
| 602 | + x = at(x, 1).add(2, copy=None) |
| 603 | +
|
| 604 | + If x is a JAX array, they are the same as:: |
| 605 | +
|
| 606 | + x = x.at[1].add(2) |
| 607 | +
|
| 608 | + If x is a read-only numpy array, they are the same as:: |
| 609 | +
|
| 610 | + x = x.copy() |
| 611 | + x[1] += 2 |
| 612 | +
|
| 613 | + Otherwise, they are the same as:: |
| 614 | +
|
| 615 | + x[1] += 2 |
| 616 | +
|
| 617 | + Warning |
| 618 | + ------- |
| 619 | + When you use copy=None, you should always immediately overwrite |
| 620 | + the parameter array:: |
| 621 | +
|
| 622 | + x = at(x, 0).set(2, copy=None) |
| 623 | +
|
| 624 | + The anti-pattern below must be avoided, as it will result in different behaviour |
| 625 | + on read-only versus writeable arrays:: |
| 626 | +
|
| 627 | + x = xp.asarray([0, 0, 0]) |
| 628 | + y = at(x, 0).set(2, copy=None) |
| 629 | + z = at(x, 1).set(3, copy=None) |
| 630 | +
|
| 631 | + In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]`` |
| 632 | + when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable! |
| 633 | +
|
| 634 | + Warning |
| 635 | + ------- |
| 636 | + The behaviour of update methods when the index is an array of integers which |
| 637 | + contains multiple occurrences of the same index is undefined; |
| 638 | + e.g. ``at(x, [0, 0]).set(2)`` |
| 639 | +
|
| 640 | + Note |
| 641 | + ---- |
| 642 | + `sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet. |
| 643 | +
|
| 644 | + See Also |
| 645 | + -------- |
| 646 | + `jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_ |
| 647 | + """ |
| 648 | + |
| 649 | + x: Array |
| 650 | + idx: Any |
| 651 | + __slots__ = ("x", "idx") |
| 652 | + |
| 653 | + def __init__(self, x: Array, idx: Any = _undef, /): |
| 654 | + self.x = x |
| 655 | + self.idx = idx |
| 656 | + |
| 657 | + def __getitem__(self, idx: Any) -> Any: |
| 658 | + """Allow for the alternate syntax ``at(x)[start:stop:step]``, |
| 659 | + which looks prettier than ``at(x, slice(start, stop, step))`` |
| 660 | + and feels more intuitive coming from the JAX documentation. |
| 661 | + """ |
| 662 | + if self.idx is not _undef: |
| 663 | + msg = "Index has already been set" |
| 664 | + raise ValueError(msg) |
| 665 | + self.idx = idx |
| 666 | + return self |
| 667 | + |
| 668 | + def _common( |
| 669 | + self, |
| 670 | + at_op: str, |
| 671 | + y: Array = _undef, |
| 672 | + /, |
| 673 | + copy: bool | None = True, |
| 674 | + xp: ModuleType | None = None, |
| 675 | + _is_update: bool = True, |
| 676 | + **kwargs: Any, |
| 677 | + ) -> tuple[Any, None] | tuple[None, Array]: |
| 678 | + """Perform common prepocessing. |
| 679 | +
|
| 680 | + Returns |
| 681 | + ------- |
| 682 | + If the operation can be resolved by at[], (return value, None) |
| 683 | + Otherwise, (None, preprocessed x) |
| 684 | + """ |
| 685 | + if self.idx is _undef: |
| 686 | + msg = ( |
| 687 | + "Index has not been set.\n" |
| 688 | + "Usage: either\n" |
| 689 | + " at(x, idx).set(value)\n" |
| 690 | + "or\n" |
| 691 | + " at(x)[idx].set(value)\n" |
| 692 | + "(same for all other methods)." |
| 693 | + ) |
| 694 | + raise TypeError(msg) |
| 695 | + |
| 696 | + x = self.x |
| 697 | + |
| 698 | + if copy is True: |
| 699 | + writeable = None |
| 700 | + elif copy is False: |
| 701 | + writeable = is_writeable_array(x) |
| 702 | + if not writeable: |
| 703 | + msg = "Cannot modify parameter in place" |
| 704 | + raise ValueError(msg) |
| 705 | + elif copy is None: |
| 706 | + writeable = is_writeable_array(x) |
| 707 | + copy = _is_update and not writeable |
| 708 | + else: |
| 709 | + msg = f"Invalid value for copy: {copy!r}" |
| 710 | + raise ValueError(msg) |
| 711 | + |
| 712 | + if copy: |
| 713 | + try: |
| 714 | + at_ = x.at |
| 715 | + except AttributeError: |
| 716 | + # Emulate at[] behaviour for non-JAX arrays |
| 717 | + # with a copy followed by an update |
| 718 | + if xp is None: |
| 719 | + xp = array_namespace(x) |
| 720 | + # Create writeable copy of read-only numpy array |
| 721 | + x = xp.asarray(x, copy=True) |
| 722 | + if writeable is False: |
| 723 | + # A copy of a read-only numpy array is writeable |
| 724 | + writeable = None |
| 725 | + else: |
| 726 | + # Use JAX's at[] or other library that with the same duck-type API |
| 727 | + args = (y,) if y is not _undef else () |
| 728 | + return getattr(at_[self.idx], at_op)(*args, **kwargs), None |
| 729 | + |
| 730 | + if _is_update: |
| 731 | + if writeable is None: |
| 732 | + writeable = is_writeable_array(x) |
| 733 | + if not writeable: |
| 734 | + # sparse crashes here |
| 735 | + msg = f"Array {x} has no `at` method and is read-only" |
| 736 | + raise ValueError(msg) |
| 737 | + |
| 738 | + return None, x |
| 739 | + |
| 740 | + def get(self, **kwargs: Any) -> Any: |
| 741 | + """Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring |
| 742 | + that the output is either a copy or a view; it also allows passing |
| 743 | + keyword arguments to the backend. |
| 744 | + """ |
| 745 | + if kwargs.get("copy") is False and ( |
| 746 | + is_array_api_obj(self.idx) |
| 747 | + or isinstance(self.idx, tuple) |
| 748 | + and any(is_array_api_obj(i) for i in self.idx) |
| 749 | + ): |
| 750 | + # Fancy index. Note that the array API spec does not allow for |
| 751 | + # list, tuple, or numpy arrays although many backends support them. |
| 752 | + msg = "get() with an array index always returns in a copy" |
| 753 | + raise ValueError(msg) |
| 754 | + |
| 755 | + res, x = self._common("get", _is_update=False, **kwargs) |
| 756 | + if res is not None: |
| 757 | + return res |
| 758 | + assert x is not None |
| 759 | + return x[self.idx] |
| 760 | + |
| 761 | + def set(self, y: Array, /, **kwargs: Any) -> Array: |
| 762 | + """Apply ``x[idx] = y`` and return the update array""" |
| 763 | + res, x = self._common("set", y, **kwargs) |
| 764 | + if res is not None: |
| 765 | + return res |
| 766 | + assert x is not None |
| 767 | + x[self.idx] = y |
| 768 | + return x |
| 769 | + |
| 770 | + def _iop( |
| 771 | + self, |
| 772 | + at_op: str, |
| 773 | + elwise_op: Callable[[Array, Array], Array], |
| 774 | + y: Array, |
| 775 | + /, |
| 776 | + **kwargs: Any, |
| 777 | + ) -> Array: |
| 778 | + """x[idx] += y or equivalent in-place operation on a subset of x |
| 779 | +
|
| 780 | + which is the same as saying |
| 781 | + x[idx] = x[idx] + y |
| 782 | + Note that this is not the same as |
| 783 | + operator.iadd(x[idx], y) |
| 784 | + Consider for example when x is a numpy array and idx is a fancy index, which |
| 785 | + triggers a deep copy on __getitem__. |
| 786 | + """ |
| 787 | + res, x = self._common(at_op, y, **kwargs) |
| 788 | + if res is not None: |
| 789 | + return res |
| 790 | + assert x is not None |
| 791 | + x[self.idx] = elwise_op(x[self.idx], y) |
| 792 | + return x |
| 793 | + |
| 794 | + def add(self, y: Array, /, **kwargs: Any) -> Array: |
| 795 | + """Apply ``x[idx] += y`` and return the updated array""" |
| 796 | + return self._iop("add", operator.add, y, **kwargs) |
| 797 | + |
| 798 | + def subtract(self, y: Array, /, **kwargs: Any) -> Array: |
| 799 | + """Apply ``x[idx] -= y`` and return the updated array""" |
| 800 | + return self._iop("subtract", operator.sub, y, **kwargs) |
| 801 | + |
| 802 | + def multiply(self, y: Array, /, **kwargs: Any) -> Array: |
| 803 | + """Apply ``x[idx] *= y`` and return the updated array""" |
| 804 | + return self._iop("multiply", operator.mul, y, **kwargs) |
| 805 | + |
| 806 | + def divide(self, y: Array, /, **kwargs: Any) -> Array: |
| 807 | + """Apply ``x[idx] /= y`` and return the updated array""" |
| 808 | + return self._iop("divide", operator.truediv, y, **kwargs) |
| 809 | + |
| 810 | + def power(self, y: Array, /, **kwargs: Any) -> Array: |
| 811 | + """Apply ``x[idx] **= y`` and return the updated array""" |
| 812 | + return self._iop("power", operator.pow, y, **kwargs) |
| 813 | + |
| 814 | + def min(self, y: Array, /, **kwargs: Any) -> Array: |
| 815 | + """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array""" |
| 816 | + xp = array_namespace(self.x) |
| 817 | + y = xp.asarray(y) |
| 818 | + return self._iop("min", xp.minimum, y, **kwargs) |
| 819 | + |
| 820 | + def max(self, y: Array, /, **kwargs: Any) -> Array: |
| 821 | + """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array""" |
| 822 | + xp = array_namespace(self.x) |
| 823 | + y = xp.asarray(y) |
| 824 | + return self._iop("max", xp.maximum, y, **kwargs) |
0 commit comments