|
1 | 1 | from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
|
2 | 2 |
|
3 |
| -import typing |
| 3 | +import operator |
4 | 4 | import warnings
|
| 5 | +from typing import TYPE_CHECKING, Any, Callable, Literal |
5 | 6 |
|
6 |
| -if typing.TYPE_CHECKING: |
| 7 | +if TYPE_CHECKING: |
7 | 8 | from ._lib._typing import Array, ModuleType
|
8 | 9 |
|
9 | 10 | from ._lib import _utils
|
10 |
| -from ._lib._compat import array_namespace |
| 11 | +from ._lib._compat import ( |
| 12 | + array_namespace, |
| 13 | + is_array_api_obj, |
| 14 | + is_writeable_array, |
| 15 | +) |
11 | 16 |
|
12 | 17 | __all__ = [
|
| 18 | + "at", |
13 | 19 | "atleast_nd",
|
14 | 20 | "cov",
|
15 | 21 | "create_diagonal",
|
@@ -546,3 +552,275 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
|
546 | 552 | x, x, xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device)
|
547 | 553 | )
|
548 | 554 | return xp.sin(y) / y
|
| 555 | + |
| 556 | + |
| 557 | +def _is_fancy_index(idx: object | tuple[object, ...]) -> bool: |
| 558 | + if not isinstance(idx, tuple): |
| 559 | + idx = (idx,) |
| 560 | + return any(isinstance(i, (list, tuple)) or is_array_api_obj(i) for i in idx) |
| 561 | + |
| 562 | + |
| 563 | +_undef = object() |
| 564 | + |
| 565 | + |
| 566 | +class at: |
| 567 | + """ |
| 568 | + Update operations for read-only arrays. |
| 569 | +
|
| 570 | + This implements ``jax.numpy.ndarray.at`` for all backends. |
| 571 | +
|
| 572 | + Parameters |
| 573 | + ---------- |
| 574 | + x : array |
| 575 | + Input array. |
| 576 | +
|
| 577 | + copy : bool, optional |
| 578 | + True (default) |
| 579 | + Ensure that the inputs are not modified. |
| 580 | + False |
| 581 | + Ensure that the update operation writes back to the input. |
| 582 | + Raise ValueError if a copy cannot be avoided. |
| 583 | + None |
| 584 | + The array parameter *may* be modified in place if it is possible and |
| 585 | + beneficial for performance. |
| 586 | + You should not reuse it after calling this function. |
| 587 | + xp : array_namespace, optional |
| 588 | + The standard-compatible namespace for `x`. Default: infer |
| 589 | +
|
| 590 | + Additionally, if the backend supports an `at` method, any additional keyword |
| 591 | + arguments are passed to it verbatim; e.g. this allows passing |
| 592 | + ``indices_are_sorted=True`` to JAX. |
| 593 | +
|
| 594 | + Returns |
| 595 | + ------- |
| 596 | + Updated input array. |
| 597 | +
|
| 598 | + Examples |
| 599 | + -------- |
| 600 | + Given either of these equivalent expressions:: |
| 601 | +
|
| 602 | + x = at(x)[1].add(2, copy=None) |
| 603 | + x = at(x, 1).add(2, copy=None) |
| 604 | +
|
| 605 | + If x is a JAX array, they are the same as:: |
| 606 | +
|
| 607 | + x = x.at[1].add(2) |
| 608 | +
|
| 609 | + If x is a read-only numpy array, they are the same as:: |
| 610 | +
|
| 611 | + x = x.copy() |
| 612 | + x[1] += 2 |
| 613 | +
|
| 614 | + Otherwise, they are the same as:: |
| 615 | +
|
| 616 | + x[1] += 2 |
| 617 | +
|
| 618 | + Warning |
| 619 | + ------- |
| 620 | + When you use copy=None, you should always immediately overwrite |
| 621 | + the parameter array:: |
| 622 | +
|
| 623 | + x = at(x, 0).set(2, copy=None) |
| 624 | +
|
| 625 | + The anti-pattern below must be avoided, as it will result in different behaviour |
| 626 | + on read-only versus writeable arrays:: |
| 627 | +
|
| 628 | + x = xp.asarray([0, 0, 0]) |
| 629 | + y = at(x, 0).set(2, copy=None) |
| 630 | + z = at(x, 1).set(3, copy=None) |
| 631 | +
|
| 632 | + In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]`` |
| 633 | + when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable! |
| 634 | +
|
| 635 | + Warning |
| 636 | + ------- |
| 637 | + The behaviour of update methods when the index is an array of integers which |
| 638 | + contains multiple occurrences of the same index is undefined; |
| 639 | + e.g. ``at(x, [0, 0]).set(2)`` |
| 640 | +
|
| 641 | + Note |
| 642 | + ---- |
| 643 | + `sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet. |
| 644 | +
|
| 645 | + See Also |
| 646 | + -------- |
| 647 | + `jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_ |
| 648 | + """ |
| 649 | + |
| 650 | + x: Array |
| 651 | + idx: Any |
| 652 | + __slots__ = ("x", "idx") |
| 653 | + |
| 654 | + def __init__(self, x: Array, idx: Any = _undef, /): |
| 655 | + self.x = x |
| 656 | + self.idx = idx |
| 657 | + |
| 658 | + def __getitem__(self, idx: Any) -> Any: |
| 659 | + """Allow for the alternate syntax ``at(x)[start:stop:step]``, |
| 660 | + which looks prettier than ``at(x, slice(start, stop, step))`` |
| 661 | + and feels more intuitive coming from the JAX documentation. |
| 662 | + """ |
| 663 | + if self.idx is not _undef: |
| 664 | + raise ValueError("Index has already been set") |
| 665 | + self.idx = idx |
| 666 | + return self |
| 667 | + |
| 668 | + def _common( |
| 669 | + self, |
| 670 | + at_op: str, |
| 671 | + y: Array = _undef, |
| 672 | + /, |
| 673 | + copy: bool | None | Literal["_force_false"] = True, |
| 674 | + xp: ModuleType | None = None, |
| 675 | + _is_update: bool = True, |
| 676 | + **kwargs: Any, |
| 677 | + ) -> tuple[Any, None] | tuple[None, Array]: |
| 678 | + """Perform common prepocessing. |
| 679 | +
|
| 680 | + Returns |
| 681 | + ------- |
| 682 | + If the operation can be resolved by at[], (return value, None) |
| 683 | + Otherwise, (None, preprocessed x) |
| 684 | + """ |
| 685 | + if self.idx is _undef: |
| 686 | + raise TypeError( |
| 687 | + "Index has not been set.\n" |
| 688 | + "Usage: either\n" |
| 689 | + " at(x, idx).set(value)\n" |
| 690 | + "or\n" |
| 691 | + " at(x)[idx].set(value)\n" |
| 692 | + "(same for all other methods)." |
| 693 | + ) |
| 694 | + |
| 695 | + x = self.x |
| 696 | + |
| 697 | + if copy is True: |
| 698 | + writeable = None |
| 699 | + elif copy is False: |
| 700 | + writeable = is_writeable_array(x) |
| 701 | + if not writeable: |
| 702 | + raise ValueError("Cannot modify parameter in place") |
| 703 | + elif copy is None: |
| 704 | + writeable = is_writeable_array(x) |
| 705 | + copy = _is_update and not writeable |
| 706 | + elif copy == "_force_false": # type: ignore[redundant-expr] |
| 707 | + # __getitem__ with fancy index on a numpy array |
| 708 | + writeable = True |
| 709 | + copy = False |
| 710 | + else: |
| 711 | + raise ValueError(f"Invalid value for copy: {copy!r}") |
| 712 | + |
| 713 | + if copy: |
| 714 | + try: |
| 715 | + at_ = x.at |
| 716 | + except AttributeError: |
| 717 | + # Emulate at[] behaviour for non-JAX arrays |
| 718 | + # with a copy followed by an update |
| 719 | + if xp is None: |
| 720 | + xp = array_namespace(x) |
| 721 | + # Create writeable copy of read-only numpy array |
| 722 | + x = xp.asarray(x, copy=True) |
| 723 | + else: |
| 724 | + # Use JAX's at[] or other library that with the same duck-type API |
| 725 | + args = (y,) if y is not _undef else () |
| 726 | + return getattr(at_[self.idx], at_op)(*args, **kwargs), None |
| 727 | + |
| 728 | + # This blindly expects that if x is writeable its copy is also writeable |
| 729 | + if _is_update: |
| 730 | + if writeable is None: |
| 731 | + writeable = is_writeable_array(x) |
| 732 | + if not writeable: |
| 733 | + # sparse crashes here |
| 734 | + raise ValueError(f"Array {x} has no `at` method and is read-only") |
| 735 | + |
| 736 | + return None, x |
| 737 | + |
| 738 | + def get(self, **kwargs: Any) -> Any: |
| 739 | + """Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring |
| 740 | + that the output is either a copy or a view; it also allows passing |
| 741 | + keyword arguments to the backend. |
| 742 | + """ |
| 743 | + # __getitem__ with a fancy index always returns a copy. |
| 744 | + # Avoid an unnecessary double copy. |
| 745 | + # If copy is forced to False, raise. |
| 746 | + # FIXME This is an assumption based on numpy behaviour; it may not hold true |
| 747 | + # for other backends. Namely, a backend could decide to conditionally return a |
| 748 | + # view if the index can be coerced into a slice. |
| 749 | + if _is_fancy_index(self.idx): |
| 750 | + if kwargs.get("copy") is False: |
| 751 | + raise TypeError( |
| 752 | + "Indexing an array with a fancy index always results in a copy" |
| 753 | + ) |
| 754 | + # Skip copy inside _common, even if array is not writeable |
| 755 | + kwargs["copy"] = "_force_false" |
| 756 | + |
| 757 | + res, x = self._common("get", _is_update=False, **kwargs) |
| 758 | + if res is not None: |
| 759 | + return res |
| 760 | + assert x is not None |
| 761 | + return x[self.idx] |
| 762 | + |
| 763 | + def set(self, y: Array, /, **kwargs: Any) -> Array: |
| 764 | + """Apply ``x[idx] = y`` and return the update array""" |
| 765 | + res, x = self._common("set", y, **kwargs) |
| 766 | + if res is not None: |
| 767 | + return res |
| 768 | + assert x is not None |
| 769 | + x[self.idx] = y |
| 770 | + return x |
| 771 | + |
| 772 | + def _iop( |
| 773 | + self, |
| 774 | + at_op: str, |
| 775 | + elwise_op: Callable[[Array, Array], Array], |
| 776 | + y: Array, |
| 777 | + /, |
| 778 | + **kwargs: Any, |
| 779 | + ) -> Array: |
| 780 | + """x[idx] += y or equivalent in-place operation on a subset of x |
| 781 | +
|
| 782 | + which is the same as saying |
| 783 | + x[idx] = x[idx] + y |
| 784 | + Note that this is not the same as |
| 785 | + operator.iadd(x[idx], y) |
| 786 | + Consider for example when x is a numpy array and idx is a fancy index, which |
| 787 | + triggers a deep copy on __getitem__. |
| 788 | + """ |
| 789 | + res, x = self._common(at_op, y, **kwargs) |
| 790 | + if res is not None: |
| 791 | + return res |
| 792 | + assert x is not None |
| 793 | + x[self.idx] = elwise_op(x[self.idx], y) |
| 794 | + return x |
| 795 | + |
| 796 | + def add(self, y: Array, /, **kwargs: Any) -> Array: |
| 797 | + """Apply ``x[idx] += y`` and return the updated array""" |
| 798 | + return self._iop("add", operator.add, y, **kwargs) |
| 799 | + |
| 800 | + def subtract(self, y: Array, /, **kwargs: Any) -> Array: |
| 801 | + """Apply ``x[idx] -= y`` and return the updated array""" |
| 802 | + return self._iop("subtract", operator.sub, y, **kwargs) |
| 803 | + |
| 804 | + def multiply(self, y: Array, /, **kwargs: Any) -> Array: |
| 805 | + """Apply ``x[idx] *= y`` and return the updated array""" |
| 806 | + return self._iop("multiply", operator.mul, y, **kwargs) |
| 807 | + |
| 808 | + def divide(self, y: Array, /, **kwargs: Any) -> Array: |
| 809 | + """Apply ``x[idx] /= y`` and return the updated array""" |
| 810 | + return self._iop("divide", operator.truediv, y, **kwargs) |
| 811 | + |
| 812 | + def power(self, y: Array, /, **kwargs: Any) -> Array: |
| 813 | + """Apply ``x[idx] **= y`` and return the updated array""" |
| 814 | + return self._iop("power", operator.pow, y, **kwargs) |
| 815 | + |
| 816 | + def min(self, y: Array, /, **kwargs: Any) -> Array: |
| 817 | + """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array""" |
| 818 | + xp = array_namespace(self.x) |
| 819 | + y = xp.asarray(y) |
| 820 | + return self._iop("min", xp.minimum, y, **kwargs) |
| 821 | + |
| 822 | + def max(self, y: Array, /, **kwargs: Any) -> Array: |
| 823 | + """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array""" |
| 824 | + xp = array_namespace(self.x) |
| 825 | + y = xp.asarray(y) |
| 826 | + return self._iop("max", xp.maximum, y, **kwargs) |
0 commit comments