Skip to content

Commit 4b1761b

Browse files
Support numba compiled sort and argsort functions (#1309)
* feat: support numba compiled sort and argsort functions Signed-off-by: Victor Garcia Reolid <[email protected]> * default to supported kind and add warning Signed-off-by: Victor Garcia Reolid <[email protected]> * feat: support axis Signed-off-by: Victor Garcia Reolid <[email protected]> * use syntax compatible with python 3.10 Signed-off-by: Victor Garcia Reolid <[email protected]> * remove checks Signed-off-by: Victor Garcia Reolid <[email protected]> * use range instead of prange Signed-off-by: Victor Garcia Reolid <[email protected]> * add extra case to check Axis error is raised Signed-off-by: Victor Garcia Reolid <[email protected]> * simplify tests Signed-off-by: Victor Garcia Reolid <[email protected]> --------- Signed-off-by: Victor Garcia Reolid <[email protected]>
1 parent 0f5da80 commit 4b1761b

File tree

2 files changed

+128
-0
lines changed

2 files changed

+128
-0
lines changed

pytensor/link/numba/dispatch/basic.py

+63
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from pytensor.tensor.math import Dot
3939
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
4040
from pytensor.tensor.slinalg import Solve
41+
from pytensor.tensor.sort import ArgSortOp, SortOp
4142
from pytensor.tensor.type import TensorType
4243
from pytensor.tensor.type_other import MakeSlice, NoneConst
4344

@@ -433,6 +434,68 @@ def shape_i(x):
433434
return shape_i
434435

435436

437+
@numba_funcify.register(SortOp)
438+
def numba_funcify_SortOp(op, node, **kwargs):
439+
@numba_njit
440+
def sort_f(a, axis):
441+
axis = axis.item()
442+
443+
a_swapped = np.swapaxes(a, axis, -1)
444+
a_sorted = np.sort(a_swapped)
445+
a_sorted_swapped = np.swapaxes(a_sorted, -1, axis)
446+
447+
return a_sorted_swapped
448+
449+
if op.kind != "quicksort":
450+
warnings.warn(
451+
(
452+
f'Numba function sort doesn\'t support kind="{op.kind}"'
453+
" switching to `quicksort`."
454+
),
455+
UserWarning,
456+
)
457+
458+
return sort_f
459+
460+
461+
@numba_funcify.register(ArgSortOp)
462+
def numba_funcify_ArgSortOp(op, node, **kwargs):
463+
def argsort_f_kind(kind):
464+
@numba_njit
465+
def argort_vec(X, axis):
466+
axis = axis.item()
467+
468+
Y = np.swapaxes(X, axis, 0)
469+
result = np.empty_like(Y)
470+
471+
indices = list(np.ndindex(Y.shape[1:]))
472+
473+
for idx in indices:
474+
result[(slice(None), *idx)] = np.argsort(
475+
Y[(slice(None), *idx)], kind=kind
476+
)
477+
478+
result = np.swapaxes(result, 0, axis)
479+
480+
return result
481+
482+
return argort_vec
483+
484+
kind = op.kind
485+
486+
if kind not in ["quicksort", "mergesort"]:
487+
kind = "quicksort"
488+
warnings.warn(
489+
(
490+
f'Numba function argsort doesn\'t support kind="{op.kind}"'
491+
" switching to `quicksort`."
492+
),
493+
UserWarning,
494+
)
495+
496+
return argsort_f_kind(kind)
497+
498+
436499
@numba.extending.intrinsic
437500
def direct_cast(typingctx, val, typ):
438501
if isinstance(typ, numba.types.TypeRef):

tests/link/numba/test_basic.py

+65
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from pytensor.tensor import blas
3535
from pytensor.tensor.elemwise import Elemwise
3636
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
37+
from pytensor.tensor.sort import ArgSortOp, SortOp
3738

3839

3940
if TYPE_CHECKING:
@@ -383,6 +384,70 @@ def test_Shape(x, i):
383384
compare_numba_and_py([], [g], [])
384385

385386

387+
@pytest.mark.parametrize(
388+
"x",
389+
[
390+
[], # Empty list
391+
[3, 2, 1], # Simple list
392+
np.random.randint(0, 10, (3, 2, 3, 4, 4)), # Multi-dimensional array
393+
],
394+
)
395+
@pytest.mark.parametrize("axis", [0, -1, None])
396+
@pytest.mark.parametrize(
397+
("kind", "exc"),
398+
[
399+
["quicksort", None],
400+
["mergesort", UserWarning],
401+
["heapsort", UserWarning],
402+
["stable", UserWarning],
403+
],
404+
)
405+
def test_Sort(x, axis, kind, exc):
406+
if axis:
407+
g = SortOp(kind)(pt.as_tensor_variable(x), axis)
408+
else:
409+
g = SortOp(kind)(pt.as_tensor_variable(x))
410+
411+
cm = contextlib.suppress() if not exc else pytest.warns(exc)
412+
413+
with cm:
414+
compare_numba_and_py([], [g], [])
415+
416+
417+
@pytest.mark.parametrize(
418+
"x",
419+
[
420+
[], # Empty list
421+
[3, 2, 1], # Simple list
422+
None, # Multi-dimensional array (see below)
423+
],
424+
)
425+
@pytest.mark.parametrize("axis", [0, -1, None])
426+
@pytest.mark.parametrize(
427+
("kind", "exc"),
428+
[
429+
["quicksort", None],
430+
["heapsort", None],
431+
["stable", UserWarning],
432+
],
433+
)
434+
def test_ArgSort(x, axis, kind, exc):
435+
if x is None:
436+
x = np.arange(5 * 5 * 5 * 5)
437+
np.random.shuffle(x)
438+
x = np.reshape(x, (5, 5, 5, 5))
439+
440+
if axis:
441+
g = ArgSortOp(kind)(pt.as_tensor_variable(x), axis)
442+
else:
443+
g = ArgSortOp(kind)(pt.as_tensor_variable(x))
444+
445+
cm = contextlib.suppress() if not exc else pytest.warns(exc)
446+
447+
with cm:
448+
compare_numba_and_py([], [g], [])
449+
450+
386451
@pytest.mark.parametrize(
387452
"v, shape, ndim",
388453
[

0 commit comments

Comments
 (0)