Skip to content

Commit ccf66de

Browse files
committed
In-place operators now enabled for an array with itself
- This change fixes some failing tests - Added additional tests for in-place addition
1 parent 9f942e0 commit ccf66de

File tree

2 files changed

+47
-5
lines changed

2 files changed

+47
-5
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,10 @@ def inplace(self, lhs, val):
650650
raise TypeError(
651651
f"Expected dpctl.tensor.usm_ndarray, got {type(lhs)}"
652652
)
653+
if isinstance(val, dpt.usm_ndarray):
654+
if ti._array_overlap(lhs, val):
655+
# call standard operator in this case
656+
return self(lhs, val)
653657
q1, lhs_usm_type = _get_queue_usm_type(lhs)
654658
q2, val_usm_type = _get_queue_usm_type(val)
655659
if q2 is None:

dpctl/tests/elementwise/test_add.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -366,11 +366,11 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
366366
skip_if_dtype_not_supported(op1_dtype, q)
367367
skip_if_dtype_not_supported(op2_dtype, q)
368368

369-
if dpt.can_cast(op2_dtype, op1_dtype, casting="safe"):
370-
sz = 127
371-
ar1 = dpt.ones(sz, dtype=op1_dtype)
372-
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
369+
sz = 127
370+
ar1 = dpt.ones(sz, dtype=op1_dtype)
371+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
373372

373+
if dpt.can_cast(op2_dtype, op1_dtype, casting="safe"):
374374
ar1 += ar2
375375
assert (
376376
dpt.asnumpy(ar1) == np.full(ar1.shape, 2, dtype=ar1.dtype)
@@ -385,7 +385,8 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
385385
).all()
386386

387387
else:
388-
assert pytest.raises(TypeError)
388+
with pytest.raises(TypeError):
389+
ar1 += ar2
389390

390391

391392
def test_add_inplace_broadcasting():
@@ -396,3 +397,40 @@ def test_add_inplace_broadcasting():
396397

397398
m += v
398399
assert (dpt.asnumpy(m) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()
400+
401+
402+
def test_add_inplace_errors():
403+
get_queue_or_skip()
404+
try:
405+
gpu_queue = dpctl.SyclQueue("gpu")
406+
except dpctl.SyclQueueCreationError:
407+
pytest.skip("SyclQueue('gpu') failed, skipping")
408+
try:
409+
cpu_queue = dpctl.SyclQueue("cpu")
410+
except dpctl.SyclQueueCreationError:
411+
pytest.skip("SyclQueue('cpu') failed, skipping")
412+
413+
ar1 = dpt.ones(2, dtype="float32", sycl_queue=gpu_queue)
414+
ar2 = dpt.ones_like(ar1, sycl_queue=cpu_queue)
415+
with pytest.raises(ExecutionPlacementError):
416+
ar1 += ar2
417+
418+
ar1 = dpt.ones(2, dtype="float32")
419+
ar2 = dpt.ones(3, dtype="float32")
420+
with pytest.raises(ValueError):
421+
ar1 += ar2
422+
423+
ar1 = np.ones(2, dtype="float32")
424+
ar2 = dpt.ones(2, dtype="float32")
425+
with pytest.raises(TypeError):
426+
ar1 += ar2
427+
428+
ar1 = dpt.ones(2, dtype="float32")
429+
ar2 = dict()
430+
with pytest.raises(ValueError):
431+
ar1 += ar2
432+
433+
ar1 = dpt.ones((2, 1), dtype="float32")
434+
ar2 = dpt.ones((1, 2), dtype="float32")
435+
with pytest.raises(ValueError):
436+
ar1 += ar2

0 commit comments

Comments
 (0)