Skip to content

Commit 6ed5349

Browse files
committed
Default to JAX test mode in random tests
1 parent b248eba commit 6ed5349

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

tests/link/jax/test_random.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402
2828

2929

30-
def compile_random_function(*args, mode="JAX", **kwargs):
30+
def compile_random_function(*args, mode=jax_mode, **kwargs):
3131
with pytest.warns(
3232
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
3333
):
@@ -42,7 +42,7 @@ def test_random_RandomStream():
4242
srng = RandomStream(seed=123)
4343
out = srng.normal() - srng.normal()
4444

45-
fn = compile_random_function([], out, mode=jax_mode)
45+
fn = compile_random_function([], out)
4646
jax_res_1 = fn()
4747
jax_res_2 = fn()
4848

@@ -55,7 +55,7 @@ def test_random_updates(rng_ctor):
5555
rng = shared(original_value, name="original_rng", borrow=False)
5656
next_rng, x = pt.random.normal(name="x", rng=rng).owner.outputs
5757

58-
f = compile_random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
58+
f = compile_random_function([], [x], updates={rng: next_rng})
5959
assert f() != f()
6060

6161
# Check that original rng variable content was not overwritten when calling jax_typify
@@ -475,7 +475,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
475475
"""
476476
rng = shared(np.random.default_rng(29403))
477477
g = rv_op(*dist_params, size=(10000, *base_size), rng=rng)
478-
g_fn = compile_random_function(dist_params, g, mode=jax_mode)
478+
g_fn = compile_random_function(dist_params, g)
479479
samples = g_fn(
480480
*[
481481
i.tag.test_value
@@ -517,7 +517,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
517517
param_that_implies_size = pt.matrix("param_that_implies_size", shape=(None, None))
518518

519519
rv = rv_fn(param_that_implies_size)
520-
draws = rv.eval({param_that_implies_size: np.zeros((2, 2))}, mode=jax_mode)
520+
draws = rv.eval({param_that_implies_size: np.zeros((2, 2))})
521521

522522
assert draws.shape == (2, 2)
523523
assert np.unique(draws).size == 4
@@ -527,7 +527,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
527527
def test_random_bernoulli(size):
528528
rng = shared(np.random.default_rng(123))
529529
g = pt.random.bernoulli(0.5, size=(1000, *size), rng=rng)
530-
g_fn = compile_random_function([], g, mode=jax_mode)
530+
g_fn = compile_random_function([], g)
531531
samples = g_fn()
532532
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
533533

@@ -538,7 +538,7 @@ def test_random_mvnormal():
538538
mu = np.ones(4)
539539
cov = np.eye(4)
540540
g = pt.random.multivariate_normal(mu, cov, size=(10000,), rng=rng)
541-
g_fn = compile_random_function([], g, mode=jax_mode)
541+
g_fn = compile_random_function([], g)
542542
samples = g_fn()
543543
np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)
544544

@@ -553,7 +553,7 @@ def test_random_mvnormal():
553553
def test_random_dirichlet(parameter, size):
554554
rng = shared(np.random.default_rng(123))
555555
g = pt.random.dirichlet(parameter, size=(1000, *size), rng=rng)
556-
g_fn = compile_random_function([], g, mode=jax_mode)
556+
g_fn = compile_random_function([], g)
557557
samples = g_fn()
558558
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
559559

@@ -562,7 +562,7 @@ def test_random_choice():
562562
# `replace=True` and `p is None`
563563
rng = shared(np.random.default_rng(123))
564564
g = pt.random.choice(np.arange(4), size=10_000, rng=rng)
565-
g_fn = compile_random_function([], g, mode=jax_mode)
565+
g_fn = compile_random_function([], g)
566566
samples = g_fn()
567567
assert samples.shape == (10_000,)
568568
# Elements are picked at equal frequency
@@ -571,7 +571,7 @@ def test_random_choice():
571571
# `replace=True` and `p is not None`
572572
rng = shared(np.random.default_rng(123))
573573
g = pt.random.choice(4, p=np.array([0.0, 0.5, 0.0, 0.5]), size=(5, 2), rng=rng)
574-
g_fn = compile_random_function([], g, mode=jax_mode)
574+
g_fn = compile_random_function([], g)
575575
samples = g_fn()
576576
assert samples.shape == (5, 2)
577577
# Only odd numbers are picked
@@ -580,7 +580,7 @@ def test_random_choice():
580580
# `replace=False` and `p is None`
581581
rng = shared(np.random.default_rng(123))
582582
g = pt.random.choice(np.arange(100), replace=False, size=(2, 49), rng=rng)
583-
g_fn = compile_random_function([], g, mode=jax_mode)
583+
g_fn = compile_random_function([], g)
584584
samples = g_fn()
585585
assert samples.shape == (2, 49)
586586
# Elements are unique
@@ -595,7 +595,7 @@ def test_random_choice():
595595
rng=rng,
596596
replace=False,
597597
)
598-
g_fn = compile_random_function([], g, mode=jax_mode)
598+
g_fn = compile_random_function([], g)
599599
samples = g_fn()
600600
assert samples.shape == (3,)
601601
# Elements are unique
@@ -607,14 +607,14 @@ def test_random_choice():
607607
def test_random_categorical():
608608
rng = shared(np.random.default_rng(123))
609609
g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
610-
g_fn = compile_random_function([], g, mode=jax_mode)
610+
g_fn = compile_random_function([], g)
611611
samples = g_fn()
612612
assert samples.shape == (10000, 4)
613613
np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)
614614

615615
# Test zero probabilities
616616
g = pt.random.categorical([0, 0.5, 0, 0.5], size=(1000,), rng=rng)
617-
g_fn = compile_random_function([], g, mode=jax_mode)
617+
g_fn = compile_random_function([], g)
618618
samples = g_fn()
619619
assert samples.shape == (1000,)
620620
assert np.all(samples % 2 == 1)
@@ -624,7 +624,7 @@ def test_random_permutation():
624624
array = np.arange(4)
625625
rng = shared(np.random.default_rng(123))
626626
g = pt.random.permutation(array, rng=rng)
627-
g_fn = compile_random_function([], g, mode=jax_mode)
627+
g_fn = compile_random_function([], g)
628628
permuted = g_fn()
629629
with pytest.raises(AssertionError):
630630
np.testing.assert_allclose(array, permuted)
@@ -647,7 +647,7 @@ def test_random_geometric():
647647
rng = shared(np.random.default_rng(123))
648648
p = np.array([0.3, 0.7])
649649
g = pt.random.geometric(p, size=(10_000, 2), rng=rng)
650-
g_fn = compile_random_function([], g, mode=jax_mode)
650+
g_fn = compile_random_function([], g)
651651
samples = g_fn()
652652
np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1)
653653
np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1)
@@ -658,7 +658,7 @@ def test_negative_binomial():
658658
n = np.array([10, 40])
659659
p = np.array([0.3, 0.7])
660660
g = pt.random.negative_binomial(n, p, size=(10_000, 2), rng=rng)
661-
g_fn = compile_random_function([], g, mode=jax_mode)
661+
g_fn = compile_random_function([], g)
662662
samples = g_fn()
663663
np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1)
664664
np.testing.assert_allclose(
@@ -672,7 +672,7 @@ def test_binomial():
672672
n = np.array([10, 40])
673673
p = np.array([0.3, 0.7])
674674
g = pt.random.binomial(n, p, size=(10_000, 2), rng=rng)
675-
g_fn = compile_random_function([], g, mode=jax_mode)
675+
g_fn = compile_random_function([], g)
676676
samples = g_fn()
677677
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
678678
np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
@@ -687,7 +687,7 @@ def test_beta_binomial():
687687
a = np.array([1.5, 13])
688688
b = np.array([0.5, 9])
689689
g = pt.random.betabinom(n, a, b, size=(10_000, 2), rng=rng)
690-
g_fn = compile_random_function([], g, mode=jax_mode)
690+
g_fn = compile_random_function([], g)
691691
samples = g_fn()
692692
np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1)
693693
np.testing.assert_allclose(
@@ -721,7 +721,7 @@ def test_vonmises_mu_outside_circle():
721721
mu = np.array([-30, 40])
722722
kappa = np.array([100, 10])
723723
g = pt.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
724-
g_fn = compile_random_function([], g, mode=jax_mode)
724+
g_fn = compile_random_function([], g)
725725
samples = g_fn()
726726
np.testing.assert_allclose(
727727
samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1
@@ -819,15 +819,15 @@ def test_random_concrete_shape():
819819
rng = shared(np.random.default_rng(123))
820820
x_pt = pt.dmatrix()
821821
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
822-
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
822+
jax_fn = compile_random_function([x_pt], out)
823823
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
824824

825825

826826
def test_random_concrete_shape_from_param():
827827
rng = shared(np.random.default_rng(123))
828828
x_pt = pt.dmatrix()
829829
out = pt.random.normal(x_pt, 1, rng=rng)
830-
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
830+
jax_fn = compile_random_function([x_pt], out)
831831
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
832832

833833

@@ -846,7 +846,7 @@ def test_random_concrete_shape_subtensor():
846846
rng = shared(np.random.default_rng(123))
847847
x_pt = pt.dmatrix()
848848
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
849-
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
849+
jax_fn = compile_random_function([x_pt], out)
850850
assert jax_fn(np.ones((2, 3))).shape == (3,)
851851

852852

@@ -862,7 +862,7 @@ def test_random_concrete_shape_subtensor_tuple():
862862
rng = shared(np.random.default_rng(123))
863863
x_pt = pt.dmatrix()
864864
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
865-
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
865+
jax_fn = compile_random_function([x_pt], out)
866866
assert jax_fn(np.ones((2, 3))).shape == (2,)
867867

868868

@@ -873,7 +873,7 @@ def test_random_concrete_shape_graph_input():
873873
rng = shared(np.random.default_rng(123))
874874
size_pt = pt.scalar()
875875
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
876-
jax_fn = compile_random_function([size_pt], out, mode=jax_mode)
876+
jax_fn = compile_random_function([size_pt], out)
877877
assert jax_fn(10).shape == (10,)
878878

879879

0 commit comments

Comments
 (0)