|
1 |
| -__all__ = ["argmax", "argmin", "nonzero", "where"] |
| 1 | +__all__ = [ |
| 2 | + "argmax", |
| 3 | + "argmin", |
| 4 | + "nonzero", |
| 5 | + "top_k", |
| 6 | + "top_k_values", |
| 7 | + "top_k_indices", |
| 8 | + "where", |
| 9 | +] |
2 | 10 |
|
3 | 11 |
|
4 |
| -from ._types import Optional, Tuple, array |
| 12 | +from ._types import Optional, Literal, Tuple, array |
5 | 13 |
|
6 | 14 |
|
7 | 15 | def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array:
|
@@ -87,6 +95,126 @@ def nonzero(x: array, /) -> Tuple[array, ...]:
|
87 | 95 | """
|
88 | 96 |
|
89 | 97 |
|
| 98 | +def top_k( |
| 99 | + x: array, |
| 100 | + k: int, |
| 101 | + /, |
| 102 | + *, |
| 103 | + axis: Optional[int] = None, |
| 104 | + mode: Literal["largest", "smallest"] = "largest", |
| 105 | +) -> Tuple[array, array]: |
| 106 | + """ |
| 107 | + Returns the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension. |
| 108 | +
|
| 109 | + Parameters |
| 110 | + ---------- |
| 111 | + x: array |
| 112 | + input array. Should have a real-valued data type. |
| 113 | + k: int |
| 114 | + number of elements to find. Must be a positive integer value. |
| 115 | + axis: Optional[int] |
| 116 | + axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``. |
| 117 | + mode: Literal['largest', 'smallest'] |
| 118 | + search mode. Must be one of the following modes: |
| 119 | +
|
| 120 | + - ``'largest'``: return the ``k`` largest elements. |
| 121 | + - ``'smallest'``: return the ``k`` smallest elements. |
| 122 | +
|
| 123 | + Returns |
| 124 | + ------- |
| 125 | + out: Tuple[array, array] |
| 126 | + a namedtuple ``(values, indices)`` whose |
| 127 | +
|
| 128 | + - first element must have the field name ``values`` and must be an array containing the ``k`` largest (or smallest) elements of ``x``. The array must have the same data type as ``x``. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``. |
| 129 | + - second element must have the field name ``indices`` and must be an array containing indices of ``x`` that result in ``values``. The array must have the same shape as ``values`` and must have the default array index data type. If ``axis`` is ``None``, ``indices`` must be the indices of a flattened ``x``. |
| 130 | +
|
| 131 | + Notes |
| 132 | + ----- |
| 133 | +
|
| 134 | + - If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all elements. |
| 135 | + - The order of the returned values and indices is left unspecified and thus implementation-dependent. Conforming implementations may return sorted or unsorted values. |
| 136 | + - Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). |
| 137 | + """ |
| 138 | + |
| 139 | + |
| 140 | +def top_k_indices( |
| 141 | + x: array, |
| 142 | + k: int, |
| 143 | + /, |
| 144 | + *, |
| 145 | + axis: Optional[int] = None, |
| 146 | + mode: Literal["largest", "smallest"] = "largest", |
| 147 | +) -> array: |
| 148 | + """ |
| 149 | + Returns the indices of the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension. |
| 150 | +
|
| 151 | + Parameters |
| 152 | + ---------- |
| 153 | + x: array |
| 154 | + input array. Should have a real-valued data type. |
| 155 | + k: int |
| 156 | + number of elements to find. Must be a positive integer value. |
| 157 | + axis: Optional[int] |
| 158 | + axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``. |
| 159 | + mode: Literal['largest', 'smallest'] |
| 160 | + search mode. Must be one of the following modes: |
| 161 | +
|
| 162 | + - ``'largest'``: return the indices of the ``k`` largest elements. |
| 163 | + - ``'smallest'``: return the indices of the ``k`` smallest elements. |
| 164 | +
|
| 165 | + Returns |
| 166 | + ------- |
| 167 | + out: array |
| 168 | + an array containing indices corresponding to the ``k`` largest (or smallest) elements of ``x``. The array must have the default array index data type. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)`` and contain the indices of a flattened ``x``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``. |
| 169 | +
|
| 170 | + Notes |
| 171 | + ----- |
| 172 | +
|
| 173 | + - If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all indices. |
| 174 | + - The order of the returned indices is left unspecified and thus implementation-dependent. Conforming implementations may return indices corresponding to sorted or unsorted values. |
| 175 | + - Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). |
| 176 | + """ |
| 177 | + |
| 178 | + |
| 179 | +def top_k_values( |
| 180 | + x: array, |
| 181 | + k: int, |
| 182 | + /, |
| 183 | + *, |
| 184 | + axis: Optional[int] = None, |
| 185 | + mode: Literal["largest", "smallest"] = "largest", |
| 186 | +) -> array: |
| 187 | + """ |
| 188 | + Returns the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension. |
| 189 | +
|
| 190 | + Parameters |
| 191 | + ---------- |
| 192 | + x: array |
| 193 | + input array. Should have a real-valued data type. |
| 194 | + k: int |
| 195 | + number of elements to find. Must be a positive integer value. |
| 196 | + axis: Optional[int] |
| 197 | + axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``. |
| 198 | + mode: Literal['largest', 'smallest'] |
| 199 | + search mode. Must be one of the following modes: |
| 200 | +
|
| 201 | + - ``'largest'``: return the indices of the ``k`` largest elements. |
| 202 | + - ``'smallest'``: return the indices of the ``k`` smallest elements. |
| 203 | +
|
| 204 | + Returns |
| 205 | + ------- |
| 206 | + out: array |
| 207 | + an array containing the ``k`` largest (or smallest) elements of ``x``. The array must have the same data type as ``x``. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``. |
| 208 | +
|
| 209 | + Notes |
| 210 | + ----- |
| 211 | +
|
| 212 | + - If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all indices. |
| 213 | + - The order of the returned values is left unspecified and thus implementation-dependent. Conforming implementations may return sorted or unsorted values. |
| 214 | + - Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). |
| 215 | + """ |
| 216 | + |
| 217 | + |
90 | 218 | def where(condition: array, x1: array, x2: array, /) -> array:
|
91 | 219 | """
|
92 | 220 | Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``.
|
|
0 commit comments