Skip to content

Commit fcd643a

Browse files
twaclaw
authored and Ch0ronomato committed
Proposal to infer Torch's generator state from the Numpy one
1 parent 69d62c7 commit fcd643a

File tree

2 files changed

+71
-27
lines changed

2 files changed

+71
-27
lines changed

pytensor/link/pytorch/dispatch/random.py

+23-24
Original file line numberDiff line numberDiff line change
@@ -4,40 +4,27 @@
44
from numpy.random import Generator
55

66
import pytensor.tensor.random.basic as ptr
7-
from pytensor.graph import Constant
87
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
9-
from pytensor.tensor.type_other import NoneTypeT
108

119

1210
@pytorch_typify.register(Generator)
def pytorch_typify_Generator(rng, **kwargs):
    """Typify a NumPy ``Generator`` for the PyTorch backend.

    Returns the NumPy generator's state dict augmented with a
    ``"pytorch_gen"`` entry holding a seeded ``torch.Generator``.
    """
    # XXX: Check if there is a better way.
    # NumPy uses PCG64 while Torch uses Mersenne-Twister
    # (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/CPUGeneratorImpl.cpp),
    # so the bit-generator state cannot be copied across; instead a fresh
    # torch seed is drawn from the NumPy generator.
    state = rng.__getstate__()
    # Draw one 32-bit seed as a plain int. The previous form,
    # torch.from_numpy(rng.integers([2**32])), built a 1-element tensor
    # (list ``high``) that torch.manual_seed only accepted via implicit
    # int() coercion.
    seed = int(rng.integers(2**32))
    # NOTE(review): torch.manual_seed seeds and returns the *global* default
    # generator, so every typified rng shares that generator — confirm this
    # is intended rather than a per-rng torch.Generator().manual_seed(seed).
    state["pytorch_gen"] = torch.manual_seed(seed)
    return state
1718

1819

1920
@pytorch_funcify.register(ptr.RandomVariable)
def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
    """Return a PyTorch sampling closure for a ``RandomVariable`` node."""
    draw_output = node.outputs[1]
    dtype = draw_output.type.dtype
    static_shape = draw_output.type.shape
    # Resolve the per-Op sampler once, instead of on every call of the
    # returned closure; ``pytorch_sample_fn`` is a pure dispatch lookup.
    sampler = pytorch_sample_fn(op, node=node)

    def sample_fn(rng, size, *parameters):
        return sampler(rng, static_shape, dtype, *parameters)

    return sample_fn
4330

@@ -53,10 +40,22 @@ def pytorch_sample_fn(op, node):
5340
@pytorch_sample_fn.register(ptr.BernoulliRV)
def pytorch_sample_fn_bernoulli(op, node):
    """Build a sampler that draws Bernoulli variates with torch."""

    def sample_fn(rng, size, dtype, p):
        generator = rng["pytorch_gen"]
        # Expand the probabilities to the requested draw shape before sampling.
        probs = torch.broadcast_to(p, size)
        draws = torch.bernoulli(probs, generator=generator)
        return (generator, draws)

    return sample_fn
48+
49+
50+
@pytorch_sample_fn.register(ptr.BinomialRV)
def pytorch_sample_fn_binomial(op, node):
    """Build a sampler that draws Binomial variates with torch."""

    def sample_fn(rng, size, dtype, n, p):
        generator = rng["pytorch_gen"]
        # torch.binomial wants both arguments as floating tensors of the same
        # shape, so cast the counts to p's dtype and broadcast both to ``size``.
        counts = torch.broadcast_to(n.to(p.dtype), size)
        probs = torch.broadcast_to(p, size)
        draws = torch.binomial(counts, probs, generator=generator)
        return (generator, draws)

    return sample_fn

tests/link/pytorch/test_random.py

+48-3
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,56 @@
1010
torch = pytest.importorskip("torch")
1111

1212

13-
@pytest.mark.parametrize(
    "size,p",
    [
        ((1000,), 0.5),
        ((1000, 4), 0.5),
        ((10, 2), np.array([0.5, 0.3])),
        ((1000, 10, 2), np.array([0.5, 0.3])),
    ],
)
def test_random_bernoulli(size, p):
    """Bernoulli draws through the PyTorch backend track the requested ``p``."""
    rng = shared(np.random.default_rng(123))

    g = pt.random.bernoulli(p, size=size, rng=rng)
    g_fn = function([], g, mode=pytorch_mode)
    samples = g_fn()
    # Compare against the parametrized probability, not a hard-coded 0.5 —
    # two of the cases draw with p = [0.5, 0.3]. The tolerance stays loose
    # (rtol=1) because the (10, 2) case averages only 10 samples; a tighter
    # bound would make it flaky.
    np.testing.assert_allclose(samples.mean(axis=0), p, rtol=1)
35+
36+
37+
@pytest.mark.parametrize(
    "size,n,p",
    [
        ((1000,), 10, 0.5),
        ((1000, 4), 10, 0.5),
        ((1000, 2), np.array([10, 40]), np.array([0.5, 0.3])),
    ],
)
def test_binomial(n, p, size):
    """Binomial draws through the PyTorch backend match the first two moments."""
    rng = shared(np.random.default_rng(123))
    g = pt.random.binomial(n, p, size=size, rng=rng)
    g_fn = function([], g, mode=pytorch_mode)
    samples = g_fn()
    # Binomial(n, p): mean is n*p and variance is n*p*(1-p).
    expected_mean = n * p
    expected_std = np.sqrt(n * p * (1 - p))
    np.testing.assert_allclose(samples.mean(axis=0), expected_mean, rtol=0.1)
    np.testing.assert_allclose(samples.std(axis=0), expected_std, rtol=0.1)

0 commit comments

Comments
 (0)