Skip to content

Commit 7eb77df

Browse files
twaclawCh0ronomato
authored andcommitted
- Added suport for size None
- Copied generator before sampling from it
1 parent fcd643a commit 7eb77df

File tree

2 files changed

+27
-28
lines changed

2 files changed

+27
-28
lines changed

pytensor/link/pytorch/dispatch/random.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from functools import singledispatch
22

3+
import numpy as np
34
import torch
45
from numpy.random import Generator
56

@@ -12,7 +13,9 @@ def pytorch_typify_Generator(rng, **kwargs):
1213
# XXX: Check if there is a better way.
1314
# Numpy uses PCG64 while Torch uses Mersenne-Twister (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/CPUGeneratorImpl.cpp)
1415
state = rng.__getstate__()
15-
seed = torch.from_numpy(rng.integers([2**32]))
16+
rng_copy = np.random.default_rng()
17+
rng_copy.bit_generator.state = rng.bit_generator.state
18+
seed = torch.from_numpy(rng_copy.integers([2**32]))
1619
state["pytorch_gen"] = torch.manual_seed(seed)
1720
return state
1821

@@ -41,8 +44,12 @@ def pytorch_sample_fn(op, node):
4144
def pytorch_sample_fn_bernoulli(op, node):
4245
def sample_fn(rng, size, dtype, p):
4346
gen = rng["pytorch_gen"]
47+
if not size:
48+
size = (1,)
49+
4450
sample = torch.bernoulli(torch.broadcast_to(p, size), generator=gen)
45-
return (gen, sample)
51+
rng["pytorch_gen"] = gen
52+
return (rng, sample)
4653

4754
return sample_fn
4855

@@ -51,11 +58,15 @@ def sample_fn(rng, size, dtype, p):
5158
def pytorch_sample_fn_binomial(op, node):
5259
def sample_fn(rng, size, dtype, n, p):
5360
gen = rng["pytorch_gen"]
61+
if not size:
62+
size = (1,)
63+
5464
sample = torch.binomial(
5565
torch.broadcast_to(n.to(p.dtype), size),
5666
torch.broadcast_to(p, size),
5767
generator=gen,
5868
)
59-
return (gen, sample)
69+
rng["pytorch_gen"] = gen
70+
return (rng, sample)
6071

6172
return sample_fn

tests/link/pytorch/test_random.py

+13-25
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,8 @@
1414
"size,p",
1515
[
1616
((1000,), 0.5),
17-
(
18-
(
19-
1000,
20-
4,
21-
),
22-
0.5,
23-
),
17+
(None, 0.5),
18+
((1000, 4), 0.5),
2419
((10, 2), np.array([0.5, 0.3])),
2520
((1000, 10, 2), np.array([0.5, 0.3])),
2621
],
@@ -37,29 +32,22 @@ def test_random_bernoulli(size, p):
3732
@pytest.mark.parametrize(
3833
"size,n,p",
3934
[
35+
(None, 10, 0.5),
4036
((1000,), 10, 0.5),
41-
(
42-
(
43-
1000,
44-
4,
45-
),
46-
10,
47-
0.5,
48-
),
49-
(
50-
(
51-
1000,
52-
2,
53-
),
54-
np.array([10, 40]),
55-
np.array([0.5, 0.3]),
56-
),
37+
((1000, 4), 10, 0.5),
38+
((1000, 2), np.array([10, 40]), np.array([0.5, 0.3])),
5739
],
5840
)
5941
def test_binomial(n, p, size):
6042
rng = shared(np.random.default_rng(123))
6143
g = pt.random.binomial(n, p, size=size, rng=rng)
6244
g_fn = function([], g, mode=pytorch_mode)
6345
samples = g_fn()
64-
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
65-
np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
46+
if size:
47+
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
48+
np.testing.assert_allclose(
49+
samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.2
50+
)
51+
else:
52+
...
53+
# TODO: define test

0 commit comments

Comments
 (0)