|
61 | 61 | matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
|
62 | 62 | tensordot = get_xp(np)(_aliases.tensordot)
|
63 | 63 |
|
| 64 | + |
| 65 | +def top_k(a, k, /, *, axis=-1, largest=True): |
| 66 | + """ |
| 67 | + Returns the ``k`` largest/smallest elements and corresponding |
| 68 | + indices along the given ``axis``. |
| 69 | +
|
| 70 | + When ``axis`` is None, a flattened array is used. |
| 71 | +
|
| 72 | + If ``largest`` is false, then the ``k`` smallest elements are returned. |
| 73 | +
|
| 74 | + A tuple of ``(values, indices)`` is returned, where ``values`` and |
| 75 | + ``indices`` of the largest/smallest elements of each row of the input |
| 76 | + array in the given ``axis``. |
| 77 | +
|
| 78 | + Parameters |
| 79 | + ---------- |
| 80 | + a: array_like |
| 81 | + The source array |
| 82 | + k: int |
| 83 | + The number of largest/smallest elements to return. ``k`` must |
| 84 | + be a positive integer and within indexable range specified by |
| 85 | + ``axis``. |
| 86 | + axis: int, optional |
| 87 | + Axis along which to find the largest/smallest elements. |
| 88 | + The default is -1 (the last axis). |
| 89 | + If None, a flattened array is used. |
| 90 | + largest: bool, optional |
| 91 | + If True, largest elements are returned. Otherwise the smallest |
| 92 | + are returned. |
| 93 | +
|
| 94 | + Returns |
| 95 | + ------- |
| 96 | + tuple_of_array: tuple |
| 97 | + The output tuple of ``(topk_values, topk_indices)``, where |
| 98 | + ``topk_values`` are returned elements from the source array |
| 99 | + (not necessarily in sorted order), and ``topk_indices`` are |
| 100 | + the corresponding indices. |
| 101 | +
|
| 102 | + See Also |
| 103 | + -------- |
| 104 | + argpartition : Indirect partition. |
| 105 | + sort : Full sorting. |
| 106 | +
|
| 107 | + Notes |
| 108 | + ----- |
| 109 | + The returned indices are not guaranteed to be sorted according to |
| 110 | + the values. Furthermore, the returned indices are not guaranteed |
| 111 | + to be the earliest/latest occurrence of the element. E.g., |
| 112 | + ``np.top_k([3,3,3], 1)`` can return ``(array([3]), array([1]))`` |
| 113 | + rather than ``(array([3]), array([0]))`` or |
| 114 | + ``(array([3]), array([2]))``. |
| 115 | +
|
| 116 | + Warning: The treatment of ``np.nan`` in the input array is undefined. |
| 117 | +
|
| 118 | + Examples |
| 119 | + -------- |
| 120 | + >>> a = np.array([[1,2,3,4,5], [5,4,3,2,1], [3,4,5,1,2]]) |
| 121 | + >>> np.top_k(a, 2) |
| 122 | + (array([[4, 5], |
| 123 | + [4, 5], |
| 124 | + [4, 5]]), |
| 125 | + array([[3, 4], |
| 126 | + [1, 0], |
| 127 | + [1, 2]])) |
| 128 | + >>> np.top_k(a, 2, axis=0) |
| 129 | + (array([[3, 4, 3, 2, 2], |
| 130 | + [5, 4, 5, 4, 5]]), |
| 131 | + array([[2, 1, 1, 1, 2], |
| 132 | + [1, 2, 2, 0, 0]])) |
| 133 | + >>> a.flatten() |
| 134 | + array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 3, 4, 5, 1, 2]) |
| 135 | + >>> np.top_k(a, 2, axis=None) |
| 136 | + (array([5, 5]), array([ 5, 12])) |
| 137 | + """ |
| 138 | + if k <= 0: |
| 139 | + raise ValueError(f'k(={k}) provided must be positive.') |
| 140 | + |
| 141 | + positive_axis: int |
| 142 | + _arr = np.asanyarray(a) |
| 143 | + if axis is None: |
| 144 | + arr = _arr.ravel() |
| 145 | + positive_axis = 0 |
| 146 | + else: |
| 147 | + arr = _arr |
| 148 | + positive_axis = axis if axis > 0 else axis % arr.ndim |
| 149 | + |
| 150 | + slice_start = (np.s_[:],) * positive_axis |
| 151 | + if largest: |
| 152 | + indices_array = np.argpartition(arr, -k, axis=axis) |
| 153 | + slice = slice_start + (np.s_[-k:],) |
| 154 | + topk_indices = indices_array[slice] |
| 155 | + else: |
| 156 | + indices_array = np.argpartition(arr, k-1, axis=axis) |
| 157 | + slice = slice_start + (np.s_[:k],) |
| 158 | + topk_indices = indices_array[slice] |
| 159 | + |
| 160 | + topk_values = np.take_along_axis(arr, topk_indices, axis=axis) |
| 161 | + |
| 162 | + return (topk_values, topk_indices) |
| 163 | + |
| 164 | + |
64 | 165 | def _supports_buffer_protocol(obj):
|
65 | 166 | try:
|
66 | 167 | memoryview(obj)
|
@@ -126,6 +227,6 @@ def asarray(
|
126 | 227 | __all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
|
127 | 228 | 'acosh', 'asin', 'asinh', 'atan', 'atan2',
|
128 | 229 | 'atanh', 'bitwise_left_shift', 'bitwise_invert',
|
129 |
| - 'bitwise_right_shift', 'concat', 'pow'] |
| 230 | + 'bitwise_right_shift', 'concat', 'pow', 'top_k'] |
130 | 231 |
|
131 | 232 | _all_ignore = ['np', 'get_xp']
|
0 commit comments