Skip to content

Commit 7960513

Browse files
Merge pull request #2229 from IntelPython/fix_repeat_scalar_conv
Update `dpt.repeat()` to explicitly extract the single element from size-1 1D usm_ndarray repeats before converting it to a Python scalar and extends tests to cover both 0D and 1D repeats cases
2 parents 0797836 + cba240e commit 7960513

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,12 @@ def repeat(x, repeats, /, *, axis=None):
829829
if repeats.size == 1:
830830
scalar = True
831831
# bring the single element to the host
832-
repeats = int(repeats)
832+
if repeats.ndim == 0:
833+
repeats = int(repeats)
834+
else:
835+
# Get the single element explicitly
836+
# since non-0D arrays can not be converted to scalars
837+
repeats = int(repeats[0])
833838
if repeats < 0:
834839
raise ValueError("`repeats` elements must be positive")
835840
else:

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,6 +1342,21 @@ def test_repeat_strided_repeats():
13421342
assert dpt.all(res == x)
13431343

13441344

1345+
def test_repeat_size1_repeats():
1346+
get_queue_or_skip()
1347+
1348+
x = dpt.arange(5, dtype="i4")
1349+
expected_res = dpt.repeat(x, 2)
1350+
# 0D repeats
1351+
reps_0d = dpt.asarray(2, dtype="i8")
1352+
res = dpt.repeat(x, reps_0d)
1353+
assert dpt.all(res == expected_res)
1354+
# 1D repeats
1355+
reps_1d = dpt.asarray([2], dtype="i8")
1356+
res = dpt.repeat(x, reps_1d)
1357+
assert dpt.all(res == expected_res)
1358+
1359+
13451360
def test_repeat_arg_validation():
13461361
get_queue_or_skip()
13471362

0 commit comments

Comments
 (0)