@@ -1,9 +1,11 @@
 import numpy as np
+import pytest
 
 import pytensor.tensor as pt
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.tensor import elemwise as pt_elemwise
+from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
 from pytensor.tensor.type import matrix, tensor, vector
 from tests.link.pytorch.test_basic import compare_pytorch_and_py
 
@@ -53,3 +55,50 @@ def test_pytorch_elemwise():
 
     fg = FunctionGraph([x], [out])
     compare_pytorch_and_py(fg, [[0.9, 0.9]])
+
+
+@pytest.mark.parametrize("dtype", ["float64", "int64"])
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_softmax(axis, dtype):
+    x = matrix("x", dtype=dtype)
+    out = softmax(x, axis=axis)
+    fgraph = FunctionGraph([x], [out])
+    test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)
+
+    if dtype == "int64":
+        with pytest.raises(
+            NotImplementedError,
+            match="Pytorch Softmax is not currently implemented for non-float types.",
+        ):
+            compare_pytorch_and_py(fgraph, [test_input])
+    else:
+        compare_pytorch_and_py(fgraph, [test_input])
+
+
+@pytest.mark.parametrize("dtype", ["float64", "int64"])
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_logsoftmax(axis, dtype):
+    x = matrix("x", dtype=dtype)
+    out = log_softmax(x, axis=axis)
+    fgraph = FunctionGraph([x], [out])
+    test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)
+
+    if dtype == "int64":
+        with pytest.raises(
+            NotImplementedError,
+            match="Pytorch LogSoftmax is not currently implemented for non-float types.",
+        ):
+            compare_pytorch_and_py(fgraph, [test_input])
+    else:
+        compare_pytorch_and_py(fgraph, [test_input])
+
+
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_softmax_grad(axis):
+    dy = matrix("dy")
+    dy_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
+    sm = matrix("sm")
+    sm_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    out = SoftmaxGrad(axis=axis)(dy, sm)
+    fgraph = FunctionGraph([dy, sm], [out])
+    compare_pytorch_and_py(fgraph, [dy_value, sm_value])