Skip to content

Commit 124458f

Browse files
Addresses gh-1360 with documentation
Document the race condition in ``put`` when the ``indices`` array contains duplicates. Provide an example of resolving the race condition for 1-d ``x``, ``indices``, and ``vals``.
1 parent 7757857 commit 124458f

File tree

1 file changed

+107
-54
lines changed

1 file changed

+107
-54
lines changed

dpctl/tensor/_indexing_functions.py

+107-54
Original file line numberDiff line numberDiff line change
@@ -41,25 +41,29 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
4141
Takes elements from an array along a given axis at given indices.
4242
4343
Args:
44-
x (usm_ndarray):
45-
The array that elements will be taken from.
46-
indices (usm_ndarray):
47-
One-dimensional array of indices.
48-
axis:
49-
The axis along which the values will be selected.
50-
If ``x`` is one-dimensional, this argument is optional.
51-
Default: ``None``.
52-
mode:
53-
How out-of-bounds indices will be handled.
54-
``"wrap"`` - clamps indices to (-n <= i < n), then wraps
55-
negative indices.
56-
``"clip"`` - clips indices to (0 <= i < n)
57-
Default: ``"wrap"``.
44+
x (usm_ndarray):
45+
The array that elements will be taken from.
46+
indices (usm_ndarray):
47+
One-dimensional array of indices.
48+
axis (int, optional):
49+
The axis along which the values will be selected.
50+
If ``x`` is one-dimensional, this argument is optional.
51+
Default: ``None``.
52+
mode (str, optional):
53+
How out-of-bounds indices will be handled. Possible values
54+
are:
55+
56+
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
57+
negative indices.
58+
- ``"clip"``: clips indices to (``0 <= i < n``).
59+
60+
Default: ``"wrap"``.
5861
5962
Returns:
6063
usm_ndarray:
61-
Array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:]
62-
filled with elements from x.
64+
Array with shape
65+
``x.shape[:axis] + indices.shape + x.shape[axis + 1:]``
66+
filled with elements from ``x``.
6367
"""
6468
if not isinstance(x, dpt.usm_ndarray):
6569
raise TypeError(
@@ -128,30 +132,76 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
128132
Puts values into an array along a given axis at given indices.
129133
130134
Args:
131-
x (usm_ndarray):
132-
The array the values will be put into.
133-
indices (usm_ndarray)
134-
One-dimensional array of indices.
135-
136-
Note that if indices are not unique, a race
137-
condition will result, and the value written to
138-
``x`` will not be deterministic.
139-
:py:func:`dpctl.tensor.unique` can be used to
140-
guarantee unique elements in ``indices``.
141-
vals:
142-
Array of values to be put into ``x``.
143-
Must be broadcastable to the result shape
144-
``x.shape[:axis] + indices.shape + x.shape[axis+1:]``.
145-
axis:
146-
The axis along which the values will be placed.
147-
If ``x`` is one-dimensional, this argument is optional.
148-
Default: ``None``.
149-
mode:
150-
How out-of-bounds indices will be handled.
151-
``"wrap"`` - clamps indices to (-n <= i < n), then wraps
152-
negative indices.
153-
``"clip"`` - clips indices to (0 <= i < n)
154-
Default: ``"wrap"``.
135+
x (usm_ndarray):
136+
The array the values will be put into.
137+
indices (usm_ndarray):
138+
One-dimensional array of indices.
139+
140+
Note that if indices are not unique, a race
141+
condition will result, and the value written to
142+
``x`` will not be deterministic.
143+
:func:`dpctl.tensor.unique` can be used to
144+
guarantee unique elements in ``indices``.
145+
vals (usm_ndarray):
146+
Array of values to be put into ``x``.
147+
Must be broadcastable to the result shape
148+
``x.shape[:axis] + indices.shape + x.shape[axis+1:]``.
149+
axis (int, optional):
150+
The axis along which the values will be placed.
151+
If ``x`` is one-dimensional, this argument is optional.
152+
Default: ``None``.
153+
mode (str, optional):
154+
How out-of-bounds indices will be handled. Possible values
155+
are:
156+
157+
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
158+
negative indices.
159+
- ``"clip"``: clips indices to (``0 <= i < n``).
160+
161+
Default: ``"wrap"``.
162+
163+
.. note::
164+
165+
If input array ``indices`` contains duplicates, a race condition
166+
occurs, and the value written into corresponding positions in ``x``
167+
may vary from run to run. Preserving sequential semantics in handling
168+
the duplicates requires additional work, e.g.
169+
170+
:Example:
171+
172+
.. code-block:: python
173+
174+
from dpctl import tensor as dpt
175+
176+
def put_vec_duplicates(vec, ind, vals):
177+
"Put values into vec, handling possible duplicates in ind"
178+
assert (vec.ndim, ind.ndim, vals.ndim) == (1, 1, 1)
179+
180+
# find positions of last occurrences of each
181+
# unique index
182+
ind_flipped = dpt.flip(ind)
183+
ind_uniq = dpt.unique_all(ind_flipped).indices
184+
has_dups = len(ind) != len(ind_uniq)
185+
186+
if has_dups:
187+
ind_uniq = dpt.subtract(ind.size - 1, ind_uniq)
188+
ind = dpt.take(ind, ind_uniq)
189+
vals = dpt.take(vals, ind_uniq)
190+
191+
dpt.put(vec, ind, vals)
192+
193+
n = 512
194+
ind = dpt.concat((dpt.arange(n), dpt.arange(n, -1, step=-1)))
195+
x = dpt.zeros(ind.size, dtype="int32")
196+
vals = dpt.arange(ind.size, dtype=x.dtype)
197+
198+
# Values corresponding to last positions of
199+
# duplicate indices are written into the vector x
200+
put_vec_duplicates(x, ind, vals)
201+
202+
parts = (vals[-1:-n-2:-1], dpt.zeros(n, dtype=x.dtype))
203+
expected = dpt.concat(parts)
204+
assert dpt.all(x == expected)
155205
"""
156206
if not isinstance(x, dpt.usm_ndarray):
157207
raise TypeError(
@@ -237,22 +287,24 @@ def extract(condition, arr):
237287
238288
Returns the elements of an array that satisfy the condition.
239289
240-
If `condition` is boolean ``dpctl.tensor.extract`` is
290+
If ``condition`` is boolean ``dpctl.tensor.extract`` is
241291
equivalent to ``arr[condition]``.
242292
243293
Note that ``dpctl.tensor.place`` does the opposite of
244294
``dpctl.tensor.extract``.
245295
246296
Args:
247297
condition (usm_ndarray):
248-
An array whose non-zero or True entries indicate the element
249-
of `arr` to extract.
298+
An array whose non-zero or ``True`` entries indicate the element
299+
of ``arr`` to extract.
300+
250301
arr (usm_ndarray):
251-
Input array of the same size as `condition`.
302+
Input array of the same size as ``condition``.
252303
253304
Returns:
254305
usm_ndarray:
255-
Rank 1 array of values from `arr` where `condition` is True.
306+
Rank 1 array of values from ``arr`` where ``condition`` is
307+
``True``.
256308
"""
257309
if not isinstance(condition, dpt.usm_ndarray):
258310
raise TypeError(
@@ -280,20 +332,20 @@ def place(arr, mask, vals):
280332
281333
Change elements of an array based on conditional and input values.
282334
283-
If `mask` is boolean ``dpctl.tensor.place`` is
335+
If ``mask`` is boolean ``dpctl.tensor.place`` is
284336
equivalent to ``arr[mask] = vals``.
285337
286338
Args:
287339
arr (usm_ndarray):
288340
Array to put data into.
289341
mask (usm_ndarray):
290-
Boolean mask array. Must have the same size as `arr`.
342+
Boolean mask array. Must have the same size as ``arr``.
291343
vals (usm_ndarray, sequence):
292-
Values to put into `arr`. Only the first N elements are
293-
used, where N is the number of True values in `mask`. If
294-
`vals` is smaller than N, it will be repeated, and if
295-
elements of `arr` are to be masked, this sequence must be
296-
non-empty. Array `vals` must be one dimensional.
344+
Values to put into ``arr``. Only the first N elements are
345+
used, where N is the number of True values in ``mask``. If
346+
``vals`` is smaller than N, it will be repeated, and if
347+
elements of ``arr`` are to be masked, this sequence must be
348+
non-empty. Array ``vals`` must be one dimensional.
297349
"""
298350
if not isinstance(arr, dpt.usm_ndarray):
299351
raise TypeError(
@@ -345,13 +397,14 @@ def nonzero(arr):
345397
Return the indices of non-zero elements.
346398
347399
Returns a tuple of usm_ndarrays, one for each dimension
348-
of `arr`, containing the indices of the non-zero elements
349-
in that dimension. The values of `arr` are always tested in
400+
of ``arr``, containing the indices of the non-zero elements
401+
in that dimension. The values of ``arr`` are always tested in
350402
row-major, C-style order.
351403
352404
Args:
353405
arr (usm_ndarray):
354406
Input array, which has non-zero array rank.
407+
355408
Returns:
356409
Tuple[usm_ndarray, ...]:
357410
Indices of non-zero array elements.

0 commit comments

Comments
 (0)