Skip to content

Commit e62d4e6

Browse files
committed
Fix LogDensityFunction argument order
1 parent 97cf02b commit e62d4e6

File tree

8 files changed

+26
-12
lines changed

8 files changed

+26
-12
lines changed

ext/TuringDynamicHMCExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ function DynamicPPL.initialstep(
7070

7171
# Define log-density function.
7272
= LogDensityProblemsAD.ADgradient(
73-
Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
73+
Turing.LogDensityFunction(
74+
model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
75+
),
7476
)
7577

7678
# Perform initial step.

src/mcmc/ess.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ function AbstractMCMC.step(
5959
rng,
6060
EllipticalSliceSampling.ESSModel(
6161
ESSPrior(model, spl, vi),
62-
Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()),
62+
Turing.LogDensityFunction(
63+
model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
64+
),
6365
),
6466
EllipticalSliceSampling.ESS(),
6567
oldstate,

src/mcmc/hmc.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ function DynamicPPL.initialstep(
158158
metric = metricT(length(theta))
159159
= LogDensityProblemsAD.ADgradient(
160160
Turing.LogDensityFunction(
161-
vi,
162161
model,
162+
vi,
163163
# Use the leaf-context from the `model` in case the user has
164164
# contextualized the model with something like `PriorContext`
165165
# to sample from the prior.
@@ -289,8 +289,8 @@ function get_hamiltonian(model, spl, vi, state, n)
289289
metric = gen_metric(n, spl, state)
290290
= LogDensityProblemsAD.ADgradient(
291291
Turing.LogDensityFunction(
292-
vi,
293292
model,
293+
vi,
294294
DynamicPPL.SamplingContext(spl, DynamicPPL.leafcontext(model.context)),
295295
),
296296
)

src/mcmc/mh.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,8 @@ function propose!!(
310310
Base.Fix1(
311311
LogDensityProblems.logdensity,
312312
Turing.LogDensityFunction(
313-
vi,
314313
model,
314+
vi,
315315
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)),
316316
),
317317
),
@@ -345,8 +345,8 @@ function propose!!(
345345
Base.Fix1(
346346
LogDensityProblems.logdensity,
347347
Turing.LogDensityFunction(
348-
vi,
349348
model,
349+
vi,
350350
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)),
351351
),
352352
),

src/mcmc/sghmc.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ function DynamicPPL.initialstep(
6767
# Compute initial sample and state.
6868
sample = Transition(model, vi)
6969
= LogDensityProblemsAD.ADgradient(
70-
Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
70+
Turing.LogDensityFunction(
71+
model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
72+
),
7173
)
7274
state = SGHMCState(ℓ, vi, zero(vi[spl]))
7375

@@ -227,7 +229,9 @@ function DynamicPPL.initialstep(
227229
# Create first sample and state.
228230
sample = SGLDTransition(model, vi, zero(spl.alg.stepsize(0)))
229231
= LogDensityProblemsAD.ADgradient(
230-
Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
232+
Turing.LogDensityFunction(
233+
model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
234+
),
231235
)
232236
state = SGLDState(ℓ, vi, 1)
233237

src/optimisation/Optimisation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ Create a callable `OptimLogDensity` struct that evaluates a model using the give
110110
"""
111111
function OptimLogDensity(model::DynamicPPL.Model, context::OptimizationContext)
112112
init = DynamicPPL.VarInfo(model)
113-
return Turing.LogDensityFunction(init, model, context)
113+
return Turing.LogDensityFunction(model, init, context)
114114
end
115115

116116
"""

test/essential/ad.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ function test_model_ad(model, f, syms::Vector{Symbol})
4343
= LogDensityProblemsAD.ADgradient(
4444
Turing.AutoForwardDiff(; chunksize=chunksize, tag=standardtag),
4545
Turing.LogDensityFunction(
46-
vi, model, SampleFromPrior(), DynamicPPL.DefaultContext()
46+
model,
47+
vi,
48+
DynamicPPL.SamplingContext(SampleFromPrior(), DynamicPPL.DefaultContext()),
4749
),
4850
)
4951
l, ∇E = LogDensityProblems.logdensity_and_gradient(ℓ, z)
@@ -84,7 +86,9 @@ end
8486
grad_FWAD = sort(g(_x))
8587

8688
= Turing.LogDensityFunction(
87-
vi, ad_test_f, SampleFromPrior(), DynamicPPL.DefaultContext()
89+
ad_test_f,
90+
vi,
91+
DynamicPPL.SamplingContext(SampleFromPrior(), DynamicPPL.DefaultContext()),
8892
)
8993
x = map(x -> Float64(x), vi[:])
9094

test/skipped/unit_test_helper.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ function test_grad(turing_model, grad_f; trans=Dict())
1111
= LogDensityProblemsAD.ADgradient(
1212
Turing.AutoTracker(),
1313
Turing.LogDensityFunction(
14-
vi, model_f, SampleFromPrior(), DynamicPPL.DefaultContext()
14+
model_f,
15+
vi,
16+
DynamicPPL.SamplingContext(SampleFromPrior(), DynamicPPL.DefaultContext()),
1517
),
1618
)
1719
for _ in 1:10000

0 commit comments

Comments
 (0)