Skip to content

Commit 98a1041

Browse files
sunxd3penelopeysm
andauthored
Add getparams and setparams!! following AbstractMCMC v5.5 and v5.6 (#103)
* add `getparams` and `setparams!!` * add BangBang as dep, add functions for `GradientTransition` * increase atol * Update test/runtests.jl Co-authored-by: Penelope Yong <[email protected]> * update functions with `model` arguments * fix test errors * remove BangBang dep --------- Co-authored-by: Penelope Yong <[email protected]>
1 parent 4442783 commit 98a1041

File tree

4 files changed

+55
-7
lines changed

4 files changed

+55
-7
lines changed

Project.toml

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.8.3"
3+
version = "0.8.4"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -23,18 +23,18 @@ AdvancedMHMCMCChainsExt = "MCMCChains"
2323
AdvancedMHStructArraysExt = "StructArrays"
2424

2525
[compat]
26-
AbstractMCMC = "5"
26+
AbstractMCMC = "5.6"
2727
DiffResults = "1"
2828
Distributions = "0.25"
2929
FillArrays = "1"
3030
ForwardDiff = "0.10"
31+
LinearAlgebra = "1.6"
3132
LogDensityProblems = "2"
3233
MCMCChains = "6.0.4"
34+
Random = "1.6"
3335
Requires = "1"
3436
StructArrays = "0.6"
3537
julia = "1.6"
36-
LinearAlgebra = "1.6"
37-
Random = "1.6"
3838

3939
[extras]
4040
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"

src/AdvancedMH.jl

+14
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,20 @@ function __init__()
140140
end
141141
end
142142

143+
# AbstractMCMC.jl interface
144+
function AbstractMCMC.getparams(t::Transition)
145+
return t.params
146+
end
147+
148+
# TODO (sunxd): remove `DensityModel` in favor of `AbstractMCMC.LogDensityModel`
149+
function AbstractMCMC.setparams!!(model::DensityModelOrLogDensityModel, t::Transition, params)
150+
return Transition(
151+
params,
152+
logdensity(model, params),
153+
t.accepted
154+
)
155+
end
156+
143157
# Include inference methods.
144158
include("proposal.jl")
145159
include("mh-core.jl")

src/MALA.jl

+16-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ MALA(d::RandomWalkProposal) = MALA{typeof(d)}(d)
1111
MALA(d) = MALA(RandomWalkProposal(d))
1212

1313

14-
struct GradientTransition{T<:Union{Vector, Real, NamedTuple}, L<:Real, G<:Union{Vector, Real, NamedTuple}} <: AbstractTransition
14+
struct GradientTransition{T<:Union{Vector,Real,NamedTuple},L<:Real,G<:Union{Vector,Real,NamedTuple}} <: AbstractTransition
1515
params::T
1616
lp::L
1717
gradient::G
@@ -20,6 +20,20 @@ end
2020

2121
logdensity(model::DensityModelOrLogDensityModel, t::GradientTransition) = t.lp
2222

23+
function AbstractMCMC.getparams(t::GradientTransition)
24+
return t.params
25+
end
26+
27+
function AbstractMCMC.setparams!!(model::DensityModelOrLogDensityModel, t::GradientTransition, params)
28+
lp, gradient = logdensity_and_gradient(model, params)
29+
return GradientTransition(
30+
params,
31+
lp,
32+
gradient,
33+
t.accepted
34+
)
35+
end
36+
2337
propose(::Random.AbstractRNG, ::MALA, ::DensityModelOrLogDensityModel) = error("please specify initial parameters")
2438
function transition(sampler::MALA, model::DensityModelOrLogDensityModel, params, accepted)
2539
return GradientTransition(params, logdensity_and_gradient(model, params)..., accepted)
@@ -88,6 +102,6 @@ logdensity_and_gradient(::DensityModelOrLogDensityModel, ::Any)
88102
function logdensity_and_gradient(model::AbstractMCMC.LogDensityModel, params)
89103
check_capabilities(model)
90104
return LogDensityProblems.logdensity_and_gradient(model.logdensity, params)
91-
end
105+
end
92106

93107

test/runtests.jl

+21-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using AdvancedMH
2+
using AbstractMCMC
23
using DiffResults
34
using Distributions
45
using ForwardDiff
@@ -33,6 +34,25 @@ include("util.jl")
3334
LogDensityProblems.logdensity(::typeof(density), θ) = density(θ)
3435
LogDensityProblems.dimension(::typeof(density)) = 2
3536

37+
@testset "getparams/setparams!! (AbstractMCMC interface)" begin
38+
t1, _ = AbstractMCMC.step(Random.default_rng(), model, StaticMH([Normal(0, 1), Normal(0, 1)]))
39+
t2, _ = AbstractMCMC.step(Random.default_rng(), model, MALA(x -> MvNormal(x, I)); initial_params=ones(2))
40+
for t in [t1, t2]
41+
@test AbstractMCMC.getparams(model, t) == t.params
42+
43+
new_transition = AbstractMCMC.setparams!!(model, t, AbstractMCMC.getparams(model, t))
44+
@test new_transition.lp == t.lp
45+
@test new_transition.accepted == t.accepted
46+
@test new_transition.params == t.params
47+
if hasfield(typeof(t), :gradient)
48+
@test new_transition.gradient == t.gradient
49+
end
50+
51+
t_replaced = AbstractMCMC.setparams!!(model, t, [1.0, 2.0])
52+
@test t_replaced.params == [1.0, 2.0]
53+
end
54+
end
55+
3656
@testset "StaticMH" begin
3757
# Set up our sampler with initial parameters.
3858
spl1 = StaticMH([Normal(0,1), Normal(0, 1)])
@@ -69,7 +89,7 @@ include("util.jl")
6989
@test mean(chain1.σ) 1.0 atol=0.1
7090
@test mean(chain2.μ) 0.0 atol=0.1
7191
@test mean(chain2.σ) 1.0 atol=0.1
72-
@test mean(chain3.μ) 0.0 atol=0.1
92+
@test mean(chain3.μ) 0.0 atol=0.15
7393
@test mean(chain3.σ) 1.0 atol=0.1
7494
end
7595

0 commit comments

Comments
 (0)