Skip to content

Commit ea5fe1c

Browse files
committed
Vectorize make_vector
1 parent 31bf682 commit ea5fe1c

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

pytensor/tensor/basic.py

+23
Original file line numberDiff line numberDiff line change
@@ -1888,6 +1888,25 @@ def _get_vector_length_MakeVector(op, var):
18881888
return len(var.owner.inputs)
18891889

18901890

1891+
@_vectorize_node.register
1892+
def vectorize_make_vector(op: MakeVector, node, *batch_inputs):
1893+
# We vectorize make_vector as a join along the last axis of the broadcasted inputs
1894+
from pytensor.tensor.extra_ops import broadcast_arrays
1895+
1896+
# Check if we need to broadcast at all
1897+
bcast_pattern = batch_inputs[0].type.broadcastable
1898+
if not all(
1899+
batch_input.type.broadcastable == bcast_pattern for batch_input in batch_inputs
1900+
):
1901+
batch_inputs = broadcast_arrays(*batch_inputs)
1902+
1903+
# Join along the last axis
1904+
new_out = join(
1905+
-1, *[expand_dims(batch_inputs, axis=-1) for batch_inputs in batch_inputs]
1906+
)
1907+
return new_out.owner
1908+
1909+
18911910
def transfer(var, target):
18921911
"""
18931912
Return a version of `var` transferred to `target`.
@@ -2687,6 +2706,10 @@ def vectorize_join(op: Join, node, batch_axis, *batch_inputs):
26872706
# We can vectorize join as a shifted axis on the batch inputs if:
26882707
# 1. The batch axis is a constant and has not changed
26892708
# 2. All inputs are batched with the same broadcastable pattern
2709+
2710+
# TODO: We can relax the second condition by broadcasting the batch dimensions
2711+
# This can be done with `broadcast_arrays` if the tensors shape match at the axis or reduction
2712+
# Or otherwise by calling `broadcast_to` for each tensor that needs it
26902713
if (
26912714
original_axis.type.ndim == 0
26922715
and isinstance(original_axis, Constant)

tests/tensor/test_basic.py

+30
Original file line numberDiff line numberDiff line change
@@ -4577,6 +4577,36 @@ def core_np(x):
45774577
)
45784578

45794579

4580+
@pytest.mark.parametrize("requires_broadcasting", [False, True])
4581+
def test_vectorize_make_vector(requires_broadcasting):
4582+
signature = "(),(),()->(4)"
4583+
4584+
def core_pt(a, b, c):
4585+
return ptb.stack([a, b, c])
4586+
4587+
def core_np(a, b, c):
4588+
return np.stack([a, b, c])
4589+
4590+
a, b, c = (vector(shape=(3,)) for _ in range(3))
4591+
if requires_broadcasting:
4592+
b = matrix(shape=(5, 3))
4593+
4594+
vectorize_pt = function([a, b, c], vectorize(core_pt, signature=signature)(a, b, c))
4595+
assert not any(
4596+
isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
4597+
)
4598+
4599+
a_test = np.random.normal(size=a.type.shape).astype(a.type.dtype)
4600+
b_test = np.random.normal(size=b.type.shape).astype(b.type.dtype)
4601+
c_test = np.random.normal(size=c.type.shape).astype(c.type.dtype)
4602+
4603+
vectorize_np = np.vectorize(core_np, signature=signature)
4604+
np.testing.assert_allclose(
4605+
vectorize_pt(a_test, b_test, c_test),
4606+
vectorize_np(a_test, b_test, c_test),
4607+
)
4608+
4609+
45804610
@pytest.mark.parametrize("axis", [constant(1), constant(-2), shared(1)])
45814611
@pytest.mark.parametrize("broadcasting_y", ["none", "implicit", "explicit"])
45824612
@config.change_flags(cxx="") # C code not needed

0 commit comments

Comments
 (0)