Skip to content

Commit 9b9bcba

Browse files
committed
Allow running JAX functions with scalar inputs for RV shapes
1 parent ec793a9 commit 9b9bcba

File tree

2 files changed

+80
-14
lines changed

2 files changed

+80
-14
lines changed

pytensor/link/jax/linker.py

+33-7
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33
from numpy.random import Generator, RandomState
44

55
from pytensor.compile.sharedvalue import SharedVariable, shared
6-
from pytensor.graph.basic import Constant
76
from pytensor.link.basic import JITLinker
87

98

109
class JAXLinker(JITLinker):
1110
"""A `Linker` that JIT-compiles NumPy-based operations using JAX."""
1211

12+
def __init__(self, *args, **kwargs):
13+
self.scalar_shape_inputs: tuple[int] = () # type: ignore[annotation-unchecked]
14+
super().__init__(*args, **kwargs)
15+
1316
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
1417
from pytensor.link.jax.dispatch import jax_funcify
18+
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
1519
from pytensor.tensor.random.type import RandomType
1620

1721
shared_rng_inputs = [
@@ -63,19 +67,41 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
6367
fgraph.inputs.remove(new_inp)
6468
fgraph.inputs.insert(old_inp_fgrap_index, new_inp)
6569

70+
fgraph_inputs = fgraph.inputs
71+
clients = fgraph.clients
72+
# Detect scalar shape inputs that are used only in JAXShapeTuple nodes
73+
scalar_shape_inputs = [
74+
inp
75+
for node in fgraph.apply_nodes
76+
if isinstance(node.op, JAXShapeTuple)
77+
for inp in node.inputs
78+
if inp in fgraph_inputs
79+
and all(isinstance(node.op, JAXShapeTuple) for node, _ in clients[inp])
80+
]
81+
self.scalar_shape_inputs = tuple(
82+
fgraph_inputs.index(inp) for inp in scalar_shape_inputs
83+
)
84+
6685
return jax_funcify(
6786
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
6887
)
6988

7089
def jit_compile(self, fn):
7190
import jax
7291

73-
# I suppose we can consider `Constant`s to be "static" according to
74-
# JAX.
75-
static_argnums = [
76-
n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
77-
]
78-
return jax.jit(fn, static_argnums=static_argnums)
92+
jit_fn = jax.jit(fn, static_argnums=self.scalar_shape_inputs)
93+
94+
if not self.scalar_shape_inputs:
95+
return jit_fn
96+
97+
def convert_scalar_shape_inputs(*args):
98+
new_args = [
99+
int(arg) if i in self.scalar_shape_inputs else arg
100+
for i, arg in enumerate(args)
101+
]
102+
return jit_fn(*new_args)
103+
104+
return convert_scalar_shape_inputs
79105

80106
def create_thunk_inputs(self, storage_map):
81107
from pytensor.link.jax.dispatch import jax_typify

tests/link/jax/test_random.py

+47-7
Original file line numberDiff line numberDiff line change
@@ -863,15 +863,55 @@ def test_random_concrete_shape_subtensor_tuple(self):
863863
jax_fn = compile_random_function([x_pt], out)
864864
assert jax_fn(np.ones((2, 3))).shape == (2,)
865865

866+
def test_random_scalar_shape_input(self):
867+
dim0 = pt.scalar("dim0", dtype=int)
868+
dim1 = pt.scalar("dim1", dtype=int)
869+
870+
out = pt.random.normal(0, 1, size=dim0)
871+
jax_fn = compile_random_function([dim0], out)
872+
assert jax_fn(np.array(2)).shape == (2,)
873+
assert jax_fn(np.array(3)).shape == (3,)
874+
875+
out = pt.random.normal(0, 1, size=[dim0, dim1])
876+
jax_fn = compile_random_function([dim0, dim1], out)
877+
assert jax_fn(np.array(2), np.array(3)).shape == (2, 3)
878+
assert jax_fn(np.array(4), np.array(5)).shape == (4, 5)
879+
866880
@pytest.mark.xfail(
867-
reason="`size_pt` should be specified as a static argument", strict=True
881+
raises=TypeError, reason="Cannot convert scalar input to integer"
868882
)
869-
def test_random_concrete_shape_graph_input(self):
870-
rng = shared(np.random.default_rng(123))
871-
size_pt = pt.scalar()
872-
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
873-
jax_fn = compile_random_function([size_pt], out)
874-
assert jax_fn(10).shape == (10,)
883+
def test_random_scalar_shape_input_not_supported(self):
884+
dim = pt.scalar("dim", dtype=int)
885+
out1 = pt.random.normal(0, 1, size=dim)
886+
# An operation that wouldn't work if we replaced 0d array by integer
887+
out2 = dim[...].set(1)
888+
jax_fn = compile_random_function([dim], [out1, out2])
889+
890+
res1, res2 = jax_fn(np.array(2))
891+
assert res1.shape == (2,)
892+
assert res2 == 1
893+
894+
@pytest.mark.xfail(
895+
raises=TypeError, reason="Cannot convert scalar input to integer"
896+
)
897+
def test_random_scalar_shape_input_not_supported2(self):
898+
dim = pt.scalar("dim", dtype=int)
899+
# This could theoretically be supported
900+
# but would require knowing that * 2 is a safe operation for a python integer
901+
out = pt.random.normal(0, 1, size=dim * 2)
902+
jax_fn = compile_random_function([dim], out)
903+
assert jax_fn(np.array(2)).shape == (4,)
904+
905+
@pytest.mark.xfail(
906+
raises=TypeError, reason="Cannot convert tensor input to shape tuple"
907+
)
908+
def test_random_vector_shape_graph_input(self):
909+
shape = pt.vector("shape", shape=(2,), dtype=int)
910+
out = pt.random.normal(0, 1, size=shape)
911+
912+
jax_fn = compile_random_function([shape], out)
913+
assert jax_fn(np.array([2, 3])).shape == (2, 3)
914+
assert jax_fn(np.array([4, 5])).shape == (4, 5)
875915

876916
def test_constant_shape_after_graph_rewriting(self):
877917
size = pt.vector("size", shape=(2,), dtype=int)

0 commit comments

Comments
 (0)