Skip to content

Commit 4b4f8d0

Browse files
Ch0ronomatoischweer
authored and committed
Ensure semantics of rng immutability
1 parent 288d2c4 commit 4b4f8d0

File tree

2 files changed

+49
-35
lines changed

2 files changed

+49
-35
lines changed
Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
from functools import singledispatch
22

3-
import numpy as np
3+
import numpy.random
44
import torch
5-
from numpy.random import Generator
65

76
import pytensor.tensor.random.basic as ptr
87
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
98

109

11-
@pytorch_typify.register(numpy.random.Generator)
def pytorch_typify_Generator(rng, **kwargs):
    """Convert a NumPy ``Generator`` into a seeded torch generator.

    Numpy uses PCG64 while Torch uses Mersenne-Twister
    (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/CPUGeneratorImpl.cpp),
    so the bit-generator state cannot be transferred directly; instead a
    32-bit seed is drawn from ``rng`` and used to seed torch.
    """
    # XXX: Check if there is a better way.
    # NOTE(review): ``rng.integers`` advances the caller's generator state —
    # confirm this side effect is intended.
    seed_array = rng.integers([2**32])
    torch_seed = torch.from_numpy(seed_array)
    return torch.manual_seed(torch_seed)
16+
@pytorch_typify.register(torch._C.Generator)
def pytorch_typify_pass_generator(rng, **kwargs):
    """A torch generator is already in its typified form; return it unchanged."""
    return rng
2121

2222

2323
@pytorch_funcify.register(ptr.RandomVariable)
@@ -28,7 +28,9 @@ def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
2828
rv_sample = pytorch_sample_fn(op, node=node)
2929

3030
def sample_fn(rng, size, *args):
31-
return rv_sample(rng, shape, out_dtype, *args)
31+
new_rng = torch.Generator(device="cpu")
32+
new_rng.set_state(rng.get_state().clone())
33+
return rv_sample(new_rng, shape, out_dtype, *args)
3234

3335
return sample_fn
3436

@@ -43,25 +45,31 @@ def pytorch_sample_fn(op, node):
4345

4446
@pytorch_sample_fn.register(ptr.BernoulliRV)
def pytorch_sample_fn_bernoulli(op, node):
    """Build the torch sampling function for a Bernoulli random variable."""

    def sample_fn(gen, size, dtype, p):
        # Broadcast the success probability to the requested shape, then draw.
        probs = torch.broadcast_to(p, size)
        draw = torch.bernoulli(probs, generator=gen)
        return (gen, draw)

    return sample_fn
5353

5454

5555
@pytorch_sample_fn.register(ptr.BinomialRV)
def pytorch_sample_fn_binomial(op, node):
    """Build the torch sampling function for a Binomial random variable."""

    def sample_fn(gen, size, dtype, n, p):
        # torch.binomial only accepts floating-point inputs, so cast the
        # count and probability tensors after broadcasting to ``size``.
        counts = torch.broadcast_to(n, size).to(torch.float32)
        probs = torch.broadcast_to(p, size).to(torch.float32)
        draw = torch.binomial(counts, probs, generator=gen)
        return (gen, draw)

    return sample_fn
66+
67+
68+
@pytorch_sample_fn.register(ptr.UniformRV)
def pytorch_sample_fn_uniform(op, node):
    """Build the torch sampling function for a Uniform random variable.

    The returned ``sample_fn(gen, size, dtype, low, high)`` draws a float32
    tensor of shape ``size`` uniformly from ``[low, high)`` with the supplied
    torch generator and returns ``(gen, sample)``.
    """

    def sample_fn(gen, size, dtype, low, high):
        # BUG FIX: ``torch.FloatTensor(size)`` is the deprecated legacy
        # constructor and interprets a tuple/list ``size`` as tensor *data*
        # rather than as a shape; ``torch.empty`` allocates by shape in all
        # cases and is the modern equivalent.
        sample = torch.empty(size, dtype=torch.float32)
        sample.uniform_(low.item(), high.item(), generator=gen)
        return (gen, sample)

    return sample_fn

tests/link/pytorch/test_random.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,31 @@
44
import pytensor.tensor as pt
55
from pytensor.compile.function import function
66
from pytensor.compile.sharedvalue import shared
7+
from pytensor.link.pytorch.dispatch.basic import pytorch_typify
78
from tests.link.pytorch.test_basic import pytorch_mode
89

910

1011
torch = pytest.importorskip("torch")
1112

1213

@pytest.mark.parametrize("update", [True, False])
def test_random_updates(update):
    """Bernoulli draws via PYTORCH mode advance the shared rng only when an
    update mapping is supplied."""
    original = np.random.default_rng(seed=123)
    original_torch = pytorch_typify(original)
    rng = shared(original, name="rng", borrow=False)
    rv = pt.random.bernoulli(0.5, name="y", rng=rng)
    next_rng, x = rv.owner.outputs
    x.dprint()
    f = function([], x, updates={rng: next_rng} if update else None, mode="PYTORCH")
    draws = np.stack([f() for _ in range(5)])
    if update:
        # assert we are getting different values: with p=0.5, five draws
        # should be neither all zeros nor all ones.
        assert 1 <= draws.sum() < 5
        # assert we have a new rng: the stored generator state must have
        # diverged from the state captured at typify time.
        rng_value = rng.get_value(borrow=True)  # we can't copy torch generator
        # BUG FIX: ``assert torch.eq(a, b)`` raises ("Boolean value of Tensor
        # with more than one element is ambiguous") on the multi-byte state
        # tensors; compare with torch.equal and assert the states differ.
        assert not torch.equal(rng_value.get_state(), original_torch.get_state())
3132

3233

3334
@pytest.mark.parametrize(
@@ -68,5 +69,10 @@ def test_binomial(size, n, p, update):
6869
samples = g_fn()
6970
if not update:
7071
np.testing.assert_allclose(samples, g_fn(), rtol=0.1)
71-
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
72-
np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.2)
72+
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
73+
np.testing.assert_allclose(
74+
samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.2
75+
)
76+
else:
77+
second_samples = g_fn()
78+
np.testing.assert_array_equal(second_samples, samples)

0 commit comments

Comments
 (0)