@@ -41,25 +41,29 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
41
41
Takes elements from an array along a given axis at given indices.
42
42
43
43
Args:
44
- x (usm_ndarray):
45
- The array that elements will be taken from.
46
- indices (usm_ndarray):
47
- One-dimensional array of indices.
48
- axis:
49
- The axis along which the values will be selected.
50
- If ``x`` is one-dimensional, this argument is optional.
51
- Default: ``None``.
52
- mode:
53
- How out-of-bounds indices will be handled.
54
- ``"wrap"`` - clamps indices to (-n <= i < n), then wraps
55
- negative indices.
56
- ``"clip"`` - clips indices to (0 <= i < n)
57
- Default: ``"wrap"``.
44
+ x (usm_ndarray):
45
+ The array that elements will be taken from.
46
+ indices (usm_ndarray):
47
+ One-dimensional array of indices.
48
+ axis (int, optional):
49
+ The axis along which the values will be selected.
50
+ If ``x`` is one-dimensional, this argument is optional.
51
+ Default: ``None``.
52
+ mode (str, optional):
53
+ How out-of-bounds indices will be handled. Possible values
54
+ are:
55
+
56
+ - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
57
+ negative indices.
58
+ - ``"clip"``: clips indices to (``0 <= i < n``).
59
+
60
+ Default: ``"wrap"``.
58
61
59
62
Returns:
60
63
usm_ndarray:
61
- Array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:]
62
- filled with elements from x.
64
+ Array with shape
65
+ ``x.shape[:axis] + indices.shape + x.shape[axis + 1:]``
66
+ filled with elements from ``x``.
63
67
"""
64
68
if not isinstance (x , dpt .usm_ndarray ):
65
69
raise TypeError (
@@ -128,30 +132,76 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
128
132
Puts values into an array along a given axis at given indices.
129
133
130
134
Args:
131
- x (usm_ndarray):
132
- The array the values will be put into.
133
- indices (usm_ndarray)
134
- One-dimensional array of indices.
135
-
136
- Note that if indices are not unique, a race
137
- condition will result, and the value written to
138
- ``x`` will not be deterministic.
139
- :py:func:`dpctl.tensor.unique` can be used to
140
- guarantee unique elements in ``indices``.
141
- vals:
142
- Array of values to be put into ``x``.
143
- Must be broadcastable to the result shape
144
- ``x.shape[:axis] + indices.shape + x.shape[axis+1:]``.
145
- axis:
146
- The axis along which the values will be placed.
147
- If ``x`` is one-dimensional, this argument is optional.
148
- Default: ``None``.
149
- mode:
150
- How out-of-bounds indices will be handled.
151
- ``"wrap"`` - clamps indices to (-n <= i < n), then wraps
152
- negative indices.
153
- ``"clip"`` - clips indices to (0 <= i < n)
154
- Default: ``"wrap"``.
135
+ x (usm_ndarray):
136
+ The array the values will be put into.
137
+ indices (usm_ndarray):
138
+ One-dimensional array of indices.
139
+
140
+ Note that if indices are not unique, a race
141
+ condition will result, and the value written to
142
+ ``x`` will not be deterministic.
143
+ :func:`dpctl.tensor.unique` can be used to
144
+ guarantee unique elements in ``indices``.
145
+ vals (usm_ndarray):
146
+ Array of values to be put into ``x``.
147
+ Must be broadcastable to the result shape
148
+ ``x.shape[:axis] + indices.shape + x.shape[axis+1:]``.
149
+ axis (int, optional):
150
+ The axis along which the values will be placed.
151
+ If ``x`` is one-dimensional, this argument is optional.
152
+ Default: ``None``.
153
+ mode (str, optional):
154
+ How out-of-bounds indices will be handled. Possible values
155
+ are:
156
+
157
+ - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
158
+ negative indices.
159
+ - ``"clip"``: clips indices to (``0 <= i < n``).
160
+
161
+ Default: ``"wrap"``.
162
+
163
+ .. note::
164
+
165
+ If input array ``indices`` contains duplicates, a race condition
166
+ occurs, and the value written into corresponding positions in ``x``
167
+ may vary from run to run. Preserving sequential semantics in handling
168
+ the duplicates requires additional work, e.g.
169
+
170
+ :Example:
171
+
172
+ .. code-block:: python
173
+
174
+ from dpctl import tensor as dpt
175
+
176
+ def put_vec_duplicates(vec, ind, vals):
177
+ "Put values into vec, handling possible duplicates in ind"
178
+ assert (vec.ndim, ind.ndim, vals.ndim) == (1, 1, 1)
179
+
180
+ # find positions of last occurrences of each
181
+ # unique index
182
+ ind_flipped = dpt.flip(ind)
183
+ ind_uniq = dpt.unique_all(ind_flipped).indices
184
+ has_dups = len(ind) != len(ind_uniq)
185
+
186
+ if has_dups:
187
+ ind_uniq = dpt.subtract(ind.size - 1, ind_uniq)
188
+ ind = dpt.take(ind, ind_uniq)
189
+ vals = dpt.take(vals, ind_uniq)
190
+
191
+ dpt.put(vec, ind, vals)
192
+
193
+ n = 512
194
+ ind = dpt.concat((dpt.arange(n), dpt.arange(n, -1, step=-1)))
195
+ x = dpt.zeros(ind.size, dtype="int32")
196
+ vals = dpt.arange(ind.size, dtype=x.dtype)
197
+
198
+ # Values corresponding to last positions of
199
+ # duplicate indices are written into the vector x
200
+ put_vec_duplicates(x, ind, vals)
201
+
202
+ parts = (vals[-1:-n-2:-1], dpt.zeros(n, dtype=x.dtype))
203
+ expected = dpt.concat(parts)
204
+ assert dpt.all(x == expected)
155
205
"""
156
206
if not isinstance (x , dpt .usm_ndarray ):
157
207
raise TypeError (
@@ -237,22 +287,24 @@ def extract(condition, arr):
237
287
238
288
Returns the elements of an array that satisfy the condition.
239
289
240
- If `condition` is boolean ``dpctl.tensor.extract`` is
290
+ If ``condition`` is boolean ``dpctl.tensor.extract`` is
241
291
equivalent to ``arr[condition]``.
242
292
243
293
Note that ``dpctl.tensor.place`` does the opposite of
244
294
``dpctl.tensor.extract``.
245
295
246
296
Args:
247
297
condition (usm_ndarray):
248
- An array whose non-zero or True entries indicate the element
249
- of `arr` to extract.
298
+ An array whose non-zero or ``True`` entries indicate the elements
299
+ of ``arr`` to extract.
300
+
250
301
arr (usm_ndarray):
251
- Input array of the same size as `condition`.
302
+ Input array of the same size as ``condition``.
252
303
253
304
Returns:
254
305
usm_ndarray:
255
- Rank 1 array of values from `arr` where `condition` is True.
306
+ Rank 1 array of values from ``arr`` where ``condition`` is
307
+ ``True``.
256
308
"""
257
309
if not isinstance (condition , dpt .usm_ndarray ):
258
310
raise TypeError (
@@ -280,20 +332,20 @@ def place(arr, mask, vals):
280
332
281
333
Change elements of an array based on conditional and input values.
282
334
283
- If `mask` is boolean ``dpctl.tensor.place`` is
335
+ If ``mask`` is boolean ``dpctl.tensor.place`` is
284
336
equivalent to ``arr[mask] = vals``.
285
337
286
338
Args:
287
339
arr (usm_ndarray):
288
340
Array to put data into.
289
341
mask (usm_ndarray):
290
- Boolean mask array. Must have the same size as `arr`.
342
+ Boolean mask array. Must have the same size as ``arr``.
291
343
vals (usm_ndarray, sequence):
292
- Values to put into `arr`. Only the first N elements are
293
- used, where N is the number of True values in `mask`. If
294
- `vals` is smaller than N, it will be repeated, and if
295
- elements of `arr` are to be masked, this sequence must be
296
- non-empty. Array `vals` must be one dimensional.
344
+ Values to put into ``arr``. Only the first N elements are
345
+ used, where N is the number of True values in ``mask``. If
346
+ ``vals`` is smaller than N, it will be repeated, and if
347
+ elements of ``arr`` are to be masked, this sequence must be
348
+ non-empty. Array ``vals`` must be one dimensional.
297
349
"""
298
350
if not isinstance (arr , dpt .usm_ndarray ):
299
351
raise TypeError (
@@ -345,13 +397,14 @@ def nonzero(arr):
345
397
Return the indices of non-zero elements.
346
398
347
399
Returns a tuple of usm_ndarrays, one for each dimension
348
- of `arr`, containing the indices of the non-zero elements
349
- in that dimension. The values of `arr` are always tested in
400
+ of ``arr``, containing the indices of the non-zero elements
401
+ in that dimension. The values of ``arr`` are always tested in
350
402
row-major, C-style order.
351
403
352
404
Args:
353
405
arr (usm_ndarray):
354
406
Input array, which has non-zero array rank.
407
+
355
408
Returns:
356
409
Tuple[usm_ndarray, ...]:
357
410
Indices of non-zero array elements.
0 commit comments