|
4 | 4 | # See https://llvm.org/LICENSE.txt for license information.
|
5 | 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
6 | 6 |
|
7 |
| -import array |
8 | 7 | import math
|
| 8 | +from typing import List |
9 | 9 | import pytest
|
| 10 | +import random |
10 | 11 |
|
11 | 12 | import shortfin as sf
|
12 | 13 | import shortfin.array as sfnp
|
@@ -112,6 +113,189 @@ def test_argmax_dtypes(device, dtype):
|
112 | 113 | sfnp.argmax(src)
|
113 | 114 |
|
114 | 115 |
|
| 116 | +@pytest.mark.parametrize( |
| 117 | + "k,axis", |
| 118 | + [ |
| 119 | + # Min sort, default axis |
| 120 | + [3, None], |
| 121 | + # Min sort, axis=-1 |
| 122 | + [20, -1], |
| 123 | + # Max sort, default axis |
| 124 | + [-3, None], |
| 125 | + # Max sort, axis=-1 |
| 126 | + [-20, -1], |
| 127 | + ], |
| 128 | +) |
| 129 | +def test_argpartition(device, k, axis): |
| 130 | + src = sfnp.device_array(device, [1, 1, 128], dtype=sfnp.float32) |
| 131 | + data = [float(i) for i in range(math.prod([1, 1, 128]))] |
| 132 | + randomized_data = data[:] |
| 133 | + random.shuffle(randomized_data) |
| 134 | + src.items = randomized_data |
| 135 | + |
| 136 | + result = ( |
| 137 | + sfnp.argpartition(src, k) if axis is None else sfnp.argpartition(src, k, axis) |
| 138 | + ) |
| 139 | + |
| 140 | + assert result.shape == src.shape |
| 141 | + |
| 142 | + expected_values = data[:k] if k >= 0 else data[k:] |
| 143 | + |
| 144 | + k_slice = slice(0, k) if k >= 0 else slice(k, None) |
| 145 | + |
| 146 | + indices = result.view(0, 0, k_slice).items.tolist() |
| 147 | + values = [randomized_data[index] for index in indices] |
| 148 | + assert sorted(values) == sorted(expected_values) |
| 149 | + |
| 150 | + |
| 151 | +def test_argpartition_out_variant(device): |
| 152 | + k, axis = -3, -1 |
| 153 | + src = sfnp.device_array(device, [1, 1, 128], dtype=sfnp.float32) |
| 154 | + data = [float(i) for i in range(math.prod(src.shape))] |
| 155 | + |
| 156 | + randomized_data = data[:] |
| 157 | + random.shuffle(randomized_data) |
| 158 | + src.items = randomized_data |
| 159 | + |
| 160 | + output_array = sfnp.device_array(device, src.shape, dtype=sfnp.int64) |
| 161 | + result_out = sfnp.argpartition(src, k, axis, out=output_array) |
| 162 | + result_no_out = sfnp.argpartition(src, k, axis) |
| 163 | + |
| 164 | + assert result_out.shape == src.shape |
| 165 | + out_items = result_out.items.tolist() |
| 166 | + no_out_items = result_no_out.items.tolist() |
| 167 | + assert out_items == no_out_items |
| 168 | + |
| 169 | + |
| 170 | +def test_argpartition_axis0(device): |
| 171 | + def _get_top_values_by_col_indices( |
| 172 | + indices: List[int], data: List[List[int]], k: int |
| 173 | + ) -> List[List[int]]: |
| 174 | + """Obtain the top-k values from out matrix, using column indices. |
| 175 | +
|
| 176 | + For this test, we partition by column (axis == 0). This is just some |
| 177 | + helper logic to obtain the values from the original matrix, given |
| 178 | + then column indices. |
| 179 | +
|
| 180 | + Args: |
| 181 | + indices (List[int]): Flattened indices from `sfnp.argpartition` |
| 182 | + data (List[List[int]]): Matrix containing original values. |
| 183 | + k (int): Specify top-k values to select. |
| 184 | +
|
| 185 | + Returns: |
| 186 | + List[List[int]]: Top-k values for each column. |
| 187 | + """ |
| 188 | + num_cols = len(data[0]) |
| 189 | + |
| 190 | + top_values_by_col = [] |
| 191 | + |
| 192 | + for c in range(num_cols): |
| 193 | + # Collect the row indices for the first k entries in column c. |
| 194 | + col_row_idxs = [indices[r * num_cols + c] for r in range(k)] |
| 195 | + |
| 196 | + # Map those row indices into actual values in `data`. |
| 197 | + col_values = [data[row_idx][c] for row_idx in col_row_idxs] |
| 198 | + |
| 199 | + top_values_by_col.append(col_values) |
| 200 | + |
| 201 | + return top_values_by_col |
| 202 | + |
| 203 | + def _get_top_values_by_sorting( |
| 204 | + data: List[List[float]], k: int |
| 205 | + ) -> List[List[float]]: |
| 206 | + """Get the top-k value for each col in the matrix, using sorting. |
| 207 | +
|
| 208 | + This is just to obtain a comparison for our `argpartition` testing. |
| 209 | +
|
| 210 | + Args: |
| 211 | + data (List[List[int]]): Matrix of data. |
| 212 | + k (int): Specify top-k values to select. |
| 213 | +
|
| 214 | + Returns: |
| 215 | + List[List[float]]: Top-k values for each column. |
| 216 | + """ |
| 217 | + num_rows = len(data) |
| 218 | + num_cols = len(data[0]) |
| 219 | + |
| 220 | + top_values_by_col = [] |
| 221 | + |
| 222 | + for c in range(num_cols): |
| 223 | + # Extract the entire column 'c' into a list |
| 224 | + col = [data[r][c] for r in range(num_rows)] |
| 225 | + # Sort the column in ascending order |
| 226 | + col_sorted = sorted(col) |
| 227 | + # The first k elements are the k smallest |
| 228 | + col_k_smallest = col_sorted[:k] |
| 229 | + top_values_by_col.append(col_k_smallest) |
| 230 | + |
| 231 | + return top_values_by_col |
| 232 | + |
| 233 | + k, axis = 2, 0 |
| 234 | + src = sfnp.device_array(device, [3, 4], dtype=sfnp.float32) |
| 235 | + # data = [[float(i) for i in range(math.prod(src.shape))]] |
| 236 | + data = [[i for i in range(src.shape[-1])] for _ in range(src.shape[0])] |
| 237 | + for i in range(len(data)): |
| 238 | + random.shuffle(data[i]) |
| 239 | + |
| 240 | + for i in range(src.shape[0]): |
| 241 | + src.view(i).items = data[i] |
| 242 | + |
| 243 | + result = sfnp.argpartition(src, k, axis) |
| 244 | + assert result.shape == src.shape |
| 245 | + |
| 246 | + expected_values = _get_top_values_by_sorting(data, k) |
| 247 | + top_values = _get_top_values_by_col_indices(result.items.tolist(), data, k) |
| 248 | + for result, expected in zip(top_values, expected_values): |
| 249 | + assert sorted(result) == sorted(expected) |
| 250 | + |
| 251 | + |
| 252 | +def test_argpartition_error_cases(device): |
| 253 | + # Invalid `input` dtype |
| 254 | + with pytest.raises( |
| 255 | + ValueError, |
| 256 | + ): |
| 257 | + src = sfnp.device_array(device, [1, 1, 16], dtype=sfnp.int64) |
| 258 | + sfnp.argpartition(src, 0) |
| 259 | + |
| 260 | + src = sfnp.device_array(device, [1, 1, 16], dtype=sfnp.float32) |
| 261 | + data = [float(i) for i in range(math.prod(src.shape))] |
| 262 | + src.items = data |
| 263 | + |
| 264 | + # Invalid `axis` |
| 265 | + with pytest.raises( |
| 266 | + ValueError, |
| 267 | + ): |
| 268 | + sfnp.argpartition(src, 1, 3) |
| 269 | + sfnp.argpartition(src, 1, -4) |
| 270 | + |
| 271 | + # Invalid `k` |
| 272 | + with pytest.raises( |
| 273 | + ValueError, |
| 274 | + ): |
| 275 | + sfnp.argpartition(src, 17) |
| 276 | + sfnp.argpartition(src, -17) |
| 277 | + |
| 278 | + # Invalid `out` dtype |
| 279 | + with pytest.raises( |
| 280 | + ValueError, |
| 281 | + ): |
| 282 | + out = sfnp.device_array(device, src.shape, dtype=sfnp.float32) |
| 283 | + sfnp.argpartition(src, 2, -1, out) |
| 284 | + |
| 285 | + |
| 286 | +@pytest.mark.parametrize( |
| 287 | + "dtype", |
| 288 | + [ |
| 289 | + sfnp.bfloat16, |
| 290 | + sfnp.float16, |
| 291 | + sfnp.float32, |
| 292 | + ], |
| 293 | +) |
| 294 | +def test_argpartition_dtypes(device, dtype): |
| 295 | + src = sfnp.device_array(device, [4, 16, 128], dtype=dtype) |
| 296 | + sfnp.argpartition(src, 0) |
| 297 | + |
| 298 | + |
115 | 299 | @pytest.mark.parametrize(
|
116 | 300 | "dtype",
|
117 | 301 | [
|
|
0 commit comments