@@ -37,6 +37,7 @@ cimport dpctl.memory as c_dpmem
37
37
cimport dpctl.tensor._dlpack as c_dlpack
38
38
39
39
import dpctl.tensor._flags as _flags
40
+ from dpctl.tensor._tensor_impl import default_device_fp_type
40
41
41
42
include " _stride_utils.pxi"
42
43
include " _types.pxi"
@@ -104,7 +105,7 @@ cdef class InternalUSMArrayError(Exception):
104
105
105
106
106
107
cdef class usm_ndarray:
107
- """ usm_ndarray(shape, dtype="|f8" , strides=None, buffer="device", \
108
+ """ usm_ndarray(shape, dtype=None , strides=None, buffer="device", \
108
109
offset=0, order="C", buffer_ctor_kwargs=dict(), \
109
110
array_namespace=None)
110
111
@@ -116,6 +117,8 @@ cdef class usm_ndarray:
116
117
Shape of the array to be created.
117
118
dtype (str, dtype):
118
119
Array data type, i.e. the type of array elements.
120
+ If ``dtype`` has the value ``None``, it is determined by default
121
+ floating point type supported by target device.
119
122
The supported types are
120
123
* ``bool``
121
124
boolean type
@@ -134,7 +137,7 @@ cdef class usm_ndarray:
134
137
double-precision real and complex floating
135
138
types, supported if target device's property
136
139
``has_aspect_fp64`` is ``True``.
137
- Default: ``"|f8" ``.
140
+ Default: ``None ``.
138
141
strides (tuple, optional):
139
142
Strides of the array to be created in elements.
140
143
If ``strides`` has the value ``None``, it is determined by the
@@ -219,7 +222,7 @@ cdef class usm_ndarray:
219
222
" Data pointers of cloned and original objects are different." )
220
223
return res
221
224
222
- def __cinit__ (self , shape , dtype = " |f8 " , strides = None , buffer = ' device' ,
225
+ def __cinit__ (self , shape , dtype = None , strides = None , buffer = ' device' ,
223
226
Py_ssize_t offset = 0 , order = ' C' ,
224
227
buffer_ctor_kwargs = dict (),
225
228
array_namespace = None ):
@@ -252,6 +255,13 @@ cdef class usm_ndarray:
252
255
except Exception :
253
256
raise TypeError (" Argument shape must be a list or a tuple." )
254
257
nd = len (shape)
258
+ if dtype is None :
259
+ q = buffer_ctor_kwargs.get(" queue" )
260
+ if q is not None :
261
+ dtype = default_device_fp_type(q)
262
+ else :
263
+ dev = dpctl.select_default_device()
264
+ dtype = " f8" if dev.has_aspect_fp64 else " f4"
255
265
typenum = dtype_to_typenum(dtype)
256
266
if (typenum < 0 ):
257
267
if typenum == - 2 :
0 commit comments