Commit ee2b148

remove LogDensityProblemsAD (#2490)
* Remove LogDensityProblemsAD, part 1
* update Optimisation code to not use LogDensityProblemsAD
* Fix field name change
* Don't put chunksize=0
* Remove LogDensityProblemsAD dep
* Improve OptimLogDensity docstring
* Remove unneeded model argument to _optimize
* Fix more tests
* Remove essential/ad from the list of CI groups
* Fix HMC function
1 parent 52b2105 commit ee2b148
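
In short: where Turing previously wrapped its log-density objects in `LogDensityProblemsAD.ADgradient` to obtain gradients, the AD backend is now passed straight to `DynamicPPL.LogDensityFunction`. A minimal before/after sketch, assuming a DynamicPPL version whose `LogDensityFunction` accepts the `adtype` keyword (as used throughout this diff); the model is illustrative:

```julia
using Turing, DynamicPPL, ADTypes

@model function demo()   # illustrative model
    x ~ Normal()
end
model = demo()
varinfo = DynamicPPL.VarInfo(model)

# Before this commit, gradients required an external wrapper:
#   f = LogDensityProblemsAD.ADgradient(AutoForwardDiff(), DynamicPPL.LogDensityFunction(model))

# After, the AD backend is carried by the log-density object itself:
f = DynamicPPL.LogDensityFunction(model, varinfo; adtype=AutoForwardDiff())
```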

17 files changed, +191 -485 lines


.github/workflows/Tests.yml

Lines changed: 1 addition & 3 deletions
@@ -24,8 +24,6 @@ jobs:
       test:
         # Run some of the slower test files individually. The last one catches everything
         # not included in the others.
-        - name: "essential/ad"
-          args: "essential/ad.jl"
         - name: "mcmc/gibbs"
           args: "mcmc/gibbs.jl"
         - name: "mcmc/hmc"
@@ -37,7 +35,7 @@ jobs:
         - name: "mcmc/ess"
           args: "mcmc/ess.jl"
         - name: "everything else"
-          args: "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl mcmc/ess.jl"
+          args: "--skip mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl mcmc/ess.jl"
       runner:
         # Default
         - version: '1'

Manifest.toml

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 julia_version = "1.11.3"
 manifest_format = "2.0"
-project_hash = "83ec9face19bc568fc30cc287161517dc49f6c5c"
+project_hash = "afdf28a30966aaa4af542a30879dd92074661565"
 
 [[deps.ADTypes]]
 git-tree-sha1 = "fb97701c117c8162e84dfcf80215caa904aef44f"

Project.toml

Lines changed: 0 additions & 2 deletions
@@ -23,7 +23,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
-LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
@@ -69,7 +68,6 @@ ForwardDiff = "0.10.3"
 Libtask = "0.8.8"
 LinearAlgebra = "1"
 LogDensityProblems = "2"
-LogDensityProblemsAD = "1.7.0"
 MCMCChains = "5, 6"
 NamedArrays = "0.9, 0.10"
 Optim = "1"

ext/TuringDynamicHMCExt.jl

Lines changed: 9 additions & 15 deletions
@@ -3,17 +3,10 @@ module TuringDynamicHMCExt
 ### DynamicHMC backend - https://github.com/tpapp/DynamicHMC.jl
 ###
 
-if isdefined(Base, :get_extension)
-    using DynamicHMC: DynamicHMC
-    using Turing
-    using Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
-    using Turing.Inference: ADTypes, LogDensityProblemsAD, TYPEDFIELDS
-else
-    import ..DynamicHMC
-    using ..Turing
-    using ..Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
-    using ..Turing.Inference: ADTypes, LogDensityProblemsAD, TYPEDFIELDS
-end
+using DynamicHMC: DynamicHMC
+using Turing
+using Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
+using Turing.Inference: ADTypes, TYPEDFIELDS
 
 """
     DynamicNUTS
@@ -69,10 +62,11 @@ function DynamicPPL.initialstep(
     end
 
     # Define log-density function.
-    ℓ = LogDensityProblemsAD.ADgradient(
-        Turing.LogDensityFunction(
-            model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
-        ),
+    ℓ = DynamicPPL.LogDensityFunction(
+        model,
+        vi,
+        DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
+        adtype=spl.alg.adtype,
     )
 
     # Perform initial step.
ext/TuringOptimExt.jl

Lines changed: 24 additions & 25 deletions
@@ -1,14 +1,8 @@
 module TuringOptimExt
 
-if isdefined(Base, :get_extension)
-    using Turing: Turing
-    import Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
-    using Optim: Optim
-else
-    import ..Turing
-    import ..Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
-    import ..Optim
-end
+using Turing: Turing
+import Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
+using Optim: Optim
 
 ####################
 # Optim.jl methods #
@@ -42,7 +36,7 @@ function Optim.optimize(
 )
     ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
     f = Optimisation.OptimLogDensity(model, ctx)
-    init_vals = DynamicPPL.getparams(f)
+    init_vals = DynamicPPL.getparams(f.ldf)
     optimizer = Optim.LBFGS()
     return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
 end
@@ -65,7 +59,7 @@ function Optim.optimize(
 )
     ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
     f = Optimisation.OptimLogDensity(model, ctx)
-    init_vals = DynamicPPL.getparams(f)
+    init_vals = DynamicPPL.getparams(f.ldf)
     return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
 end
 function Optim.optimize(
@@ -81,7 +75,7 @@ end
 
 function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...)
     ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
-    return _optimize(model, Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
+    return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
 end
 
 """
@@ -112,7 +106,7 @@ function Optim.optimize(
 )
     ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
     f = Optimisation.OptimLogDensity(model, ctx)
-    init_vals = DynamicPPL.getparams(f)
+    init_vals = DynamicPPL.getparams(f.ldf)
     optimizer = Optim.LBFGS()
     return _map_optimize(model, init_vals, optimizer, options; kwargs...)
 end
@@ -135,7 +129,7 @@ function Optim.optimize(
 )
     ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
     f = Optimisation.OptimLogDensity(model, ctx)
-    init_vals = DynamicPPL.getparams(f)
+    init_vals = DynamicPPL.getparams(f.ldf)
     return _map_optimize(model, init_vals, optimizer, options; kwargs...)
 end
 function Optim.optimize(
@@ -151,28 +145,29 @@ end
 
 function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
     ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
-    return _optimize(model, Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
+    return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
 end
-
 """
-    _optimize(model::Model, f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)
+    _optimize(f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)
 
 Estimate a mode, i.e., compute a MLE or MAP estimate.
 """
 function _optimize(
-    model::DynamicPPL.Model,
     f::Optimisation.OptimLogDensity,
-    init_vals::AbstractArray=DynamicPPL.getparams(f),
+    init_vals::AbstractArray=DynamicPPL.getparams(f.ldf),
     optimizer::Optim.AbstractOptimizer=Optim.LBFGS(),
     options::Optim.Options=Optim.Options(),
     args...;
     kwargs...,
 )
     # Convert the initial values, since it is assumed that users provide them
     # in the constrained space.
-    f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals)
-    f = Accessors.@set f.varinfo = DynamicPPL.link(f.varinfo, model)
-    init_vals = DynamicPPL.getparams(f)
+    # TODO(penelopeysm): As with in src/optimisation/Optimisation.jl, unclear
+    # whether initialisation is really necessary at all
+    vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals)
+    vi = DynamicPPL.link(vi, f.ldf.model)
+    f = Optimisation.OptimLogDensity(f.ldf.model, vi, f.ldf.context; adtype=f.ldf.adtype)
+    init_vals = DynamicPPL.getparams(f.ldf)
 
     # Optimize!
     M = Optim.optimize(Optim.only_fg!(f), init_vals, optimizer, options, args...; kwargs...)
@@ -186,12 +181,16 @@ function _optimize(
     end
 
     # Get the optimum in unconstrained space. `getparams` does the invlinking.
-    f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
-    vns_vals_iter = Turing.Inference.getparams(model, f.varinfo)
+    vi = f.ldf.varinfo
+    vi_optimum = DynamicPPL.unflatten(vi, M.minimizer)
+    logdensity_optimum = Optimisation.OptimLogDensity(
+        f.ldf.model, vi_optimum, f.ldf.context
+    )
+    vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum)
     varnames = map(Symbol ∘ first, vns_vals_iter)
     vals = map(last, vns_vals_iter)
     vmat = NamedArrays.NamedArray(vals, varnames)
-    return Optimisation.ModeResult(vmat, M, -M.minimum, f)
+    return Optimisation.ModeResult(vmat, M, -M.minimum, logdensity_optimum)
 end
 
 end # module
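
The recurring `f` to `f.ldf` edits above assume the reworked `OptimLogDensity`, which now stores a `DynamicPPL.LogDensityFunction` in an `ldf` field rather than holding the varinfo and context directly. A hedged sketch of the access pattern (illustrative model):

```julia
using Turing
using Turing: DynamicPPL, Optimisation

@model function demo()   # illustrative model
    x ~ Normal()
end

ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
f = Optimisation.OptimLogDensity(demo(), ctx)

# Components are now reached through the wrapped LogDensityFunction:
init_vals = DynamicPPL.getparams(f.ldf)      # was: DynamicPPL.getparams(f)
f.ldf.model, f.ldf.varinfo, f.ldf.context    # the pieces `_optimize` reads above
```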

src/mcmc/Inference.jl

Lines changed: 0 additions & 24 deletions
@@ -48,7 +48,6 @@ import AdvancedPS
 import Accessors
 import EllipticalSliceSampling
 import LogDensityProblems
-import LogDensityProblemsAD
 import Random
 import MCMCChains
 import StatsBase: predict
@@ -160,29 +159,6 @@ function externalsampler(
     return ExternalSampler(sampler, adtype, Val(unconstrained))
 end
 
-getADType(spl::Sampler) = getADType(spl.alg)
-getADType(::SampleFromPrior) = Turing.DEFAULT_ADTYPE
-
-getADType(ctx::DynamicPPL.SamplingContext) = getADType(ctx.sampler)
-getADType(ctx::DynamicPPL.AbstractContext) = getADType(DynamicPPL.NodeTrait(ctx), ctx)
-getADType(::DynamicPPL.IsLeaf, ctx::DynamicPPL.AbstractContext) = Turing.DEFAULT_ADTYPE
-function getADType(::DynamicPPL.IsParent, ctx::DynamicPPL.AbstractContext)
-    return getADType(DynamicPPL.childcontext(ctx))
-end
-
-getADType(alg::Hamiltonian) = alg.adtype
-
-function LogDensityProblemsAD.ADgradient(ℓ::DynamicPPL.LogDensityFunction)
-    return LogDensityProblemsAD.ADgradient(getADType(ℓ.context), ℓ)
-end
-
-function LogDensityProblems.logdensity(
-    f::Turing.LogDensityFunction{<:AbstractVarInfo,<:Model,<:DynamicPPL.DefaultContext},
-    x::NamedTuple,
-)
-    return DynamicPPL.logjoint(f.model, DynamicPPL.unflatten(f.varinfo, x))
-end
-
 # TODO: make a nicer `set_namedtuple!` and move these functions to DynamicPPL.
 function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple)
     set_namedtuple!(deepcopy(vi), θ)
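
The deleted `getADType` methods walked a sampling-context tree to recover an AD backend, falling back to `Turing.DEFAULT_ADTYPE`. After this commit the backend is fixed when the log-density is constructed, so no traversal is needed. A hedged one-liner (illustrative model; `Turing.DEFAULT_ADTYPE` is taken from the removed code above):

```julia
using Turing, DynamicPPL

@model function demo()   # illustrative model
    x ~ Normal()
end

# The backend is stated up front instead of being recovered from the context tree:
ldf = DynamicPPL.LogDensityFunction(
    demo(), DynamicPPL.VarInfo(demo()); adtype=Turing.DEFAULT_ADTYPE
)
```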

src/mcmc/abstractmcmc.jl

Lines changed: 11 additions & 40 deletions
@@ -1,6 +1,6 @@
-struct TuringState{S,F}
+struct TuringState{S,M,V,C}
     state::S
-    logdensity::F
+    ldf::DynamicPPL.LogDensityFunction{M,V,C}
 end
 
 state_to_turing(f::DynamicPPL.LogDensityFunction, state) = TuringState(state, f)
@@ -12,20 +12,10 @@ function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition)
     return Transition(f.model, varinfo, transition)
 end
 
-state_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, state) = TuringState(state, f)
-function transition_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, transition)
-    return transition_to_turing(parent(f), transition)
-end
-
-function varinfo_from_logdensityfn(f::LogDensityProblemsAD.ADGradientWrapper)
-    return varinfo_from_logdensityfn(parent(f))
-end
-varinfo_from_logdensityfn(f::DynamicPPL.LogDensityFunction) = f.varinfo
-
 function varinfo(state::TuringState)
-    θ = getparams(DynamicPPL.getmodel(state.logdensity), state.state)
+    θ = getparams(state.ldf.model, state.state)
     # TODO: Do we need to link here first?
-    return DynamicPPL.unflatten(varinfo_from_logdensityfn(state.logdensity), θ)
+    return DynamicPPL.unflatten(state.ldf.varinfo, θ)
 end
 varinfo(state::AbstractVarInfo) = state
 
@@ -40,23 +30,6 @@ getstats(transition::AdvancedHMC.Transition) = transition.stat
 
 getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params
 
-getvarinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo
-function getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper)
-    return getvarinfo(LogDensityProblemsAD.parent(f))
-end
-
-function setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo)
-    return DynamicPPL.LogDensityFunction(f.model, varinfo, f.context; adtype=f.adtype)
-end
-
-function setvarinfo(
-    f::LogDensityProblemsAD.ADGradientWrapper, varinfo, adtype::ADTypes.AbstractADType
-)
-    return LogDensityProblemsAD.ADgradient(
-        adtype, setvarinfo(LogDensityProblemsAD.parent(f), varinfo)
-    )
-end
-
 # TODO: Do we also support `resume`, etc?
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
@@ -69,12 +42,8 @@ function AbstractMCMC.step(
     alg = sampler_wrapper.alg
     sampler = alg.sampler
 
-    # Create a log-density function with an implementation of the
-    # gradient so we ensure that we're using the same AD backend as in Turing.
-    f = LogDensityProblemsAD.ADgradient(alg.adtype, DynamicPPL.LogDensityFunction(model))
-
-    # Link the varinfo if needed.
-    varinfo = getvarinfo(f)
+    # Initialise varinfo with initial params and link the varinfo if needed.
+    varinfo = DynamicPPL.VarInfo(model)
     if requires_unconstrained_space(alg)
         if initial_params !== nothing
             # If we have initial parameters, we need to set the varinfo before linking.
@@ -85,9 +54,11 @@ function AbstractMCMC.step(
             varinfo = DynamicPPL.link(varinfo, model)
         end
     end
-    f = setvarinfo(f, varinfo, alg.adtype)
 
-    # Then just call `AdvancedHMC.step` with the right arguments.
+    # Construct LogDensityFunction
+    f = DynamicPPL.LogDensityFunction(model, varinfo; adtype=alg.adtype)
+
+    # Then just call `AbstractMCMC.step` with the right arguments.
     if initial_state === nothing
         transition_inner, state_inner = AbstractMCMC.step(
             rng, AbstractMCMC.LogDensityModel(f), sampler; initial_params, kwargs...
@@ -114,7 +85,7 @@ function AbstractMCMC.step(
     kwargs...,
 )
     sampler = sampler_wrapper.alg.sampler
-    f = state.logdensity
+    f = state.ldf
 
     # Then just call `AdvancedHMC.step` with the right arguments.
     transition_inner, state_inner = AbstractMCMC.step(
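
For users, the external-sampler entry point is unchanged; only the plumbing differs (a plain `LogDensityFunction` in the state's `ldf` field instead of an `ADGradientWrapper`). A hedged end-to-end sketch with AdvancedHMC (model and sampler settings illustrative):

```julia
using Turing, AdvancedHMC, ADTypes

@model function gdemo()   # illustrative model
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
end

# `externalsampler` carries the adtype that the step above forwards to LogDensityFunction:
spl = externalsampler(AdvancedHMC.NUTS(0.8); adtype=AutoForwardDiff())
chain = sample(gdemo(), spl, 100)
```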

src/mcmc/gibbs.jl

Lines changed: 1 addition & 1 deletion
@@ -438,7 +438,7 @@ function setparams_varinfo!!(
     state::TuringState,
     params::AbstractVarInfo,
 )
-    logdensity = DynamicPPL.setmodel(state.logdensity, model, sampler.alg.adtype)
+    logdensity = DynamicPPL.setmodel(state.ldf, model, sampler.alg.adtype)
     new_inner_state = setparams_varinfo!!(
         AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params
     )
