Skip to content

Commit 1f863f5

Browse files
committed
Update test_binomial
1 parent f93a20e commit 1f863f5

File tree

1 file changed

+14
-16
lines changed

1 file changed

+14
-16
lines changed

Diff for: tests/link/pytorch/test_random.py

+14-16
Original file line numberDiff line numberDiff line change
@@ -50,24 +50,22 @@ def test_random_bernoulli(size, p):
5050

5151

5252
@pytest.mark.parametrize(
53-
"size,n,p",
53+
"size,n,p,update",
5454
[
55-
(None, 10, 0.5),
56-
((1000,), 10, 0.5),
57-
((1000, 4), 10, 0.5),
58-
((1000, 2), np.array([10, 40]), np.array([0.5, 0.3])),
55+
((1000,), 10, 0.5, False),
56+
((1000, 4), 10, 0.5, False),
57+
((1000, 2), np.array([10, 40]), np.array([0.5, 0.3]), True),
5958
],
6059
)
61-
def test_binomial(n, p, size):
60+
def test_binomial(size, n, p, update):
6261
rng = shared(np.random.default_rng(123))
63-
g = pt.random.binomial(n, p, size=size, rng=rng)
64-
g_fn = function([], g, mode=pytorch_mode)
62+
rv = pt.random.binomial(n, p, size=size, rng=rng)
63+
next_rng, *_ = rv.owner.inputs
64+
g_fn = function(
65+
[], rv, mode=pytorch_mode, updates={rng: next_rng} if update else None
66+
)
6567
samples = g_fn()
66-
if size:
67-
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
68-
np.testing.assert_allclose(
69-
samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.2
70-
)
71-
else:
72-
...
73-
# TODO: define test
68+
if not update:
69+
np.testing.assert_allclose(samples, g_fn(), rtol=0.1)
70+
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
71+
np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.2)

0 commit comments

Comments
 (0)