Skip to content

Commit 98141ca

Browse files
committed
Add xpx namespace in documentation
1 parent 64396d0 commit 98141ca

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

src/array_api_extra/_funcs.py

+19-14
Original file line numberDiff line numberDiff line change
@@ -623,8 +623,10 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
623623
624624
You may use two alternate syntaxes::
625625
626-
at(x, idx).set(value) # or add(value), etc.
627-
at(x)[idx].set(value)
626+
>>> import array_api_extra as xpx
627+
>>> xpx.at(x, idx).set(value) # or add(value), etc.
628+
>>> xpx.at(x)[idx].set(value)
629+
628630
copy : bool, optional
629631
True (default)
630632
Ensure that the inputs are not modified.
@@ -647,14 +649,15 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
647649
(a) When you use ``copy=None``, you should always immediately overwrite
648650
the parameter array::
649651
650-
x = at(x, 0).set(2, copy=None)
652+
>>> import array_api_extra as xpx
653+
>>> x = xpx.at(x, 0).set(2, copy=None)
651654
652655
The anti-pattern below must be avoided, as it will result in different
653656
behaviour on read-only versus writeable arrays::
654657
655-
x = xp.asarray([0, 0, 0])
656-
y = at(x, 0).set(2, copy=None)
657-
z = at(x, 1).set(3, copy=None)
658+
>>> x = xp.asarray([0, 0, 0])
659+
>>> y = xpx.at(x, 0).set(2, copy=None)
660+
>>> z = xpx.at(x, 1).set(3, copy=None)
658661
659662
In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
660663
when ``x`` is read-only, whereas ``x == y == z == [2, 3, 0]`` when ``x`` is
@@ -667,9 +670,10 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
667670
668671
>>> import numpy as np
669672
>>> import jax.numpy as jnp
670-
>>> at(np.asarray([123]), np.asarray([0, 0])).add(1)
673+
>>> import array_api_extra as xpx
674+
>>> xpx.at(np.asarray([123]), np.asarray([0, 0])).add(1)
671675
array([124])
672-
>>> at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
676+
>>> xpx.at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
673677
Array([125], dtype=int32)
674678
675679
See Also
@@ -686,21 +690,22 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
686690
--------
687691
Given either of these equivalent expressions::
688692
689-
x = at(x)[1].add(2, copy=None)
690-
x = at(x, 1).add(2, copy=None)
693+
>>> import array_api_extra as xpx
694+
>>> x = xpx.at(x)[1].add(2, copy=None)
695+
>>> x = xpx.at(x, 1).add(2, copy=None)
691696
692697
If x is a JAX array, they are the same as::
693698
694-
x = x.at[1].add(2)
699+
>>> x = x.at[1].add(2)
695700
696701
If x is a read-only numpy array, they are the same as::
697702
698-
x = x.copy()
699-
x[1] += 2
703+
>>> x = x.copy()
704+
>>> x[1] += 2
700705
701706
For other known backends, they are the same as::
702707
703-
x[1] += 2
708+
>>> x[1] += 2
704709
"""
705710

706711
_x: Array

0 commit comments

Comments
 (0)