5
5
from pytensor .configdefaults import config
6
6
from pytensor .graph .fg import FunctionGraph
7
7
from pytensor .tensor import subtensor as pt_subtensor
8
+ from pytensor .tensor import tensor
8
9
from pytensor .tensor .rewriting .jax import (
9
10
boolean_indexing_set_or_inc ,
10
11
boolean_indexing_sum ,
13
14
14
15
15
16
def test_jax_Subtensor_constant ():
17
+ shape = (3 , 4 , 5 )
18
+ x_pt = tensor ("x" , shape = shape , dtype = "int" )
19
+ x_np = np .arange (np .prod (shape )).reshape (shape )
20
+
16
21
# Basic indices
17
- x_pt = pt .as_tensor (np .arange (3 * 4 * 5 ).reshape ((3 , 4 , 5 )))
18
22
out_pt = x_pt [1 , 2 , 0 ]
19
23
assert isinstance (out_pt .owner .op , pt_subtensor .Subtensor )
20
- out_fg = FunctionGraph ([], [out_pt ])
21
- compare_jax_and_py (out_fg , [])
24
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
25
+ compare_jax_and_py (out_fg , [x_np ])
22
26
23
27
out_pt = x_pt [1 :, 1 , :]
24
28
assert isinstance (out_pt .owner .op , pt_subtensor .Subtensor )
25
- out_fg = FunctionGraph ([], [out_pt ])
26
- compare_jax_and_py (out_fg , [])
29
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
30
+ compare_jax_and_py (out_fg , [x_np ])
27
31
28
32
out_pt = x_pt [:2 , 1 , :]
29
33
assert isinstance (out_pt .owner .op , pt_subtensor .Subtensor )
30
- out_fg = FunctionGraph ([], [out_pt ])
31
- compare_jax_and_py (out_fg , [])
34
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
35
+ compare_jax_and_py (out_fg , [x_np ])
32
36
33
37
out_pt = x_pt [1 :2 , 1 , :]
34
38
assert isinstance (out_pt .owner .op , pt_subtensor .Subtensor )
35
- out_fg = FunctionGraph ([], [out_pt ])
36
- compare_jax_and_py (out_fg , [])
39
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
40
+ compare_jax_and_py (out_fg , [x_np ])
37
41
38
42
# Advanced indexing
39
43
out_pt = pt_subtensor .advanced_subtensor1 (x_pt , [1 , 2 ])
40
44
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedSubtensor1 )
41
- out_fg = FunctionGraph ([], [out_pt ])
42
- compare_jax_and_py (out_fg , [])
45
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
46
+ compare_jax_and_py (out_fg , [x_np ])
43
47
44
48
out_pt = x_pt [[1 , 2 ], [2 , 3 ]]
45
49
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedSubtensor )
46
- out_fg = FunctionGraph ([], [out_pt ])
47
- compare_jax_and_py (out_fg , [])
50
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
51
+ compare_jax_and_py (out_fg , [x_np ])
48
52
49
53
# Advanced and basic indexing
50
54
out_pt = x_pt [[1 , 2 ], :]
51
55
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedSubtensor )
52
- out_fg = FunctionGraph ([], [out_pt ])
53
- compare_jax_and_py (out_fg , [])
56
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
57
+ compare_jax_and_py (out_fg , [x_np ])
54
58
55
59
out_pt = x_pt [[1 , 2 ], :, [3 , 4 ]]
56
60
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedSubtensor )
57
- out_fg = FunctionGraph ([], [out_pt ])
58
- compare_jax_and_py (out_fg , [])
61
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
62
+ compare_jax_and_py (out_fg , [x_np ])
59
63
60
64
# Flipping
61
65
out_pt = x_pt [::- 1 ]
62
- out_fg = FunctionGraph ([], [out_pt ])
63
- compare_jax_and_py (out_fg , [])
66
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
67
+ compare_jax_and_py (out_fg , [x_np ])
68
+
69
+ # Boolean indexing should work if indexes are constant
70
+ out_pt = x_pt [np .random .binomial (1 , 0.5 , size = (3 , 4 , 5 ))]
71
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
72
+ compare_jax_and_py (out_fg , [x_np ])
64
73
65
74
66
75
@pytest .mark .xfail (reason = "`a` should be specified as static when JIT-compiling" )
@@ -73,16 +82,18 @@ def test_jax_Subtensor_dynamic():
73
82
compare_jax_and_py (out_fg , [1 ])
74
83
75
84
76
- def test_jax_Subtensor_boolean_mask ():
77
- """JAX does not support resizing arrays with boolean masks."""
85
+ def test_jax_Subtensor_dynamic_boolean_mask ():
86
+ """JAX does not support resizing arrays with dynamic boolean masks."""
87
+ from jax .errors import NonConcreteBooleanIndexError
88
+
78
89
x_pt = pt .vector ("x" , dtype = "float64" )
79
90
out_pt = x_pt [x_pt < 0 ]
80
91
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedSubtensor )
81
92
82
93
out_fg = FunctionGraph ([x_pt ], [out_pt ])
83
94
84
95
x_pt_test = np .arange (- 5 , 5 )
85
- with pytest .raises (NotImplementedError , match = "resizing arrays with boolean" ):
96
+ with pytest .raises (NonConcreteBooleanIndexError ):
86
97
compare_jax_and_py (out_fg , [x_pt_test ])
87
98
88
99
0 commit comments