Commit e57e25b

Pytorch support for Join and Careduce Ops (#869)
1 parent df769f6 commit e57e25b

4 files changed: +159 -3 lines changed

pytensor/link/pytorch/dispatch/basic.py

+12 -1

@@ -6,7 +6,7 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.link.utils import fgraph_to_python
 from pytensor.raise_op import CheckAndRaise
-from pytensor.tensor.basic import Alloc, AllocEmpty, ARange
+from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Join


 @singledispatch
@@ -89,3 +89,14 @@ def arange(start, stop, step):
         return torch.arange(start, stop, step, dtype=dtype)

     return arange
+
+
+@pytorch_funcify.register(Join)
+def pytorch_funcify_Join(op, **kwargs):
+    def join(axis, *tensors):
+        # tensors could also be tuples, and in this case they don't have a ndim
+        tensors = [torch.tensor(tensor) for tensor in tensors]
+
+        return torch.cat(tensors, dim=axis)
+
+    return join
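For reference, a minimal usage sketch (not part of the commit) that drives the new Join dispatch end to end through the PyTorch backend; it assumes the PYTORCH compile mode registered by the earlier linker work and the public pytensor.function API.

import numpy as np

import pytensor
import pytensor.tensor as pt

a = pt.matrix("a")
b = pt.matrix("b")
out = pt.join(0, a, b)  # lowered to torch.cat(tensors, dim=0) by pytorch_funcify_Join

# mode="PYTORCH" assumes the PyTorch compile mode is available in this build
f = pytensor.function([a, b], out, mode="PYTORCH")
print(
    f(
        np.ones((2, 3), dtype=pytensor.config.floatX),
        np.zeros((1, 3), dtype=pytensor.config.floatX),
    )
)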

pytensor/link/pytorch/dispatch/elemwise.py

+64

@@ -2,6 +2,7 @@

 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad


@@ -37,6 +38,69 @@ def dimshuffle(x):
     return dimshuffle


+@pytorch_funcify.register(Sum)
+def pytorch_funcify_sum(op, **kwargs):
+    def torch_sum(x):
+        return torch.sum(x, dim=op.axis)
+
+    return torch_sum
+
+
+@pytorch_funcify.register(All)
+def pytorch_funcify_all(op, **kwargs):
+    def torch_all(x):
+        return torch.all(x, dim=op.axis)
+
+    return torch_all
+
+
+@pytorch_funcify.register(Prod)
+def pytorch_funcify_prod(op, **kwargs):
+    def torch_prod(x):
+        if isinstance(op.axis, tuple):
+            for d in sorted(op.axis, reverse=True):
+                x = torch.prod(x, dim=d)
+            return x
+        else:
+            return torch.prod(x.flatten(), dim=0)
+
+    return torch_prod
+
+
+@pytorch_funcify.register(Any)
+def pytorch_funcify_any(op, **kwargs):
+    def torch_any(x):
+        return torch.any(x, dim=op.axis)
+
+    return torch_any
+
+
+@pytorch_funcify.register(Max)
+def pytorch_funcify_max(op, **kwargs):
+    def torch_max(x):
+        if isinstance(op.axis, tuple):
+            for d in sorted(op.axis, reverse=True):
+                x = torch.max(x, dim=d).values
+            return x
+        else:
+            return torch.max(x.flatten(), dim=0).values
+
+    return torch_max
+
+
+@pytorch_funcify.register(Min)
+def pytorch_funcify_min(op, **kwargs):
+    def torch_min(x):
+        if isinstance(op.axis, tuple):
+            for d in sorted(op.axis, reverse=True):
+                x = torch.min(x, dim=d).values
+            return x
+        else:
+            return torch.min(x.flatten(), dim=0).values
+
+    return torch_min
+
+
 @pytorch_funcify.register(Softmax)
 def pytorch_funcify_Softmax(op, **kwargs):
     axis = op.axis
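The single-dim loops in torch_prod, torch_max, and torch_min exist because those torch reductions only reduce one dim per call, so a tuple axis is handled by reducing axes in sorted(op.axis, reverse=True) order (so the remaining indices stay valid), while torch.sum (and, in recent PyTorch, torch.all/torch.any) accepts the axis tuple directly. A minimal sketch (not part of the commit), again assuming the PYTORCH compile mode is registered:

import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x")
# A multi-axis Max reduction: the dispatcher applies torch.max(..., dim=d).values
# once per axis from the tuple, highest axis first.
out = pt.max(x, axis=(0, -1))

f = pytensor.function([x], out, mode="PYTORCH")  # assumes the PyTorch mode is available
print(f(np.arange(24, dtype=pytensor.config.floatX).reshape(2, 3, 4)))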

tests/link/pytorch/test_basic.py

+41 -1

@@ -4,6 +4,7 @@
 import numpy as np
 import pytest

+import pytensor.tensor.basic as ptb
 from pytensor.compile.function import function
 from pytensor.compile.mode import get_mode
 from pytensor.compile.sharedvalue import SharedVariable, shared
@@ -13,7 +14,7 @@
 from pytensor.graph.op import Op
 from pytensor.raise_op import CheckAndRaise
 from pytensor.tensor import alloc, arange, as_tensor, empty
-from pytensor.tensor.type import scalar, vector
+from pytensor.tensor.type import matrix, scalar, vector


 torch = pytest.importorskip("torch")
@@ -235,3 +236,42 @@ def test_arange():
         FunctionGraph([start, stop, step], [out]),
         [np.array(1), np.array(10), np.array(2)],
     )
+
+
+def test_pytorch_Join():
+    a = matrix("a")
+    b = matrix("b")
+
+    x = ptb.join(0, a, b)
+    x_fg = FunctionGraph([a, b], [x])
+    compare_pytorch_and_py(
+        x_fg,
+        [
+            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
+            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
+        ],
+    )
+    compare_pytorch_and_py(
+        x_fg,
+        [
+            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
+            np.c_[[4.0, 5.0]].astype(config.floatX),
+        ],
+    )
+
+    x = ptb.join(1, a, b)
+    x_fg = FunctionGraph([a, b], [x])
+    compare_pytorch_and_py(
+        x_fg,
+        [
+            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
+            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
+        ],
+    )
+    compare_pytorch_and_py(
+        x_fg,
+        [
+            np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
+            np.c_[[5.0, 6.0]].astype(config.floatX),
+        ],
+    )

tests/link/pytorch/test_elemwise.py

+42 -1

@@ -2,11 +2,12 @@
 import pytest

 import pytensor.tensor as pt
+import pytensor.tensor.math as ptm
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.tensor import elemwise as pt_elemwise
 from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
-from pytensor.tensor.type import matrix, tensor, vector
+from pytensor.tensor.type import matrix, tensor, tensor3, vector
 from tests.link.pytorch.test_basic import compare_pytorch_and_py


@@ -57,6 +58,46 @@ def test_pytorch_elemwise():
     compare_pytorch_and_py(fg, [[0.9, 0.9]])


+@pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min])
+@pytest.mark.parametrize("axis", [None, 0, 1, (0, -1)])
+def test_pytorch_careduce(fn, axis):
+    a_pt = tensor3("a")
+    test_value = np.array(
+        [
+            [
+                [1, 1, 1, 1],
+                [2, 2, 2, 2],
+            ],
+            [
+                [3, 3, 3, 3],
+                [
+                    4,
+                    4,
+                    4,
+                    4,
+                ],
+            ],
+        ]
+    ).astype(config.floatX)
+
+    x = fn(a_pt, axis=axis)
+    x_fg = FunctionGraph([a_pt], [x])
+
+    compare_pytorch_and_py(x_fg, [test_value])
+
+
+@pytest.mark.parametrize("fn", [ptm.any, ptm.all])
+@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
+def test_pytorch_any_all(fn, axis):
+    a_pt = matrix("a")
+    test_value = np.array([[True, False, True], [False, True, True]])
+
+    x = fn(a_pt, axis=axis)
+    x_fg = FunctionGraph([a_pt], [x])
+
+    compare_pytorch_and_py(x_fg, [test_value])
+
+
 @pytest.mark.parametrize("dtype", ["float64", "int64"])
 @pytest.mark.parametrize("axis", [None, 0, 1])
 def test_softmax(axis, dtype):
