Skip to content

Commit 97d29ef

Browse files
authored
Add argpartition functionality to shortfin.array API (#1063)
This implements `argpartition` on the `shortfin.array` API, using `xtensor`. Argpartition is a partial-sorting operation that returns indices: the indices to the left of position k are guaranteed to refer to the k smallest elements along an axis, and all indices to the right of k refer to values that are at least as large. One can use a positive k-value, so that the first k elements along an axis are the k smallest elements, or a negative k-value, so that the last |k| elements along an axis are the |k| largest elements. Note that those top-k indices are not guaranteed to be in sorted order.
1 parent 4ebbccf commit 97d29ef

File tree

4 files changed

+262
-1
lines changed

4 files changed

+262
-1
lines changed

shortfin/python/array_host_ops.cc

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,33 @@ Implemented for dtypes: float16, float32.
148148
A device_array of dtype=int64, allocated on the host and not visible to the device.
149149
)";
150150

151+
// Python-facing docstring for the `argpartition` host op binding.
static const char DOCSTRING_ARGPARTITION[] =
    R"(Returns the indices that partition `input` along the specified `axis` so that
certain elements occupy the first or last positions depending on `k`.
Similar to `numpy.argpartition`:

- If `k` is non-negative, the first `k` positions along `axis` are the indices of the
  `k` smallest values, while all larger values occupy positions to the right of `k`.
- If `k` is negative, it counts from the end. For example, `k = -3` means the last
  3 positions along `axis` are the indices of the 3 largest values, while all smaller
  values occupy positions to the left of that boundary.

The indices on either side of the boundary are not guaranteed to be in sorted order.

Implemented for dtypes: float16, bfloat16, float32.

Args:
  input: An input array.
  k: The partition position along `axis`. Negative values count from the end.
    Must be in the range [-input.shape[axis], input.shape[axis]).
  axis: Axis along which to partition. Defaults to the last axis (note that the
    numpy default is into the flattened array, which we do not support).
  out: Array to write into. If specified, it must have an expected shape and
    int64 dtype.
  device_visible: Whether to make the result array visible to devices. Only
    used when `out` is not specified. Defaults to False.

Returns:
  A device_array of dtype=int64 containing the partitioned indices. This is
  `out` if it was specified; otherwise it is a new array allocated on the host.
)";
177+
151178
static const char DOCSTRING_CONVERT[] =
152179
R"(Does an elementwise conversion from one dtype to another.
153180
@@ -795,6 +822,53 @@ void BindArrayHostOps(py::module_ &m) {
795822
py::kw_only(), py::arg("keepdims") = false,
796823
py::arg("device_visible") = false, DOCSTRING_ARGMAX);
797824

825+
m.def(
826+
"argpartition",
827+
[](device_array &input, int k, int axis, std::optional<device_array> out,
828+
bool device_visible) {
829+
SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::argpartition");
830+
if (axis < 0) axis += input.shape().size();
831+
if (axis < 0 || axis >= input.shape().size()) {
832+
throw std::invalid_argument(
833+
fmt::format("Axis out of range: Must be [0, {}) but got {}",
834+
input.shape().size(), axis));
835+
}
836+
// Simulate numpy's negative `k` behavior for max argpartition
837+
if (k < 0) k += input.shape()[axis];
838+
if (k < 0 || k >= input.shape()[axis]) {
839+
throw std::invalid_argument(
840+
fmt::format("K out of range: Must be [-{}, {}) but got {}",
841+
input.shape()[axis], input.shape()[axis], k));
842+
}
843+
if (out && (out->dtype() != DType::int64())) {
844+
throw std::invalid_argument("out array must have dtype=int64");
845+
}
846+
auto compute = [&]<typename EltTy>() {
847+
auto input_t = input.map_xtensor<EltTy>();
848+
auto result = xt::argpartition(*input_t, k, /*axis=*/axis);
849+
if (!out) {
850+
out.emplace(device_array::for_host(input.device(), result.shape(),
851+
DType::int64(), device_visible));
852+
}
853+
auto out_t = out->map_xtensor_w<int64_t>();
854+
*out_t = result;
855+
return *out;
856+
};
857+
858+
switch (input.dtype()) {
859+
SF_UNARY_FUNCTION_CASE(float16, half_float::half);
860+
SF_UNARY_FUNCTION_CASE(bfloat16, bfloat16_t);
861+
SF_UNARY_FUNCTION_CASE(float32, float);
862+
default:
863+
throw std::invalid_argument(
864+
fmt::format("Unsupported dtype({}) for operator argmax",
865+
input.dtype().name()));
866+
}
867+
},
868+
py::arg("input"), py::arg("k"), py::arg("axis") = -1,
869+
py::arg("out") = py::none(), py::arg("device_visible") = false,
870+
DOCSTRING_ARGPARTITION);
871+
798872
// Random number generation.
799873
py::class_<PyRandomGenerator>(m, "RandomGenerator")
800874
.def(py::init<std::optional<PyRandomGenerator::SeedType>>(),

shortfin/python/shortfin/array/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
# Ops.
4949
argmax = _sfl.array.argmax
50+
argpartition = _sfl.array.argpartition
5051
add = _sfl.array.add
5152
ceil = _sfl.array.ceil
5253
convert = _sfl.array.convert
@@ -99,6 +100,7 @@
99100
# Ops.
100101
"add",
101102
"argmax",
103+
"argpartition",
102104
"ceil",
103105
"convert",
104106
"divide",

shortfin/src/shortfin/array/dims.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ class SHORTFIN_API InlinedDims {
108108
return p != other.p;
109109
}
110110
constexpr reference operator*() { return *p; }
111+
// Unchecked offset access: returns the element `d` positions away from the
// iterator's current position `p`.
constexpr reference operator[](difference_type d) const { return *(p + d); }
111112
constexpr const_iterator operator+(difference_type d) const {
112113
return const_iterator(p + d);
113114
}

shortfin/tests/api/array_ops_test.py

Lines changed: 185 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
# See https://llvm.org/LICENSE.txt for license information.
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7-
import array
87
import math
8+
from typing import List
99
import pytest
10+
import random
1011

1112
import shortfin as sf
1213
import shortfin.array as sfnp
@@ -112,6 +113,189 @@ def test_argmax_dtypes(device, dtype):
112113
sfnp.argmax(src)
113114

114115

116+
@pytest.mark.parametrize(
    "k,axis",
    [
        # Min sort, default axis
        (3, None),
        # Min sort, axis=-1
        (20, -1),
        # Max sort, default axis
        (-3, None),
        # Max sort, axis=-1
        (-20, -1),
    ],
)
def test_argpartition(device, k, axis):
    """Partitions shuffled data along the last axis and checks the selected values.

    For positive `k` the first `k` indices must refer to the `k` smallest
    values; for negative `k` the last `|k|` indices must refer to the `|k|`
    largest values. Order within the selection is not checked.
    """
    shape = [1, 1, 128]
    src = sfnp.device_array(device, shape, dtype=sfnp.float32)
    ordered = [float(i) for i in range(math.prod(shape))]
    shuffled = list(ordered)
    random.shuffle(shuffled)
    src.items = shuffled

    # Exercise the default `axis` argument when no axis is given.
    if axis is None:
        result = sfnp.argpartition(src, k)
    else:
        result = sfnp.argpartition(src, k, axis)

    assert result.shape == src.shape

    if k >= 0:
        expected_values = ordered[:k]
        k_slice = slice(0, k)
    else:
        expected_values = ordered[k:]
        k_slice = slice(k, None)

    indices = result.view(0, 0, k_slice).items.tolist()
    selected = [shuffled[index] for index in indices]
    assert sorted(selected) == sorted(expected_values)
149+
150+
151+
def test_argpartition_out_variant(device):
    """Checks that passing `out=` yields the same indices as the allocating call."""
    k, axis = -3, -1
    src = sfnp.device_array(device, [1, 1, 128], dtype=sfnp.float32)
    ordered = [float(i) for i in range(math.prod(src.shape))]

    shuffled = list(ordered)
    random.shuffle(shuffled)
    src.items = shuffled

    output_array = sfnp.device_array(device, src.shape, dtype=sfnp.int64)
    result_out = sfnp.argpartition(src, k, axis, out=output_array)
    result_no_out = sfnp.argpartition(src, k, axis)

    assert result_out.shape == src.shape
    assert result_out.items.tolist() == result_no_out.items.tolist()
168+
169+
170+
def test_argpartition_axis0(device):
    """Partitions a small matrix along axis 0 and compares against sorting."""

    def _get_top_values_by_col_indices(
        indices: List[int], data: List[List[int]], k: int
    ) -> List[List[int]]:
        """Obtain the top-k values from our matrix, using column indices.

        For this test, we partition by column (axis == 0). This helper maps
        the row indices produced for each column back to the values in the
        original matrix.

        Args:
            indices (List[int]): Flattened indices from `sfnp.argpartition`.
            data (List[List[int]]): Matrix containing original values.
            k (int): Specify top-k values to select.

        Returns:
            List[List[int]]: Top-k values for each column.
        """
        num_cols = len(data[0])
        # Entry (r, c) of the flattened result sits at `r * num_cols + c` and
        # holds a row index into column `c` of `data`.
        return [
            [data[indices[r * num_cols + c]][c] for r in range(k)]
            for c in range(num_cols)
        ]

    def _get_top_values_by_sorting(
        data: List[List[float]], k: int
    ) -> List[List[float]]:
        """Get the top-k values for each column in the matrix, using sorting.

        This is just to obtain a comparison for our `argpartition` testing.

        Args:
            data (List[List[float]]): Matrix of data.
            k (int): Specify top-k values to select.

        Returns:
            List[List[float]]: The k smallest values of each column.
        """
        # `zip(*data)` walks the matrix column by column.
        return [sorted(column)[:k] for column in zip(*data)]

    k, axis = 2, 0
    src = sfnp.device_array(device, [3, 4], dtype=sfnp.float32)
    data = [list(range(src.shape[-1])) for _ in range(src.shape[0])]
    for row in data:
        random.shuffle(row)

    for row_idx in range(src.shape[0]):
        src.view(row_idx).items = data[row_idx]

    result = sfnp.argpartition(src, k, axis)
    assert result.shape == src.shape

    expected_values = _get_top_values_by_sorting(data, k)
    top_values = _get_top_values_by_col_indices(result.items.tolist(), data, k)
    for actual, expected in zip(top_values, expected_values):
        assert sorted(actual) == sorted(expected)
250+
251+
252+
def test_argpartition_error_cases(device):
    """Checks that invalid arguments to `argpartition` raise `ValueError`.

    Each invalid call gets its own `pytest.raises` block: once a call raises
    inside the block, any statements after it in the same block never run.
    """
    # Invalid `input` dtype
    int_src = sfnp.device_array(device, [1, 1, 16], dtype=sfnp.int64)
    with pytest.raises(ValueError):
        sfnp.argpartition(int_src, 0)

    src = sfnp.device_array(device, [1, 1, 16], dtype=sfnp.float32)
    data = [float(i) for i in range(math.prod(src.shape))]
    src.items = data

    # Invalid `axis` (rank is 3, so valid values are -3..2)
    for bad_axis in (3, -4):
        with pytest.raises(ValueError):
            sfnp.argpartition(src, 1, bad_axis)

    # Invalid `k` (last axis has 16 elements, so valid values are -16..15)
    for bad_k in (16, 17, -17):
        with pytest.raises(ValueError):
            sfnp.argpartition(src, bad_k)

    # Invalid `out` dtype
    bad_out = sfnp.device_array(device, src.shape, dtype=sfnp.float32)
    with pytest.raises(ValueError):
        sfnp.argpartition(src, 2, -1, bad_out)
284+
285+
286+
@pytest.mark.parametrize(
    "dtype",
    [sfnp.bfloat16, sfnp.float16, sfnp.float32],
)
def test_argpartition_dtypes(device, dtype):
    """Smoke test: `argpartition` runs for each supported dtype.

    The input contents are never set, so only successful execution is checked.
    """
    input_array = sfnp.device_array(device, [4, 16, 128], dtype=dtype)
    sfnp.argpartition(input_array, 0)
297+
298+
115299
@pytest.mark.parametrize(
116300
"dtype",
117301
[

0 commit comments

Comments
 (0)