|
9 | 9 |
|
10 | 10 | import numpy as np
|
11 | 11 | from xarray import DataArray
|
| 12 | +from xarray import broadcast as xr_broadcast |
12 | 13 | from xarray import concat as xr_concat
|
13 | 14 |
|
14 | 15 | from pytensor.tensor import scalar
|
15 | 16 | from pytensor.xtensor.shape import (
|
| 17 | + broadcast, |
16 | 18 | concat,
|
17 | 19 | stack,
|
18 | 20 | unstack,
|
@@ -466,3 +468,169 @@ def test_expand_dims_errors():
|
466 | 468 | # Test with a numpy array as dim (not supported)
|
467 | 469 | with pytest.raises(TypeError, match="unhashable type"):
|
468 | 470 | y.expand_dims(np.array([1, 2]))
|
| 471 | + |
| 472 | + |
| 473 | +class TestBroadcast: |
| 474 | + @pytest.mark.parametrize( |
| 475 | + "exclude", |
| 476 | + [ |
| 477 | + None, |
| 478 | + [], |
| 479 | + ["b"], |
| 480 | + ["b", "d"], |
| 481 | + ["a", "d"], |
| 482 | + ["b", "c", "d"], |
| 483 | + ["a", "b", "c", "d"], |
| 484 | + ], |
| 485 | + ) |
| 486 | + def test_compatible_excluded_shapes(self, exclude): |
| 487 | + # Create test data |
| 488 | + x = xtensor("x", dims=("a", "b"), shape=(3, 4)) |
| 489 | + y = xtensor("y", dims=("c", "d"), shape=(5, 6)) |
| 490 | + z = xtensor("z", dims=("b", "d"), shape=(4, 6)) |
| 491 | + |
| 492 | + x_test = xr_arange_like(x) |
| 493 | + y_test = xr_arange_like(y) |
| 494 | + z_test = xr_arange_like(z) |
| 495 | + |
| 496 | + # Test with excluded dims |
| 497 | + x2_expected, y2_expected, z2_expected = xr_broadcast( |
| 498 | + x_test, y_test, z_test, exclude=exclude |
| 499 | + ) |
| 500 | + x2, y2, z2 = broadcast(x, y, z, exclude=exclude) |
| 501 | + fn = xr_function([x, y, z], [x2, y2, z2]) |
| 502 | + x2_result, y2_result, z2_result = fn(x_test, y_test, z_test) |
| 503 | + |
| 504 | + xr_assert_allclose(x2_result, x2_expected) |
| 505 | + xr_assert_allclose(y2_result, y2_expected) |
| 506 | + xr_assert_allclose(z2_result, z2_expected) |
| 507 | + |
| 508 | + def test_incompatible_excluded_shapes(self): |
| 509 | + # Test that excluded dims are allowed to be different sizes |
| 510 | + x = xtensor("x", dims=("a", "b"), shape=(3, 4)) |
| 511 | + y = xtensor("y", dims=("c", "d"), shape=(5, 6)) |
| 512 | + z = xtensor("z", dims=("b", "d"), shape=(4, 7)) |
| 513 | + out = broadcast(x, y, z, exclude=["d"]) |
| 514 | + |
| 515 | + x_test = xr_arange_like(x) |
| 516 | + y_test = xr_arange_like(y) |
| 517 | + z_test = xr_arange_like(z) |
| 518 | + fn = xr_function([x, y, z], out) |
| 519 | + results = fn(x_test, y_test, z_test) |
| 520 | + expected_results = xr_broadcast(x_test, y_test, z_test, exclude=["d"]) |
| 521 | + for res, expected_res in zip(results, expected_results, strict=True): |
| 522 | + xr_assert_allclose(res, expected_res) |
| 523 | + |
| 524 | + @pytest.mark.parametrize("exclude", [[], ["b"], ["b", "c"], ["a", "b", "d"]]) |
| 525 | + def test_runtime_shapes(self, exclude): |
| 526 | + # Test with symbolic shapes but no excluded dims |
| 527 | + x = xtensor("x", dims=("a", "b"), shape=(None, 4)) |
| 528 | + y = xtensor("y", dims=("c", "d"), shape=(5, None)) |
| 529 | + z = xtensor("z", dims=("b", "d"), shape=(None, None)) |
| 530 | + out = broadcast(x, y, z, exclude=exclude) |
| 531 | + |
| 532 | + x_test = xr_arange_like(xtensor(dims=x.dims, shape=(3, 4))) |
| 533 | + y_test = xr_arange_like(xtensor(dims=y.dims, shape=(5, 6))) |
| 534 | + z_test = xr_arange_like(xtensor(dims=z.dims, shape=(4, 6))) |
| 535 | + fn = xr_function([x, y, z], out) |
| 536 | + results = fn(x_test, y_test, z_test) |
| 537 | + expected_results = xr_broadcast(x_test, y_test, z_test, exclude=exclude) |
| 538 | + for res, expected_res in zip(results, expected_results, strict=True): |
| 539 | + xr_assert_allclose(res, expected_res) |
| 540 | + |
| 541 | + # Test invalid shape raises an error |
| 542 | + # Note: We might decide not to raise an error in the lowered graphs for performance reasons |
| 543 | + if "d" not in exclude: |
| 544 | + z_test_bad = xr_arange_like(xtensor(dims=z.dims, shape=(4, 7))) |
| 545 | + with pytest.raises(Exception): |
| 546 | + fn(x_test, y_test, z_test_bad) |
| 547 | + |
| 548 | + def test_broadcast_excluded_dims_in_different_order(self): |
| 549 | + """Test broadcasting excluded dims are aligned with user input.""" |
| 550 | + x = xtensor("x", dims=("a", "c", "b"), shape=(3, 4, 5)) |
| 551 | + y = xtensor("y", dims=("a", "b", "c"), shape=(3, 5, 4)) |
| 552 | + out = (out_x, out_y) = broadcast(x, y, exclude=["c", "b"]) |
| 553 | + assert out_x.type.dims == ("a", "c", "b") |
| 554 | + assert out_y.type.dims == ("a", "c", "b") |
| 555 | + |
| 556 | + x_test = xr_arange_like(x) |
| 557 | + y_test = xr_arange_like(y) |
| 558 | + fn = xr_function([x, y], out) |
| 559 | + results = fn(x_test, y_test) |
| 560 | + expected_results = xr_broadcast(x_test, y_test, exclude=["c", "b"]) |
| 561 | + for res, expected_res in zip(results, expected_results, strict=True): |
| 562 | + xr_assert_allclose(res, expected_res) |
| 563 | + |
| 564 | + def test_broadcast_errors(self): |
| 565 | + """Test error handling in broadcast.""" |
| 566 | + x = xtensor("x", dims=("a", "b"), shape=(3, 4)) |
| 567 | + y = xtensor("y", dims=("c", "d"), shape=(5, 6)) |
| 568 | + z = xtensor("z", dims=("b", "d"), shape=(4, 6)) |
| 569 | + |
| 570 | + with pytest.raises(TypeError, match="exclude must be None, str, or Sequence"): |
| 571 | + broadcast(x, y, z, exclude=1) |
| 572 | + |
| 573 | + # Test with conflicting shapes |
| 574 | + x = xtensor("x", dims=("a", "b"), shape=(3, 4)) |
| 575 | + y = xtensor("y", dims=("c", "d"), shape=(5, 6)) |
| 576 | + z = xtensor("z", dims=("b", "d"), shape=(4, 7)) |
| 577 | + |
| 578 | + with pytest.raises(ValueError, match="Dimension .* has conflicting shapes"): |
| 579 | + broadcast(x, y, z) |
| 580 | + |
| 581 | + def test_broadcast_no_input(self): |
| 582 | + assert broadcast() == xr_broadcast() |
| 583 | + assert broadcast(exclude=("a",)) == xr_broadcast(exclude=("a",)) |
| 584 | + |
| 585 | + def test_broadcast_single_input(self): |
| 586 | + """Test broadcasting a single input.""" |
| 587 | + x = xtensor("x", dims=("a", "b"), shape=(3, 4)) |
| 588 | + # Broadcast with a single input can still imply a transpose via the exclude parameter |
| 589 | + outs = [ |
| 590 | + *broadcast(x), |
| 591 | + *broadcast(x, exclude=("a", "b")), |
| 592 | + *broadcast(x, exclude=("b", "a")), |
| 593 | + *broadcast(x, exclude=("b",)), |
| 594 | + ] |
| 595 | + |
| 596 | + fn = xr_function([x], outs) |
| 597 | + x_test = xr_arange_like(x) |
| 598 | + results = fn(x_test) |
| 599 | + expected_results = [ |
| 600 | + *xr_broadcast(x_test), |
| 601 | + *xr_broadcast(x_test, exclude=("a", "b")), |
| 602 | + *xr_broadcast(x_test, exclude=("b", "a")), |
| 603 | + *xr_broadcast(x_test, exclude=("b",)), |
| 604 | + ] |
| 605 | + for res, expected_res in zip(results, expected_results, strict=True): |
| 606 | + xr_assert_allclose(res, expected_res) |
| 607 | + |
| 608 | + @pytest.mark.parametrize("exclude", [None, ["b"], ["b", "c"]]) |
| 609 | + def test_broadcast_like(self, exclude): |
| 610 | + """Test broadcast_like method""" |
| 611 | + # Create test data |
| 612 | + x = xtensor("x", dims=("a", "b"), shape=(3, 4)) |
| 613 | + y = xtensor("y", dims=("c", "d"), shape=(5, 6)) |
| 614 | + z = xtensor("z", dims=("b", "d"), shape=(4, 6)) |
| 615 | + |
| 616 | + # Order matters so we test both orders |
| 617 | + outs = [ |
| 618 | + x.broadcast_like(y, exclude=exclude), |
| 619 | + y.broadcast_like(x, exclude=exclude), |
| 620 | + y.broadcast_like(z, exclude=exclude), |
| 621 | + z.broadcast_like(y, exclude=exclude), |
| 622 | + ] |
| 623 | + |
| 624 | + x_test = xr_arange_like(x) |
| 625 | + y_test = xr_arange_like(y) |
| 626 | + z_test = xr_arange_like(z) |
| 627 | + fn = xr_function([x, y, z], outs) |
| 628 | + results = fn(x_test, y_test, z_test) |
| 629 | + expected_results = [ |
| 630 | + x_test.broadcast_like(y_test, exclude=exclude), |
| 631 | + y_test.broadcast_like(x_test, exclude=exclude), |
| 632 | + y_test.broadcast_like(z_test, exclude=exclude), |
| 633 | + z_test.broadcast_like(y_test, exclude=exclude), |
| 634 | + ] |
| 635 | + for res, expected_res in zip(results, expected_results, strict=True): |
| 636 | + xr_assert_allclose(res, expected_res) |
0 commit comments