Skip to content

Commit ee4d4f7

Browse files
authored
Implemented Sort/Argsort Ops in PyTorch (#897)
1 parent a99d067 commit ee4d4f7

File tree

3 files changed

+52
-0
lines changed

3 files changed

+52
-0
lines changed

pytensor/link/pytorch/dispatch/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
import pytensor.link.pytorch.dispatch.scalar
66
import pytensor.link.pytorch.dispatch.elemwise
77
import pytensor.link.pytorch.dispatch.extra_ops
8+
import pytensor.link.pytorch.dispatch.sort
89
# isort: on
+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import torch
2+
3+
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
4+
from pytensor.tensor.sort import ArgSortOp, SortOp
5+
6+
7+
@pytorch_funcify.register(SortOp)
8+
def pytorch_funcify_Sort(op, **kwargs):
9+
stable = op.kind == "stable"
10+
11+
def sort(arr, axis):
12+
sorted, _ = torch.sort(arr, dim=axis, stable=stable)
13+
return sorted
14+
15+
return sort
16+
17+
18+
@pytorch_funcify.register(ArgSortOp)
19+
def pytorch_funcify_ArgSort(op, **kwargs):
20+
stable = op.kind == "stable"
21+
22+
def argsort(arr, axis):
23+
return torch.argsort(arr, dim=axis, stable=stable)
24+
25+
return argsort

tests/link/pytorch/test_sort.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pytensor.graph import FunctionGraph
5+
from pytensor.tensor import matrix
6+
from pytensor.tensor.sort import argsort, sort
7+
from tests.link.pytorch.test_basic import compare_pytorch_and_py
8+
9+
10+
@pytest.mark.parametrize("func", (sort, argsort))
11+
@pytest.mark.parametrize(
12+
"axis",
13+
[
14+
pytest.param(0),
15+
pytest.param(1),
16+
pytest.param(
17+
None, marks=pytest.mark.xfail(reason="Reshape Op not implemented")
18+
),
19+
],
20+
)
21+
def test_sort(func, axis):
22+
x = matrix("x", shape=(2, 2), dtype="float64")
23+
out = func(x, axis=axis)
24+
fgraph = FunctionGraph([x], [out])
25+
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
26+
compare_pytorch_and_py(fgraph, [arr])

0 commit comments

Comments
 (0)