# test_basic.py (forked from pymc-devs/pytensor)
from collections.abc import Callable, Iterable
from functools import partial

import numpy as np
import pytest

from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty
from pytensor.tensor.type import scalar, vector

torch = pytest.importorskip("torch")

pytorch_mode = get_mode("PYTORCH")
py_mode = get_mode("FAST_COMPILE")


def compare_pytorch_and_py(
    fgraph: FunctionGraph,
    test_inputs: Iterable,
    assert_fn: Callable | None = None,
    must_be_device_array: bool = True,
    pytorch_mode=pytorch_mode,
    py_mode=py_mode,
):
    """Compare the Python-backend and PyTorch-backend outputs of a graph for equality.

    Parameters
    ----------
    fgraph: FunctionGraph
        PyTensor function graph object.
    test_inputs: iter
        Numerical inputs with which to test the function graph.
    assert_fn: func, opt
        Assert function used to check for equality between the Python and
        PyTorch results. Defaults to `np.testing.assert_allclose`.
    must_be_device_array: Bool
        If True, check that the PyTorch results are `torch.Tensor`s (or, for a
        single non-list result, that it lives on a CUDA device).
    """
    if assert_fn is None:
        assert_fn = partial(np.testing.assert_allclose)

    fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]

    pytensor_torch_fn = function(fn_inputs, fgraph.outputs, mode=pytorch_mode)
    pytorch_res = pytensor_torch_fn(*test_inputs)

    if must_be_device_array:
        if isinstance(pytorch_res, list):
            assert all(isinstance(res, torch.Tensor) for res in pytorch_res)
        else:
            assert pytorch_res.device.type == "cuda"

    pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
    py_res = pytensor_py_fn(*test_inputs)

    if len(fgraph.outputs) > 1:
        for j, p in zip(pytorch_res, py_res):
            assert_fn(j.cpu(), p)
    else:
        assert_fn([pytorch_res[0].cpu()], py_res)

    return pytensor_torch_fn, pytorch_res
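

# A minimal usage sketch for `compare_pytorch_and_py` (illustrative only; this
# test name is introduced here and is not part of the upstream suite): build a
# tiny elemwise graph and let the helper check that the PYTORCH and
# FAST_COMPILE backends agree on it.
def test_compare_pytorch_and_py_example():
    x = vector("x")
    out = x * 2 + 1

    compare_pytorch_and_py(
        FunctionGraph([x], [out]),
        [np.arange(3, dtype=config.floatX)],
    )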


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_pytorch_FunctionGraph_once(device):
    """Make sure that an output is only computed once when it's referenced multiple times."""
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    from pytensor.link.pytorch.dispatch import pytorch_funcify

    with torch.device(device):
        x = vector("x")
        y = vector("y")

        class TestOp(Op):
            def __init__(self):
                self.called = 0

            def make_node(self, *args):
                return Apply(self, list(args), [x.type() for x in args])

            def perform(self, inputs, outputs):
                for i, inp in enumerate(inputs):
                    outputs[i][0] = inp[0]

        @pytorch_funcify.register(TestOp)
        def pytorch_funcify_TestOp(op, **kwargs):
            def func(*args, op=op):
                op.called += 1
                for arg in args:
                    assert arg.device.type == device
                return list(args)

            return func

        op1 = TestOp()
        op2 = TestOp()

        # `q + r` is fed to op2 twice; the linker must compute it only once.
        q, r = op1(x, y)
        outs = op2(q + r, q + r)
        out_fg = FunctionGraph([x, y], outs, clone=False)
        assert len(out_fg.outputs) == 2

        out_torch = pytorch_funcify(out_fg)

        x_val = torch.tensor([1, 2]).to(getattr(torch, config.floatX))
        y_val = torch.tensor([2, 3]).to(getattr(torch, config.floatX))

        res = out_torch(x_val, y_val)

        for output in res:
            assert torch.equal(
                output, torch.tensor([3, 5]).to(getattr(torch, config.floatX))
            )

        assert len(res) == 2
        assert op1.called == 1
        assert op2.called == 1

        res = out_torch(x_val, y_val)

        for output in res:
            assert torch.equal(
                output, torch.tensor([3, 5]).to(getattr(torch, config.floatX))
            )

        assert len(res) == 2
        assert op1.called == 2
        assert op2.called == 2
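

# Hedged sketch (the test below is hypothetical and not part of the upstream
# suite; it assumes `pytorch_funcify` returns a callable mapping torch tensors
# to a sequence of output tensors, as exercised by the test above):
# `pytorch_funcify` can also be applied to a plain graph without custom Ops.
def test_pytorch_funcify_plain_graph_example():
    from pytensor.link.pytorch.dispatch import pytorch_funcify

    x = vector("x")
    fg = FunctionGraph([x], [x + 1], clone=False)
    torch_fn = pytorch_funcify(fg)

    (res,) = torch_fn(torch.tensor([1.0, 2.0]))
    assert torch.equal(res, torch.tensor([2.0, 3.0]))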


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_shared(device):
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    with torch.device(device):
        a = shared(np.array([1, 2, 3], dtype=config.floatX))

        pytensor_torch_fn = function([], a, mode="PYTORCH")
        pytorch_res = pytensor_torch_fn()

        assert isinstance(pytorch_res, torch.Tensor)
        assert isinstance(a.get_value(), np.ndarray)
        np.testing.assert_allclose(pytorch_res.cpu(), a.get_value())

        pytensor_torch_fn = function([], a * 2, mode="PYTORCH")
        pytorch_res = pytensor_torch_fn()

        assert isinstance(pytorch_res, torch.Tensor)
        assert isinstance(a.get_value(), np.ndarray)
        np.testing.assert_allclose(pytorch_res.cpu(), a.get_value() * 2)

        new_a_value = np.array([3, 4, 5], dtype=config.floatX)
        a.set_value(new_a_value)

        pytorch_res = pytensor_torch_fn()
        assert isinstance(pytorch_res, torch.Tensor)
        np.testing.assert_allclose(pytorch_res.cpu(), new_a_value * 2)


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_shared_updates(device):
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    with torch.device(device):
        a = shared(0)

        pytensor_torch_fn = function([], a, updates={a: a + 1}, mode="PYTORCH")
        res1, res2 = pytensor_torch_fn(), pytensor_torch_fn()
        assert res1 == 0
        assert res2 == 1
        assert a.get_value() == 2
        assert isinstance(a.get_value(), np.ndarray)

        a.set_value(5)
        res1, res2 = pytensor_torch_fn(), pytensor_torch_fn()
        assert res1 == 5
        assert res2 == 6
        assert a.get_value() == 7
        assert isinstance(a.get_value(), np.ndarray)


def test_checkandraise():
    check_and_raise = CheckAndRaise(AssertionError, "testing")

    x = scalar("x")
    conds = (x > 0, x > 3)
    y = check_and_raise(x, *conds)

    y_fn = function([x], y, mode="PYTORCH")

    with pytest.raises(AssertionError, match="testing"):
        y_fn(0.0)
    assert y_fn(4).item() == 4


def test_alloc_and_empty():
    dim0 = as_tensor(5, dtype="int64")
    dim1 = scalar("dim1", dtype="int64")

    out = empty((dim0, dim1, 3), dtype="float32")
    fn = function([dim1], out, mode=pytorch_mode)
    res = fn(7)
    assert res.shape == (5, 7, 3)
    assert res.dtype == torch.float32

    v = vector("v", shape=(3,), dtype="float64")
    # `alloc` takes the target shape as separate scalar arguments, not a tuple.
    out = alloc(v, dim0, dim1, 3)
    compare_pytorch_and_py(
        FunctionGraph([v, dim1], [out]),
        [np.array([1, 2, 3]), np.array(7)],
    )


def test_arange():
    start = scalar("start", dtype="int64")
    stop = scalar("stop", dtype="int64")
    step = scalar("step", dtype="int64")

    out = arange(start, stop, step, dtype="int16")

    compare_pytorch_and_py(
        FunctionGraph([start, stop, step], [out]),
        [np.array(1), np.array(10), np.array(2)],
    )