Skip to content

Commit 0f9c857

Browse files
committed
Added tests for in-place multiplication, subtraction
- Adjusted tests for in-place addition to improve coverage
1 parent cb87b68 commit 0f9c857

File tree

3 files changed

+109
-1
lines changed

3 files changed

+109
-1
lines changed

dpctl/tests/elementwise/test_add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ def test_add_inplace_errors():
423423
ar1 = np.ones(2, dtype="float32")
424424
ar2 = dpt.ones(2, dtype="float32")
425425
with pytest.raises(TypeError):
426-
ar1 += ar2
426+
dpt.add.inplace(ar1, ar2)
427427

428428
ar1 = dpt.ones(2, dtype="float32")
429429
ar2 = dict()

dpctl/tests/elementwise/test_multiply.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,59 @@ def test_multiply_python_scalar_gh1219(arr_dt, sc):
172172
R = dpt.multiply(sc, X)
173173
Rnp = np.multiply(sc, Xnp)
174174
assert _compare_dtypes(R.dtype, Rnp.dtype, sycl_queue=q)
175+
176+
177+
@pytest.mark.parametrize("dtype", _all_dtypes)
178+
def test_multiply_inplace_python_scalar(dtype):
179+
q = get_queue_or_skip()
180+
skip_if_dtype_not_supported(dtype, q)
181+
X = dpt.ones((10, 10), dtype=dtype, sycl_queue=q)
182+
dt_kind = X.dtype.kind
183+
if dt_kind in "ui":
184+
X *= int(1)
185+
elif dt_kind == "f":
186+
X *= float(1)
187+
elif dt_kind == "c":
188+
X *= complex(1)
189+
elif dt_kind == "b":
190+
X *= bool(1)
191+
192+
193+
@pytest.mark.parametrize("op1_dtype", _all_dtypes)
194+
@pytest.mark.parametrize("op2_dtype", _all_dtypes)
195+
def test_multiply_inplace_dtype_matrix(op1_dtype, op2_dtype):
196+
q = get_queue_or_skip()
197+
skip_if_dtype_not_supported(op1_dtype, q)
198+
skip_if_dtype_not_supported(op2_dtype, q)
199+
200+
sz = 127
201+
ar1 = dpt.ones(sz, dtype=op1_dtype)
202+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
203+
204+
if dpt.can_cast(op2_dtype, op1_dtype, casting="safe"):
205+
ar1 *= ar2
206+
assert (
207+
dpt.asnumpy(ar1) == np.full(ar1.shape, 1, dtype=ar1.dtype)
208+
).all()
209+
210+
ar3 = dpt.ones(sz, dtype=op1_dtype)
211+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)
212+
213+
ar3[::-1] *= ar4[::2]
214+
assert (
215+
dpt.asnumpy(ar3) == np.full(ar3.shape, 1, dtype=ar3.dtype)
216+
).all()
217+
218+
else:
219+
with pytest.raises(TypeError):
220+
ar1 *= ar2
221+
222+
223+
def test_multiply_inplace_broadcasting():
224+
get_queue_or_skip()
225+
226+
m = dpt.ones((100, 5), dtype="i4")
227+
v = dpt.arange(5, dtype="i4")
228+
229+
m *= v
230+
assert (dpt.asnumpy(m) == np.arange(0, 5, dtype="i4")[np.newaxis, :]).all()

dpctl/tests/elementwise/test_subtract.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,55 @@ def test_subtract_python_scalar(arr_dt):
169169
assert isinstance(R, dpt.usm_ndarray)
170170
R = dpt.subtract(sc, X)
171171
assert isinstance(R, dpt.usm_ndarray)
172+
173+
174+
@pytest.mark.parametrize("dtype", _all_dtypes[1:])
175+
def test_subtract_inplace_python_scalar(dtype):
176+
q = get_queue_or_skip()
177+
skip_if_dtype_not_supported(dtype, q)
178+
X = dpt.zeros((10, 10), dtype=dtype, sycl_queue=q)
179+
dt_kind = X.dtype.kind
180+
if dt_kind in "ui":
181+
X -= int(0)
182+
elif dt_kind == "f":
183+
X -= float(0)
184+
elif dt_kind == "c":
185+
X -= complex(0)
186+
187+
188+
@pytest.mark.parametrize("op1_dtype", _all_dtypes[1:])
189+
@pytest.mark.parametrize("op2_dtype", _all_dtypes[1:])
190+
def test_subtract_inplace_dtype_matrix(op1_dtype, op2_dtype):
191+
q = get_queue_or_skip()
192+
skip_if_dtype_not_supported(op1_dtype, q)
193+
skip_if_dtype_not_supported(op2_dtype, q)
194+
195+
sz = 127
196+
ar1 = dpt.ones(sz, dtype=op1_dtype)
197+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
198+
199+
if dpt.can_cast(op2_dtype, op1_dtype, casting="safe"):
200+
ar1 -= ar2
201+
assert (dpt.asnumpy(ar1) == np.zeros(ar1.shape, dtype=ar1.dtype)).all()
202+
203+
ar3 = dpt.ones(sz, dtype=op1_dtype)
204+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)
205+
206+
ar3[::-1] -= ar4[::2]
207+
assert (dpt.asnumpy(ar3) == np.zeros(ar3.shape, dtype=ar3.dtype)).all()
208+
209+
else:
210+
with pytest.raises(TypeError):
211+
ar1 -= ar2
212+
213+
214+
def test_subtract_inplace_broadcasting():
215+
get_queue_or_skip()
216+
217+
m = dpt.ones((100, 5), dtype="i4")
218+
v = dpt.arange(5, dtype="i4")
219+
220+
m -= v
221+
assert (
222+
dpt.asnumpy(m) == np.arange(1, -4, step=-1, dtype="i4")[np.newaxis, :]
223+
).all()

0 commit comments

Comments
 (0)