Skip to content

Commit f93a20e

Browse files
committed
test rng updates do not overwrite original rng
1 parent 7eb77df commit f93a20e

File tree

3 files changed

+26
-3
lines changed

3 files changed

+26
-3
lines changed

Diff for: pytensor/link/pytorch/dispatch/random.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
2525
rv = node.outputs[1]
2626
out_dtype = rv.type.dtype
2727
shape = rv.type.shape
28+
rv_sample = pytorch_sample_fn(op, node=node)
2829

29-
def sample_fn(rng, size, *parameters):
30-
return pytorch_sample_fn(op, node=node)(rng, shape, out_dtype, *parameters)
30+
def sample_fn(rng, size, *args):
31+
return rv_sample(rng, shape, out_dtype, *args)
3132

3233
return sample_fn
3334

Diff for: pytensor/link/pytorch/linker.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ def __call__(self, *inputs, **kwargs):
7474
if getattr(pytensor.link.utils, n[1:], False):
7575
delattr(pytensor.link.utils, n[1:])
7676

77-
return tuple(out.cpu().numpy() for out in outs)
77+
return tuple(
78+
out.cpu().numpy() if torch.is_tensor(out) else out for out in outs
79+
)
7880

7981
def __del__(self):
8082
del self.gen_functors

Diff for: tests/link/pytorch/test_random.py

+20
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,26 @@
1010
torch = pytest.importorskip("torch")
1111

1212

13+
def test_random_updates():
14+
original = np.random.default_rng(seed=123)
15+
rng = shared(original, name="rng", borrow=False)
16+
rv = pt.random.bernoulli(0.5, name="y", rng=rng)
17+
next_rng, x = rv.owner.outputs
18+
x.dprint()
19+
f = function([], [x], updates={rng: next_rng}, mode="PYTORCH")
20+
assert any(f() for _ in range(5))
21+
22+
keys_to_check = set(rng.get_value().keys()) - {"pytorch_gen"}
23+
torch_rand = {k: rng.get_value()[k] for k in keys_to_check}
24+
25+
assert all(
26+
torch_rand[k] == original.bit_generator.state[k]
27+
if not isinstance(original.bit_generator.state, np.ndarray)
28+
else np.array_equal(torch_rand[k], original.bit_generator.state[k])
29+
for k in keys_to_check
30+
)
31+
32+
1333
@pytest.mark.parametrize(
1434
"size,p",
1535
[

0 commit comments

Comments
 (0)