@@ -863,15 +863,55 @@ def test_random_concrete_shape_subtensor_tuple(self):
863
863
jax_fn = compile_random_function ([x_pt ], out )
864
864
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
865
865
866
+ def test_random_scalar_shape_input (self ):
867
+ dim0 = pt .scalar ("dim0" , dtype = int )
868
+ dim1 = pt .scalar ("dim1" , dtype = int )
869
+
870
+ out = pt .random .normal (0 , 1 , size = dim0 )
871
+ jax_fn = compile_random_function ([dim0 ], out )
872
+ assert jax_fn (np .array (2 )).shape == (2 ,)
873
+ assert jax_fn (np .array (3 )).shape == (3 ,)
874
+
875
+ out = pt .random .normal (0 , 1 , size = [dim0 , dim1 ])
876
+ jax_fn = compile_random_function ([dim0 , dim1 ], out )
877
+ assert jax_fn (np .array (2 ), np .array (3 )).shape == (2 , 3 )
878
+ assert jax_fn (np .array (4 ), np .array (5 )).shape == (4 , 5 )
879
+
866
880
@pytest .mark .xfail (
867
- reason = "`size_pt` should be specified as a static argument" , strict = True
881
+ raises = TypeError , reason = "Cannot convert scalar input to integer"
868
882
)
869
- def test_random_concrete_shape_graph_input (self ):
870
- rng = shared (np .random .default_rng (123 ))
871
- size_pt = pt .scalar ()
872
- out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
873
- jax_fn = compile_random_function ([size_pt ], out )
874
- assert jax_fn (10 ).shape == (10 ,)
883
+ def test_random_scalar_shape_input_not_supported (self ):
884
+ dim = pt .scalar ("dim" , dtype = int )
885
+ out1 = pt .random .normal (0 , 1 , size = dim )
886
+ # An operation that wouldn't work if we replaced 0d array by integer
887
+ out2 = dim [...].set (1 )
888
+ jax_fn = compile_random_function ([dim ], [out1 , out2 ])
889
+
890
+ res1 , res2 = jax_fn (np .array (2 ))
891
+ assert res1 .shape == (2 ,)
892
+ assert res2 == 1
893
+
894
+ @pytest .mark .xfail (
895
+ raises = TypeError , reason = "Cannot convert scalar input to integer"
896
+ )
897
+ def test_random_scalar_shape_input_not_supported2 (self ):
898
+ dim = pt .scalar ("dim" , dtype = int )
899
+ # This could theoretically be supported
900
+ # but would require knowing that * 2 is a safe operation for a python integer
901
+ out = pt .random .normal (0 , 1 , size = dim * 2 )
902
+ jax_fn = compile_random_function ([dim ], out )
903
+ assert jax_fn (np .array (2 )).shape == (4 ,)
904
+
905
+ @pytest .mark .xfail (
906
+ raises = TypeError , reason = "Cannot convert tensor input to shape tuple"
907
+ )
908
+ def test_random_vector_shape_graph_input (self ):
909
+ shape = pt .vector ("shape" , shape = (2 ,), dtype = int )
910
+ out = pt .random .normal (0 , 1 , size = shape )
911
+
912
+ jax_fn = compile_random_function ([shape ], out )
913
+ assert jax_fn (np .array ([2 , 3 ])).shape == (2 , 3 )
914
+ assert jax_fn (np .array ([4 , 5 ])).shape == (4 , 5 )
875
915
876
916
def test_constant_shape_after_graph_rewriting (self ):
877
917
size = pt .vector ("size" , shape = (2 ,), dtype = int )
0 commit comments