Skip to content

Commit d155042

Browse files
committed
Improve pushforward implementation
Rename TransformVolCorr and subtypes (with backward compatibility).
1 parent 7fcbec4 commit d155042

File tree

4 files changed

+184
-79
lines changed

4 files changed

+184
-79
lines changed

src/MeasureBase.jl

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ using DensityInterface: FuncDensity, LogFuncDensity
2121
using DensityInterface
2222

2323
using InverseFunctions
24+
using InverseFunctions: FunctionWithInverse
2425
using ChangesOfVariables
2526
using ConstantRNGs
2627

src/combinators/transformedmeasure.jl

+165-52
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,46 @@
1-
# TODO: Compare with ChangesOfVariables.jl
1+
"""
2+
abstract type PushFwdStyle
3+
4+
Provides the behavior of a measure's [`rootmeasure`](@ref) under a
5+
pushforward. Either [`AdaptRootMeasure()`](@ref) or
6+
[`PushfwdRootMeasure()`](@ref)
7+
"""
8+
abstract type PushFwdStyle end
9+
export PushFwdStyle
10+
11+
const TransformVolCorr = PushFwdStyle
12+
13+
"""
14+
AdaptRootMeasure()
15+
16+
Indicates that when applying a pushforward to a measure, it's
17+
[`rootmeasure`](@ref) not not be pushed forward. Instead, the root measure
18+
should be kept just "reshaped" to the new measurable space if necessary.
19+
20+
Density calculations for pushforward measures constructed with
21+
`AdaptRootMeasure()` will take take the volume element of variate
22+
transform (typically via the log-abs-det-Jacobian of the transform) into
23+
account.
24+
"""
25+
struct AdaptRootMeasure <: TransformVolCorr end
26+
export AdaptRootMeasure
27+
28+
const WithVolCorr = AdaptRootMeasure
229

3-
using InverseFunctions: FunctionWithInverse
30+
"""
31+
PushfwdRootMeasure()
32+
33+
Indicates than when applying a pushforward to a measure, it's
34+
[`rootmeasure`](@ref) should be pushed forward with the same function.
35+
36+
Density calculations for pushforward measures constructed with
37+
`PushfwdRootMeasure()` will ignore the volume element of the variate
38+
transform.
39+
"""
40+
struct PushfwdRootMeasure <: TransformVolCorr end
41+
export PushfwdRootMeasure
42+
43+
const NoVolCorr = PushfwdRootMeasure
444

545
abstract type AbstractTransformedMeasure <: AbstractMeasure end
646

@@ -19,23 +59,42 @@ function parent(::AbstractTransformedMeasure) end
1959
export PushforwardMeasure
2060

2161
"""
22-
struct PushforwardMeasure{F,I,M,VC<:TransformVolCorr} <: AbstractPushforward
62+
struct PushforwardMeasure{F,I,M,S<:PushFwdStyle} <: AbstractPushforward
2363
f :: F
2464
finv :: I
2565
origin :: M
26-
volcorr :: VC
66+
style :: S
2767
end
2868
2969
Users should not call `PushforwardMeasure` directly. Instead call or add
3070
methods to `pushfwd`.
3171
"""
32-
struct PushforwardMeasure{F,I,M,VC<:TransformVolCorr} <: AbstractPushforward
72+
struct PushforwardMeasure{F,I,M,S<:PushFwdStyle} <: AbstractPushforward
3373
f::F
3474
finv::I
3575
origin::M
36-
volcorr::VC
76+
style::S
77+
78+
function PushforwardMeasure{F,I,M,S}(
79+
f::F,
80+
finv::I,
81+
origin::M,
82+
style::S,
83+
) where {F,I,M,S<:PushFwdStyle}
84+
new{F,I,M,S}(f, finv, origin, style)
85+
end
86+
87+
function PushforwardMeasure(f, finv, origin::M, style::S) where {M,S<:PushFwdStyle}
88+
new{Core.Typeof(f),Core.Typeof(finv),M,S}(f, finv, origin, style)
89+
end
3790
end
3891

92+
const _NonBijectivePusfwdMeasure{M<:PushforwardMeasure,S<:PushFwdStyle} = Union{
93+
PushforwardMeasure{<:Any,<:NoInverse,M,S},
94+
PushforwardMeasure{<:NoInverse,<:Any,M,S},
95+
PushforwardMeasure{<:NoInverse,<:NoInverse,M,S},
96+
}
97+
3998
gettransform::PushforwardMeasure) = ν.f
4099
parent::PushforwardMeasure) = ν.origin
41100

@@ -45,55 +104,94 @@ end
45104

46105
# TODO: THIS IS ALMOST CERTAINLY WRONG
47106
# @inline function logdensity_rel(
48-
# ν::PushforwardMeasure{FF1,IF1,M1,<:WithVolCorr},
49-
# β::PushforwardMeasure{FF2,IF2,M2,<:WithVolCorr},
107+
# ν::PushforwardMeasure{FF1,IF1,M1,<:AdaptRootMeasure},
108+
# β::PushforwardMeasure{FF2,IF2,M2,<:AdaptRootMeasure},
50109
# y,
51110
# ) where {FF1,IF1,M1,FF2,IF2,M2}
52111
# x = β.inv_f(y)
53112
# f = ν.inv_f ∘ β.f
54113
# inv_f = β.inv_f ∘ ν.f
55-
# logdensity_rel(pushfwd(f, inv_f, ν.origin, WithVolCorr()), β.origin, x)
114+
# logdensity_rel(pushfwd(f, inv_f, ν.origin, AdaptRootMeasure()), β.origin, x)
56115
# end
57116

117+
# TODO: Would profit from custom pullback:
118+
function _combine_logd_with_ladj(logd_orig::Real, ladj::Real)
119+
logd_result = logd_orig + ladj
120+
R = typeof(logd_result)
121+
122+
if isnan(logd_result) && isneginf(logd_orig) && isposinf(ladj)
123+
# Zero μ wins against infinite volume:
124+
R(-Inf)::R
125+
elseif isfinite(logd_orig) && isneginf(ladj)
126+
# Maybe also for isneginf(logd_orig) && isfinite(ladj) ?
127+
# Return constant -Inf to prevent problems with ForwardDiff:
128+
#R(-Inf)
129+
near_neg_inf(R)::R # Avoids AdvancedHMC warnings
130+
else
131+
logd_result::R
132+
end
133+
end
134+
135+
function logdensityof(
136+
@nospecialize::_NonBijectivePusfwdMeasure{M,<:PushfwdRootMeasure}),
137+
@nospecialize(v::Any)
138+
) where {M}
139+
throw(
140+
ArgumentError(
141+
"Can't calculate densities for non-bijective pushforward measure $(nameof(M))",
142+
),
143+
)
144+
end
145+
146+
function logdensityof(
147+
@nospecialize::_NonBijectivePusfwdMeasure{M,<:AdaptRootMeasure}),
148+
@nospecialize(v::Any)
149+
) where {M}
150+
throw(
151+
ArgumentError(
152+
"Can't calculate densities for non-bijective pushforward measure $(nameof(M))",
153+
),
154+
)
155+
end
156+
58157
for func in [:logdensityof, :logdensity_def]
59-
@eval @inline function $func::PushforwardMeasure{F,I,M,<:WithVolCorr}, y) where {F,I,M}
60-
f = ν.f
61-
finv = ν.finv
62-
x_orig, inv_ladj = with_logabsdet_jacobian(unwrap(finv), y)
63-
logd_orig = $func.origin, x_orig)
64-
logd = float(logd_orig + inv_ladj)
65-
neginf = oftype(logd, -Inf)
66-
return ifelse(
67-
# Zero density wins against infinite volume:
68-
(isnan(logd) && logd_orig == -Inf && inv_ladj == +Inf) ||
69-
# Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ?
70-
# Return constant -Inf to prevent problems with ForwardDiff:
71-
(isfinite(logd_orig) && (inv_ladj == -Inf)),
72-
neginf,
73-
logd,
74-
)
158+
@eval function $func::PushforwardMeasure{F,I,M,<:AdaptRootMeasure}, y) where {F,I,M}
159+
f_inv = unwrap.finv)
160+
x, inv_ladj = with_logabsdet_jacobian(f_inv, y)
161+
logd_orig = $func.origin, x)
162+
return _combine_logd_with_ladj(logd_orig, inv_ladj)
75163
end
76164

77-
@eval @inline function $func::PushforwardMeasure{F,I,M,<:NoVolCorr}, y) where {F,I,M}
78-
x = ν.finv(y)
79-
return $func.origin, x)
165+
@eval function $func::PushforwardMeasure{F,I,M,<:PushfwdRootMeasure}, y) where {F,I,M}
166+
f_inv = unwrap.finv)
167+
x = f_inv(y)
168+
logd_orig = $func.origin, x)
169+
return logd_orig
80170
end
81171
end
82172

83-
insupport(ν::PushforwardMeasure, y) = insupport(ν.origin, ν.finv(y))
173+
insupport(m::PushforwardMeasure, x) = insupport(transport_origin(m), to_origin(m, x))
84174

85175
function testvalue(::Type{T}, ν::PushforwardMeasure) where {T}
86176
ν.f(testvalue(T, parent(ν)))
87177
end
88178

89179
@inline function basemeasure::PushforwardMeasure)
90-
pushfwd.f, basemeasure(parent(ν)), NoVolCorr())
180+
pushfwd.f, basemeasure(parent(ν)), PushfwdRootMeasure())
181+
end
182+
183+
function rootmeasure(m::PushforwardMeasure{F,I,M,PushfwdRootMeasure}) where {F,I,M}
184+
pushfwd(m.f, rootmeasure(m.origin))
185+
end
186+
function rootmeasure(m::PushforwardMeasure{F,I,M,AdaptRootMeasure}) where {F,I,M}
187+
rootmeasure(m.origin)
91188
end
92189

93190
_pushfwd_dof(::Type{MU}, ::Type, dof) where {MU} = NoDOF{MU}()
94191
_pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where {MU} = dof
95192

96193
@inline getdof::MU) where {MU<:PushforwardMeasure} = getdof.origin)
194+
@inline getdof(m::_NonBijectivePusfwdMeasure) = MeasureBase.NoDOF{typeof(m)}()
97195

98196
# Bypass `checked_arg`, would require potentially costly transformation:
99197
@inline checked_arg(::PushforwardMeasure, x) = x
@@ -102,47 +200,53 @@ _pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where {MU} = dof
102200
@inline from_origin::PushforwardMeasure, x) = ν.f(x)
103201
@inline to_origin::PushforwardMeasure, y) = ν.finv(y)
104202

203+
massof(m::PushforwardMeasure) = massof(transport_origin(m))
204+
105205
function Base.rand(rng::AbstractRNG, ::Type{T}, ν::PushforwardMeasure) where {T}
106-
return ν.f(rand(rng, T, parent(ν)))
206+
return ν.f(rand(rng, T, ν.origin))
107207
end
108208

109209
###############################################################################
110210
# pushfwd
111211

112-
export pushfwd
113-
114212
"""
115-
pushfwd(f, μ, volcorr = WithVolCorr())
213+
pushfwd(f, μ, style = AdaptRootMeasure())
116214
117215
Return the [pushforward
118216
measure](https://en.wikipedia.org/wiki/Pushforward_measure) from `μ` the
119217
[measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`.
120218
121219
To manually specify an inverse, call
122-
`pushfwd(InverseFunctions.setinverse(f, finv), μ, volcorr)`.
220+
`pushfwd(InverseFunctions.setinverse(f, finv), μ, style)`.
123221
"""
124-
function pushfwd(f, μ, volcorr::TransformVolCorr = WithVolCorr())
125-
PushforwardMeasure(f, inverse(f), μ, volcorr)
126-
end
127-
128-
function pushfwd(f, μ::PushforwardMeasure, volcorr::TransformVolCorr = WithVolCorr())
129-
_pushfwd_of_pushfwd(f, μ, μ.volcorr, volcorr)
130-
end
222+
function pushfwd end
223+
export pushfwd
131224

132-
# Either both WithVolCorr or both NoVolCorr, so we can merge them
133-
function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, ::V, v::V) where {V}
134-
pushfwd(fchain((μ.f, f)), μ.origin, v)
225+
@inline pushfwd(f, μ) = _pushfwd_impl(f, μ, AdaptRootMeasure())
226+
@inline pushfwd(f, μ, style::AdaptRootMeasure) = _pushfwd_impl(f, μ, style)
227+
@inline pushfwd(f, μ, style::PushfwdRootMeasure) = _pushfwd_impl(f, μ, style)
228+
229+
_pushfwd_impl(f, μ, style) = PushforwardMeasure(f, inverse(f), μ, style)
230+
231+
function _pushfwd_impl(
232+
f,
233+
μ::PushforwardMeasure{F,I,M,S},
234+
style::S,
235+
) where {F,I,M,S<:PushFwdStyle}
236+
orig_μ = μ.origin
237+
new_f = fcomp(f, μ.f)
238+
new_f_inv = fcomp.finv, inverse(f))
239+
PushforwardMeasure(new_f, new_f_inv, orig_μ, style)
135240
end
136241

137-
function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, _, v)
138-
PushforwardMeasure(f, inverse(f), μ, v)
139-
end
242+
_pushfwd_impl(::typeof(identity), μ, ::AdaptRootMeasure) = μ
243+
_pushfwd_impl(::typeof(identity), μ, ::PushfwdRootMeasure) = μ
140244

141245
###############################################################################
142246
# pullback
143247

144248
"""
145-
pullback(f, μ, volcorr = WithVolCorr())
249+
pullbck(f, μ, style = AdaptRootMeasure())
146250
147251
A _pullback_ is a dual concept to a _pushforward_. While a pushforward needs a
148252
map _from_ the support of a measure, a pullback requires a map _into_ the
@@ -154,8 +258,17 @@ in terms of the inverse function; the "forward" function is not used at all. In
154258
some cases, we may be focusing on log-density (and not, for example, sampling).
155259
156260
To manually specify an inverse, call
157-
`pullback(InverseFunctions.setinverse(f, finv), μ, volcorr)`.
261+
`pullbck(InverseFunctions.setinverse(f, finv), μ, style)`.
158262
"""
159-
function pullback(f, μ, volcorr::TransformVolCorr = WithVolCorr())
160-
pushfwd(setinverse(inverse(f), f), μ, volcorr)
263+
function pullbck end
264+
export pullbck
265+
266+
@inline pullbck(f, μ) = _pullback_impl(f, μ, AdaptRootMeasure())
267+
@inline pullbck(f, μ, style::AdaptRootMeasure) = _pullback_impl(f, μ, style)
268+
@inline pullbck(f, μ, style::PushfwdRootMeasure) = _pullback_impl(f, μ, style)
269+
270+
function _pullback_impl(f, μ, style = AdaptRootMeasure())
271+
pushfwd(setinverse(inverse(f), f), μ, style)
161272
end
273+
274+
@deprecate pullback(f, μ, style::PushFwdStyle = AdaptRootMeasure()) pullbck(f, μ, style)

src/transport.jl

-27
Original file line numberDiff line numberDiff line change
@@ -274,30 +274,3 @@ function Base.show(io::IO, f::TransportFunction)
274274
end
275275

276276
Base.show(io::IO, M::MIME"text/plain", f::TransportFunction) = show(io, f)
277-
278-
"""
279-
abstract type TransformVolCorr
280-
281-
Provides control over density correction by transform volume element.
282-
Either [`NoVolCorr()`](@ref) or [`WithVolCorr()`](@ref)
283-
"""
284-
abstract type TransformVolCorr end
285-
286-
"""
287-
NoVolCorr()
288-
289-
Indicate that density calculations should ignore the volume element of
290-
variate transformations. Should only be used in special cases in which
291-
the volume element has already been taken into account in a different
292-
way.
293-
"""
294-
struct NoVolCorr <: TransformVolCorr end
295-
296-
"""
297-
WithVolCorr()
298-
299-
Indicate that density calculations should take the volume element of
300-
variate transformations into account (typically via the
301-
log-abs-det-Jacobian of the transform).
302-
"""
303-
struct WithVolCorr <: TransformVolCorr end

src/utils.jl

+18
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,21 @@ using InverseFunctions: FunctionWithInverse
164164

165165
unwrap(f) = f
166166
unwrap(f::FunctionWithInverse) = f.f
167+
168+
169+
fcomp(f, g) = fchain(g, f)
170+
fcomp(::typeof(identity), g) = g
171+
fcomp(f, ::typeof(identity)) = f
172+
fcomp(::typeof(identity), ::typeof(identity)) = identity
173+
174+
175+
near_neg_inf(::Type{T}) where T<:Real = T(-1E38) # Still fits into Float32
176+
177+
isneginf(x) = isinf(x) && x < 0
178+
isposinf(x) = isinf(x) && x > 0
179+
180+
isapproxzero(x::T) where T<:Real = x zero(T)
181+
isapproxzero(A::AbstractArray) = all(isapproxzero, A)
182+
183+
isapproxone(x::T) where T<:Real = x one(T)
184+
isapproxone(A::AbstractArray) = all(isapproxone, A)

0 commit comments

Comments
 (0)