Skip to content

Commit 684a929

Browse files
ricardoV94lucianopaz
authored andcommitted
Remove false positive check for supported Subtensors operations in JAX
The check was failing incorrectly for cases that are supported such as constant Boolean arrays. Besides that, user may dispatch without necessarily jitting the graph. There is no reason to fail eagerly.
1 parent 17fa8b1 commit 684a929

File tree

2 files changed

+35
-40
lines changed

2 files changed

+35
-40
lines changed

pytensor/link/jax/dispatch/subtensor.py

+2-18
Original file line numberDiff line numberDiff line change
@@ -31,36 +31,20 @@
3131
"""
3232

3333

34-
def subtensor_assert_indices_jax_compatible(node, idx_list):
35-
from pytensor.graph.basic import Constant
36-
from pytensor.tensor.variable import TensorVariable
37-
38-
ilist = indices_from_subtensor(node.inputs[1:], idx_list)
39-
for idx in ilist:
40-
if isinstance(idx, TensorVariable):
41-
if idx.type.dtype == "bool":
42-
raise NotImplementedError(BOOLEAN_MASK_ERROR)
43-
elif isinstance(idx, slice):
44-
for slice_arg in (idx.start, idx.stop, idx.step):
45-
if slice_arg is not None and not isinstance(slice_arg, Constant):
46-
raise NotImplementedError(DYNAMIC_SLICE_LENGTH_ERROR)
47-
48-
4934
@jax_funcify.register(Subtensor)
5035
@jax_funcify.register(AdvancedSubtensor)
5136
@jax_funcify.register(AdvancedSubtensor1)
5237
def jax_funcify_Subtensor(op, node, **kwargs):
5338
idx_list = getattr(op, "idx_list", None)
54-
subtensor_assert_indices_jax_compatible(node, idx_list)
5539

56-
def subtensor_constant(x, *ilists):
40+
def subtensor(x, *ilists):
5741
indices = indices_from_subtensor(ilists, idx_list)
5842
if len(indices) == 1:
5943
indices = indices[0]
6044

6145
return x.__getitem__(indices)
6246

63-
return subtensor_constant
47+
return subtensor
6448

6549

6650
@jax_funcify.register(IncSubtensor)

tests/link/jax/test_subtensor.py

+33-22
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pytensor.configdefaults import config
66
from pytensor.graph.fg import FunctionGraph
77
from pytensor.tensor import subtensor as pt_subtensor
8+
from pytensor.tensor import tensor
89
from pytensor.tensor.rewriting.jax import (
910
boolean_indexing_set_or_inc,
1011
boolean_indexing_sum,
@@ -13,54 +14,62 @@
1314

1415

1516
def test_jax_Subtensor_constant():
17+
shape = (3, 4, 5)
18+
x_pt = tensor("x", shape=shape, dtype="int")
19+
x_np = np.arange(np.prod(shape)).reshape(shape)
20+
1621
# Basic indices
17-
x_pt = pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
1822
out_pt = x_pt[1, 2, 0]
1923
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
20-
out_fg = FunctionGraph([], [out_pt])
21-
compare_jax_and_py(out_fg, [])
24+
out_fg = FunctionGraph([x_pt], [out_pt])
25+
compare_jax_and_py(out_fg, [x_np])
2226

2327
out_pt = x_pt[1:, 1, :]
2428
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
25-
out_fg = FunctionGraph([], [out_pt])
26-
compare_jax_and_py(out_fg, [])
29+
out_fg = FunctionGraph([x_pt], [out_pt])
30+
compare_jax_and_py(out_fg, [x_np])
2731

2832
out_pt = x_pt[:2, 1, :]
2933
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
30-
out_fg = FunctionGraph([], [out_pt])
31-
compare_jax_and_py(out_fg, [])
34+
out_fg = FunctionGraph([x_pt], [out_pt])
35+
compare_jax_and_py(out_fg, [x_np])
3236

3337
out_pt = x_pt[1:2, 1, :]
3438
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
35-
out_fg = FunctionGraph([], [out_pt])
36-
compare_jax_and_py(out_fg, [])
39+
out_fg = FunctionGraph([x_pt], [out_pt])
40+
compare_jax_and_py(out_fg, [x_np])
3741

3842
# Advanced indexing
3943
out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2])
4044
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1)
41-
out_fg = FunctionGraph([], [out_pt])
42-
compare_jax_and_py(out_fg, [])
45+
out_fg = FunctionGraph([x_pt], [out_pt])
46+
compare_jax_and_py(out_fg, [x_np])
4347

4448
out_pt = x_pt[[1, 2], [2, 3]]
4549
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
46-
out_fg = FunctionGraph([], [out_pt])
47-
compare_jax_and_py(out_fg, [])
50+
out_fg = FunctionGraph([x_pt], [out_pt])
51+
compare_jax_and_py(out_fg, [x_np])
4852

4953
# Advanced and basic indexing
5054
out_pt = x_pt[[1, 2], :]
5155
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
52-
out_fg = FunctionGraph([], [out_pt])
53-
compare_jax_and_py(out_fg, [])
56+
out_fg = FunctionGraph([x_pt], [out_pt])
57+
compare_jax_and_py(out_fg, [x_np])
5458

5559
out_pt = x_pt[[1, 2], :, [3, 4]]
5660
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
57-
out_fg = FunctionGraph([], [out_pt])
58-
compare_jax_and_py(out_fg, [])
61+
out_fg = FunctionGraph([x_pt], [out_pt])
62+
compare_jax_and_py(out_fg, [x_np])
5963

6064
# Flipping
6165
out_pt = x_pt[::-1]
62-
out_fg = FunctionGraph([], [out_pt])
63-
compare_jax_and_py(out_fg, [])
66+
out_fg = FunctionGraph([x_pt], [out_pt])
67+
compare_jax_and_py(out_fg, [x_np])
68+
69+
# Boolean indexing should work if indexes are constant
70+
out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5))]
71+
out_fg = FunctionGraph([x_pt], [out_pt])
72+
compare_jax_and_py(out_fg, [x_np])
6473

6574

6675
@pytest.mark.xfail(reason="`a` should be specified as static when JIT-compiling")
@@ -73,16 +82,18 @@ def test_jax_Subtensor_dynamic():
7382
compare_jax_and_py(out_fg, [1])
7483

7584

76-
def test_jax_Subtensor_boolean_mask():
77-
"""JAX does not support resizing arrays with boolean masks."""
85+
def test_jax_Subtensor_dynamic_boolean_mask():
86+
"""JAX does not support resizing arrays with dynamic boolean masks."""
87+
from jax.errors import NonConcreteBooleanIndexError
88+
7889
x_pt = pt.vector("x", dtype="float64")
7990
out_pt = x_pt[x_pt < 0]
8091
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
8192

8293
out_fg = FunctionGraph([x_pt], [out_pt])
8394

8495
x_pt_test = np.arange(-5, 5)
85-
with pytest.raises(NotImplementedError, match="resizing arrays with boolean"):
96+
with pytest.raises(NonConcreteBooleanIndexError):
8697
compare_jax_and_py(out_fg, [x_pt_test])
8798

8899

0 commit comments

Comments
 (0)