
Commit 17fa8b1

PyTorch Softmax Ops (#846)
Co-authored-by: HarshvirSandhu <[email protected]>
Parent: f3d2ede

File tree (3 files changed, +100 -1 lines changed)

  environment.yml
  pytensor/link/pytorch/dispatch/elemwise.py
  tests/link/pytorch/test_elemwise.py

environment.yml (+1 -1)

@@ -9,7 +9,7 @@ channels:
 dependencies:
   - python>=3.10
   - compilers
-  - numpy>=1.17.0
+  - numpy>=1.17.0,<2
   - scipy>=0.14,<1.14.0
   - filelock
   - etuples
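
The only packaging change is tightening the NumPy pin to numpy>=1.17.0,<2. As a quick sanity check against an environment that was not rebuilt from this file, a minimal sketch (assuming the third-party packaging helper is available; it is not something environment.yml itself requires) could assert the same bound at runtime:

# Minimal sketch mirroring the new pin (numpy>=1.17.0,<2); the `packaging`
# import is an assumption here, not a dependency added by this commit.
import numpy as np
from packaging.version import Version

installed = Version(np.__version__)
assert Version("1.17.0") <= installed < Version("2"), f"unsupported NumPy {installed}"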

pytensor/link/pytorch/dispatch/elemwise.py (+50)

@@ -2,6 +2,7 @@
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 
 
 @pytorch_funcify.register(Elemwise)
@@ -34,3 +35,52 @@ def dimshuffle(x):
         return res
 
     return dimshuffle
+
+
+@pytorch_funcify.register(Softmax)
+def pytorch_funcify_Softmax(op, **kwargs):
+    axis = op.axis
+    dtype = kwargs["node"].inputs[0].dtype
+
+    if not dtype.startswith("float"):
+        raise NotImplementedError(
+            "Pytorch Softmax is not currently implemented for non-float types."
+        )
+
+    def softmax(x):
+        if axis is not None:
+            return torch.softmax(x, dim=axis)
+        else:
+            return torch.softmax(x.ravel(), dim=0).reshape(x.shape)
+
+    return softmax
+
+
+@pytorch_funcify.register(LogSoftmax)
+def pytorch_funcify_LogSoftmax(op, **kwargs):
+    axis = op.axis
+    dtype = kwargs["node"].inputs[0].dtype
+
+    if not dtype.startswith("float"):
+        raise NotImplementedError(
+            "Pytorch LogSoftmax is not currently implemented for non-float types."
+        )
+
+    def log_softmax(x):
+        if axis is not None:
+            return torch.log_softmax(x, dim=axis)
+        else:
+            return torch.log_softmax(x.ravel(), dim=0).reshape(x.shape)
+
+    return log_softmax
+
+
+@pytorch_funcify.register(SoftmaxGrad)
+def pytorch_funcify_SoftmaxGrad(op, **kwargs):
+    axis = op.axis
+
+    def softmax_grad(dy, sm):
+        dy_times_sm = dy * sm
+        return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm
+
+    return softmax_grad
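
SoftmaxGrad is the only translation without a direct torch counterpart: it computes the vector-Jacobian product of softmax in closed form as dy * sm - sum(dy * sm) * sm. A standalone sanity check (not part of the commit, just torch plus autograd) confirms that this expression matches what PyTorch's own autograd produces for torch.softmax:

# Standalone check: the closed-form SoftmaxGrad expression should agree with
# the softmax vector-Jacobian product computed by torch.autograd.
import torch

x = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)
dy = torch.randn(2, 3, dtype=torch.float64)

sm = torch.softmax(x, dim=1)
(autograd_grad,) = torch.autograd.grad(sm, x, grad_outputs=dy)

dy_times_sm = dy * sm
manual_grad = dy_times_sm - torch.sum(dy_times_sm, dim=1, keepdim=True) * sm
assert torch.allclose(autograd_grad, manual_grad)

Note also the axis=None branches in the Softmax and LogSoftmax translations: torch.softmax requires a dim, so the input is flattened, normalized over all elements, and reshaped back, giving a result that sums to one globally rather than per row.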

tests/link/pytorch/test_elemwise.py (+49)

@@ -1,9 +1,11 @@
 import numpy as np
+import pytest
 
 import pytensor.tensor as pt
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.tensor import elemwise as pt_elemwise
+from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
 from pytensor.tensor.type import matrix, tensor, vector
 from tests.link.pytorch.test_basic import compare_pytorch_and_py
 
@@ -53,3 +55,50 @@ def test_pytorch_elemwise():
 
     fg = FunctionGraph([x], [out])
     compare_pytorch_and_py(fg, [[0.9, 0.9]])
+
+
+@pytest.mark.parametrize("dtype", ["float64", "int64"])
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_softmax(axis, dtype):
+    x = matrix("x", dtype=dtype)
+    out = softmax(x, axis=axis)
+    fgraph = FunctionGraph([x], [out])
+    test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)
+
+    if dtype == "int64":
+        with pytest.raises(
+            NotImplementedError,
+            match="Pytorch Softmax is not currently implemented for non-float types.",
+        ):
+            compare_pytorch_and_py(fgraph, [test_input])
+    else:
+        compare_pytorch_and_py(fgraph, [test_input])
+
+
+@pytest.mark.parametrize("dtype", ["float64", "int64"])
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_logsoftmax(axis, dtype):
+    x = matrix("x", dtype=dtype)
+    out = log_softmax(x, axis=axis)
+    fgraph = FunctionGraph([x], [out])
+    test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)
+
+    if dtype == "int64":
+        with pytest.raises(
+            NotImplementedError,
+            match="Pytorch LogSoftmax is not currently implemented for non-float types.",
+        ):
+            compare_pytorch_and_py(fgraph, [test_input])
+    else:
+        compare_pytorch_and_py(fgraph, [test_input])
+
+
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_softmax_grad(axis):
+    dy = matrix("dy")
+    dy_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
+    sm = matrix("sm")
+    sm_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    out = SoftmaxGrad(axis=axis)(dy, sm)
+    fgraph = FunctionGraph([dy, sm], [out])
+    compare_pytorch_and_py(fgraph, [dy_value, sm_value])
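
The tests exercise the new dispatches through compare_pytorch_and_py, but they can also be driven end to end by compiling a PyTensor function on the PyTorch backend. A minimal usage sketch, assuming the backend is registered under the mode name "PYTORCH" by the existing PyTorch linker work (that registration is not part of this commit):

# Minimal usage sketch; the mode name "PYTORCH" is an assumption about the
# existing PyTorch linker setup, not something introduced here.
import numpy as np
import pytensor
from pytensor.tensor.special import softmax
from pytensor.tensor.type import matrix

x = matrix("x", dtype="float64")
out = softmax(x, axis=1)  # dispatched to torch.softmax by the new rule

f = pytensor.function([x], out, mode="PYTORCH")
print(f(np.arange(6, dtype="float64").reshape(2, 3)))  # each row sums to 1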
