Skip to content

Commit 47b2921

Browse files
committed
Adjusted dtype matrix test for in-place operators
- By using _type_utils._can_cast, prevents failures on platforms where the maximal precision type may not be 64-bit
1 parent 2e7cbe0 commit 47b2921

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

dpctl/tests/elementwise/test_add.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import dpctl
2424
import dpctl.tensor as dpt
25+
from dpctl.tensor._type_utils import _can_cast
2526
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2627
from dpctl.utils import ExecutionPlacementError
2728

@@ -371,7 +372,10 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
371372
ar1 = dpt.ones(sz, dtype=op1_dtype)
372373
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
373374

374-
if dpt.can_cast(op2_dtype, op1_dtype, casting="safe"):
375+
dev = q.sycl_device
376+
_fp16 = dev.has_aspect_fp16
377+
_fp64 = dev.has_aspect_fp64
378+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
375379
ar1 += ar2
376380
assert (
377381
dpt.asnumpy(ar1) == np.full(ar1.shape, 2, dtype=ar1.dtype)

dpctl/tests/elementwise/test_multiply.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import dpctl
2323
import dpctl.tensor as dpt
24+
from dpctl.tensor._type_utils import _can_cast
2425
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2526

2627
from .utils import _all_dtypes, _compare_dtypes, _usm_types
@@ -201,7 +202,10 @@ def test_multiply_inplace_dtype_matrix(op1_dtype, op2_dtype):
201202
ar1 = dpt.ones(sz, dtype=op1_dtype)
202203
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
203204

204-
if dpt.can_cast(op2_dtype, op1_dtype, casting="safe"):
205+
dev = q.sycl_device
206+
_fp16 = dev.has_aspect_fp16
207+
_fp64 = dev.has_aspect_fp64
208+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
205209
ar1 *= ar2
206210
assert (
207211
dpt.asnumpy(ar1) == np.full(ar1.shape, 1, dtype=ar1.dtype)

dpctl/tests/elementwise/test_subtract.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import dpctl
2323
import dpctl.tensor as dpt
24+
from dpctl.tensor._type_utils import _can_cast
2425
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2526

2627
from .utils import _all_dtypes, _compare_dtypes, _usm_types
@@ -196,7 +197,10 @@ def test_subtract_inplace_dtype_matrix(op1_dtype, op2_dtype):
196197
ar1 = dpt.ones(sz, dtype=op1_dtype)
197198
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
198199

199-
if dpt.can_cast(op2_dtype, op1_dtype, casting="safe"):
200+
dev = q.sycl_device
201+
_fp16 = dev.has_aspect_fp16
202+
_fp64 = dev.has_aspect_fp64
203+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
200204
ar1 -= ar2
201205
assert (dpt.asnumpy(ar1) == np.zeros(ar1.shape, dtype=ar1.dtype)).all()
202206

0 commit comments

Comments
 (0)