I got this working, sort of:
```julia
julia> d = For(j -> Normal(j, 2.0), 1:3)
For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,))

julia> test_transport(d, Normal() ^ 3)
Test Summary:                                                                                               | Pass  Total  Time
transport_to Normal() ^ 3 to For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,)) |    8      8  0.0s
DefaultTestSet("transport_to Normal() ^ 3 to For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,))", Any[], 8, false, false, true, 1.66725e9, 1.66725e9)
```

To do this, I added `for_constructor`, which is like `For` but a little smarter: when the inferred return type of `f` is a singleton type, it collapses to a power measure.
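```julia
# A single collection of indices: wrap it in a tuple
for_constructor(f, x) = for_constructor(f, (x,))

@generated function for_constructor(f::F, inds::I) where {F,I<:Tuple}
    # Element types of the index collections, known at generation time
    eltypes = Tuple{eltype.(I.types)...}
    quote
        # Inferred return type of `f` applied to one element of each collection
        T = Core.Compiler.return_type(f, $eltypes)
        _for(T, f, inds, static(Base.issingletontype(T)))
    end
end

# Singleton result type: every component is the same instance, so collapse to a power
function _for(::Type{T}, f::F, inds::I, ::True) where {T,F,I}
    instance(T) ^ size(first(inds))
end

# Otherwise, build an ordinary `For`
function _for(::Type{T}, f::F, inds::I, ::False) where {T,F,I}
    For{T,F,I}(f, inds)
end
```

Then we just need the standard stuff: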
```julia
# The origin of a product is built from the origins of its marginals
function MeasureBase.transport_origin(d::AbstractProductMeasure)
    for_constructor(MeasureBase.transport_origin, marginals(d))
end

# Transport each component to (or from) the origin of its marginal
function MeasureBase.to_origin(d::AbstractProductMeasure, x)
    map(MeasureBase.to_origin, marginals(d), x)
end

function MeasureBase.from_origin(d::AbstractProductMeasure, x)
    map(MeasureBase.from_origin, marginals(d), x)
end
```

Well, almost. There's also this bug:
```julia
julia> MeasureBase._origin_depth(Normal() ^ 3)
ERROR: MethodError: no method matching ^(::MeasureBase.NoTransportOrigin{StdNormal}, ::Tuple{Int64})
Closest candidates are:
  ^(::AbstractMeasure, ::Tuple) at ~/git/MeasureBase.jl/src/combinators/power.jl:55
  ^(::AbstractMeasure, ::Any) at ~/git/MeasureBase.jl/src/combinators/power.jl:56
Stacktrace:
 [1] _for(#unused#::Type{MeasureBase.NoTransportOrigin{StdNormal}}, f::typeof(MeasureBase.transport_origin), inds::Tuple{FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}}}, #unused#::Static.True)
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:37
 [2] macro expansion
   @ ~/git/MeasureTheory.jl/src/combinators/for.jl:32 [inlined]
 [3] for_constructor(f::typeof(MeasureBase.transport_origin), inds::Tuple{FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}}})
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:28
 [4] for_constructor(f::Function, x::FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}})
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:26
 [5] transport_origin(d::PowerMeasure{StdNormal, Tuple{Base.OneTo{Int64}}})
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:305
 [6] _origin_depth(ν::PowerMeasure{Normal{(), Tuple{}}, Tuple{Base.OneTo{Int64}}})
   @ MeasureBase ~/git/MeasureBase.jl/src/transport.jl:130
 [7] top-level scope
   @ REPL[60]:1
```

We end up taking a power of a `NoTransportOrigin`, which makes no sense. `transport_origin` applied to a `StdNormal` is inferred to return the singleton type `NoTransportOrigin{StdNormal}`, so `for_constructor` takes the `True` branch and tries to raise it to a power, but `^` is only defined for `AbstractMeasure`s. As a quick fix, I temporarily changed `MeasureBase._origin_depth` to:
```julia
@inline function _origin_depth(ν::NU) where {NU}
    ν_0 = ν
    Base.Cartesian.@nexprs 10 i -> begin  # 10 is just some "big enough" number
        ν_{i} = transport_origin(ν_{i - 1})
        if ν_{i} isa PowerMeasure
            # Look through the power to its base measure
            ν_{i} = ν_{i}.parent
        elseif ν_{i} isa NoTransportOrigin
            # No further origin: the depth is the number of successful steps
            return static(i - 1)
        end
    end
    return static(10)
end
```

This last part feels kind of hacky. There's also the problem that `map` forces allocation in `to_origin` and `from_origin`. It would be nice to use `mappedarray` instead, but that doesn't infer properly. Maybe a modified version of it could?
Also, a product whose marginals have different "origin depths" seems like a problem. A fixpoint approach would handle this, but I think the current approach will break. Any ideas for this, @oschulz?
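To make the "different origin depths" concern concrete, here is a toy runtime-loop version of the depth computation next to one possible reading of a fixpoint approach. Everything in it is hypothetical: `NoOrigin` stands in for `NoTransportOrigin`, `origin` for `transport_origin`, and the two measure types are made up. A product whose marginals have depths 1 and 0 has no single depth, while following each marginal to its own fixpoint lands both on the same origin.

```julia
# Hypothetical stand-ins: `NoOrigin` plays the role of `NoTransportOrigin`,
# `origin` plays the role of `transport_origin`.
struct NoOrigin end
struct ToyStd end      # already an origin
struct ToyAffine end   # transports to ToyStd in one step

origin(::ToyStd) = NoOrigin()
origin(::ToyAffine) = ToyStd()

# Number of `origin` steps before reaching `NoOrigin`, capped like `_origin_depth`
function origin_depth(m; maxdepth = 10)
    for depth in 0:(maxdepth - 1)
        next = origin(m)
        next isa NoOrigin && return depth
        m = next
    end
    return maxdepth
end

# Fixpoint: keep following `origin` until there is nowhere further to go
function final_origin(m; maxsteps = 10)
    for _ in 1:maxsteps
        next = origin(m)
        next isa NoOrigin && return m
        m = next
    end
    error("no fixpoint within $maxsteps steps")
end

marginals = (ToyAffine(), ToyStd())
depths = map(origin_depth, marginals)     # (1, 0): no single depth for the product
origins = map(final_origin, marginals)    # (ToyStd(), ToyStd()): a common fixpoint
```

The per-marginal fixpoint avoids needing one depth for the whole product, at the cost of a runtime loop rather than the statically unrolled `@nexprs` version.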