Skip to content

Commit b19cc11

Browse files
committed
raise on incompatible cast
1 parent 38562db commit b19cc11

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

src/array_api_extra/_funcs.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,11 @@ def add(
848848
xp: ModuleType | None = None,
849849
) -> Array: # numpydoc ignore=PR01,RT01
850850
"""Apply ``x[idx] += y`` and return the updated array."""
851-
return self._iop("add", operator.add, y, copy=copy, xp=xp)
851+
852+
# Note for this and all other methods based on _iop:
853+
# operator.iadd and operator.add subtly differ in behaviour, as
854+
# only iadd will trigger exceptions when y has an incompatible dtype.
855+
return self._iop("add", operator.iadd, y, copy=copy, xp=xp)
852856

853857
def subtract(
854858
self,
@@ -858,7 +862,7 @@ def subtract(
858862
xp: ModuleType | None = None,
859863
) -> Array: # numpydoc ignore=PR01,RT01
860864
"""Apply ``x[idx] -= y`` and return the updated array."""
861-
return self._iop("subtract", operator.sub, y, copy=copy, xp=xp)
865+
return self._iop("subtract", operator.isub, y, copy=copy, xp=xp)
862866

863867
def multiply(
864868
self,
@@ -868,7 +872,7 @@ def multiply(
868872
xp: ModuleType | None = None,
869873
) -> Array: # numpydoc ignore=PR01,RT01
870874
"""Apply ``x[idx] *= y`` and return the updated array."""
871-
return self._iop("multiply", operator.mul, y, copy=copy, xp=xp)
875+
return self._iop("multiply", operator.imul, y, copy=copy, xp=xp)
872876

873877
def divide(
874878
self,
@@ -878,7 +882,7 @@ def divide(
878882
xp: ModuleType | None = None,
879883
) -> Array: # numpydoc ignore=PR01,RT01
880884
"""Apply ``x[idx] /= y`` and return the updated array."""
881-
return self._iop("divide", operator.truediv, y, copy=copy, xp=xp)
885+
return self._iop("divide", operator.itruediv, y, copy=copy, xp=xp)
882886

883887
def power(
884888
self,
@@ -888,7 +892,7 @@ def power(
888892
xp: ModuleType | None = None,
889893
) -> Array: # numpydoc ignore=PR01,RT01
890894
"""Apply ``x[idx] **= y`` and return the updated array."""
891-
return self._iop("power", operator.pow, y, copy=copy, xp=xp)
895+
return self._iop("power", operator.ipow, y, copy=copy, xp=xp)
892896

893897
def min(
894898
self,

tests/test_at.py

+20
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,23 @@ def test_alternate_index_syntax():
136136
at(a).set(4)
137137
with pytest.raises(ValueError, match="Index"):
138138
at(a, 0)[0].set(4)
139+
140+
141+
@pytest.mark.parametrize("copy", [True, False])
142+
@pytest.mark.parametrize("op", ["add", "subtract", "multiply", "divide", "power"])
143+
def test_iops_incompatible_dtype(op, copy):
144+
"""Test that at() replicates the backend's behaviour for
145+
in-place operations with incompatible dtypes.
146+
147+
Note:
148+
>>> a = np.asarray([1, 2, 3])
149+
>>> a / 1.5
150+
array([0. , 0.66666667, 1.33333333])
151+
>>> a /= 1.5
152+
UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64')
153+
to dtype('int64') with casting rule 'same_kind'
154+
"""
155+
a = np.asarray([2, 4])
156+
func = getattr(at(a)[:], op)
157+
with pytest.raises(TypeError, match="Cannot cast ufunc"):
158+
func(1.1, copy=copy)

0 commit comments

Comments
 (0)