Commit 9bc1c3a
Implement broadcast for XTensorVariables
Co-authored-by: Ricardo <[email protected]>
1 parent a36d55a commit 9bc1c3a

6 files changed, +312 -4 lines changed

pytensor/xtensor/__init__.py (1 addition, 1 deletion)

@@ -3,7 +3,7 @@
 import pytensor.xtensor.rewriting
 from pytensor.xtensor import linalg, random
 from pytensor.xtensor.math import dot
-from pytensor.xtensor.shape import concat
+from pytensor.xtensor.shape import broadcast, concat
 from pytensor.xtensor.type import (
     as_xtensor,
     xtensor,

pytensor/xtensor/rewriting/shape.py (60 additions, 0 deletions)

@@ -1,3 +1,4 @@
+import pytensor.tensor as pt
 from pytensor.graph import node_rewriter
 from pytensor.tensor import (
     broadcast_to,
@@ -17,6 +18,7 @@
     Stack,
     Transpose,
     UnStack,
+    XBroadcast,
 )


@@ -157,3 +159,61 @@ def lower_expand_dims(fgraph, node):
     # Convert result back to xtensor
     result = xtensor_from_tensor(result_tensor, dims=out.type.dims)
     return [result]
+
+
+@register_lower_xtensor
+@node_rewriter(tracks=[XBroadcast])
+def lower_broadcast(fgraph, node):
+    """Rewrite XBroadcast using tensor operations."""
+
+    excluded_dims = node.op.exclude
+
+    tensor_inputs = [
+        lower_aligned(inp, out.type.dims)
+        for inp, out in zip(node.inputs, node.outputs, strict=True)
+    ]
+
+    if not excluded_dims:
+        # Simple case: all dimensions are broadcast
+        tensor_outputs = pt.broadcast_arrays(*tensor_inputs)
+
+    else:
+        # Complex case: some dimensions are excluded from broadcasting
+        # Pick the first dim_length seen for each broadcast dim
+        broadcast_dims = {
+            d: None for d in node.outputs[0].type.dims if d not in excluded_dims
+        }
+        for xtensor_inp in node.inputs:
+            for dim, dim_length in xtensor_inp.sizes.items():
+                if dim in broadcast_dims and broadcast_dims[dim] is None:
+                    # The dimension is not excluded, so record its length
+                    broadcast_dims[dim] = dim_length
+        assert not any(
+            value is None for value in broadcast_dims.values()
+        ), "All dimensions must have a length"
+
+        # Create zeros with the broadcast dimensions, to then broadcast each input against.
+        # PyTensor will rewrite this to use only the shapes of the zeros tensor.
+        broadcast_dims = pt.zeros(
+            tuple(broadcast_dims.values()),
+            dtype=node.outputs[0].type.dtype,
+        )
+        n_broadcast_dims = broadcast_dims.ndim
+
+        tensor_outputs = []
+        for tensor_inp, xtensor_out in zip(tensor_inputs, node.outputs, strict=True):
+            n_excluded_dims = tensor_inp.type.ndim - n_broadcast_dims
+            # Excluded dimensions sit on the right side of the output tensor, so we padright the broadcast_dims.
+            # pt.second is PyTensor's equivalent of `np.broadcast_arrays(x, y)[1]`
+            tensor_outputs.append(
+                pt.second(
+                    pt.shape_padright(broadcast_dims, n_excluded_dims),
+                    tensor_inp,
+                )
+            )
+
+    new_outs = [
+        xtensor_from_tensor(out_tensor, dims=out.type.dims)
+        for out_tensor, out in zip(tensor_outputs, node.outputs)
+    ]
+    return new_outs
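
Note (not part of the diff): the `pt.second` / `pt.shape_padright` pair is the core trick in the excluded-dims branch; `pt.second(a, b)` returns `b` broadcast against `a` without ever reading `a`'s values. A minimal, self-contained sketch of the same mechanism, with illustrative names and shapes:

import numpy as np
import pytensor.tensor as pt
from pytensor import function

# A zeros "template" carrying only the target broadcast shape (2, 3);
# its values are never used, only its shape
template = pt.zeros((2, 3), dtype="float64")

# An input with one excluded trailing dim (length 4) that must survive untouched
x = pt.tensor3("x", dtype="float64")

# Pad the template on the right so it leaves the excluded dim alone, then broadcast:
# pt.second(a, b) behaves like np.broadcast_arrays(a, b)[1]
out = pt.second(pt.shape_padright(template, 1), x)

fn = function([x], out)
print(fn(np.zeros((1, 3, 4))).shape)  # (2, 3, 4)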

pytensor/xtensor/shape.py (62 additions, 1 deletion)

@@ -13,7 +13,8 @@
 from pytensor.tensor.type import integer_dtypes
 from pytensor.tensor.utils import get_static_shape_from_size_variables
 from pytensor.xtensor.basic import XOp
-from pytensor.xtensor.type import as_xtensor, xtensor
+from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
+from pytensor.xtensor.vectorization import combine_dims_and_shape


 class Stack(XOp):
@@ -504,3 +505,63 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwargs):
     x = Transpose(dims=tuple(target_dims))(x)

     return x
+
+
+class XBroadcast(XOp):
+    """Broadcast multiple XTensorVariables against each other."""
+
+    __props__ = ("exclude",)
+
+    def __init__(self, exclude: Sequence[str] = ()):
+        self.exclude = tuple(exclude)
+
+    def make_node(self, *inputs):
+        inputs = [as_xtensor(x) for x in inputs]
+
+        exclude = self.exclude
+        dims_and_shape = combine_dims_and_shape(inputs, exclude=exclude)
+
+        broadcast_dims = tuple(d for d in dims_and_shape if d not in exclude)
+        broadcast_shape = tuple(dims_and_shape[d] for d in broadcast_dims)
+        dtype = upcast(*[x.type.dtype for x in inputs])
+
+        outputs = []
+        for x in inputs:
+            x_dims = x.type.dims
+            x_shape = x.type.shape
+            # The output has excluded dimensions in the order they appear in the op argument
+            excluded_dims = tuple(d for d in exclude if d in x_dims)
+            excluded_shape = tuple(x_shape[x_dims.index(d)] for d in excluded_dims)
+
+            output = xtensor(
+                dtype=dtype,
+                shape=broadcast_shape + excluded_shape,
+                dims=broadcast_dims + excluded_dims,
+            )
+            outputs.append(output)
+
+        return Apply(self, inputs, outputs)
+
+
+def broadcast(
+    *args, exclude: str | Sequence[str] | None = None
+) -> tuple[XTensorVariable, ...]:
+    """Broadcast any number of XTensorVariables against each other.
+
+    Parameters
+    ----------
+    *args : XTensorVariable
+        The tensors to broadcast against each other.
+    exclude : str or Sequence[str] or None, optional
+        Dimensions to exclude from broadcasting.
+    """
+    if not args:
+        return ()
+
+    if exclude is None:
+        exclude = ()
+    elif isinstance(exclude, str):
+        exclude = (exclude,)
+    elif not isinstance(exclude, Sequence):
+        raise TypeError(f"exclude must be None, str, or Sequence, got {type(exclude)}")
+    # xarray broadcast always returns a tuple, even if there's only one tensor
+    return tuple(XBroadcast(exclude=exclude)(*args, return_list=True))
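
Usage sketch for the new helper (dims and shapes are illustrative; the assertions follow from the semantics above, with shared broadcast dims first and excluded dims appended per input):

from pytensor.xtensor import xtensor
from pytensor.xtensor.shape import broadcast

x = xtensor("x", dims=("a", "b"), shape=(3, 4))
y = xtensor("y", dims=("c",), shape=(5,))

# No exclusions: every output gets the union of dims, in first-seen order
bx, by = broadcast(x, y)
assert bx.type.dims == by.type.dims == ("a", "b", "c")

# "b" excluded: it is kept only where it already exists,
# appended after the shared broadcast dims ("a", "c")
bx2, by2 = broadcast(x, y, exclude="b")
assert bx2.type.dims == ("a", "c", "b")
assert by2.type.dims == ("a", "c")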

pytensor/xtensor/type.py (9 additions, 0 deletions)

@@ -736,6 +736,15 @@ def dot(self, other, dim=None):
         """Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
         return px.math.dot(self, other, dim=dim)

+    def broadcast(self, *others, exclude=None):
+        """Broadcast this tensor against other XTensorVariables."""
+        return px.shape.broadcast(self, *others, exclude=exclude)
+
+    def broadcast_like(self, other, exclude=None):
+        """Broadcast this tensor against another XTensorVariable."""
+        _, self_bcast = px.shape.broadcast(other, self, exclude=exclude)
+        return self_bcast
+

 class XTensorConstantSignature(TensorConstantSignature):
     pass
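
The two method forms mirror xarray's API: `broadcast` returns all participants, while `broadcast_like` keeps only `self` broadcast against `other`. A quick sketch with illustrative shapes:

from pytensor.xtensor import xtensor

x = xtensor("x", dims=("a",), shape=(3,))
y = xtensor("y", dims=("a", "b"), shape=(3, 4))

bx, by = x.broadcast(y)   # both outputs returned
x2 = x.broadcast_like(y)  # only x, broadcast against y
assert bx.type.dims == by.type.dims == x2.type.dims == ("a", "b")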

pytensor/xtensor/vectorization.py (12 additions, 2 deletions)

@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from itertools import chain

 import numpy as np
@@ -13,13 +14,22 @@
     get_static_shape_from_size_variables,
 )
 from pytensor.xtensor.basic import XOp
-from pytensor.xtensor.type import as_xtensor, xtensor
+from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor


-def combine_dims_and_shape(inputs):
+def combine_dims_and_shape(
+    inputs: Sequence[XTensorVariable], exclude: Sequence[str] | None = None
+) -> dict[str, int | None]:
+    """Combine static dims and shape information from multiple xtensor inputs.
+
+    Dimensions listed in `exclude` are ignored.
+    """
+    exclude = set() if exclude is None else set(exclude)
     dims_and_shape: dict[str, int | None] = {}
     for inp in inputs:
         for dim, dim_length in zip(inp.type.dims, inp.type.shape):
+            if dim in exclude:
+                continue
             if dim not in dims_and_shape:
                 dims_and_shape[dim] = dim_length
             elif dim_length is not None:
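
Illustrative sketch of the extended helper's behavior (expected values written out from the logic above; `None` marks a statically unknown length):

from pytensor.xtensor import xtensor
from pytensor.xtensor.vectorization import combine_dims_and_shape

x = xtensor("x", dims=("a", "b"), shape=(None, 4))
y = xtensor("y", dims=("b", "c"), shape=(4, 5))

assert combine_dims_and_shape([x, y]) == {"a": None, "b": 4, "c": 5}
assert combine_dims_and_shape([x, y], exclude=("b",)) == {"a": None, "c": 5}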

tests/xtensor/test_shape.py (168 additions, 0 deletions)

@@ -9,10 +9,12 @@

 import numpy as np
 from xarray import DataArray
+from xarray import broadcast as xr_broadcast
 from xarray import concat as xr_concat

 from pytensor.tensor import scalar
 from pytensor.xtensor.shape import (
+    broadcast,
     concat,
     stack,
     unstack,
@@ -466,3 +468,169 @@ def test_expand_dims_errors():
     # Test with a numpy array as dim (not supported)
     with pytest.raises(TypeError, match="unhashable type"):
         y.expand_dims(np.array([1, 2]))
+
+
+class TestBroadcast:
+    @pytest.mark.parametrize(
+        "exclude",
+        [
+            None,
+            [],
+            ["b"],
+            ["b", "d"],
+            ["a", "d"],
+            ["b", "c", "d"],
+            ["a", "b", "c", "d"],
+        ],
+    )
+    def test_compatible_excluded_shapes(self, exclude):
+        # Create test data
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
+        z = xtensor("z", dims=("b", "d"), shape=(4, 6))
+
+        x_test = xr_arange_like(x)
+        y_test = xr_arange_like(y)
+        z_test = xr_arange_like(z)
+
+        # Test with excluded dims
+        x2_expected, y2_expected, z2_expected = xr_broadcast(
+            x_test, y_test, z_test, exclude=exclude
+        )
+        x2, y2, z2 = broadcast(x, y, z, exclude=exclude)
+        fn = xr_function([x, y, z], [x2, y2, z2])
+        x2_result, y2_result, z2_result = fn(x_test, y_test, z_test)
+
+        xr_assert_allclose(x2_result, x2_expected)
+        xr_assert_allclose(y2_result, y2_expected)
+        xr_assert_allclose(z2_result, z2_expected)
+
+    def test_incompatible_excluded_shapes(self):
+        # Excluded dims are allowed to have different sizes
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
+        z = xtensor("z", dims=("b", "d"), shape=(4, 7))
+        out = broadcast(x, y, z, exclude=["d"])

+        x_test = xr_arange_like(x)
+        y_test = xr_arange_like(y)
+        z_test = xr_arange_like(z)
+        fn = xr_function([x, y, z], out)
+        results = fn(x_test, y_test, z_test)
+        expected_results = xr_broadcast(x_test, y_test, z_test, exclude=["d"])
+        for res, expected_res in zip(results, expected_results, strict=True):
+            xr_assert_allclose(res, expected_res)
+
+    @pytest.mark.parametrize("exclude", [[], ["b"], ["b", "c"], ["a", "b", "d"]])
+    def test_runtime_shapes(self, exclude):
+        # Test with symbolic (runtime) shapes
+        x = xtensor("x", dims=("a", "b"), shape=(None, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, None))
+        z = xtensor("z", dims=("b", "d"), shape=(None, None))
+        out = broadcast(x, y, z, exclude=exclude)
+
+        x_test = xr_arange_like(xtensor(dims=x.dims, shape=(3, 4)))
+        y_test = xr_arange_like(xtensor(dims=y.dims, shape=(5, 6)))
+        z_test = xr_arange_like(xtensor(dims=z.dims, shape=(4, 6)))
+        fn = xr_function([x, y, z], out)
+        results = fn(x_test, y_test, z_test)
+        expected_results = xr_broadcast(x_test, y_test, z_test, exclude=exclude)
+        for res, expected_res in zip(results, expected_results, strict=True):
+            xr_assert_allclose(res, expected_res)
+
+        # Test that an invalid runtime shape raises an error
+        # Note: We might decide not to raise an error in the lowered graphs for performance reasons
+        if "d" not in exclude:
+            z_test_bad = xr_arange_like(xtensor(dims=z.dims, shape=(4, 7)))
+            with pytest.raises(Exception):
+                fn(x_test, y_test, z_test_bad)
+
+    def test_broadcast_excluded_dims_in_different_order(self):
+        """Test that broadcast excluded dims are aligned with the user input."""
+        x = xtensor("x", dims=("a", "c", "b"), shape=(3, 4, 5))
+        y = xtensor("y", dims=("a", "b", "c"), shape=(3, 5, 4))
+        out = (out_x, out_y) = broadcast(x, y, exclude=["c", "b"])
+        assert out_x.type.dims == ("a", "c", "b")
+        assert out_y.type.dims == ("a", "c", "b")
+
+        x_test = xr_arange_like(x)
+        y_test = xr_arange_like(y)
+        fn = xr_function([x, y], out)
+        results = fn(x_test, y_test)
+        expected_results = xr_broadcast(x_test, y_test, exclude=["c", "b"])
+        for res, expected_res in zip(results, expected_results, strict=True):
+            xr_assert_allclose(res, expected_res)
+
+    def test_broadcast_errors(self):
+        """Test error handling in broadcast."""
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
+        z = xtensor("z", dims=("b", "d"), shape=(4, 6))
+
+        with pytest.raises(TypeError, match="exclude must be None, str, or Sequence"):
+            broadcast(x, y, z, exclude=1)
+
+        # Test with conflicting shapes
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
+        z = xtensor("z", dims=("b", "d"), shape=(4, 7))
+
+        with pytest.raises(ValueError, match="Dimension .* has conflicting shapes"):
+            broadcast(x, y, z)
+
+    def test_broadcast_no_input(self):
+        assert broadcast() == xr_broadcast()
+        assert broadcast(exclude=("a",)) == xr_broadcast(exclude=("a",))
+
+    def test_broadcast_single_input(self):
+        """Test broadcasting a single input."""
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        # Broadcast with a single input can still imply a transpose via the exclude parameter
+        outs = [
+            *broadcast(x),
+            *broadcast(x, exclude=("a", "b")),
+            *broadcast(x, exclude=("b", "a")),
+            *broadcast(x, exclude=("b",)),
+        ]
+
+        fn = xr_function([x], outs)
+        x_test = xr_arange_like(x)
+        results = fn(x_test)
+        expected_results = [
+            *xr_broadcast(x_test),
+            *xr_broadcast(x_test, exclude=("a", "b")),
+            *xr_broadcast(x_test, exclude=("b", "a")),
+            *xr_broadcast(x_test, exclude=("b",)),
+        ]
+        for res, expected_res in zip(results, expected_results, strict=True):
+            xr_assert_allclose(res, expected_res)
+
+    @pytest.mark.parametrize("exclude", [None, ["b"], ["b", "c"]])
+    def test_broadcast_like(self, exclude):
+        """Test the broadcast_like method."""
+        # Create test data
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
+        z = xtensor("z", dims=("b", "d"), shape=(4, 6))
+
+        # Order matters, so we test both orders
+        outs = [
+            x.broadcast_like(y, exclude=exclude),
+            y.broadcast_like(x, exclude=exclude),
+            y.broadcast_like(z, exclude=exclude),
+            z.broadcast_like(y, exclude=exclude),
+        ]
+
+        x_test = xr_arange_like(x)
+        y_test = xr_arange_like(y)
+        z_test = xr_arange_like(z)
+        fn = xr_function([x, y, z], outs)
+        results = fn(x_test, y_test, z_test)
+        expected_results = [
+            x_test.broadcast_like(y_test, exclude=exclude),
+            y_test.broadcast_like(x_test, exclude=exclude),
+            y_test.broadcast_like(z_test, exclude=exclude),
+            z_test.broadcast_like(y_test, exclude=exclude),
+        ]
+        for res, expected_res in zip(results, expected_results, strict=True):
+            xr_assert_allclose(res, expected_res)
