Skip to content

Commit 5ebd6c3

Browse files
Set default dtype of usm_ndarray depending on capabilities of device (#1265)
* Change default dtype value in usm_ndarray * Use select_default_device() instead of SyclDevice()
1 parent bc59de9 commit 5ebd6c3

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

dpctl/tensor/_usmarray.pyx

+13-3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ cimport dpctl.memory as c_dpmem
3737
cimport dpctl.tensor._dlpack as c_dlpack
3838

3939
import dpctl.tensor._flags as _flags
40+
from dpctl.tensor._tensor_impl import default_device_fp_type
4041

4142
include "_stride_utils.pxi"
4243
include "_types.pxi"
@@ -104,7 +105,7 @@ cdef class InternalUSMArrayError(Exception):
104105

105106

106107
cdef class usm_ndarray:
107-
""" usm_ndarray(shape, dtype="|f8", strides=None, buffer="device", \
108+
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
108109
offset=0, order="C", buffer_ctor_kwargs=dict(), \
109110
array_namespace=None)
110111
@@ -116,6 +117,8 @@ cdef class usm_ndarray:
116117
Shape of the array to be created.
117118
dtype (str, dtype):
118119
Array data type, i.e. the type of array elements.
120+
If ``dtype`` has the value ``None``, it is determined by default
121+
floating point type supported by target device.
119122
The supported types are
120123
* ``bool``
121124
boolean type
@@ -134,7 +137,7 @@ cdef class usm_ndarray:
134137
double-precision real and complex floating
135138
types, supported if target device's property
136139
``has_aspect_fp64`` is ``True``.
137-
Default: ``"|f8"``.
140+
Default: ``None``.
138141
strides (tuple, optional):
139142
Strides of the array to be created in elements.
140143
If ``strides`` has the value ``None``, it is determined by the
@@ -219,7 +222,7 @@ cdef class usm_ndarray:
219222
"Data pointers of cloned and original objects are different.")
220223
return res
221224

222-
def __cinit__(self, shape, dtype="|f8", strides=None, buffer='device',
225+
def __cinit__(self, shape, dtype=None, strides=None, buffer='device',
223226
Py_ssize_t offset=0, order='C',
224227
buffer_ctor_kwargs=dict(),
225228
array_namespace=None):
@@ -252,6 +255,13 @@ cdef class usm_ndarray:
252255
except Exception:
253256
raise TypeError("Argument shape must be a list or a tuple.")
254257
nd = len(shape)
258+
if dtype is None:
259+
q = buffer_ctor_kwargs.get("queue")
260+
if q is not None:
261+
dtype = default_device_fp_type(q)
262+
else:
263+
dev = dpctl.select_default_device()
264+
dtype = "f8" if dev.has_aspect_fp64 else "f4"
255265
typenum = dtype_to_typenum(dtype)
256266
if (typenum < 0):
257267
if typenum == -2:

dpctl/tests/test_usm_ndarray_ctor.py

+20
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,26 @@ def test_dtypes(dtype):
140140
assert expected_fmt == actual_fmt
141141

142142

143+
@pytest.mark.parametrize("usm_type", ["device", "shared", "host"])
144+
@pytest.mark.parametrize("buffer_ctor_kwargs", [dict(), {"queue": None}])
145+
def test_default_dtype(usm_type, buffer_ctor_kwargs):
146+
q = get_queue_or_skip()
147+
dev = q.get_sycl_device()
148+
if buffer_ctor_kwargs:
149+
buffer_ctor_kwargs["queue"] = q
150+
Xusm = dpt.usm_ndarray(
151+
(1,), buffer=usm_type, buffer_ctor_kwargs=buffer_ctor_kwargs
152+
)
153+
if dev.has_aspect_fp64:
154+
expected_dtype = "f8"
155+
else:
156+
expected_dtype = "f4"
157+
assert Xusm.itemsize == dpt.dtype(expected_dtype).itemsize
158+
expected_fmt = (dpt.dtype(expected_dtype).str)[1:]
159+
actual_fmt = Xusm.__sycl_usm_array_interface__["typestr"][1:]
160+
assert expected_fmt == actual_fmt
161+
162+
143163
@pytest.mark.parametrize(
144164
"dtype",
145165
[

0 commit comments

Comments
 (0)