Skip to content

Commit a920c09

Browse files
committed
Update RNG in numba Dirichlet test
1 parent 75789de commit a920c09

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

tests/link/numba/test_random.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -652,15 +652,11 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
652652
def test_DirichletRV(a, size, cm):
653653
a, a_val = a
654654
rng = shared(np.random.default_rng(29402))
655-
g = ptr.dirichlet(a, size=size, rng=rng)
656-
g_fn = function([a], g, mode=numba_mode)
655+
next_rng, g = ptr.dirichlet(a, size=size, rng=rng).owner.outputs
656+
g_fn = function([a], g, mode=numba_mode, updates={rng: next_rng})
657657

658658
with cm:
659-
all_samples = []
660-
for i in range(1000):
661-
samples = g_fn(a_val)
662-
all_samples.append(samples)
663-
659+
all_samples = [g_fn(a_val) for _ in range(1000)]
664660
exp_res = a_val / a_val.sum(-1)
665661
res = np.mean(all_samples, axis=tuple(range(0, a_val.ndim - 1)))
666662
assert np.allclose(res, exp_res, atol=1e-4)

0 commit comments

Comments
 (0)