Commit ef34dcd

Tor Fjelde authored
Implementation of Robust Adaptive Metropolis (#106)
* added an initial implementation of `RAM`
* added proper docs for RAM
* fixed doctest for `RAM` + added impls of `getparams` and `setparams!!`
* added DocStringExtensions as a dep
* bump patch version
* attempt at making the doctest a bit more consistent
* a
* added checks for eigenvalues according to p. 13 in Vihola (2012) (in previous commit)
* fixed default value for `eigenvalue_lower_bound`
* applied suggestions from @mhauru
* more doctesting of RAM + improved docstrings
* added docstring for `RAMState`
* added proper testing of RAM
* Update src/RobustAdaptiveMetropolis.jl (co-authored by Markus Hauru <[email protected]>)
* added compat entries to docs
* apply suggestions from @devmotion (co-authored by David Widmann <[email protected]>)
* renamed `RAM` to `RobostMetropolisHastings` + removed the separate module for this
* formatting
* made the docstring for RAM a bit nicer
* fixed doctest
* formatting
* minor improvement to docstring of RAM
* fused scalar operations
* added dimensionality check of the provided `S` matrix
* fixed typo
* Update docs/src/api.md (co-authored by David Widmann <[email protected]>)
* use `randn` instead of `rand` for initialisation
* added an explanation of the `min`
* Update test/RobustAdaptiveMetropolis.jl (co-authored by David Widmann <[email protected]>)
* use explicit `Cholesky` constructor for backwards compat
* Fix typo: ```` -> ```
* formatted according to `blue`
* Update src/RobustAdaptiveMetropolis.jl (co-authored by Markus Hauru <[email protected]>)

---------

Co-authored-by: Markus Hauru <[email protected]>
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: Tor Fjelde <[email protected]>
Co-authored-by: Markus Hauru <[email protected]>
1 parent 98a1041 commit ef34dcd

7 files changed (+376, -3 lines)
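For orientation before the per-file diffs: the rule the new sampler performs, as I read it from Vihola (2012) and from the `ram_step_inner`/`ram_adapt` functions below. The notation (X_n, Y_n, U_n, η_n) is mine, not from the commit; α_* is the `α` field (target acceptance rate) and γ the decay exponent.

```math
\begin{aligned}
Y_n &= X_{n-1} + S_{n-1} U_n, \qquad U_n \sim \mathcal{N}(0, I), \\
\alpha_n &= \min\!\left\{1, \frac{\pi(Y_n)}{\pi(X_{n-1})}\right\}, \\
S_n S_n^\top &= S_{n-1}\left(I + \eta_n\,(\alpha_n - \alpha_*)\,\frac{U_n U_n^\top}{\lVert U_n\rVert^2}\right) S_{n-1}^\top,
\qquad \eta_n = n^{-\gamma}.
\end{aligned}
```

In the code, this rank-one change is applied directly to the Cholesky factor via `LinearAlgebra.lowrankupdate` / `lowrankdowndate`, and only during warmup (`AbstractMCMC.step_warmup`); the sampling phase keeps `S` fixed.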

Project.toml (+3, -1)

```diff
@@ -1,10 +1,11 @@
 name = "AdvancedMH"
 uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
-version = "0.8.4"
+version = "0.8.5"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
@@ -26,6 +27,7 @@ AdvancedMHStructArraysExt = "StructArrays"
 AbstractMCMC = "5.6"
 DiffResults = "1"
 Distributions = "0.25"
+DocStringExtensions = "0.9"
 FillArrays = "1"
 ForwardDiff = "0.10"
 LinearAlgebra = "1.6"
```

docs/Project.toml (+10)

```diff
@@ -1,5 +1,15 @@
 [deps]
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 
 [compat]
 Documenter = "1"
+Distributions = "0.25"
+LinearAlgebra = "1.6"
+LogDensityProblems = "2"
+MCMCChains = "6.0.4"
+Random = "1.6"
```

docs/src/api.md (+6)

````diff
@@ -12,3 +12,9 @@ MetropolisHastings
 ```@docs
 DensityModel
 ```
+
+## Samplers
+
+```@docs
+RobustAdaptiveMetropolis
+```
````

src/AdvancedMH.jl (+5, -2)

```diff
@@ -3,8 +3,9 @@ module AdvancedMH
 # Import the relevant libraries.
 using AbstractMCMC
 using Distributions
-using LinearAlgebra: I
+using LinearAlgebra: LinearAlgebra, I
 using FillArrays: Zeros
+using DocStringExtensions: FIELDS
 
 using LogDensityProblems: LogDensityProblems
 
@@ -22,7 +23,8 @@ export
     SymmetricRandomWalkProposal,
     Ensemble,
     StretchProposal,
-    MALA
+    MALA,
+    RobustAdaptiveMetropolis
 
 # Reexports
 export sample, MCMCThreads, MCMCDistributed, MCMCSerial
@@ -159,5 +161,6 @@ include("proposal.jl")
 include("mh-core.jl")
 include("emcee.jl")
 include("MALA.jl")
+include("RobustAdaptiveMetropolis.jl")
 
 end # module AdvancedMH
```

src/RobustAdaptiveMetropolis.jl (+278, new file)

````diff
@@ -0,0 +1,278 @@
+# TODO: Should we generalise this arbitrary symmetric proposals?
+"""
+    RobustAdaptiveMetropolis
+
+Robust Adaptive Metropolis-Hastings (RAM).
+
+This is a simple implementation of the RAM algorithm described in [^VIH12].
+
+# Fields
+
+$(FIELDS)
+
+# Examples
+
+The following demonstrates how to implement a simple Gaussian model and sample from it using the RAM algorithm.
+
+```jldoctest ram-gaussian; setup=:(using Random; Random.seed!(1234);)
+julia> using AdvancedMH, Distributions, MCMCChains, LogDensityProblems, LinearAlgebra
+
+julia> # Define a Gaussian with zero mean and some covariance.
+       struct Gaussian{A}
+           Σ::A
+       end
+
+julia> # Implement the LogDensityProblems interface.
+       LogDensityProblems.dimension(model::Gaussian) = size(model.Σ, 1)
+
+julia> function LogDensityProblems.logdensity(model::Gaussian, x)
+           d = LogDensityProblems.dimension(model)
+           return logpdf(MvNormal(zeros(d), model.Σ), x)
+       end
+
+julia> LogDensityProblems.capabilities(::Gaussian) = LogDensityProblems.LogDensityOrder{0}()
+
+julia> # Construct the model. We'll use a correlation of 0.5.
+       model = Gaussian([1.0 0.5; 0.5 1.0]);
+
+julia> # Number of samples we want in the resulting chain.
+       num_samples = 10_000;
+
+julia> # Number of warmup steps, i.e. the number of steps to adapt the covariance of the proposal.
+       # Note that these are not included in the resulting chain, as `discard_initial=num_warmup`
+       # by default in the `sample` call. To include them, pass `discard_initial=0` to `sample`.
+       num_warmup = 10_000;
+
+julia> # Sample!
+       chain = sample(
+           model,
+           RobustAdaptiveMetropolis(),
+           num_samples;
+           chain_type=Chains, num_warmup, progress=false, initial_params=zeros(2)
+       );
+
+julia> isapprox(cov(Array(chain)), model.Σ; rtol = 0.2)
+true
+```
+
+It's also possible to restrict the eigenvalues to avoid either too small or too large values. See p. 13 in [^VIH12].
+
+```jldoctest ram-gaussian
+julia> chain = sample(
+           model,
+           RobustAdaptiveMetropolis(eigenvalue_lower_bound=0.1, eigenvalue_upper_bound=2.0),
+           num_samples;
+           chain_type=Chains, num_warmup, progress=false, initial_params=zeros(2)
+       );
+
+julia> norm(cov(Array(chain)) - [1.0 0.5; 0.5 1.0]) < 0.2
+true
+```
+
+# References
+[^VIH12]: Vihola (2012) Robust adaptive Metropolis algorithm with coerced acceptance rate, Statistics and Computing.
+"""
+Base.@kwdef struct RobustAdaptiveMetropolis{T,A<:Union{Nothing,AbstractMatrix{T}}} <:
+                   AdvancedMH.MHSampler
+    "target acceptance rate. Default: 0.234."
+    α::T = 0.234
+    "negative exponent of the adaptation decay rate. Default: `0.6`."
+    γ::T = 0.6
+    "initial lower-triangular Cholesky factor of the covariance matrix. If specified, should be convertible into a `LowerTriangular`. Default: `nothing`, which is interpreted as the identity matrix."
+    S::A = nothing
+    "lower bound on eigenvalues of the adapted Cholesky factor. Default: `0.0`."
+    eigenvalue_lower_bound::T = 0.0
+    "upper bound on eigenvalues of the adapted Cholesky factor. Default: `Inf`."
+    eigenvalue_upper_bound::T = Inf
+end
+
+"""
+    RobustAdaptiveMetropolisState
+
+State of the Robust Adaptive Metropolis-Hastings (RAM) algorithm.
+
+See also: [`RobustAdaptiveMetropolis`](@ref).
+
+# Fields
+$(FIELDS)
+"""
+struct RobustAdaptiveMetropolisState{T1,L,A,T2,T3}
+    "current realization of the chain."
+    x::T1
+    "log density of `x` under the target model."
+    logprob::L
+    "current lower-triangular Cholesky factor."
+    S::A
+    "log acceptance ratio of the previous iteration (not necessarily of `x`)."
+    logα::T2
+    "current step size for adaptation of `S`."
+    η::T3
+    "current iteration."
+    iteration::Int
+    "whether the previous iteration was accepted."
+    isaccept::Bool
+end
+
+AbstractMCMC.getparams(state::RobustAdaptiveMetropolisState) = state.x
+function AbstractMCMC.setparams!!(state::RobustAdaptiveMetropolisState, x)
+    return RobustAdaptiveMetropolisState(
+        x, state.logprob, state.S, state.logα, state.η, state.iteration, state.isaccept
+    )
+end
+
+function ram_step_inner(
+    rng::Random.AbstractRNG,
+    model::AbstractMCMC.LogDensityModel,
+    sampler::RobustAdaptiveMetropolis,
+    state::RobustAdaptiveMetropolisState,
+)
+    # This is the initial state.
+    f = model.logdensity
+    d = LogDensityProblems.dimension(f)
+
+    # Sample the proposal.
+    x = state.x
+    U = randn(rng, eltype(x), d)
+    x_new = muladd(state.S, U, x)
+
+    # Compute the acceptance probability.
+    lp = state.logprob
+    lp_new = LogDensityProblems.logdensity(f, x_new)
+    # Technically, the `min` here is unnecessary for sampling according to `min(..., 1)`.
+    # However, `ram_adapt` assumes that `logα` actually represents the log acceptance probability
+    # and is thus bounded at 0. Moreover, users might be interested in inspecting the average
+    # acceptance rate to check that the algorithm achieves something similar to the target rate.
+    # Hence, it's a bit more convenient for the user if we just perform the `min` here
+    # so they can just take an average of (`exp` of) the `logα` values.
+    logα = min(lp_new - lp, zero(lp))
+    isaccept = Random.randexp(rng) > -logα
+
+    return x_new, lp_new, U, logα, isaccept
+end
+
+function ram_adapt(
+    sampler::RobustAdaptiveMetropolis,
+    state::RobustAdaptiveMetropolisState,
+    logα::Real,
+    U::AbstractVector,
+)
+    Δα = exp(logα) - sampler.α
+    S = state.S
+    # TODO: Make this configurable by defining a more general path.
+    η = state.iteration^(-sampler.γ)
+    ΔS = sqrt(η * abs(Δα)) * S * U / LinearAlgebra.norm(U)
+    # TODO: Maybe do in-place and then have the user extract it with a callback if they really want it.
+    S_new = if sign(Δα) == 1
+        # One rank update.
+        LinearAlgebra.lowrankupdate(LinearAlgebra.Cholesky(S.data, :L, 0), ΔS).L
+    else
+        # One rank downdate.
+        LinearAlgebra.lowrankdowndate(LinearAlgebra.Cholesky(S.data, :L, 0), ΔS).L
+    end
+    return S_new, η
+end
+
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::AbstractMCMC.LogDensityModel,
+    sampler::RobustAdaptiveMetropolis;
+    initial_params=nothing,
+    kwargs...,
+)
+    # This is the initial state.
+    f = model.logdensity
+    d = LogDensityProblems.dimension(f)
+
+    # Initial parameter state.
+    T = if initial_params === nothing
+        eltype(sampler.γ)
+    else
+        Base.promote_type(eltype(sampler.γ), eltype(initial_params))
+    end
+    x = if initial_params === nothing
+        randn(rng, T, d)
+    else
+        convert(AbstractVector{T}, initial_params)
+    end
+    # Initialize the Cholesky factor of the covariance matrix.
+    S_data = if sampler.S === nothing
+        LinearAlgebra.diagm(0 => ones(T, d))
+    else
+        # Check the dimensionality of the provided `S`.
+        if size(sampler.S) != (d, d)
+            throw(ArgumentError("The provided `S` has the wrong dimensionality."))
+        end
+        convert(AbstractMatrix{T}, sampler.S)
+    end
+    S = LinearAlgebra.LowerTriangular(S_data)
+
+    # Construct the initial state.
+    lp = LogDensityProblems.logdensity(f, x)
+    state = RobustAdaptiveMetropolisState(x, lp, S, zero(T), 0, 1, true)
+
+    return AdvancedMH.Transition(x, lp, true), state
+end
+
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::AbstractMCMC.LogDensityModel,
+    sampler::RobustAdaptiveMetropolis,
+    state::RobustAdaptiveMetropolisState;
+    kwargs...,
+)
+    # Take the inner step.
+    x_new, lp_new, U, logα, isaccept = ram_step_inner(rng, model, sampler, state)
+    # Accept / reject the proposal.
+    state_new = RobustAdaptiveMetropolisState(
+        isaccept ? x_new : state.x,
+        isaccept ? lp_new : state.logprob,
+        state.S,
+        logα,
+        state.η,
+        state.iteration + 1,
+        isaccept,
+    )
+    return AdvancedMH.Transition(state_new.x, state_new.logprob, state_new.isaccept),
+    state_new
+end
+
+function valid_eigenvalues(S, lower_bound, upper_bound)
+    # Short-circuit if the bounds are the default.
+    (lower_bound == 0 && upper_bound == Inf) && return true
+    # Note that this is just the diagonal when `S` is triangular.
+    eigenvals = LinearAlgebra.eigvals(S)
+    return all(x -> lower_bound <= x <= upper_bound, eigenvals)
+end
+
+function AbstractMCMC.step_warmup(
+    rng::Random.AbstractRNG,
+    model::AbstractMCMC.LogDensityModel,
+    sampler::RobustAdaptiveMetropolis,
+    state::RobustAdaptiveMetropolisState;
+    kwargs...,
+)
+    # Take the inner step.
+    x_new, lp_new, U, logα, isaccept = ram_step_inner(rng, model, sampler, state)
+    # Adapt the proposal.
+    S_new, η = ram_adapt(sampler, state, logα, U)
+    # Check that `S_new` has eigenvalues in the desired range.
+    if !valid_eigenvalues(
+        S_new, sampler.eigenvalue_lower_bound, sampler.eigenvalue_upper_bound
+    )
+        # In this case, we just keep the old `S` (p. 13 in Vihola, 2012).
+        S_new = state.S
+    end
+
+    # Update state.
+    state_new = RobustAdaptiveMetropolisState(
+        isaccept ? x_new : state.x,
+        isaccept ? lp_new : state.logprob,
+        S_new,
+        logα,
+        η,
+        state.iteration + 1,
+        isaccept,
+    )
+    return AdvancedMH.Transition(state_new.x, state_new.logprob, state_new.isaccept),
+    state_new
+end
````
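Not part of the commit: a minimal sketch of the `S` keyword documented above, reusing `model`, `num_samples`, and `num_warmup` from the docstring example (it assumes that setup has already been run); the matrix `S0` and the choice `γ=0.7` are arbitrary illustrative values, not package defaults.

```julia
# Minimal sketch (not from the commit): start adaptation from a custom
# lower-triangular Cholesky factor instead of the identity.
# For the 2-dimensional `Gaussian` model above, `S` must be 2×2;
# otherwise the initial step throws an `ArgumentError`.
S0 = [0.5 0.0; 0.1 0.5]

# Keyword construction comes from `Base.@kwdef`; unspecified fields keep their defaults.
sampler = RobustAdaptiveMetropolis(S=S0, γ=0.7)

chain = sample(
    model, sampler, num_samples;
    chain_type=Chains, num_warmup, progress=false, initial_params=zeros(2),
)
```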
