Skip to content

Commit 17ca447

Browse files
committed
only treat integers as arrays when arrays have been consecutive
this fixes cases like `x[..., arr1, int1, arr2, int2, :, int3]`, the final int will be treated as an integer instead of an array, and `arr1` through `int2` will be treated as arrays
1 parent accb135 commit 17ca447

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

dpctl/tensor/_slicing.pxi

+13-3
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
176176
array_streak_interrupted = True
177177
elif _is_integral(i):
178178
axes_referenced += 1
179-
if array_streak_started:
179+
if array_streak_started and not array_streak_interrupted:
180180
# integers converted to arrays in this case
181181
array_count += 1
182182
else:
@@ -227,6 +227,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
227227
advanced_start_pos_set = False
228228
new_offset = offset
229229
is_empty = False
230+
array_streak = False
230231
for i in range(len(ind)):
231232
ind_i = ind[i]
232233
if (ind_i is Ellipsis):
@@ -237,9 +238,13 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
237238
is_empty = True
238239
new_offset = offset
239240
k = k_new
241+
if array_streak:
242+
array_streak = False
240243
elif ind_i is None:
241244
new_shape.append(1)
242245
new_strides.append(0)
246+
if array_streak:
247+
array_streak = False
243248
elif isinstance(ind_i, slice):
244249
k_new = k + 1
245250
sl_start, sl_stop, sl_step = ind_i.indices(shape[k])
@@ -253,13 +258,16 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
253258
is_empty = True
254259
new_offset = offset
255260
k = k_new
261+
if array_streak:
262+
array_streak = False
256263
elif _is_boolean(ind_i):
257264
new_shape.append(1 if ind_i else 0)
258265
new_strides.append(0)
266+
if array_streak:
267+
array_streak = False
259268
elif _is_integral(ind_i):
260269
ind_i = ind_i.__index__()
261-
if advanced_start_pos_set:
262-
# integers converted to arrays in this case
270+
if array_streak:
263271
new_advanced_ind.append(ind_i)
264272
k_new = k + 1
265273
new_shape.extend(shape[k:k_new])
@@ -281,6 +289,8 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
281289
("Index {0} is out of range for "
282290
"axes {1} with size {2}").format(ind_i, k, shape[k]))
283291
elif isinstance(ind_i, usm_ndarray):
292+
if not array_streak:
293+
array_streak = True
284294
if not advanced_start_pos_set:
285295
new_advanced_start_pos = len(new_shape)
286296
advanced_start_pos_set = True

0 commit comments

Comments
 (0)