From 0b9a40e7fa84da419a16086e1d2f0d4db994630d Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Fri, 3 Jun 2022 10:15:10 +0200 Subject: [PATCH 1/2] provide type-inferred length resolves #138 Need to check correct general computation of length of Axis. --- src/axis.jl | 1 + src/componentarray.jl | 4 ++++ test/runtests.jl | 9 +++++++++ 3 files changed, 14 insertions(+) diff --git a/src/axis.jl b/src/axis.jl index 29ea4222..47bb8d5c 100644 --- a/src/axis.jl +++ b/src/axis.jl @@ -142,6 +142,7 @@ Base.merge(axs::Axis...) = Axis(merge(indexmap.(axs)...)) Base.firstindex(ax::AbstractAxis) = first(viewindex(first(indexmap(ax)))) Base.lastindex(ax::AbstractAxis) = last(viewindex(last(indexmap(ax)))) +Base.length(ax::AbstractAxis) = lastindex(ax) - firstindex(ax) + 1 Base.keys(ax::AbstractAxis) = keys(indexmap(ax)) diff --git a/src/componentarray.jl b/src/componentarray.jl index ea09fa84..66c6e562 100644 --- a/src/componentarray.jl +++ b/src/componentarray.jl @@ -235,6 +235,10 @@ last_index(x) = last(x) last_index(x::ViewAxis) = last_index(viewindex(x)) last_index(x::AbstractAxis) = last_index(last(indexmap(x))) +# length information is in Axis, use it to make SVector creation type stable +Base.length(ca::ComponentArray) = prod(length.(getaxes(ca))) +Base.size(ca::ComponentArray) = map(length, getaxes(ca)) + # Reduce singleton dimensions remove_nulls() = () remove_nulls(x1, args...) = (x1, remove_nulls(args...)...) diff --git a/test/runtests.jl b/test/runtests.jl index a863af41..6c6b49a4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -598,6 +598,15 @@ end @test convert(Cholesky{Float32,Matrix{Float32}}, chol).factors isa Matrix{Float32} end +@testset "length typestable" begin + # function boundary, so that cv is type-inferred + test_create_svector = (cv) -> SVector{length(cv)}(cv) + @inferred test_create_svector(ComponentVector(a=1:3)); + @inferred test_create_svector(cmat); + test_create_smatrix = (cmat) -> SMatrix{size(cmat)...}(cmat) + @test (@inferred test_create_smatrix(cmat)) isa SMatrix +end; + @testset "Autodiff" begin include("autodiff_tests.jl") end \ No newline at end of file From c998de7e467f26e93379fa220bebcf64c39f063a Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Fri, 3 Jun 2022 10:55:41 +0200 Subject: [PATCH 2/2] handle special case of top-level flat axis --- src/axis.jl | 2 ++ src/componentarray.jl | 13 +++++++++++-- test/runtests.jl | 1 - 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/axis.jl b/src/axis.jl index 47bb8d5c..3e7e27cc 100644 --- a/src/axis.jl +++ b/src/axis.jl @@ -140,9 +140,11 @@ const NotShapedOrPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis} where Base.merge(axs::Axis...) = Axis(merge(indexmap.(axs)...)) +# TODO: broken for FlatAxis Base.firstindex(ax::AbstractAxis) = first(viewindex(first(indexmap(ax)))) Base.lastindex(ax::AbstractAxis) = last(viewindex(last(indexmap(ax)))) Base.length(ax::AbstractAxis) = lastindex(ax) - firstindex(ax) + 1 +Base.length(ax::NullorFlatAxis) = error("NullorFlatAxis has no length") Base.keys(ax::AbstractAxis) = keys(indexmap(ax)) diff --git a/src/componentarray.jl b/src/componentarray.jl index 66c6e562..955ed230 100644 --- a/src/componentarray.jl +++ b/src/componentarray.jl @@ -236,8 +236,17 @@ last_index(x::ViewAxis) = last_index(viewindex(x)) last_index(x::AbstractAxis) = last_index(last(indexmap(x))) # length information is in Axis, use it to make SVector creation type stable -Base.length(ca::ComponentArray) = prod(length.(getaxes(ca))) -Base.size(ca::ComponentArray) = map(length, getaxes(ca)) +@inline _hasNullOrFlatAxis(ca) = any(map(ax -> ax isa NullorFlatAxis, getaxes(ca))) +function Base.length(ca::ComponentArray) + # vca2 = vcat(ca2', ca2') #has not length - is it a valid ComponentVector + # or rather a Vector + _hasNullOrFlatAxis(ca) && return(length(getdata(ca))) + prod(length.(getaxes(ca))) +end +function Base.size(ca::ComponentArray) + _hasNullOrFlatAxis(ca) && return(size(getdata(ca))) + map(length, getaxes(ca)) +end # Reduce singleton dimensions remove_nulls() = () diff --git a/test/runtests.jl b/test/runtests.jl index 6c6b49a4..7b27cec5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -602,7 +602,6 @@ end # function boundary, so that cv is type-inferred test_create_svector = (cv) -> SVector{length(cv)}(cv) @inferred test_create_svector(ComponentVector(a=1:3)); - @inferred test_create_svector(cmat); test_create_smatrix = (cmat) -> SMatrix{size(cmat)...}(cmat) @test (@inferred test_create_smatrix(cmat)) isa SMatrix end;