1
- # TODO : Compare with ChangesOfVariables.jl
1
+ """
2
+ abstract type PushFwdStyle
3
+
4
+ Provides the behavior of a measure's [`rootmeasure`](@ref) under a
5
+ pushforward. Either [`AdaptRootMeasure()`](@ref) or
6
+ [`PushfwdRootMeasure()`](@ref)
7
+ """
8
+ abstract type PushFwdStyle end
9
+ export PushFwdStyle
10
+
11
+ const TransformVolCorr = PushFwdStyle
12
+
13
+ """
14
+ AdaptRootMeasure()
15
+
16
+ Indicates that when applying a pushforward to a measure, it's
17
+ [`rootmeasure`](@ref) not not be pushed forward. Instead, the root measure
18
+ should be kept just "reshaped" to the new measurable space if necessary.
19
+
20
+ Density calculations for pushforward measures constructed with
21
+ `AdaptRootMeasure()` will take take the volume element of variate
22
+ transform (typically via the log-abs-det-Jacobian of the transform) into
23
+ account.
24
+ """
25
+ struct AdaptRootMeasure <: TransformVolCorr end
26
+ export AdaptRootMeasure
27
+
28
+ const WithVolCorr = AdaptRootMeasure
2
29
3
- using InverseFunctions: FunctionWithInverse
30
+ """
31
+ PushfwdRootMeasure()
32
+
33
+ Indicates than when applying a pushforward to a measure, it's
34
+ [`rootmeasure`](@ref) should be pushed forward with the same function.
35
+
36
+ Density calculations for pushforward measures constructed with
37
+ `PushfwdRootMeasure()` will ignore the volume element of the variate
38
+ transform.
39
+ """
40
+ struct PushfwdRootMeasure <: TransformVolCorr end
41
+ export PushfwdRootMeasure
42
+
43
+ const NoVolCorr = PushfwdRootMeasure
4
44
5
45
abstract type AbstractTransformedMeasure <: AbstractMeasure end
6
46
@@ -19,23 +59,42 @@ function parent(::AbstractTransformedMeasure) end
19
59
export PushforwardMeasure
20
60
21
61
"""
22
- struct PushforwardMeasure{F,I,M,VC<:TransformVolCorr } <: AbstractPushforward
62
+ struct PushforwardMeasure{F,I,M,S<:PushFwdStyle } <: AbstractPushforward
23
63
f :: F
24
64
finv :: I
25
65
origin :: M
26
- volcorr :: VC
66
+ style :: S
27
67
end
28
68
29
69
Users should not call `PushforwardMeasure` directly. Instead call or add
30
70
methods to `pushfwd`.
31
71
"""
32
- struct PushforwardMeasure{F,I,M,VC <: TransformVolCorr } <: AbstractPushforward
72
+ struct PushforwardMeasure{F,I,M,S <: PushFwdStyle } <: AbstractPushforward
33
73
f:: F
34
74
finv:: I
35
75
origin:: M
36
- volcorr:: VC
76
+ style:: S
77
+
78
+ function PushforwardMeasure {F,I,M,S} (
79
+ f:: F ,
80
+ finv:: I ,
81
+ origin:: M ,
82
+ style:: S ,
83
+ ) where {F,I,M,S<: PushFwdStyle }
84
+ new {F,I,M,S} (f, finv, origin, style)
85
+ end
86
+
87
+ function PushforwardMeasure (f, finv, origin:: M , style:: S ) where {M,S<: PushFwdStyle }
88
+ new {Core.Typeof(f),Core.Typeof(finv),M,S} (f, finv, origin, style)
89
+ end
37
90
end
38
91
92
+ const _NonBijectivePusfwdMeasure{M<: PushforwardMeasure ,S<: PushFwdStyle } = Union{
93
+ PushforwardMeasure{<: Any ,<: NoInverse ,M,S},
94
+ PushforwardMeasure{<: NoInverse ,<: Any ,M,S},
95
+ PushforwardMeasure{<: NoInverse ,<: NoInverse ,M,S},
96
+ }
97
+
39
98
gettransform (ν:: PushforwardMeasure ) = ν. f
40
99
parent (ν:: PushforwardMeasure ) = ν. origin
41
100
45
104
46
105
# TODO : THIS IS ALMOST CERTAINLY WRONG
47
106
# @inline function logdensity_rel(
48
- # ν::PushforwardMeasure{FF1,IF1,M1,<:WithVolCorr },
49
- # β::PushforwardMeasure{FF2,IF2,M2,<:WithVolCorr },
107
+ # ν::PushforwardMeasure{FF1,IF1,M1,<:AdaptRootMeasure },
108
+ # β::PushforwardMeasure{FF2,IF2,M2,<:AdaptRootMeasure },
50
109
# y,
51
110
# ) where {FF1,IF1,M1,FF2,IF2,M2}
52
111
# x = β.inv_f(y)
53
112
# f = ν.inv_f ∘ β.f
54
113
# inv_f = β.inv_f ∘ ν.f
55
- # logdensity_rel(pushfwd(f, inv_f, ν.origin, WithVolCorr ()), β.origin, x)
114
+ # logdensity_rel(pushfwd(f, inv_f, ν.origin, AdaptRootMeasure ()), β.origin, x)
56
115
# end
57
116
117
+ # TODO : Would profit from custom pullback:
118
+ function _combine_logd_with_ladj (logd_orig:: Real , ladj:: Real )
119
+ logd_result = logd_orig + ladj
120
+ R = typeof (logd_result)
121
+
122
+ if isnan (logd_result) && isneginf (logd_orig) && isposinf (ladj)
123
+ # Zero μ wins against infinite volume:
124
+ R (- Inf ):: R
125
+ elseif isfinite (logd_orig) && isneginf (ladj)
126
+ # Maybe also for isneginf(logd_orig) && isfinite(ladj) ?
127
+ # Return constant -Inf to prevent problems with ForwardDiff:
128
+ # R(-Inf)
129
+ near_neg_inf (R):: R # Avoids AdvancedHMC warnings
130
+ else
131
+ logd_result:: R
132
+ end
133
+ end
134
+
135
+ function logdensityof (
136
+ @nospecialize (μ:: _NonBijectivePusfwdMeasure{M,<:PushfwdRootMeasure} ),
137
+ @nospecialize (v:: Any )
138
+ ) where {M}
139
+ throw (
140
+ ArgumentError (
141
+ " Can't calculate densities for non-bijective pushforward measure $(nameof (M)) " ,
142
+ ),
143
+ )
144
+ end
145
+
146
+ function logdensityof (
147
+ @nospecialize (μ:: _NonBijectivePusfwdMeasure{M,<:AdaptRootMeasure} ),
148
+ @nospecialize (v:: Any )
149
+ ) where {M}
150
+ throw (
151
+ ArgumentError (
152
+ " Can't calculate densities for non-bijective pushforward measure $(nameof (M)) " ,
153
+ ),
154
+ )
155
+ end
156
+
58
157
for func in [:logdensityof , :logdensity_def ]
59
- @eval @inline function $func (ν:: PushforwardMeasure{F,I,M,<:WithVolCorr} , y) where {F,I,M}
60
- f = ν. f
61
- finv = ν. finv
62
- x_orig, inv_ladj = with_logabsdet_jacobian (unwrap (finv), y)
63
- logd_orig = $ func (ν. origin, x_orig)
64
- logd = float (logd_orig + inv_ladj)
65
- neginf = oftype (logd, - Inf )
66
- return ifelse (
67
- # Zero density wins against infinite volume:
68
- (isnan (logd) && logd_orig == - Inf && inv_ladj == + Inf ) ||
69
- # Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ?
70
- # Return constant -Inf to prevent problems with ForwardDiff:
71
- (isfinite (logd_orig) && (inv_ladj == - Inf )),
72
- neginf,
73
- logd,
74
- )
158
+ @eval function $func (ν:: PushforwardMeasure{F,I,M,<:AdaptRootMeasure} , y) where {F,I,M}
159
+ f_inv = unwrap (ν. finv)
160
+ x, inv_ladj = with_logabsdet_jacobian (f_inv, y)
161
+ logd_orig = $ func (ν. origin, x)
162
+ return _combine_logd_with_ladj (logd_orig, inv_ladj)
75
163
end
76
164
77
- @eval @inline function $func (ν:: PushforwardMeasure{F,I,M,<:NoVolCorr} , y) where {F,I,M}
78
- x = ν. finv (y)
79
- return $ func (ν. origin, x)
165
+ @eval function $func (ν:: PushforwardMeasure{F,I,M,<:PushfwdRootMeasure} , y) where {F,I,M}
166
+ f_inv = unwrap (ν. finv)
167
+ x = f_inv (y)
168
+ logd_orig = $ func (ν. origin, x)
169
+ return logd_orig
80
170
end
81
171
end
82
172
83
- insupport (ν :: PushforwardMeasure , y ) = insupport (ν . origin, ν . finv (y ))
173
+ insupport (m :: PushforwardMeasure , x ) = insupport (transport_origin (m), to_origin (m, x ))
84
174
85
175
function testvalue (:: Type{T} , ν:: PushforwardMeasure ) where {T}
86
176
ν. f (testvalue (T, parent (ν)))
87
177
end
88
178
89
179
@inline function basemeasure (ν:: PushforwardMeasure )
90
- pushfwd (ν. f, basemeasure (parent (ν)), NoVolCorr ())
180
+ pushfwd (ν. f, basemeasure (parent (ν)), PushfwdRootMeasure ())
181
+ end
182
+
183
+ function rootmeasure (m:: PushforwardMeasure{F,I,M,PushfwdRootMeasure} ) where {F,I,M}
184
+ pushfwd (m. f, rootmeasure (m. origin))
185
+ end
186
+ function rootmeasure (m:: PushforwardMeasure{F,I,M,AdaptRootMeasure} ) where {F,I,M}
187
+ rootmeasure (m. origin)
91
188
end
92
189
93
190
_pushfwd_dof (:: Type{MU} , :: Type , dof) where {MU} = NoDOF {MU} ()
94
191
_pushfwd_dof (:: Type{MU} , :: Type{<:Tuple{Any,Real}} , dof) where {MU} = dof
95
192
96
193
@inline getdof (ν:: MU ) where {MU<: PushforwardMeasure } = getdof (ν. origin)
194
+ @inline getdof (m:: _NonBijectivePusfwdMeasure ) = MeasureBase. NoDOF {typeof(m)} ()
97
195
98
196
# Bypass `checked_arg`, would require potentially costly transformation:
99
197
@inline checked_arg (:: PushforwardMeasure , x) = x
@@ -102,47 +200,53 @@ _pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where {MU} = dof
102
200
@inline from_origin (ν:: PushforwardMeasure , x) = ν. f (x)
103
201
@inline to_origin (ν:: PushforwardMeasure , y) = ν. finv (y)
104
202
203
+ massof (m:: PushforwardMeasure ) = massof (transport_origin (m))
204
+
105
205
function Base. rand (rng:: AbstractRNG , :: Type{T} , ν:: PushforwardMeasure ) where {T}
106
- return ν. f (rand (rng, T, parent (ν) ))
206
+ return ν. f (rand (rng, T, ν . origin ))
107
207
end
108
208
109
209
# ##############################################################################
110
210
# pushfwd
111
211
112
- export pushfwd
113
-
114
212
"""
115
- pushfwd(f, μ, volcorr = WithVolCorr ())
213
+ pushfwd(f, μ, style = AdaptRootMeasure ())
116
214
117
215
Return the [pushforward
118
216
measure](https://en.wikipedia.org/wiki/Pushforward_measure) from `μ` the
119
217
[measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`.
120
218
121
219
To manually specify an inverse, call
122
- `pushfwd(InverseFunctions.setinverse(f, finv), μ, volcorr )`.
220
+ `pushfwd(InverseFunctions.setinverse(f, finv), μ, style )`.
123
221
"""
124
- function pushfwd (f, μ, volcorr:: TransformVolCorr = WithVolCorr ())
125
- PushforwardMeasure (f, inverse (f), μ, volcorr)
126
- end
127
-
128
- function pushfwd (f, μ:: PushforwardMeasure , volcorr:: TransformVolCorr = WithVolCorr ())
129
- _pushfwd_of_pushfwd (f, μ, μ. volcorr, volcorr)
130
- end
222
+ function pushfwd end
223
+ export pushfwd
131
224
132
- # Either both WithVolCorr or both NoVolCorr, so we can merge them
133
- function _pushfwd_of_pushfwd (f, μ:: PushforwardMeasure , :: V , v:: V ) where {V}
134
- pushfwd (fchain ((μ. f, f)), μ. origin, v)
225
+ @inline pushfwd (f, μ) = _pushfwd_impl (f, μ, AdaptRootMeasure ())
226
+ @inline pushfwd (f, μ, style:: AdaptRootMeasure ) = _pushfwd_impl (f, μ, style)
227
+ @inline pushfwd (f, μ, style:: PushfwdRootMeasure ) = _pushfwd_impl (f, μ, style)
228
+
229
+ _pushfwd_impl (f, μ, style) = PushforwardMeasure (f, inverse (f), μ, style)
230
+
231
+ function _pushfwd_impl (
232
+ f,
233
+ μ:: PushforwardMeasure{F,I,M,S} ,
234
+ style:: S ,
235
+ ) where {F,I,M,S<: PushFwdStyle }
236
+ orig_μ = μ. origin
237
+ new_f = fcomp (f, μ. f)
238
+ new_f_inv = fcomp (μ. finv, inverse (f))
239
+ PushforwardMeasure (new_f, new_f_inv, orig_μ, style)
135
240
end
136
241
137
- function _pushfwd_of_pushfwd (f, μ:: PushforwardMeasure , _, v)
138
- PushforwardMeasure (f, inverse (f), μ, v)
139
- end
242
+ _pushfwd_impl (:: typeof (identity), μ, :: AdaptRootMeasure ) = μ
243
+ _pushfwd_impl (:: typeof (identity), μ, :: PushfwdRootMeasure ) = μ
140
244
141
245
# ##############################################################################
142
246
# pullback
143
247
144
248
"""
145
- pullback (f, μ, volcorr = WithVolCorr ())
249
+ pullbck (f, μ, style = AdaptRootMeasure ())
146
250
147
251
A _pullback_ is a dual concept to a _pushforward_. While a pushforward needs a
148
252
map _from_ the support of a measure, a pullback requires a map _into_ the
@@ -154,8 +258,17 @@ in terms of the inverse function; the "forward" function is not used at all. In
154
258
some cases, we may be focusing on log-density (and not, for example, sampling).
155
259
156
260
To manually specify an inverse, call
157
- `pullback (InverseFunctions.setinverse(f, finv), μ, volcorr )`.
261
+ `pullbck (InverseFunctions.setinverse(f, finv), μ, style )`.
158
262
"""
159
- function pullback (f, μ, volcorr:: TransformVolCorr = WithVolCorr ())
160
- pushfwd (setinverse (inverse (f), f), μ, volcorr)
263
+ function pullbck end
264
+ export pullbck
265
+
266
+ @inline pullbck (f, μ) = _pullback_impl (f, μ, AdaptRootMeasure ())
267
+ @inline pullbck (f, μ, style:: AdaptRootMeasure ) = _pullback_impl (f, μ, style)
268
+ @inline pullbck (f, μ, style:: PushfwdRootMeasure ) = _pullback_impl (f, μ, style)
269
+
270
+ function _pullback_impl (f, μ, style = AdaptRootMeasure ())
271
+ pushfwd (setinverse (inverse (f), f), μ, style)
161
272
end
273
+
274
+ @deprecate pullback (f, μ, style:: PushFwdStyle = AdaptRootMeasure ()) pullbck (f, μ, style)
0 commit comments