From 1f254d3c1dbe821d94ff5606a6bc23396d6d880b Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Thu, 21 Sep 2023 15:15:32 -0700 Subject: [PATCH 1/2] Fix Broadcast.broadcast_shape inference --- src/blockbroadcast.jl | 7 ++++--- test/test_blockbroadcast.jl | 13 +++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/blockbroadcast.jl b/src/blockbroadcast.jl index f8e7e89d..9fd1fee9 100644 --- a/src/blockbroadcast.jl +++ b/src/blockbroadcast.jl @@ -30,13 +30,14 @@ BroadcastStyle(::PseudoBlockStyle{M}, ::BlockStyle{N}) where {M,N} = BlockStyle( # sortedunion can assume inputs are already sorted so this could be improved sortedunion(a,b) = sort!(union(a,b)) +sortedunion(a::Tuple, b::Tuple) = (a..., b...) sortedunion(a::Base.OneTo, b::Base.OneTo) = Base.OneTo(max(last(a),last(b))) sortedunion(a::AbstractUnitRange, b::AbstractUnitRange) = min(first(a),first(b)):max(last(a),last(b)) combine_blockaxes(a, b) = _BlockedUnitRange(sortedunion(blocklasts(a), blocklasts(b))) -Base.Broadcast.axistype(a::BlockedUnitRange, b::BlockedUnitRange) = length(b) == 1 ? a : combine_blockaxes(a, b) -Base.Broadcast.axistype(a::BlockedUnitRange, b) = length(b) == 1 ? a : combine_blockaxes(a, b) -Base.Broadcast.axistype(a, b::BlockedUnitRange) = length(b) == 1 ? a : combine_blockaxes(a, b) +Base.Broadcast.axistype(a::BlockedUnitRange, b::BlockedUnitRange) = combine_blockaxes(a, b) +Base.Broadcast.axistype(a::BlockedUnitRange, b) = combine_blockaxes(a, b) +Base.Broadcast.axistype(a, b::BlockedUnitRange) = combine_blockaxes(a, b) similar(bc::Broadcasted{<:AbstractBlockStyle{N}}, ::Type{T}) where {T,N} = diff --git a/test/test_blockbroadcast.jl b/test/test_blockbroadcast.jl index 0d58aa3d..9591354c 100644 --- a/test/test_blockbroadcast.jl +++ b/test/test_blockbroadcast.jl @@ -182,6 +182,19 @@ import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal u = BlockArray(randn(5), [2,3]); @inferred(copyto!(similar(u), Base.broadcasted(exp, u))) @test exp.(u) == exp.(Vector(u)) + + function test_allocation!(shape1, shape2) + x = Base.Broadcast.broadcast_shape(shape1, shape2) + return nothing + end + shape1 = (BlockArrays._BlockedUnitRange((2,)),); + shape2 = (BlockArrays._BlockedUnitRange((2,)),); + @inferred Base.Broadcast.axistype(shape1[1], shape2[1]) + @inferred BlockArrays.combine_blockaxes(shape1[1], shape2[1]) + @inferred Base.Broadcast.broadcast_shape(shape1, shape2) + test_allocation!(shape1, shape2) # compile first + p = @allocated test_allocation!(shape1, shape2) + @test p == 0 end @testset "adjtrans" begin From 549a1e1aa0e551b8a03e5e4b4be1fe2a0699530e Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Wed, 11 Oct 2023 10:09:50 -0700 Subject: [PATCH 2/2] Fix upstream errors --- src/blockbroadcast.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/blockbroadcast.jl b/src/blockbroadcast.jl index 9fd1fee9..77afde55 100644 --- a/src/blockbroadcast.jl +++ b/src/blockbroadcast.jl @@ -30,10 +30,12 @@ BroadcastStyle(::PseudoBlockStyle{M}, ::BlockStyle{N}) where {M,N} = BlockStyle( # sortedunion can assume inputs are already sorted so this could be improved sortedunion(a,b) = sort!(union(a,b)) -sortedunion(a::Tuple, b::Tuple) = (a..., b...) +sortedunion_tuple(a::Tuple, b::Tuple) = (a..., b...) sortedunion(a::Base.OneTo, b::Base.OneTo) = Base.OneTo(max(last(a),last(b))) sortedunion(a::AbstractUnitRange, b::AbstractUnitRange) = min(first(a),first(b)):max(last(a),last(b)) -combine_blockaxes(a, b) = _BlockedUnitRange(sortedunion(blocklasts(a), blocklasts(b))) +combine_blockaxes(a, b) = combine_blockaxes(a, b, blocklasts(a), blocklasts(b)) +combine_blockaxes(a, b, abl, bbl) = length(b) == 1 ? a : _BlockedUnitRange(sortedunion(abl, bbl)) +combine_blockaxes(a, b, abl::Tuple, bbl::Tuple) = _BlockedUnitRange(sortedunion_tuple(abl, bbl)) Base.Broadcast.axistype(a::BlockedUnitRange, b::BlockedUnitRange) = combine_blockaxes(a, b) Base.Broadcast.axistype(a::BlockedUnitRange, b) = combine_blockaxes(a, b)