27
27
from pytensor .link .jax .dispatch .random import numpyro_available # noqa: E402
28
28
29
29
30
- def compile_random_function (* args , mode = "JAX" , ** kwargs ):
30
+ def compile_random_function (* args , mode = jax_mode , ** kwargs ):
31
31
with pytest .warns (
32
32
UserWarning , match = r"The RandomType SharedVariables \[.+\] will not be used"
33
33
):
@@ -42,7 +42,7 @@ def test_random_RandomStream():
42
42
srng = RandomStream (seed = 123 )
43
43
out = srng .normal () - srng .normal ()
44
44
45
- fn = compile_random_function ([], out , mode = jax_mode )
45
+ fn = compile_random_function ([], out )
46
46
jax_res_1 = fn ()
47
47
jax_res_2 = fn ()
48
48
@@ -55,7 +55,7 @@ def test_random_updates(rng_ctor):
55
55
rng = shared (original_value , name = "original_rng" , borrow = False )
56
56
next_rng , x = pt .random .normal (name = "x" , rng = rng ).owner .outputs
57
57
58
- f = compile_random_function ([], [x ], updates = {rng : next_rng }, mode = jax_mode )
58
+ f = compile_random_function ([], [x ], updates = {rng : next_rng })
59
59
assert f () != f ()
60
60
61
61
# Check that original rng variable content was not overwritten when calling jax_typify
@@ -475,7 +475,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
475
475
"""
476
476
rng = shared (np .random .default_rng (29403 ))
477
477
g = rv_op (* dist_params , size = (10000 , * base_size ), rng = rng )
478
- g_fn = compile_random_function (dist_params , g , mode = jax_mode )
478
+ g_fn = compile_random_function (dist_params , g )
479
479
samples = g_fn (
480
480
* [
481
481
i .tag .test_value
@@ -517,7 +517,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
517
517
param_that_implies_size = pt .matrix ("param_that_implies_size" , shape = (None , None ))
518
518
519
519
rv = rv_fn (param_that_implies_size )
520
- draws = rv .eval ({param_that_implies_size : np .zeros ((2 , 2 ))}, mode = jax_mode )
520
+ draws = rv .eval ({param_that_implies_size : np .zeros ((2 , 2 ))})
521
521
522
522
assert draws .shape == (2 , 2 )
523
523
assert np .unique (draws ).size == 4
@@ -527,7 +527,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
527
527
def test_random_bernoulli (size ):
528
528
rng = shared (np .random .default_rng (123 ))
529
529
g = pt .random .bernoulli (0.5 , size = (1000 , * size ), rng = rng )
530
- g_fn = compile_random_function ([], g , mode = jax_mode )
530
+ g_fn = compile_random_function ([], g )
531
531
samples = g_fn ()
532
532
np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
533
533
@@ -538,7 +538,7 @@ def test_random_mvnormal():
538
538
mu = np .ones (4 )
539
539
cov = np .eye (4 )
540
540
g = pt .random .multivariate_normal (mu , cov , size = (10000 ,), rng = rng )
541
- g_fn = compile_random_function ([], g , mode = jax_mode )
541
+ g_fn = compile_random_function ([], g )
542
542
samples = g_fn ()
543
543
np .testing .assert_allclose (samples .mean (axis = 0 ), mu , atol = 0.1 )
544
544
@@ -553,7 +553,7 @@ def test_random_mvnormal():
553
553
def test_random_dirichlet (parameter , size ):
554
554
rng = shared (np .random .default_rng (123 ))
555
555
g = pt .random .dirichlet (parameter , size = (1000 , * size ), rng = rng )
556
- g_fn = compile_random_function ([], g , mode = jax_mode )
556
+ g_fn = compile_random_function ([], g )
557
557
samples = g_fn ()
558
558
np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
559
559
@@ -562,7 +562,7 @@ def test_random_choice():
562
562
# `replace=True` and `p is None`
563
563
rng = shared (np .random .default_rng (123 ))
564
564
g = pt .random .choice (np .arange (4 ), size = 10_000 , rng = rng )
565
- g_fn = compile_random_function ([], g , mode = jax_mode )
565
+ g_fn = compile_random_function ([], g )
566
566
samples = g_fn ()
567
567
assert samples .shape == (10_000 ,)
568
568
# Elements are picked at equal frequency
@@ -571,7 +571,7 @@ def test_random_choice():
571
571
# `replace=True` and `p is not None`
572
572
rng = shared (np .random .default_rng (123 ))
573
573
g = pt .random .choice (4 , p = np .array ([0.0 , 0.5 , 0.0 , 0.5 ]), size = (5 , 2 ), rng = rng )
574
- g_fn = compile_random_function ([], g , mode = jax_mode )
574
+ g_fn = compile_random_function ([], g )
575
575
samples = g_fn ()
576
576
assert samples .shape == (5 , 2 )
577
577
# Only odd numbers are picked
@@ -580,7 +580,7 @@ def test_random_choice():
580
580
# `replace=False` and `p is None`
581
581
rng = shared (np .random .default_rng (123 ))
582
582
g = pt .random .choice (np .arange (100 ), replace = False , size = (2 , 49 ), rng = rng )
583
- g_fn = compile_random_function ([], g , mode = jax_mode )
583
+ g_fn = compile_random_function ([], g )
584
584
samples = g_fn ()
585
585
assert samples .shape == (2 , 49 )
586
586
# Elements are unique
@@ -595,7 +595,7 @@ def test_random_choice():
595
595
rng = rng ,
596
596
replace = False ,
597
597
)
598
- g_fn = compile_random_function ([], g , mode = jax_mode )
598
+ g_fn = compile_random_function ([], g )
599
599
samples = g_fn ()
600
600
assert samples .shape == (3 ,)
601
601
# Elements are unique
@@ -607,14 +607,14 @@ def test_random_choice():
607
607
def test_random_categorical ():
608
608
rng = shared (np .random .default_rng (123 ))
609
609
g = pt .random .categorical (0.25 * np .ones (4 ), size = (10000 , 4 ), rng = rng )
610
- g_fn = compile_random_function ([], g , mode = jax_mode )
610
+ g_fn = compile_random_function ([], g )
611
611
samples = g_fn ()
612
612
assert samples .shape == (10000 , 4 )
613
613
np .testing .assert_allclose (samples .mean (axis = 0 ), 6 / 4 , 1 )
614
614
615
615
# Test zero probabilities
616
616
g = pt .random .categorical ([0 , 0.5 , 0 , 0.5 ], size = (1000 ,), rng = rng )
617
- g_fn = compile_random_function ([], g , mode = jax_mode )
617
+ g_fn = compile_random_function ([], g )
618
618
samples = g_fn ()
619
619
assert samples .shape == (1000 ,)
620
620
assert np .all (samples % 2 == 1 )
@@ -624,7 +624,7 @@ def test_random_permutation():
624
624
array = np .arange (4 )
625
625
rng = shared (np .random .default_rng (123 ))
626
626
g = pt .random .permutation (array , rng = rng )
627
- g_fn = compile_random_function ([], g , mode = jax_mode )
627
+ g_fn = compile_random_function ([], g )
628
628
permuted = g_fn ()
629
629
with pytest .raises (AssertionError ):
630
630
np .testing .assert_allclose (array , permuted )
@@ -647,7 +647,7 @@ def test_random_geometric():
647
647
rng = shared (np .random .default_rng (123 ))
648
648
p = np .array ([0.3 , 0.7 ])
649
649
g = pt .random .geometric (p , size = (10_000 , 2 ), rng = rng )
650
- g_fn = compile_random_function ([], g , mode = jax_mode )
650
+ g_fn = compile_random_function ([], g )
651
651
samples = g_fn ()
652
652
np .testing .assert_allclose (samples .mean (axis = 0 ), 1 / p , rtol = 0.1 )
653
653
np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt ((1 - p ) / p ** 2 ), rtol = 0.1 )
@@ -658,7 +658,7 @@ def test_negative_binomial():
658
658
n = np .array ([10 , 40 ])
659
659
p = np .array ([0.3 , 0.7 ])
660
660
g = pt .random .negative_binomial (n , p , size = (10_000 , 2 ), rng = rng )
661
- g_fn = compile_random_function ([], g , mode = jax_mode )
661
+ g_fn = compile_random_function ([], g )
662
662
samples = g_fn ()
663
663
np .testing .assert_allclose (samples .mean (axis = 0 ), n * (1 - p ) / p , rtol = 0.1 )
664
664
np .testing .assert_allclose (
@@ -672,7 +672,7 @@ def test_binomial():
672
672
n = np .array ([10 , 40 ])
673
673
p = np .array ([0.3 , 0.7 ])
674
674
g = pt .random .binomial (n , p , size = (10_000 , 2 ), rng = rng )
675
- g_fn = compile_random_function ([], g , mode = jax_mode )
675
+ g_fn = compile_random_function ([], g )
676
676
samples = g_fn ()
677
677
np .testing .assert_allclose (samples .mean (axis = 0 ), n * p , rtol = 0.1 )
678
678
np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt (n * p * (1 - p )), rtol = 0.1 )
@@ -687,7 +687,7 @@ def test_beta_binomial():
687
687
a = np .array ([1.5 , 13 ])
688
688
b = np .array ([0.5 , 9 ])
689
689
g = pt .random .betabinom (n , a , b , size = (10_000 , 2 ), rng = rng )
690
- g_fn = compile_random_function ([], g , mode = jax_mode )
690
+ g_fn = compile_random_function ([], g )
691
691
samples = g_fn ()
692
692
np .testing .assert_allclose (samples .mean (axis = 0 ), n * a / (a + b ), rtol = 0.1 )
693
693
np .testing .assert_allclose (
@@ -721,7 +721,7 @@ def test_vonmises_mu_outside_circle():
721
721
mu = np .array ([- 30 , 40 ])
722
722
kappa = np .array ([100 , 10 ])
723
723
g = pt .random .vonmises (mu , kappa , size = (10_000 , 2 ), rng = rng )
724
- g_fn = compile_random_function ([], g , mode = jax_mode )
724
+ g_fn = compile_random_function ([], g )
725
725
samples = g_fn ()
726
726
np .testing .assert_allclose (
727
727
samples .mean (axis = 0 ), (mu + np .pi ) % (2.0 * np .pi ) - np .pi , rtol = 0.1
@@ -819,15 +819,15 @@ def test_random_concrete_shape():
819
819
rng = shared (np .random .default_rng (123 ))
820
820
x_pt = pt .dmatrix ()
821
821
out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
822
- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
822
+ jax_fn = compile_random_function ([x_pt ], out )
823
823
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
824
824
825
825
826
826
def test_random_concrete_shape_from_param ():
827
827
rng = shared (np .random .default_rng (123 ))
828
828
x_pt = pt .dmatrix ()
829
829
out = pt .random .normal (x_pt , 1 , rng = rng )
830
- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
830
+ jax_fn = compile_random_function ([x_pt ], out )
831
831
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
832
832
833
833
@@ -846,7 +846,7 @@ def test_random_concrete_shape_subtensor():
846
846
rng = shared (np .random .default_rng (123 ))
847
847
x_pt = pt .dmatrix ()
848
848
out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
849
- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
849
+ jax_fn = compile_random_function ([x_pt ], out )
850
850
assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
851
851
852
852
@@ -862,7 +862,7 @@ def test_random_concrete_shape_subtensor_tuple():
862
862
rng = shared (np .random .default_rng (123 ))
863
863
x_pt = pt .dmatrix ()
864
864
out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
865
- jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
865
+ jax_fn = compile_random_function ([x_pt ], out )
866
866
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
867
867
868
868
@@ -873,7 +873,7 @@ def test_random_concrete_shape_graph_input():
873
873
rng = shared (np .random .default_rng (123 ))
874
874
size_pt = pt .scalar ()
875
875
out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
876
- jax_fn = compile_random_function ([size_pt ], out , mode = jax_mode )
876
+ jax_fn = compile_random_function ([size_pt ], out )
877
877
assert jax_fn (10 ).shape == (10 ,)
878
878
879
879
0 commit comments