@@ -623,8 +623,10 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
623
623
624
624
You may use two alternate syntaxes::
625
625
626
- at(x, idx).set(value) # or add(value), etc.
627
- at(x)[idx].set(value)
626
+ >>> import array_api_extra as xpx
627
+ >>> xpx.at(x, idx).set(value) # or add(value), etc.
628
+ >>> xpx.at(x)[idx].set(value)
629
+
628
630
copy : bool, optional
629
631
True (default)
630
632
Ensure that the inputs are not modified.
@@ -647,14 +649,15 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
647
649
(a) When you use ``copy=None``, you should always immediately overwrite
648
650
the parameter array::
649
651
650
- x = at(x, 0).set(2, copy=None)
652
+ >>> import array_api_extra as xpx
653
+ >>> x = xpx.at(x, 0).set(2, copy=None)
651
654
652
655
The anti-pattern below must be avoided, as it will result in different
653
656
behaviour on read-only versus writeable arrays::
654
657
655
- x = xp.asarray([0, 0, 0])
656
- y = at(x, 0).set(2, copy=None)
657
- z = at(x, 1).set(3, copy=None)
658
+ >>> x = xp.asarray([0, 0, 0])
659
+ >>> y = xpx. at(x, 0).set(2, copy=None)
660
+ >>> z = xpx. at(x, 1).set(3, copy=None)
658
661
659
662
In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
660
663
when ``x`` is read-only, whereas ``x == y == z == [2, 3, 0]`` when ``x`` is
@@ -667,9 +670,10 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
667
670
668
671
>>> import numpy as np
669
672
>>> import jax.numpy as jnp
670
- >>> at(np.asarray([123]), np.asarray([0, 0])).add(1)
673
+ >>> import array_api_extra as xpx
674
+ >>> xpx.at(np.asarray([123]), np.asarray([0, 0])).add(1)
671
675
array([124])
672
- >>> at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
676
+ >>> xpx. at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
673
677
Array([125], dtype=int32)
674
678
675
679
See Also
@@ -686,21 +690,22 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
686
690
--------
687
691
Given either of these equivalent expressions::
688
692
689
- x = at(x)[1].add(2, copy=None)
690
- x = at(x, 1).add(2, copy=None)
693
+ >>> import array_api_extra as xpx
694
+ >>> x = xpx.at(x)[1].add(2, copy=None)
695
+ >>> x = xpx.at(x, 1).add(2, copy=None)
691
696
692
697
If x is a JAX array, they are the same as::
693
698
694
- x = x.at[1].add(2)
699
+ >>> x = x.at[1].add(2)
695
700
696
701
If x is a read-only numpy array, they are the same as::
697
702
698
- x = x.copy()
699
- x[1] += 2
703
+ >>> x = x.copy()
704
+ >>> x[1] += 2
700
705
701
706
For other known backends, they are the same as::
702
707
703
- x[1] += 2
708
+ >>> x[1] += 2
704
709
"""
705
710
706
711
_x : Array
0 commit comments