diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 19e91e5f8e..92bc44739f 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -37,6 +37,7 @@ from pytensor.tensor.math import Dot from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.tensor.slinalg import Solve +from pytensor.tensor.sort import ArgSortOp, SortOp from pytensor.tensor.type import TensorType from pytensor.tensor.type_other import MakeSlice, NoneConst @@ -432,6 +433,68 @@ def shape_i(x): return shape_i +@numba_funcify.register(SortOp) +def numba_funcify_SortOp(op, node, **kwargs): + @numba_njit + def sort_f(a, axis): + axis = axis.item() + + a_swapped = np.swapaxes(a, axis, -1) + a_sorted = np.sort(a_swapped) + a_sorted_swapped = np.swapaxes(a_sorted, -1, axis) + + return a_sorted_swapped + + if op.kind != "quicksort": + warnings.warn( + ( + f'Numba function sort doesn\'t support kind="{op.kind}"' + " switching to `quicksort`." + ), + UserWarning, + ) + + return sort_f + + +@numba_funcify.register(ArgSortOp) +def numba_funcify_ArgSortOp(op, node, **kwargs): + def argsort_f_kind(kind): + @numba_njit + def argort_vec(X, axis): + axis = axis.item() + + Y = np.swapaxes(X, axis, 0) + result = np.empty_like(Y) + + indices = list(np.ndindex(Y.shape[1:])) + + for idx in indices: + result[(slice(None), *idx)] = np.argsort( + Y[(slice(None), *idx)], kind=kind + ) + + result = np.swapaxes(result, 0, axis) + + return result + + return argort_vec + + kind = op.kind + + if kind not in ["quicksort", "mergesort"]: + kind = "quicksort" + warnings.warn( + ( + f'Numba function argsort doesn\'t support kind="{op.kind}"' + " switching to `quicksort`." + ), + UserWarning, + ) + + return argsort_f_kind(kind) + + @numba.extending.intrinsic def direct_cast(typingctx, val, typ): if isinstance(typ, numba.types.TypeRef): diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 654cbe7bd4..101dd393d3 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -33,6 +33,7 @@ from pytensor.tensor import blas from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape +from pytensor.tensor.sort import ArgSortOp, SortOp if TYPE_CHECKING: @@ -378,6 +379,70 @@ def test_Shape(x, i): compare_numba_and_py([], [g], []) +@pytest.mark.parametrize( + "x", + [ + [], # Empty list + [3, 2, 1], # Simple list + np.random.randint(0, 10, (3, 2, 3, 4, 4)), # Multi-dimensional array + ], +) +@pytest.mark.parametrize("axis", [0, -1, None]) +@pytest.mark.parametrize( + ("kind", "exc"), + [ + ["quicksort", None], + ["mergesort", UserWarning], + ["heapsort", UserWarning], + ["stable", UserWarning], + ], +) +def test_Sort(x, axis, kind, exc): + if axis: + g = SortOp(kind)(pt.as_tensor_variable(x), axis) + else: + g = SortOp(kind)(pt.as_tensor_variable(x)) + + cm = contextlib.suppress() if not exc else pytest.warns(exc) + + with cm: + compare_numba_and_py([], [g], []) + + +@pytest.mark.parametrize( + "x", + [ + [], # Empty list + [3, 2, 1], # Simple list + None, # Multi-dimensional array (see below) + ], +) +@pytest.mark.parametrize("axis", [0, -1, None]) +@pytest.mark.parametrize( + ("kind", "exc"), + [ + ["quicksort", None], + ["heapsort", None], + ["stable", UserWarning], + ], +) +def test_ArgSort(x, axis, kind, exc): + if x is None: + x = np.arange(5 * 5 * 5 * 5) + np.random.shuffle(x) + x = np.reshape(x, (5, 5, 5, 5)) + + if axis: + g = ArgSortOp(kind)(pt.as_tensor_variable(x), axis) + else: + g = ArgSortOp(kind)(pt.as_tensor_variable(x)) + + cm = contextlib.suppress() if not exc else pytest.warns(exc) + + with cm: + compare_numba_and_py([], [g], []) + + @pytest.mark.parametrize( "v, shape, ndim", [