Skip to content

Commit 9f1416a

Browse files
committed
Ensure semantics of rng immutability
1 parent 288d2c4 commit 9f1416a

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

pytensor/link/pytorch/dispatch/random.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from copy import deepcopy
12
from functools import singledispatch
23

34
import numpy as np
@@ -28,7 +29,8 @@ def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
2829
rv_sample = pytorch_sample_fn(op, node=node)
2930

3031
def sample_fn(rng, size, *args):
31-
return rv_sample(rng, shape, out_dtype, *args)
32+
_rng = deepcopy(rng)
33+
return rv_sample(_rng, shape, out_dtype, *args)
3234

3335
return sample_fn
3436

@@ -57,11 +59,23 @@ def pytorch_sample_fn_binomial(op, node):
5759
def sample_fn(rng, size, dtype, n, p):
5860
gen = rng["pytorch_gen"]
5961
sample = torch.binomial(
60-
torch.broadcast_to(n, size),
61-
torch.broadcast_to(p, size),
62+
torch.broadcast_to(n, size).to(torch.float32),
63+
torch.broadcast_to(p, size).to(torch.float32),
6264
generator=gen,
6365
)
6466
rng["pytorch_gen"] = gen
6567
return (rng, sample)
6668

6769
return sample_fn
70+
71+
72+
@pytorch_sample_fn.register(ptr.UniformRV)
73+
def pytorch_sample_fn_uniform(op, node):
74+
def sample_fn(rng, size, dtype, low, high):
75+
gen = rng["pytorch_gen"]
76+
sample = torch.FloatTensor(size)
77+
sample.uniform_(low.item(), high.item(), generator=gen)
78+
rng["pytorch_gen"] = gen
79+
return (rng, sample)
80+
81+
return sample_fn

tests/link/pytorch/test_random.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,18 @@ def test_random_updates():
1616
rv = pt.random.bernoulli(0.5, name="y", rng=rng)
1717
next_rng, x = rv.owner.outputs
1818
x.dprint()
19-
f = function([], [x], updates={rng: next_rng}, mode="PYTORCH")
20-
assert any(f() for _ in range(5))
19+
f = function([], x, updates={rng: next_rng}, mode="PYTORCH")
20+
draws = np.stack([f() for _ in range(5)])
21+
# assert we are getting different values
22+
assert draws.sum() < 5 and draws.sum() >= 1
2123

24+
# assert we didn't mutate the original rng
2225
keys_to_check = set(rng.get_value().keys()) - {"pytorch_gen"}
2326
torch_rand = {k: rng.get_value()[k] for k in keys_to_check}
2427

2528
assert all(
2629
torch_rand[k] == original.bit_generator.state[k]
27-
if not isinstance(original.bit_generator.state, np.ndarray)
30+
if not isinstance(original.bit_generator.state[k], np.ndarray)
2831
else np.array_equal(torch_rand[k], original.bit_generator.state[k])
2932
for k in keys_to_check
3033
)

0 commit comments

Comments
 (0)