@@ -50,24 +50,22 @@ def test_random_bernoulli(size, p):
50
50
51
51
52
52
@pytest .mark .parametrize (
53
- "size,n,p" ,
53
+ "size,n,p,update " ,
54
54
[
55
- (None , 10 , 0.5 ),
56
- ((1000 ,), 10 , 0.5 ),
57
- ((1000 , 4 ), 10 , 0.5 ),
58
- ((1000 , 2 ), np .array ([10 , 40 ]), np .array ([0.5 , 0.3 ])),
55
+ ((1000 ,), 10 , 0.5 , False ),
56
+ ((1000 , 4 ), 10 , 0.5 , False ),
57
+ ((1000 , 2 ), np .array ([10 , 40 ]), np .array ([0.5 , 0.3 ]), True ),
59
58
],
60
59
)
61
- def test_binomial (n , p , size ):
60
+ def test_binomial (size , n , p , update ):
62
61
rng = shared (np .random .default_rng (123 ))
63
- g = pt .random .binomial (n , p , size = size , rng = rng )
64
- g_fn = function ([], g , mode = pytorch_mode )
62
+ rv = pt .random .binomial (n , p , size = size , rng = rng )
63
+ next_rng , * _ = rv .owner .inputs
64
+ g_fn = function (
65
+ [], rv , mode = pytorch_mode , updates = {rng : next_rng } if update else None
66
+ )
65
67
samples = g_fn ()
66
- if size :
67
- np .testing .assert_allclose (samples .mean (axis = 0 ), n * p , rtol = 0.1 )
68
- np .testing .assert_allclose (
69
- samples .std (axis = 0 ), np .sqrt (n * p * (1 - p )), rtol = 0.2
70
- )
71
- else :
72
- ...
73
- # TODO: define test
68
+ if not update :
69
+ np .testing .assert_allclose (samples , g_fn (), rtol = 0.1 )
70
+ np .testing .assert_allclose (samples .mean (axis = 0 ), n * p , rtol = 0.1 )
71
+ np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt (n * p * (1 - p )), rtol = 0.2 )
0 commit comments