Skip to content

Commit a09211a

Browse files
committed
More minor bugfixes
1 parent e62d4e6 commit a09211a

File tree

9 files changed

+22
-30
lines changed

9 files changed

+22
-30
lines changed

src/mcmc/Inference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ end
173173
getADType(alg::Hamiltonian) = alg.adtype
174174

175175
function LogDensityProblemsAD.ADgradient(ℓ::DynamicPPL.LogDensityFunction)
176-
return LogDensityProblemsAD.ADgradient(getADType(DynamicPPL.getcontext(ℓ)), ℓ)
176+
return LogDensityProblemsAD.ADgradient(getADType(.context), ℓ)
177177
end
178178

179179
function LogDensityProblems.logdensity(

src/mcmc/abstractmcmc.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ function getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper)
4545
return getvarinfo(LogDensityProblemsAD.parent(f))
4646
end
4747

48-
setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo) = Accessors.@set f.varinfo = varinfo
48+
function setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo)
49+
return DynamicPPL.LogDensityFunction(f.model, varinfo, f.context; adtype=f.adtype)
50+
end
51+
4952
function setvarinfo(
5053
f::LogDensityProblemsAD.ADGradientWrapper, varinfo, adtype::ADTypes.AbstractADType
5154
)

src/mcmc/emcee.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ function AbstractMCMC.step(
5353
length(initial_params) == n ||
5454
throw(ArgumentError("initial parameters have to be specified for each walker"))
5555
vis = map(vis, initial_params) do vi, init
56-
vi = DynamicPPL.initialize_parameters!!(vi, init, spl, model)
56+
vi = DynamicPPL.initialize_parameters!!(vi, init, model)
5757

5858
# Update log joint probability.
5959
last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromPrior()))

src/mcmc/ess.jl

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -119,16 +119,10 @@ end
119119
Distributions.mean(p::ESSPrior) = p.μ
120120

121121
# Evaluate log-likelihood of proposals
122-
const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} = Turing.LogDensityFunction{
123-
V,M,<:DynamicPPL.SamplingContext{<:S}
124-
}
125-
126-
function (ℓ::ESSLogLikelihood)(f::AbstractVector)
127-
sampler = DynamicPPL.getsampler(ℓ)
128-
varinfo = DynamicPPL.unflatten(ℓ.varinfo, f)
129-
varinfo = last(DynamicPPL.evaluate!!(ℓ.model, varinfo, sampler))
130-
return getlogp(varinfo)
131-
end
122+
const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} =
123+
Turing.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD}
124+
125+
(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ, f)
132126

133127
function DynamicPPL.tilde_assume(
134128
rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
@@ -138,8 +132,6 @@ function DynamicPPL.tilde_assume(
138132
)
139133
end
140134

141-
function DynamicPPL.tilde_observe(
142-
ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi
143-
)
135+
function DynamicPPL.tilde_observe(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vi)
144136
return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi)
145137
end

src/mcmc/is.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,16 @@ function getlogevidence(samples::Vector{<:Transition}, ::Sampler{<:IS}, state)
4646
return logsumexp(map(x -> x.lp, samples)) - log(length(samples))
4747
end
4848

49-
function DynamicPPL.assume(rng, spl::Sampler{<:IS}, dist::Distribution, vn::VarName, vi)
49+
function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName, vi)
5050
if haskey(vi, vn)
5151
r = vi[vn]
5252
else
5353
r = rand(rng, dist)
54-
vi = push!!(vi, vn, r, dist, spl)
54+
vi = push!!(vi, vn, r, dist)
5555
end
5656
return r, 0, vi
5757
end
5858

59-
function DynamicPPL.observe(spl::Sampler{<:IS}, dist::Distribution, value, vi)
59+
function DynamicPPL.observe(::Sampler{<:IS}, dist::Distribution, value, vi)
6060
return logpdf(dist, value), vi
6161
end

src/mcmc/mh.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,14 +188,13 @@ A log density function for the MH sampler.
188188
189189
This variant uses the `set_namedtuple!` function to update the `VarInfo`.
190190
"""
191-
const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = Turing.LogDensityFunction{
192-
V,M,<:DynamicPPL.SamplingContext{<:S}
193-
}
191+
const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} =
192+
Turing.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD}
194193

195194
function LogDensityProblems.logdensity(f::MHLogDensityFunction, x::NamedTuple)
196195
vi = deepcopy(f.varinfo)
197196
set_namedtuple!(vi, x)
198-
vi_new = last(DynamicPPL.evaluate!!(f.model, vi, DynamicPPL.getcontext(f)))
197+
vi_new = last(DynamicPPL.evaluate!!(f.model, vi, f.context))
199198
lj = getlogp(vi_new)
200199
return lj
201200
end

src/mcmc/particle_mcmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ function DynamicPPL.assume(
365365

366366
if ~haskey(vi, vn)
367367
r = rand(trng, dist)
368-
push!!(vi, vn, r, dist, spl)
368+
push!!(vi, vn, r, dist)
369369
elseif is_flagged(vi, vn, "del")
370370
unset_flag!(vi, vn, "del") # Reference particle parent
371371
r = rand(trng, dist)

src/optimisation/Optimisation.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ end
9999
100100
A struct that stores the negative log density function of a `DynamicPPL` model.
101101
"""
102-
const OptimLogDensity{M<:DynamicPPL.Model,C<:OptimizationContext,V<:DynamicPPL.VarInfo} = Turing.LogDensityFunction{
103-
V,M,C
102+
const OptimLogDensity{M<:DynamicPPL.Model,C<:OptimizationContext,V<:DynamicPPL.VarInfo,AD} = Turing.LogDensityFunction{
103+
M,V,C,AD
104104
}
105105

106106
"""
@@ -125,9 +125,7 @@ required by Optimization.jl.
125125
"""
126126
function (f::OptimLogDensity)(z::AbstractVector)
127127
varinfo = DynamicPPL.unflatten(f.varinfo, z)
128-
return -DynamicPPL.getlogp(
129-
last(DynamicPPL.evaluate!!(f.model, varinfo, DynamicPPL.getcontext(f)))
130-
)
128+
return -DynamicPPL.getlogp(last(DynamicPPL.evaluate!!(f.model, varinfo, f.context)))
131129
end
132130

133131
(f::OptimLogDensity)(z, _) = f(z)

test/mcmc/abstractmcmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function initialize_nuts(model::Turing.Model)
2626
f = Turing.Inference.setvarinfo(
2727
f,
2828
DynamicPPL.link!!(Turing.Inference.getvarinfo(f), model),
29-
Turing.Inference.getADType(DynamicPPL.getcontext(LogDensityProblemsAD.parent(f))),
29+
Turing.Inference.getADType(LogDensityProblemsAD.parent(f).context),
3030
)
3131

3232
# Choose parameter dimensionality and initial parameter value

0 commit comments

Comments
 (0)