Skip to content

Commit 5601b2d

Browse files
authored
Clean up old code (#2574)
* Remove src/essential and deprecated function stubs * Fix imports * Export `@addlogprob!` * Fix more tests * Clean up more stuff
1 parent d29f810 commit 5601b2d

34 files changed

+148
-1131
lines changed

HISTORY.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
# Release 0.39.0
2+
3+
## Removal of Turing.Essential
4+
5+
The Turing.Essential module has been removed.
6+
Anything exported from there can be imported from either `Turing` or `DynamicPPL`.
7+
8+
## `@addlogprob!`
9+
10+
The `@addlogprob!` macro is now exported from Turing, making it officially part of the public interface.
11+
112
# Release 0.38.4
213

314
The minimum Julia version was increased to 1.10.2 (from 1.10.0).

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
4242
| `to_submodel` | [`DynamicPPL.to_submodel`](@extref) | Define a submodel |
4343
| `prefix` | [`DynamicPPL.prefix`](@extref) | Prefix all variable names in a model with a given VarName |
4444
| `LogDensityFunction` | [`DynamicPPL.LogDensityFunction`](@extref) | A struct containing all information about how to evaluate a model. Mostly for advanced users |
45+
| `@addlogprob!` | [`DynamicPPL.@addlogprob!`](@extref) | Add arbitrary log-probability terms during model evaluation |
4546

4647
### Inference
4748

src/Turing.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using Printf: Printf
2323
using Random: Random
2424
using LinearAlgebra: I
2525

26-
using ADTypes: ADTypes
26+
using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake
2727

2828
const DEFAULT_ADTYPE = ADTypes.AutoForwardDiff()
2929

@@ -47,8 +47,6 @@ end
4747
# Random probability measures.
4848
include("stdlib/distributions.jl")
4949
include("stdlib/RandomMeasures.jl")
50-
include("essential/Essential.jl")
51-
using .Essential
5250
include("mcmc/Inference.jl") # inference algorithms
5351
using .Inference
5452
include("variational/VariationalInference.jl")
@@ -57,13 +55,13 @@ using .Variational
5755
include("optimisation/Optimisation.jl")
5856
using .Optimisation
5957

60-
include("deprecated.jl") # to be removed in the next minor version release
61-
6258
###########
6359
# Exports #
6460
###########
6561
# `using` statements for stuff to re-export
6662
using DynamicPPL:
63+
@model,
64+
@varname,
6765
pointwise_loglikelihoods,
6866
generated_quantities,
6967
returned,
@@ -73,9 +71,12 @@ using DynamicPPL:
7371
decondition,
7472
fix,
7573
unfix,
74+
prefix,
7675
conditioned,
76+
@submodel,
7777
to_submodel,
78-
LogDensityFunction
78+
LogDensityFunction,
79+
@addlogprob!
7980
using StatsBase: predict
8081
using OrderedCollections: OrderedDict
8182

@@ -90,6 +91,7 @@ export
9091
to_submodel,
9192
prefix,
9293
LogDensityFunction,
94+
@addlogprob!,
9395
# Sampling - AbstractMCMC
9496
sample,
9597
MCMCThreads,

src/deprecated.jl

Lines changed: 0 additions & 39 deletions
This file was deleted.

src/essential/Essential.jl

Lines changed: 0 additions & 24 deletions
This file was deleted.

src/essential/container.jl

Lines changed: 0 additions & 70 deletions
This file was deleted.

src/mcmc/Inference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module Inference
22

3-
using ..Essential
43
using DynamicPPL:
4+
@model,
55
Metadata,
66
VarInfo,
77
# TODO(mhauru) all_varnames_grouped_by_symbol isn't exported by DPPL, because it is only

src/mcmc/particle_mcmc.jl

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,79 @@
22
### Particle Filtering and Particle MCMC Samplers.
33
###
44

5+
### AdvancedPS models and interface
6+
7+
struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <:
8+
AdvancedPS.AbstractGenericModel
9+
model::M
10+
sampler::S
11+
varinfo::V
12+
evaluator::E
13+
end
14+
15+
function TracedModel(
16+
model::Model,
17+
sampler::AbstractSampler,
18+
varinfo::AbstractVarInfo,
19+
rng::Random.AbstractRNG,
20+
)
21+
context = SamplingContext(rng, sampler, DefaultContext())
22+
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context)
23+
if kwargs !== nothing && !isempty(kwargs)
24+
error(
25+
"Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.",
26+
)
27+
end
28+
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
29+
model, sampler, varinfo, (model.f, args...)
30+
)
31+
end
32+
33+
function AdvancedPS.advance!(
34+
trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false
35+
)
36+
# Make sure we load/reset the rng in the new replaying mechanism
37+
DynamicPPL.increment_num_produce!(trace.model.f.varinfo)
38+
isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
39+
score = consume(trace.model.ctask)
40+
if score === nothing
41+
return nothing
42+
else
43+
return score + DynamicPPL.getlogp(trace.model.f.varinfo)
44+
end
45+
end
46+
47+
function AdvancedPS.delete_retained!(trace::TracedModel)
48+
DynamicPPL.set_retained_vns_del!(trace.varinfo)
49+
return trace
50+
end
51+
52+
function AdvancedPS.reset_model(trace::TracedModel)
53+
DynamicPPL.reset_num_produce!(trace.varinfo)
54+
return trace
55+
end
56+
57+
function AdvancedPS.reset_logprob!(trace::TracedModel)
58+
DynamicPPL.resetlogp!!(trace.model.varinfo)
59+
return trace
60+
end
61+
62+
function AdvancedPS.update_rng!(
63+
trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}
64+
)
65+
# Extract the `args`.
66+
args = trace.model.ctask.args
67+
# From `args`, extract the `SamplingContext`, which contains the RNG.
68+
sampling_context = args[3]
69+
rng = sampling_context.rng
70+
trace.rng = rng
71+
return trace
72+
end
73+
74+
function Libtask.TapedTask(model::TracedModel, ::Random.AbstractRNG, args...; kwargs...) # RNG ?
75+
return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...)
76+
end
77+
578
####
679
#### Generic Sequential Monte Carlo sampler.
780
####
@@ -408,7 +481,7 @@ function AdvancedPS.Trace(
408481
newvarinfo = deepcopy(varinfo)
409482
DynamicPPL.reset_num_produce!(newvarinfo)
410483

411-
tmodel = Turing.Essential.TracedModel(model, sampler, newvarinfo, rng)
484+
tmodel = TracedModel(model, sampler, newvarinfo, rng)
412485
newtrace = AdvancedPS.Trace(tmodel, rng)
413486
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace)
414487
return newtrace

test/ad.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Ration
2525

2626
"""A dictionary mapping ADTypes to the element types they use."""
2727
eltypes_by_adtype = Dict(
28-
Turing.AutoForwardDiff => (ForwardDiff.Dual,),
29-
Turing.AutoReverseDiff => (
28+
AutoForwardDiff => (ForwardDiff.Dual,),
29+
AutoReverseDiff => (
3030
ReverseDiff.TrackedArray,
3131
ReverseDiff.TrackedMatrix,
3232
ReverseDiff.TrackedReal,
@@ -37,7 +37,7 @@ eltypes_by_adtype = Dict(
3737
),
3838
)
3939
if INCLUDE_MOONCAKE
40-
eltypes_by_adtype[Turing.AutoMooncake] = (Mooncake.CoDual,)
40+
eltypes_by_adtype[AutoMooncake] = (Mooncake.CoDual,)
4141
end
4242

4343
"""
@@ -189,32 +189,32 @@ end
189189
"""
190190
All the ADTypes on which we want to run the tests.
191191
"""
192-
ADTYPES = [Turing.AutoForwardDiff(), Turing.AutoReverseDiff(; compile=false)]
192+
ADTYPES = [AutoForwardDiff(), AutoReverseDiff(; compile=false)]
193193
if INCLUDE_MOONCAKE
194-
push!(ADTYPES, Turing.AutoMooncake(; config=nothing))
194+
push!(ADTYPES, AutoMooncake(; config=nothing))
195195
end
196196

197197
# Check that ADTypeCheckContext itself works as expected.
198198
@testset "ADTypeCheckContext" begin
199199
@model test_model() = x ~ Normal(0, 1)
200200
tm = test_model()
201201
adtypes = (
202-
Turing.AutoForwardDiff(),
203-
Turing.AutoReverseDiff(),
202+
AutoForwardDiff(),
203+
AutoReverseDiff(),
204204
# Don't need to test Mooncake as it doesn't use tracer types
205205
)
206206
for actual_adtype in adtypes
207-
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
207+
sampler = HMC(0.1, 5; adtype=actual_adtype)
208208
for expected_adtype in adtypes
209209
contextualised_tm = DynamicPPL.contextualize(
210210
tm, ADTypeCheckContext(expected_adtype, tm.context)
211211
)
212212
@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
213213
if actual_adtype == expected_adtype
214214
# Check that this does not throw an error.
215-
Turing.sample(contextualised_tm, sampler, 2)
215+
sample(contextualised_tm, sampler, 2)
216216
else
217-
@test_throws AbstractWrongADBackendError Turing.sample(
217+
@test_throws AbstractWrongADBackendError sample(
218218
contextualised_tm, sampler, 2
219219
)
220220
end

0 commit comments

Comments
 (0)