Skip to content

Commit 45ab64c

Browse files
authored
Merge pull request #2010 from IntelPython/support-out-kwarg-take
Add `out` keyword to `dpt.take`
2 parents e7b2b1b + 6f7a653 commit 45ab64c

File tree

3 files changed

+128
-22
lines changed

3 files changed

+128
-22
lines changed

CHANGELOG.md

+2
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
### Added
1010

11+
* Added `out` keyword to `tensor.take` [gh-2010](https://github.com/IntelPython/dpctl/pull/2010)
12+
1113
### Changed
1214

1315
### Fixed

dpctl/tensor/_indexing_functions.py

+45-7
Original file line number | Diff line number | Diff line change
@@ -40,8 +40,8 @@ def _get_indexing_mode(name):
4040
)
4141

4242

43-
def take(x, indices, /, *, axis=None, mode="wrap"):
44-
"""take(x, indices, axis=None, mode="wrap")
43+
def take(x, indices, /, *, axis=None, out=None, mode="wrap"):
44+
"""take(x, indices, axis=None, out=None, mode="wrap")
4545
4646
Takes elements from an array along a given axis at given indices.
4747
@@ -54,6 +54,9 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
5454
The axis along which the values will be selected.
5555
If ``x`` is one-dimensional, this argument is optional.
5656
Default: ``None``.
57+
out (Optional[usm_ndarray]):
58+
Output array to populate. Array must have the correct
59+
shape and the expected data type.
5760
mode (str, optional):
5861
How out-of-bounds indices will be handled. Possible values
5962
are:
@@ -121,18 +124,53 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
121124
raise ValueError("`axis` must be 0 for an array of dimension 0.")
122125
res_shape = indices.shape
123126

124-
res = dpt.empty(
125-
res_shape, dtype=x.dtype, usm_type=res_usm_type, sycl_queue=exec_q
126-
)
127+
dt = x.dtype
128+
129+
orig_out = out
130+
if out is not None:
131+
if not isinstance(out, dpt.usm_ndarray):
132+
raise TypeError(
133+
f"output array must be of usm_ndarray type, got {type(out)}"
134+
)
135+
if not out.flags.writable:
136+
raise ValueError("provided `out` array is read-only")
137+
138+
if out.shape != res_shape:
139+
raise ValueError(
140+
"The shape of input and output arrays are inconsistent. "
141+
f"Expected output shape is {res_shape}, got {out.shape}"
142+
)
143+
if dt != out.dtype:
144+
raise ValueError(
145+
f"Output array of type {dt} is needed, got {out.dtype}"
146+
)
147+
if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
148+
raise dpctl.utils.ExecutionPlacementError(
149+
"Input and output allocation queues are not compatible"
150+
)
151+
if ti._array_overlap(x, out):
152+
out = dpt.empty_like(out)
153+
else:
154+
out = dpt.empty(
155+
res_shape, dtype=dt, usm_type=res_usm_type, sycl_queue=exec_q
156+
)
127157

128158
_manager = dpctl.utils.SequentialOrderManager[exec_q]
129159
deps_ev = _manager.submitted_events
130160
hev, take_ev = ti._take(
131-
x, (indices,), res, axis, mode, sycl_queue=exec_q, depends=deps_ev
161+
x, (indices,), out, axis, mode, sycl_queue=exec_q, depends=deps_ev
132162
)
133163
_manager.add_event_pair(hev, take_ev)
134164

135-
return res
165+
if not (orig_out is None or out is orig_out):
166+
# Copy the out data from temporary buffer to original memory
167+
ht_e_cpy, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
168+
src=out, dst=orig_out, sycl_queue=exec_q, depends=[take_ev]
169+
)
170+
_manager.add_event_pair(ht_e_cpy, cpy_ev)
171+
out = orig_out
172+
173+
return out
136174

137175

138176
def put(x, indices, vals, /, *, axis=None, mode="wrap"):

dpctl/tests/test_usm_ndarray_indexing.py

+81-15
Original file line number | Diff line number | Diff line change
@@ -625,7 +625,7 @@ def test_put_0d_val(data_dt):
625625
skip_if_dtype_not_supported(data_dt, q)
626626

627627
x = dpt.arange(5, dtype=data_dt, sycl_queue=q)
628-
ind = dpt.asarray([0], dtype=np.intp, sycl_queue=q)
628+
ind = dpt.asarray([0], dtype="i8", sycl_queue=q)
629629
val = dpt.asarray(2, dtype=x.dtype, sycl_queue=q)
630630
x[ind] = val
631631
assert_array_equal(np.asarray(2, dtype=data_dt), dpt.asnumpy(x[0]))
@@ -644,7 +644,7 @@ def test_take_0d_data(data_dt):
644644
skip_if_dtype_not_supported(data_dt, q)
645645

646646
x = dpt.asarray(0, dtype=data_dt, sycl_queue=q)
647-
ind = dpt.arange(5, dtype=np.intp, sycl_queue=q)
647+
ind = dpt.arange(5, dtype="i8", sycl_queue=q)
648648

649649
y = dpt.take(x, ind)
650650
assert (
@@ -662,7 +662,7 @@ def test_put_0d_data(data_dt):
662662
skip_if_dtype_not_supported(data_dt, q)
663663

664664
x = dpt.asarray(0, dtype=data_dt, sycl_queue=q)
665-
ind = dpt.arange(5, dtype=np.intp, sycl_queue=q)
665+
ind = dpt.arange(5, dtype="i8", sycl_queue=q)
666666
val = dpt.asarray(2, dtype=data_dt, sycl_queue=q)
667667

668668
dpt.put(x, ind, val, axis=0)
@@ -710,7 +710,7 @@ def test_take_strided_1d_source(data_dt):
710710
skip_if_dtype_not_supported(data_dt, q)
711711

712712
x = dpt.arange(27, dtype=data_dt, sycl_queue=q)
713-
ind = dpt.arange(4, 9, dtype=np.intp, sycl_queue=q)
713+
ind = dpt.arange(4, 9, dtype="i8", sycl_queue=q)
714714

715715
x_np = dpt.asnumpy(x)
716716
ind_np = dpt.asnumpy(ind)
@@ -748,7 +748,7 @@ def test_take_strided(data_dt, order):
748748
skip_if_dtype_not_supported(data_dt, q)
749749

750750
x = dpt.reshape(_make_3d(data_dt, q), (9, 3), order=order)
751-
ind = dpt.arange(2, dtype=np.intp, sycl_queue=q)
751+
ind = dpt.arange(2, dtype="i8", sycl_queue=q)
752752

753753
x_np = dpt.asnumpy(x)
754754
ind_np = dpt.asnumpy(ind)
@@ -781,7 +781,7 @@ def test_take_strided_1d_indices(ind_dt):
781781
ind = dpt.arange(12, 24, dtype=ind_dt, sycl_queue=q)
782782

783783
x_np = dpt.asnumpy(x)
784-
ind_np = dpt.asnumpy(ind).astype(np.intp)
784+
ind_np = dpt.asnumpy(ind).astype("i8")
785785

786786
for s in (
787787
slice(None, None, 2),
@@ -820,7 +820,7 @@ def test_take_strided_indices(ind_dt, order):
820820
)
821821

822822
x_np = dpt.asnumpy(x)
823-
ind_np = dpt.asnumpy(ind).astype(np.intp)
823+
ind_np = dpt.asnumpy(ind).astype("i8")
824824

825825
for s in (
826826
slice(None, None, 2),
@@ -845,7 +845,7 @@ def test_put_strided_1d_destination(data_dt, order):
845845
skip_if_dtype_not_supported(data_dt, q)
846846

847847
x = dpt.arange(27, dtype=data_dt, sycl_queue=q)
848-
ind = dpt.arange(4, 9, dtype=np.intp, sycl_queue=q)
848+
ind = dpt.arange(4, 9, dtype="i8", sycl_queue=q)
849849
val = dpt.asarray(9, dtype=x.dtype, sycl_queue=q)
850850

851851
x_np = dpt.asnumpy(x)
@@ -875,7 +875,7 @@ def test_put_strided_destination(data_dt, order):
875875
skip_if_dtype_not_supported(data_dt, q)
876876

877877
x = dpt.reshape(_make_3d(data_dt, q), (9, 3), order=order)
878-
ind = dpt.arange(2, dtype=np.intp, sycl_queue=q)
878+
ind = dpt.arange(2, dtype="i8", sycl_queue=q)
879879
val = dpt.asarray(9, dtype=x.dtype, sycl_queue=q)
880880

881881
x_np = dpt.asnumpy(x)
@@ -924,7 +924,7 @@ def test_put_strided_1d_indices(ind_dt):
924924
val = dpt.asarray(-1, dtype=x.dtype, sycl_queue=q)
925925

926926
x_np = dpt.asnumpy(x)
927-
ind_np = dpt.asnumpy(ind).astype(np.intp)
927+
ind_np = dpt.asnumpy(ind).astype("i8")
928928
val_np = dpt.asnumpy(val)
929929

930930
for s in (
@@ -955,7 +955,7 @@ def test_put_strided_indices(ind_dt, order):
955955
val = dpt.asarray(-1, sycl_queue=q, dtype=x.dtype)
956956

957957
x_np = dpt.asnumpy(x)
958-
ind_np = dpt.asnumpy(ind).astype(np.intp)
958+
ind_np = dpt.asnumpy(ind).astype("i8")
959959
val_np = dpt.asnumpy(val)
960960

961961
for s in (
@@ -982,15 +982,15 @@ def test_integer_indexing_modes():
982982
x_np = dpt.asnumpy(x)
983983

984984
# wrapping negative indices
985-
ind = dpt.asarray([-4, -3, 0, 2, 4], dtype=np.intp, sycl_queue=q)
985+
ind = dpt.asarray([-4, -3, 0, 2, 4], dtype="i8", sycl_queue=q)
986986

987987
res = dpt.take(x, ind, mode="wrap")
988988
expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="raise")
989989

990990
assert (dpt.asnumpy(res) == expected_arr).all()
991991

992992
# clipping to 0 (disabling negative indices)
993-
ind = dpt.asarray([-6, -3, 0, 2, 6], dtype=np.intp, sycl_queue=q)
993+
ind = dpt.asarray([-6, -3, 0, 2, 6], dtype="i8", sycl_queue=q)
994994

995995
res = dpt.take(x, ind, mode="clip")
996996
expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="clip")
@@ -1002,7 +1002,7 @@ def test_take_arg_validation():
10021002
q = get_queue_or_skip()
10031003

10041004
x = dpt.arange(4, dtype="i4", sycl_queue=q)
1005-
ind0 = dpt.arange(4, dtype=np.intp, sycl_queue=q)
1005+
ind0 = dpt.arange(4, dtype="i8", sycl_queue=q)
10061006
ind1 = dpt.arange(2.0, dtype="f", sycl_queue=q)
10071007

10081008
with pytest.raises(TypeError):
@@ -1034,7 +1034,7 @@ def test_put_arg_validation():
10341034
q = get_queue_or_skip()
10351035

10361036
x = dpt.arange(4, dtype="i4", sycl_queue=q)
1037-
ind0 = dpt.arange(4, dtype=np.intp, sycl_queue=q)
1037+
ind0 = dpt.arange(4, dtype="i8", sycl_queue=q)
10381038
ind1 = dpt.arange(2.0, dtype="f", sycl_queue=q)
10391039
val = dpt.asarray(2, dtype=x.dtype, sycl_queue=q)
10401040

@@ -1890,3 +1890,69 @@ def test_put_along_axis_uint64_indices():
18901890
dpt.put_along_axis(x, inds, dpt.asarray(2, dtype=x.dtype), axis=1)
18911891
expected = dpt.tile(dpt.asarray([0, 2], dtype="i4"), (2, 5))
18921892
assert dpt.all(expected == x)
1893+
1894+
1895+
@pytest.mark.parametrize("data_dt", _all_dtypes)
1896+
@pytest.mark.parametrize("order", ["C", "F"])
1897+
def test_take_out(data_dt, order):
1898+
q = get_queue_or_skip()
1899+
skip_if_dtype_not_supported(data_dt, q)
1900+
1901+
axis = 0
1902+
x = dpt.reshape(_make_3d(data_dt, q), (9, 3), order=order)
1903+
ind = dpt.arange(2, dtype="i8", sycl_queue=q)
1904+
out_sh = x.shape[:axis] + ind.shape + x.shape[axis + 1 :]
1905+
out = dpt.empty(out_sh, dtype=data_dt, sycl_queue=q)
1906+
1907+
expected = dpt.take(x, ind, axis=axis)
1908+
1909+
dpt.take(x, ind, axis=axis, out=out)
1910+
1911+
assert dpt.all(out == expected)
1912+
1913+
1914+
@pytest.mark.parametrize("data_dt", _all_dtypes)
1915+
@pytest.mark.parametrize("order", ["C", "F"])
1916+
def test_take_out_overlap(data_dt, order):
1917+
q = get_queue_or_skip()
1918+
skip_if_dtype_not_supported(data_dt, q)
1919+
1920+
axis = 0
1921+
x = dpt.reshape(_make_3d(data_dt, q), (9, 3), order=order)
1922+
ind = dpt.arange(2, dtype="i8", sycl_queue=q)
1923+
out = x[x.shape[axis] - ind.shape[axis] : x.shape[axis], :]
1924+
1925+
expected = dpt.take(x, ind, axis=axis)
1926+
1927+
dpt.take(x, ind, axis=axis, out=out)
1928+
1929+
assert dpt.all(out == expected)
1930+
assert dpt.all(x[x.shape[0] - ind.shape[0] : x.shape[0], :] == out)
1931+
1932+
1933+
def test_take_out_errors():
1934+
q1 = get_queue_or_skip()
1935+
q2 = get_queue_or_skip()
1936+
1937+
x = dpt.arange(10, dtype="i4", sycl_queue=q1)
1938+
ind = dpt.arange(2, dtype="i4", sycl_queue=q1)
1939+
1940+
with pytest.raises(TypeError):
1941+
dpt.take(x, ind, out=dict())
1942+
1943+
out_read_only = dpt.empty(ind.shape, dtype=x.dtype, sycl_queue=q1)
1944+
out_read_only.flags["W"] = False
1945+
with pytest.raises(ValueError):
1946+
dpt.take(x, ind, out=out_read_only)
1947+
1948+
out_bad_shape = dpt.empty(0, dtype=x.dtype, sycl_queue=q1)
1949+
with pytest.raises(ValueError):
1950+
dpt.take(x, ind, out=out_bad_shape)
1951+
1952+
out_bad_dt = dpt.empty(ind.shape, dtype="i8", sycl_queue=q1)
1953+
with pytest.raises(ValueError):
1954+
dpt.take(x, ind, out=out_bad_dt)
1955+
1956+
out_bad_q = dpt.empty(ind.shape, dtype=x.dtype, sycl_queue=q2)
1957+
with pytest.raises(dpctl.utils.ExecutionPlacementError):
1958+
dpt.take(x, ind, out=out_bad_q)

0 commit comments

Comments
 (0)