Skip to content

Commit 38562db

Browse files
committed
Change copy to default to None
1 parent 98141ca commit 38562db

File tree

2 files changed

+29
-28
lines changed

2 files changed

+29
-28
lines changed

src/array_api_extra/_funcs.py

+24-23
Original file line numberDiff line numberDiff line change
@@ -628,15 +628,16 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
628628
>>> xpx.at(x)[idx].set(value)
629629
630630
copy : bool, optional
631-
True (default)
631+
None (default)
632+
The array parameter *may* be modified in place if it is
633+
possible and beneficial for performance.
634+
You should not reuse it after calling this function.
635+
True
632636
Ensure that the inputs are not modified.
633637
False
634638
Ensure that the update operation writes back to the input.
635639
Raise ``ValueError`` if a copy cannot be avoided.
636-
None
637-
The array parameter *may* be modified in place if it is
638-
possible and beneficial for performance.
639-
You should not reuse it after calling this function.
640+
640641
xp : array_namespace, optional
641642
The standard-compatible namespace for `x`. Default: infer.
642643
@@ -646,18 +647,18 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
646647
647648
Warnings
648649
--------
649-
(a) When you use ``copy=None``, you should always immediately overwrite
650+
(a) When you omit the ``copy`` parameter, you should always immediately overwrite
650651
the parameter array::
651652
652653
>>> import array_api_extra as xpx
653-
>>> x = xpx.at(x, 0).set(2, copy=None)
654+
>>> x = xpx.at(x, 0).set(2)
654655
655656
The anti-pattern below must be avoided, as it will result in different
656657
behaviour on read-only versus writeable arrays::
657658
658659
>>> x = xp.asarray([0, 0, 0])
659-
>>> y = xpx.at(x, 0).set(2, copy=None)
660-
>>> z = xpx.at(x, 1).set(3, copy=None)
660+
>>> y = xpx.at(x, 0).set(2)
661+
>>> z = xpx.at(x, 1).set(3)
661662
662663
In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
663664
when ``x`` is read-only, whereas ``x == y == z == [2, 3, 0]`` when ``x`` is
@@ -691,8 +692,8 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
691692
Given either of these equivalent expressions::
692693
693694
>>> import array_api_extra as xpx
694-
>>> x = xpx.at(x)[1].add(2, copy=None)
695-
>>> x = xpx.at(x, 1).add(2, copy=None)
695+
>>> x = xpx.at(x)[1].add(2)
696+
>>> x = xpx.at(x, 1).add(2)
696697
697698
If x is a JAX array, they are the same as::
698699
@@ -735,8 +736,8 @@ def _update_common(
735736
at_op: str,
736737
y: Array,
737738
/,
738-
copy: bool | None = True,
739-
xp: ModuleType | None = None,
739+
copy: bool | None,
740+
xp: ModuleType | None,
740741
) -> tuple[Array, None] | tuple[None, Array]: # numpydoc ignore=PR01
741742
"""
742743
Perform common prepocessing to all update operations.
@@ -800,7 +801,7 @@ def set(
800801
self,
801802
y: Array,
802803
/,
803-
copy: bool | None = True,
804+
copy: bool | None = None,
804805
xp: ModuleType | None = None,
805806
) -> Array: # numpydoc ignore=PR01,RT01
806807
"""Apply ``x[idx] = y`` and return the update array."""
@@ -819,8 +820,8 @@ def _iop(
819820
elwise_op: Callable[[Array, Array], Array],
820821
y: Array,
821822
/,
822-
copy: bool | None = True,
823-
xp: ModuleType | None = None,
823+
copy: bool | None,
824+
xp: ModuleType | None,
824825
) -> Array: # numpydoc ignore=PR01,RT01
825826
"""
826827
``x[idx] += y`` or equivalent in-place operation on a subset of x.
@@ -843,7 +844,7 @@ def add(
843844
self,
844845
y: Array,
845846
/,
846-
copy: bool | None = True,
847+
copy: bool | None = None,
847848
xp: ModuleType | None = None,
848849
) -> Array: # numpydoc ignore=PR01,RT01
849850
"""Apply ``x[idx] += y`` and return the updated array."""
@@ -853,7 +854,7 @@ def subtract(
853854
self,
854855
y: Array,
855856
/,
856-
copy: bool | None = True,
857+
copy: bool | None = None,
857858
xp: ModuleType | None = None,
858859
) -> Array: # numpydoc ignore=PR01,RT01
859860
"""Apply ``x[idx] -= y`` and return the updated array."""
@@ -863,7 +864,7 @@ def multiply(
863864
self,
864865
y: Array,
865866
/,
866-
copy: bool | None = True,
867+
copy: bool | None = None,
867868
xp: ModuleType | None = None,
868869
) -> Array: # numpydoc ignore=PR01,RT01
869870
"""Apply ``x[idx] *= y`` and return the updated array."""
@@ -873,7 +874,7 @@ def divide(
873874
self,
874875
y: Array,
875876
/,
876-
copy: bool | None = True,
877+
copy: bool | None = None,
877878
xp: ModuleType | None = None,
878879
) -> Array: # numpydoc ignore=PR01,RT01
879880
"""Apply ``x[idx] /= y`` and return the updated array."""
@@ -883,7 +884,7 @@ def power(
883884
self,
884885
y: Array,
885886
/,
886-
copy: bool | None = True,
887+
copy: bool | None = None,
887888
xp: ModuleType | None = None,
888889
) -> Array: # numpydoc ignore=PR01,RT01
889890
"""Apply ``x[idx] **= y`` and return the updated array."""
@@ -893,7 +894,7 @@ def min(
893894
self,
894895
y: Array,
895896
/,
896-
copy: bool | None = True,
897+
copy: bool | None = None,
897898
xp: ModuleType | None = None,
898899
) -> Array: # numpydoc ignore=PR01,RT01
899900
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
@@ -906,7 +907,7 @@ def max(
906907
self,
907908
y: Array,
908909
/,
909-
copy: bool | None = True,
910+
copy: bool | None = None,
910911
xp: ModuleType | None = None,
911912
) -> Array: # numpydoc ignore=PR01,RT01
912913
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""

tests/test_at.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
7272
({"copy": True}, True),
7373
({"copy": False}, False),
7474
({"copy": None}, None), # Behavior is backend-specific
75-
({}, True), # Test that the copy parameter defaults to True
75+
({}, None), # Test that the copy parameter defaults to None
7676
],
7777
)
7878
@pytest.mark.parametrize(
@@ -125,12 +125,12 @@ def test_xp():
125125

126126
def test_alternate_index_syntax():
127127
a = np.asarray([1, 2, 3])
128-
assert_array_equal(at(a, 0).set(4), [4, 2, 3])
129-
assert_array_equal(at(a)[0].set(4), [4, 2, 3])
128+
assert_array_equal(at(a, 0).set(4, copy=True), [4, 2, 3])
129+
assert_array_equal(at(a)[0].set(4, copy=True), [4, 2, 3])
130130

131131
a_at = at(a)
132-
assert_array_equal(a_at[0].add(1), [2, 2, 3])
133-
assert_array_equal(a_at[1].add(2), [1, 4, 3])
132+
assert_array_equal(a_at[0].add(1, copy=True), [2, 2, 3])
133+
assert_array_equal(a_at[1].add(2, copy=True), [1, 4, 3])
134134

135135
with pytest.raises(ValueError, match="Index"):
136136
at(a).set(4)

0 commit comments

Comments
 (0)