Skip to content

Commit c159659

Browse files
committed
Add implicitly mapped measures and kernels
1 parent 5b1f4aa commit c159659

File tree

5 files changed

+351
-0
lines changed

5 files changed

+351
-0
lines changed

src/MeasureBase.jl

+4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module MeasureBase
22

33
using Base: @propagate_inbounds
4+
using Base: OneTo
45

56
using Random
67
import Random: rand!
@@ -144,6 +145,7 @@ include("combinators/restricted.jl")
144145
include("combinators/smart-constructors.jl")
145146
include("combinators/powerweighted.jl")
146147
include("combinators/conditional.jl")
148+
include("combinators/implicitlymapped.jl")
147149

148150
include("standard/stdmeasure.jl")
149151
include("standard/stduniform.jl")
@@ -152,6 +154,8 @@ include("standard/stdlogistic.jl")
152154
include("standard/stdnormal.jl")
153155
include("combinators/half.jl")
154156

157+
#include("implicitmaps.jl")
158+
155159
include("rand.jl")
156160

157161
include("density.jl")

src/combinators/implicitlymapped.jl

+250
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
2+
"""
3+
abstract type ImplicitlyMapped
4+
5+
Supertype for objects that have been mapped in an implicit way.
6+
7+
The explicit map/function can only be determined given some kind of observed
8+
result `obs` using
9+
10+
```julia
11+
f_map = explicit_mapfunc(mapped::ImplicitlyMapped, obs)
12+
```
13+
14+
The original object that has been implicitly mapped
15+
may be retrieved via
16+
17+
```julia
18+
obj = explicit_mapfunc(mapped::ImplicitlyMapped, obs)
19+
```
20+
21+
Note that `obs` is typically *not* the directly result of `f_map(ob)`. Instead,
22+
the relationship between `obj`, `f_map`, and `obs` depends on what `obj` is:
23+
24+
* A measure `mu = obj`: The mapping process is equivalent to
25+
`mapped_mu = pushfwd(f_map, mu, PushfwdRootMeasure())` and `obs` is an
26+
element of the measurable space of `mu`. Implicitly mapped measures support
27+
28+
```julia
29+
DensityInterface.DensityKind(mapped_mu::ImplicitlyMapped)
30+
DensityInterface.logdensityof(mapped_mu::ImplicitlyMapped, obs)
31+
```
32+
33+
and the explicitly mapped measure can be generated via
34+
35+
```julia
36+
explicit_measure(mapped_mu::ImplicitlyMapped, obs)
37+
```
38+
39+
* A transition/Markov kernel `f_kernel = obj`, i.e. a function that maps
40+
points in some space to measures on a (possibly different) space:
41+
The mapping process is equivalent to
42+
`mapped_f_kernel = (p -> pushfwd(f_map, f_kernel(p), PushfwdRootMeasure()))`
43+
and `obs` is an element of the measurable space of the measures generated
44+
by the mapped kernel. Implicitly mapped transition/Markov kernels support
45+
46+
```julia
47+
Likelihood(mapped_f_kernel::ImplicitlyMapped, obs)
48+
```
49+
50+
and the explicitly mapped kernel can be generated via
51+
52+
```julia
53+
explicit_measure(mapped_mu::ImplicitlyMapped, obs)
54+
```
55+
56+
# Implementation
57+
58+
Subtypes of `ImplicitlyMapped` that should support origin measures of
59+
type `SomeRelevantMeasure` and observations of type `SomeRelevantObs`,
60+
resulting in explicit maps/functions of type `SomeMapFunc`, must
61+
implement/specialize
62+
63+
```julia
64+
MeasureBase.implicit_origin(mapped::MyImplicitlyMapped)
65+
MeasureBase.explicit_mapfunc(mapped::MyImplicitlyMapped, obs::SomeRelevantObs)::SomeMapFunc
66+
```
67+
68+
and (except if functions of type `SomeMapFunc` are invertible via
69+
`InverseFunctions.inverse`) must also specialize
70+
71+
```julia
72+
MeasureBase.pushfwd(f::SomeMapFunc, mu::SomeRelevantMeasure, ::PushfwdRootMeasure)
73+
```
74+
75+
Subtypes of `ImplicitlyMapped` may support multiple combinations of
76+
observation and measure types.
77+
"""
78+
abstract type ImplicitlyMapped end
79+
export ImplicitlyMapped
80+
81+
"""
82+
implicit_origin(mapped::ImplicitlyMapped)
83+
84+
Get the original object (a measure or transition/Markov kernel) that was
85+
implicitly mapped.
86+
87+
See [ImplicitlyMapped](@ref) for detailed semantics.
88+
89+
# Implementation
90+
91+
`implicit_origin` must be implemented for subtypes of `ImplicitlyMapped`,
92+
there is no default implementation.
93+
"""
94+
function implicit_origin end
95+
export implicit_origin
96+
97+
"""
98+
explicit_mapfunc(mapped::ImplicitlyMapped, obs)
99+
100+
Get an explicit map/function based on an implicitly mapped object and an
101+
observation.
102+
103+
See [ImplicitlyMapped](@ref) for detailed semantics.
104+
105+
# Implementation
106+
107+
`explicit_mapfunc` must be implemented for subtypes of `ImplicitlyMapped`,
108+
there is no default implementation.
109+
"""
110+
function explicit_mapfunc end
111+
export explicit_mapfunc
112+
113+
"""
114+
explicit_measure(mapped::ImplicitlyMapped, obs)
115+
116+
Get an explicitly mapped measure based on an implicitly mapped measure and an
117+
observation that provides context on which pushforward to use on the onmapped
118+
original measure `implicit_origin(mapped)`.
119+
120+
Used [`explicit_mapfunc`](@ref) to get the function to use in the pushforward.
121+
122+
# Implementation
123+
124+
`explicit_measure` does not need to be specialized for subtypes of
125+
`ImplicitlyMapped`.
126+
"""
127+
function explicit_measure(mapped_measure::ImplicitlyMapped, obs)
128+
f_map = explicit_mapfunc(mapped_measure, obs)
129+
mu = implicit_origin(mapped_measure)
130+
return pushfwd(f_map, mu, PushfwdRootMeasure())
131+
end
132+
export explicit_measure
133+
134+
function DensityInterface.logdensityof(mapped_measure::ImplicitlyMapped, obs)
135+
return logdensityof(explicit_measure(mapped_measure, obs), obs)
136+
end
137+
138+
function DensityInterface.DensityKind(mapped::ImplicitlyMapped)
139+
DensityKind(implicit_origin(mapped))
140+
end
141+
142+
"""
143+
explicit_kernel(mapped::ImplicitlyMapped, obs)
144+
145+
Get an expliclity mapped transition/Markov kernel, based on an implicitly
146+
mapped kernel and an observation that provides context on which pushforward
147+
to add to the unmapped original kernel `implicit_origin(mapped)`.
148+
149+
Used [`explicit_mapfunc`](@ref) to get the function to use in the pushforward.
150+
151+
# Implementation
152+
153+
`explicit_kernel` does not need to be specialized for subtypes of
154+
`ImplicitlyMapped`.
155+
"""
156+
function explicit_kernel(mapped_kernel::ImplicitlyMapped, obs)
157+
f_map = explicit_mapfunc(mapped_kernel, obs)
158+
f_kernel = implicit_origin(mapped_kernel)
159+
return (p -> pushfwd(f_map, f_kernel(p), PushfwdRootMeasure()))
160+
end
161+
export explicit_kernel
162+
163+
function Likelihood(mapped_kernel::ImplicitlyMapped, obs)
164+
return Likelihood(explicit_kernel(mapped_kernel, obs), obs)
165+
end
166+
167+
"""
168+
struct MeasureBase.TakeAny{T} <: Function
169+
170+
Represents a function that takes n values from a collection.
171+
172+
`f = TakeAny(n)` treats all collections as unordered: `f(xs) may take the
173+
first `n` elements of `xs`, but there is no guarantee. It must, however,
174+
always take take the same elements from collections that are identical.
175+
176+
Constructor: `TakeAny(n::Union{Integer,Static.StaticInteger})`.
177+
"""
178+
struct TakeAny{T<:IntegerLike}
179+
n::T
180+
end
181+
182+
_takeany_range(f::TakeAny, idxs) = first(idxs):first(idxs)+dynamic(f.n)-1
183+
@inline _takeany_range(f::TakeAny, ::OneTo) = OneTo(dynamic(f.n))
184+
185+
@inline _takeany_range(::TakeAny{<:Static.StaticInteger{N}}, ::OneTo) where {N} = SOneTo(N)
186+
@inline _takeany_range(::TakeAny{<:Static.StaticInteger{N}}, ::SOneTo) where {N} = SOneTo(N)
187+
188+
@inline (f::TakeAny)(xs::Tuple) = xs[begin:begin+f.n-1]
189+
@inline (f::TakeAny)(xs::AbstractVector) = xs[_takeany_range(f, eachindex(xs))]
190+
191+
function (f::TakeAny)(xs)
192+
n = dynamic(f.n)
193+
ys = collect(Iterators.take(xs, n))
194+
length(ys) != n &&
195+
throw(ArgumentError("Can't take $n elements from a sequence shorter than $n"))
196+
return typeof(xs)(ys)
197+
end
198+
199+
"""
200+
struct Marginalized{T} <: ImplicitlyMapped
201+
202+
Represents an implicitly marginalized measure or transition kernel.
203+
204+
Constructors:
205+
206+
* `Marginalized(mu)`
207+
* `Marginalized(f_kernel)`
208+
209+
See [ImplicitlyMapped](@ref) for detailed semantics.
210+
211+
Example:
212+
213+
```julia
214+
mu = productmeasure((a = StdUniform(), b = StdNormal(), c = StdExponential()))
215+
obs = (a = 0.7, c = 1.2)
216+
217+
marg_mu_equiv = productmeasure((a = StdUniform(), c = StdExponential()))
218+
219+
logdensityof(Marginalized(mu), obs) ≈ logdensityof(marg_mu_equiv, obs)
220+
```
221+
"""
222+
struct Marginalized{T} <: ImplicitlyMapped
223+
obj::T
224+
end
225+
export Marginalized
226+
227+
implicit_origin(mapped::Marginalized) = mapped.obj
228+
229+
function explicit_mapfunc(::Marginalized, obs::NamedTuple{names}) where {names}
230+
PropSelFunction{names,names}()
231+
end
232+
function pushfwd(f::PropSelFunction, mu::ProductMeasure{<:NamedTuple}, ::PushfwdRootMeasure)
233+
productmeasure(f(marginals(mu)))
234+
end
235+
236+
explicit_mapfunc(::Marginalized, obs::AbstractVector) = TakeAny(length(obs))
237+
explicit_mapfunc(::Marginalized, obs::StaticArray{Tuple{N}}) where {N} = TakeAny(static(N))
238+
239+
function pushfwd(
240+
f::TakeAny,
241+
mu::PowerMeasure{<:Any,<:Tuple{<:AbstractUnitRange}},
242+
::PushfwdRootMeasure,
243+
)
244+
n = f.n
245+
n_mu = length(mu)
246+
n_mu < n && throw(
247+
ArgumentError("Can't marginalize $n_mu dimensional power measure to $n dimensions"),
248+
)
249+
mu.parent^f.n
250+
end

test/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
AffineMaps = "2c83c9a8-abf5-4329-a0d7-deffaf474661"
23
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
34
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
45
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"

test/combinators/implicitlymapped.jl

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
using Test
2+
3+
using MeasureBase
4+
5+
using StaticArrays: SVector
6+
using Static: static
7+
using AffineMaps, PropertyFunctions
8+
9+
@testset "implicitlymapped" begin
10+
@testset "TakeAny" begin
11+
V = [3, 2, 4, 2, 7, 5, 6]
12+
mV = [3, 2, 4, 2]
13+
S = Set(V)
14+
mS = Set(mV)
15+
SV = SVector(V...)
16+
mSV = SVector(mV...)
17+
18+
@test @inferred(MeasureBase.TakeAny(4)(V)) == mV
19+
@test @inferred(MeasureBase.TakeAny(static(4))(V)) == mV
20+
tS = @inferred(MeasureBase.TakeAny(4)(S))
21+
@test tS isa Set && length(tS) == 4 && all(x -> x in S, tS)
22+
@test @inferred(MeasureBase.TakeAny(static(4))(S)) == MeasureBase.TakeAny(4)(S)
23+
@test @inferred(MeasureBase.TakeAny(static(4))(SV)) === mSV
24+
@test @inferred(MeasureBase.TakeAny(4)(SV)) == mV
25+
@test @inferred(MeasureBase.TakeAny(4)(V)) == mV
26+
end
27+
28+
function test_implicitly_mapped(
29+
label,
30+
f_kernel,
31+
ref_mapfunc,
32+
ref_mappedkernel,
33+
par,
34+
orig_obs,
35+
obs,
36+
)
37+
@testset "$label" begin
38+
im_measure = @inferred Marginalized(f_kernel(par))
39+
im_kernel = @inferred Marginalized(f_kernel)
40+
mapfunc = @inferred explicit_mapfunc(im_measure, obs)
41+
mapped_measure = @inferred explicit_measure(im_measure, obs)
42+
mapped_likelihood = @inferred Likelihood(im_kernel, obs)
43+
44+
@test mapfunc == ref_mapfunc
45+
@test @inferred(mapfunc(orig_obs)) == obs
46+
@test mapped_measure == ref_mappedkernel(par)
47+
48+
@test @inferred(logdensityof(im_measure, obs))
49+
logdensityof(mapped_measure, obs)
50+
@test @inferred(logdensityof(mapped_likelihood, par))
51+
logdensityof(Likelihood(ref_mappedkernel, obs), par)
52+
end
53+
end
54+
55+
f_kernel =
56+
par -> productmeasure(
57+
map(
58+
m -> pushfwd(Mul(par), m),
59+
(a = StdUniform(), b = StdNormal(), c = StdExponential()),
60+
),
61+
)
62+
ref_mapfunc = @pf (; $a, $c)
63+
ref_mappedkernel =
64+
par -> productmeasure(
65+
map(m -> pushfwd(Mul(par), m), (a = StdUniform(), c = StdExponential())),
66+
)
67+
par = 4.2
68+
orig_obs = (a = 0.7, b = 2.1, c = 1.2)
69+
obs = (a = 0.7, c = 1.2)
70+
test_implicitly_mapped(
71+
"marginalized nt",
72+
f_kernel,
73+
ref_mapfunc,
74+
ref_mappedkernel,
75+
par,
76+
orig_obs,
77+
obs,
78+
)
79+
80+
f_kernel = par -> pushfwd(Mul(par), StdNormal())^7
81+
ref_mapfunc = MeasureBase.TakeAny(3)
82+
ref_mappedkernel = par -> pushfwd(Mul(par), StdNormal())^3
83+
par = 4.2
84+
orig_obs = [9.4, -7.3, 1.0, -2.9, 1.9, 4.7, 0.5]
85+
obs = [9.4, -7.3, 1.0]
86+
test_implicitly_mapped(
87+
"marginalized nt",
88+
f_kernel,
89+
ref_mapfunc,
90+
ref_mappedkernel,
91+
par,
92+
orig_obs,
93+
obs,
94+
)
95+
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,6 @@ include("smf.jl")
2121

2222
include("combinators/weighted.jl")
2323
include("combinators/transformedmeasure.jl")
24+
include("combinators/implicitlymapped.jl")
2425

2526
include("test_docs.jl")

0 commit comments

Comments
 (0)