forked from pymc-devs/pytensor
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_extra_ops.py
40 lines (32 loc) · 1.24 KB
/
test_extra_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph
from tests.link.pytorch.test_basic import compare_pytorch_and_py
@pytest.mark.parametrize(
    "axis",
    [None, 1, (0,)],
)
def test_pytorch_CumOp(axis):
    """Test PyTorch conversion of the `CumOp` `Op`.

    Both cumulative modes are exercised, `pt.cumsum` first and then
    `pt.cumprod`. For an integer or ``None`` axis, the PyTorch backend must
    agree with the Python backend (via `compare_pytorch_and_py`). For a tuple
    axis, building the graph must raise ``TypeError``.
    """
    # Create a symbolic input for the first input of `CumOp`
    a = pt.matrix("a")

    # Create test value
    test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))

    for cum_fn in (pt.cumsum, pt.cumprod):
        if isinstance(axis, tuple):
            # A tuple axis is rejected when the graph is built. `match` is
            # applied as a regular expression, so the trailing dot is escaped
            # to match the message literally.
            with pytest.raises(TypeError, match=r"axis must be an integer or None\."):
                cum_fn(a, axis=axis)
        else:
            # Build a PyTensor `FunctionGraph` for this mode and check that
            # the PyTorch backend matches the Python backend.
            out = cum_fn(a, axis=axis)
            fgraph = FunctionGraph([a], [out])
            compare_pytorch_and_py(fgraph, [test_value])