Commit 7f623fe

Vectorize make_vector
1 parent 5fd729d commit 7f623fe

2 files changed: +61 lines added, 0 removed


Diff for: pytensor/tensor/basic.py

+21 lines

@@ -1890,6 +1890,23 @@ def _get_vector_length_MakeVector(op, var):
     return len(var.owner.inputs)


+@_vectorize_node.register
+def vectorize_make_vector(op: MakeVector, node, *batch_inputs):
+    # We vectorize make_vector as a join along the last axis of the broadcasted inputs
+    from pytensor.tensor.extra_ops import broadcast_arrays
+
+    # Check if we need to broadcast at all
+    bcast_pattern = batch_inputs[0].type.broadcastable
+    if not all(
+        batch_input.type.broadcastable == bcast_pattern for batch_input in batch_inputs
+    ):
+        batch_inputs = broadcast_arrays(*batch_inputs)
+
+    # Join along the last axis
+    new_out = stack(batch_inputs, axis=-1)
+    return new_out.owner
+
+
 def transfer(var, target):
     """
     Return a version of `var` transferred to `target`.
@@ -2690,6 +2707,10 @@ def vectorize_join(op: Join, node, batch_axis, *batch_inputs):
     # We can vectorize join as a shifted axis on the batch inputs if:
     # 1. The batch axis is a constant and has not changed
     # 2. All inputs are batched with the same broadcastable pattern
+
+    # TODO: We can relax the second condition by broadcasting the batch dimensions.
+    # This can be done with `broadcast_arrays` if the tensors' shapes match at the axis,
+    # or otherwise by calling `broadcast_to` for each tensor that needs it
     if (
         original_axis.type.ndim == 0
         and isinstance(original_axis, Constant)
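
For context (not part of the commit), the sketch below illustrates what the new `MakeVector` dispatch achieves from the user side: when a graph built from scalars is vectorized, the batched output becomes a stack along the last axis rather than a `Blockwise`-wrapped `MakeVector`. It assumes `vectorize_graph` is importable from `pytensor.graph.replace` (the helper's name and location have shifted between releases) and that `pt.stack` of scalar inputs lowers to `MakeVector`; the variable names are illustrative only.

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# Core graph: stacking three scalars lowers to a MakeVector node
x, y, z = pt.scalars("x", "y", "z")
vec = pt.stack([x, y, z])

# Batched replacements: each scalar input becomes a length-5 vector
bx, by, bz = pt.vectors("bx", "by", "bz")
batched_vec = vectorize_graph(vec, replace={x: bx, y: by, z: bz})

# With the dispatch above, the batched graph stacks along the last axis,
# so the result has shape (batch, 3)
fn = pytensor.function([bx, by, bz], batched_vec)
rng = np.random.default_rng(0)
vals = [rng.normal(size=5).astype(v.dtype) for v in (bx, by, bz)]
assert fn(*vals).shape == (5, 3)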

Diff for: tests/tensor/test_basic.py

+40 lines

@@ -4577,6 +4577,46 @@ def core_np(x):
     )


+@pytest.mark.parametrize(
+    "batch_shapes",
+    [
+        ((3,),),  # edge case of make_vector with a single input
+        ((), (), ()),  # Useless
+        ((3,), (3,), (3,)),  # No broadcasting needed
+        ((3,), (5, 3), ()),  # Broadcasting needed
+    ],
+)
+def test_vectorize_make_vector(batch_shapes):
+    n_inputs = len(batch_shapes)
+    input_sig = ",".join(["()"] * n_inputs)
+    signature = f"{input_sig}->({n_inputs})"  # Something like "(),(),()->(3)"
+
+    def core_pt(*scalars):
+        out = stack(scalars)
+        out.dprint()
+        return out
+
+    def core_np(*scalars):
+        return np.stack(scalars)
+
+    tensors = [tensor(shape=shape) for shape in batch_shapes]
+
+    vectorize_pt = function(tensors, vectorize(core_pt, signature=signature)(*tensors))
+    assert not any(
+        isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
+    )
+
+    test_values = [
+        np.random.normal(size=tensor.type.shape).astype(tensor.type.dtype)
+        for tensor in tensors
+    ]
+
+    np.testing.assert_allclose(
+        vectorize_pt(*test_values),
+        np.vectorize(core_np, signature=signature)(*test_values),
+    )
+
+
 @pytest.mark.parametrize("axis", [constant(1), constant(-2), shared(1)])
 @pytest.mark.parametrize("broadcasting_y", ["none", "implicit", "explicit"])
 @config.change_flags(cxx="")  # C code not needed
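
As a side note (also not part of the commit), the gufunc-style signature the test builds, e.g. "(),(),()->(3)", can be exercised with plain NumPy. The snippet below is a minimal illustration of the reference behavior that `np.vectorize` provides, with hypothetical input values.

import numpy as np

# Three scalar core inputs -> one length-3 core output
signature = "(),(),()->(3)"

def core_np(*scalars):
    return np.stack(scalars)

batched = np.vectorize(core_np, signature=signature)

# Loop (batch) dimensions broadcast: (5,), (5,) and () -> (5,)
a = np.zeros(5)
b = np.ones(5)
c = np.asarray(2.0)
assert batched(a, b, c).shape == (5, 3)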
