Skip to content

Commit 7757857

Browse files
authored
Type promotion for indices arrays and casting vals in integer indexing (#1647)
* Tweaks to advanced integer indexing
  Setting items in an array now casts the right-hand side to the array data type when the data types differ.
  Setting and getting from an empty axis with non-empty indices now throws `IndexError`.
* Integer advanced indexing now promotes indices arrays
* `put` now casts `vals` when the data type differs from `x`
  Fixes `take` and `put` being used on empty axes with non-empty indices.
  Also adds a note to `put` about race conditions for non-unique indices.
* Adds tests for indexing array casting for indices and values
* Fixes range when checking for empty axes in `_take_multi_index`/`_put_multi_index`
  Also corrects error raised in `_put_multi_index` when attempting to put into indices along an empty axis.
* Changes per PR review
1 parent f5c6610 commit 7757857

File tree

4 files changed

+237
-60
lines changed

4 files changed

+237
-60
lines changed

dpctl/tensor/_copy_utils.py

+85-39
Original file line numberDiff line numberDiff line change
@@ -763,44 +763,66 @@ def _nonzero_impl(ary):
763763

764764
def _take_multi_index(ary, inds, p):
765765
if not isinstance(ary, dpt.usm_ndarray):
766-
raise TypeError
766+
raise TypeError(
767+
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
768+
)
769+
ary_nd = ary.ndim
770+
p = normalize_axis_index(operator.index(p), ary_nd)
767771
queues_ = [
768772
ary.sycl_queue,
769773
]
770774
usm_types_ = [
771775
ary.usm_type,
772776
]
773-
if not isinstance(inds, list) and not isinstance(inds, tuple):
777+
if not isinstance(inds, (list, tuple)):
774778
inds = (inds,)
775-
all_integers = True
776779
for ind in inds:
780+
if not isinstance(ind, dpt.usm_ndarray):
781+
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
777782
queues_.append(ind.sycl_queue)
778783
usm_types_.append(ind.usm_type)
779-
if all_integers:
780-
all_integers = ind.dtype.kind in "ui"
784+
if ind.dtype.kind not in "ui":
785+
raise IndexError(
786+
"arrays used as indices must be of integer (or boolean) type"
787+
)
788+
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
781789
exec_q = dpctl.utils.get_execution_queue(queues_)
782790
if exec_q is None:
783-
raise dpctl.utils.ExecutionPlacementError("")
784-
if not all_integers:
785-
raise IndexError(
786-
"arrays used as indices must be of integer (or boolean) type"
791+
raise dpctl.utils.ExecutionPlacementError(
792+
"Can not automatically determine where to allocate the "
793+
"result or performance execution. "
794+
"Use `usm_ndarray.to_device` method to migrate data to "
795+
"be associated with the same queue."
787796
)
788797
if len(inds) > 1:
798+
ind_dt = dpt.result_type(*inds)
799+
# ind arrays have been checked to be of integer dtype
800+
if ind_dt.kind not in "ui":
801+
raise ValueError(
802+
"cannot safely promote indices to an integer data type"
803+
)
804+
inds = tuple(
805+
map(
806+
lambda ind: ind
807+
if ind.dtype == ind_dt
808+
else dpt.astype(ind, ind_dt),
809+
inds,
810+
)
811+
)
789812
inds = dpt.broadcast_arrays(*inds)
790-
ary_ndim = ary.ndim
791-
p = normalize_axis_index(operator.index(p), ary_ndim)
792-
793-
res_shape = ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds) :]
794-
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
813+
ind0 = inds[0]
814+
ary_sh = ary.shape
815+
p_end = p + len(inds)
816+
if 0 in ary_sh[p:p_end] and ind0.size != 0:
817+
raise IndexError("cannot take non-empty indices from an empty axis")
818+
res_shape = ary_sh[:p] + ind0.shape + ary_sh[p_end:]
795819
res = dpt.empty(
796820
res_shape, dtype=ary.dtype, usm_type=res_usm_type, sycl_queue=exec_q
797821
)
798-
799822
hev, _ = ti._take(
800823
src=ary, ind=inds, dst=res, axis_start=p, mode=0, sycl_queue=exec_q
801824
)
802825
hev.wait()
803-
804826
return res
805827

806828

@@ -864,6 +886,12 @@ def _place_impl(ary, ary_mask, vals, axis=0):
864886

865887

866888
def _put_multi_index(ary, inds, p, vals):
889+
if not isinstance(ary, dpt.usm_ndarray):
890+
raise TypeError(
891+
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
892+
)
893+
ary_nd = ary.ndim
894+
p = normalize_axis_index(operator.index(p), ary_nd)
867895
if isinstance(vals, dpt.usm_ndarray):
868896
queues_ = [ary.sycl_queue, vals.sycl_queue]
869897
usm_types_ = [ary.usm_type, vals.usm_type]
@@ -874,46 +902,64 @@ def _put_multi_index(ary, inds, p, vals):
874902
usm_types_ = [
875903
ary.usm_type,
876904
]
877-
if not isinstance(inds, list) and not isinstance(inds, tuple):
905+
if not isinstance(inds, (list, tuple)):
878906
inds = (inds,)
879-
all_integers = True
880907
for ind in inds:
881908
if not isinstance(ind, dpt.usm_ndarray):
882-
raise TypeError
909+
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
883910
queues_.append(ind.sycl_queue)
884911
usm_types_.append(ind.usm_type)
885-
if all_integers:
886-
all_integers = ind.dtype.kind in "ui"
912+
if ind.dtype.kind not in "ui":
913+
raise IndexError(
914+
"arrays used as indices must be of integer (or boolean) type"
915+
)
916+
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
887917
exec_q = dpctl.utils.get_execution_queue(queues_)
918+
if exec_q is not None:
919+
if not isinstance(vals, dpt.usm_ndarray):
920+
vals = dpt.asarray(
921+
vals, dtype=ary.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
922+
)
923+
else:
924+
exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
888925
if exec_q is None:
889926
raise dpctl.utils.ExecutionPlacementError(
890927
"Can not automatically determine where to allocate the "
891928
"result or performance execution. "
892929
"Use `usm_ndarray.to_device` method to migrate data to "
893930
"be associated with the same queue."
894931
)
895-
if not all_integers:
896-
raise IndexError(
897-
"arrays used as indices must be of integer (or boolean) type"
898-
)
899932
if len(inds) > 1:
933+
ind_dt = dpt.result_type(*inds)
934+
# ind arrays have been checked to be of integer dtype
935+
if ind_dt.kind not in "ui":
936+
raise ValueError(
937+
"cannot safely promote indices to an integer data type"
938+
)
939+
inds = tuple(
940+
map(
941+
lambda ind: ind
942+
if ind.dtype == ind_dt
943+
else dpt.astype(ind, ind_dt),
944+
inds,
945+
)
946+
)
900947
inds = dpt.broadcast_arrays(*inds)
901-
ary_ndim = ary.ndim
902-
903-
p = normalize_axis_index(operator.index(p), ary_ndim)
904-
vals_shape = ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds) :]
905-
906-
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
907-
if not isinstance(vals, dpt.usm_ndarray):
908-
vals = dpt.asarray(
909-
vals, dtype=ary.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
948+
ind0 = inds[0]
949+
ary_sh = ary.shape
950+
p_end = p + len(inds)
951+
if 0 in ary_sh[p:p_end] and ind0.size != 0:
952+
raise IndexError(
953+
"cannot put into non-empty indices along an empty axis"
910954
)
911-
912-
vals = dpt.broadcast_to(vals, vals_shape)
913-
955+
expected_vals_shape = ary_sh[:p] + ind0.shape + ary_sh[p_end:]
956+
if vals.dtype == ary.dtype:
957+
rhs = vals
958+
else:
959+
rhs = dpt.astype(vals, ary.dtype)
960+
rhs = dpt.broadcast_to(rhs, expected_vals_shape)
914961
hev, _ = ti._put(
915-
dst=ary, ind=inds, val=vals, axis_start=p, mode=0, sycl_queue=exec_q
962+
dst=ary, ind=inds, val=rhs, axis_start=p, mode=0, sycl_queue=exec_q
916963
)
917964
hev.wait()
918-
919965
return

dpctl/tensor/_indexing_functions.py

+40-21
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import operator
1818

19-
import numpy as np
2019
from numpy.core.numeric import normalize_axis_index
2120

2221
import dpctl
@@ -47,15 +46,15 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
4746
indices (usm_ndarray):
4847
One-dimensional array of indices.
4948
axis:
50-
The axis over which the values will be selected.
51-
If x is one-dimensional, this argument is optional.
52-
Default: `None`.
49+
The axis along which the values will be selected.
50+
If ``x`` is one-dimensional, this argument is optional.
51+
Default: ``None``.
5352
mode:
5453
How out-of-bounds indices will be handled.
55-
"wrap" - clamps indices to (-n <= i < n), then wraps
54+
``"wrap"`` - clamps indices to (-n <= i < n), then wraps
5655
negative indices.
57-
"clip" - clips indices to (0 <= i < n)
58-
Default: `"wrap"`.
56+
``"clip"`` - clips indices to (0 <= i < n)
57+
Default: ``"wrap"``.
5958
6059
Returns:
6160
usm_ndarray:
@@ -73,7 +72,7 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
7372
type(indices)
7473
)
7574
)
76-
if not np.issubdtype(indices.dtype, np.integer):
75+
if indices.dtype.kind not in "ui":
7776
raise IndexError(
7877
"`indices` expected integer data type, got `{}`".format(
7978
indices.dtype
@@ -104,6 +103,9 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
104103

105104
if x_ndim > 0:
106105
axis = normalize_axis_index(operator.index(axis), x_ndim)
106+
x_sh = x.shape
107+
if x_sh[axis] == 0 and indices.size != 0:
108+
raise IndexError("cannot take non-empty indices from an empty axis")
107109
res_shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :]
108110
else:
109111
if axis != 0:
@@ -130,19 +132,26 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
130132
The array the values will be put into.
131133
indices (usm_ndarray)
132134
One-dimensional array of indices.
135+
136+
Note that if indices are not unique, a race
137+
condition will result, and the value written to
138+
``x`` will not be deterministic.
139+
:py:func:`dpctl.tensor.unique` can be used to
140+
guarantee unique elements in ``indices``.
133141
vals:
134-
Array of values to be put into `x`.
135-
Must be broadcastable to the shape of `indices`.
142+
Array of values to be put into ``x``.
143+
Must be broadcastable to the result shape
144+
``x.shape[:axis] + indices.shape + x.shape[axis+1:]``.
136145
axis:
137-
The axis over which the values will be placed.
138-
If x is one-dimensional, this argument is optional.
139-
Default: `None`.
146+
The axis along which the values will be placed.
147+
If ``x`` is one-dimensional, this argument is optional.
148+
Default: ``None``.
140149
mode:
141150
How out-of-bounds indices will be handled.
142-
"wrap" - clamps indices to (-n <= i < n), then wraps
151+
``"wrap"`` - clamps indices to (-n <= i < n), then wraps
143152
negative indices.
144-
"clip" - clips indices to (0 <= i < n)
145-
Default: `"wrap"`.
153+
``"clip"`` - clips indices to (0 <= i < n)
154+
Default: ``"wrap"``.
146155
"""
147156
if not isinstance(x, dpt.usm_ndarray):
148157
raise TypeError(
@@ -168,7 +177,7 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
168177
raise ValueError(
169178
"`indices` expected a 1D array, got `{}`".format(indices.ndim)
170179
)
171-
if not np.issubdtype(indices.dtype, np.integer):
180+
if indices.dtype.kind not in "ui":
172181
raise IndexError(
173182
"`indices` expected integer data type, got `{}`".format(
174183
indices.dtype
@@ -195,7 +204,9 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
195204

196205
if x_ndim > 0:
197206
axis = normalize_axis_index(operator.index(axis), x_ndim)
198-
207+
x_sh = x.shape
208+
if x_sh[axis] == 0 and indices.size != 0:
209+
raise IndexError("cannot take non-empty indices from an empty axis")
199210
val_shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :]
200211
else:
201212
if axis != 0:
@@ -206,10 +217,18 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
206217
vals = dpt.asarray(
207218
vals, dtype=x.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
208219
)
220+
# choose to throw here for consistency with `place`
221+
if vals.size == 0:
222+
raise ValueError(
223+
"cannot put into non-empty indices along an empty axis"
224+
)
225+
if vals.dtype == x.dtype:
226+
rhs = vals
227+
else:
228+
rhs = dpt.astype(vals, x.dtype)
229+
rhs = dpt.broadcast_to(rhs, val_shape)
209230

210-
vals = dpt.broadcast_to(vals, val_shape)
211-
212-
hev, _ = ti._put(x, (indices,), vals, axis, mode, sycl_queue=exec_q)
231+
hev, _ = ti._put(x, (indices,), rhs, axis, mode, sycl_queue=exec_q)
213232
hev.wait()
214233

215234

dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,10 @@ usm_ndarray_take(const dpctl::tensor::usm_ndarray &src,
414414
ind_offsets.push_back(py::ssize_t(0));
415415
}
416416

417+
if (ind_nelems == 0) {
418+
return std::make_pair(sycl::event{}, sycl::event{});
419+
}
420+
417421
char **packed_ind_ptrs = sycl::malloc_device<char *>(k, exec_q);
418422

419423
if (packed_ind_ptrs == nullptr) {
@@ -717,6 +721,10 @@ usm_ndarray_put(const dpctl::tensor::usm_ndarray &dst,
717721
ind_offsets.push_back(py::ssize_t(0));
718722
}
719723

724+
if (ind_nelems == 0) {
725+
return std::make_pair(sycl::event{}, sycl::event{});
726+
}
727+
720728
char **packed_ind_ptrs = sycl::malloc_device<char *>(k, exec_q);
721729

722730
if (packed_ind_ptrs == nullptr) {

0 commit comments

Comments
 (0)