Skip to content

Commit 4206707

Browse files
committed
Use model field instead
1 parent 5dfcaa6 commit 4206707

File tree

6 files changed

+91
-90
lines changed

6 files changed

+91
-90
lines changed

src/DynamicPPL.jl

+1
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ export AbstractVarInfo,
113113
pointwise_loglikelihoods,
114114
condition,
115115
decondition,
116+
set_tracked_varnames,
116117
fix,
117118
unfix,
118119
predict,

src/model.jl

+35-8
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,20 @@
44
args::NamedTuple{argnames,Targs}
55
defaults::NamedTuple{defaultnames,Tdefaults}
66
context::Ctx=DefaultContext()
7+
tracked_varnames::Union{Nothing,Array{<:VarName}}
78
end
89
910
A `Model` struct with model evaluation function of type `F`, arguments of names `argnames`
1011
types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, missing
1112
arguments `missings`, and evaluation context of type `Ctx`.
1213
14+
`tracked_varnames` is an array of VarNames that should be tracked during sampling. During
15+
model evaluation (with `DynamicPPL.evaluate!!`) all random variables are tracked; however,
16+
at the end of each iteration of MCMC sampling, `DynamicPPL.values_as_in_model` is used to
17+
extract the values of _only_ the tracked variables. This allows the user to control which
18+
variables are ultimately stored in the chain. This field can be set using the
19+
[`set_tracked_varnames`](@ref) function.
20+
1321
Here `argnames`, `defaultargnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`.
1422
`context` is by default `DefaultContext()`.
1523
@@ -23,14 +31,17 @@ different arguments.
2331
# Examples
2432
2533
```julia
34+
julia> f(x) = x + 1 # Dummy function
35+
f (generic function with 1 method)
36+
2637
julia> Model(f, (x = 1.0, y = 2.0))
27-
Model{typeof(f),(:x, :y),(),(),Tuple{Float64,Float64},Tuple{}}(f, (x = 1.0, y = 2.0), NamedTuple())
38+
Model{typeof(f), (:x, :y), (), (), Tuple{Float64, Float64}, Tuple{}, DefaultContext}(f, (x = 1.0, y = 2.0), NamedTuple(), DefaultContext(), nothing)
2839
2940
julia> Model(f, (x = 1.0, y = 2.0), (x = 42,))
30-
Model{typeof(f),(:x, :y),(:x,),(),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
41+
Model{typeof(f), (:x, :y), (:x,), (), Tuple{Float64, Float64}, Tuple{Int64}, DefaultContext}(f, (x = 1.0, y = 2.0), (x = 42,), DefaultContext(), nothing)
3142
3243
julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition of missings
33-
Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
44+
Model{typeof(f), (:x, :y), (:x,), (:y,), Tuple{Float64, Float64}, Tuple{Int64}, DefaultContext}(f, (x = 1.0, y = 2.0), (x = 42,), DefaultContext(), nothing)
3445
```
3546
"""
3647
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <:
@@ -39,6 +50,7 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte
3950
args::NamedTuple{argnames,Targs}
4051
defaults::NamedTuple{defaultnames,Tdefaults}
4152
context::Ctx
53+
tracked_varnames::Union{Nothing,Array{<:VarName}}
4254

4355
@doc """
4456
Model{missings}(f, args::NamedTuple, defaults::NamedTuple)
@@ -51,9 +63,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte
5163
args::NamedTuple{argnames,Targs},
5264
defaults::NamedTuple{defaultnames,Tdefaults},
5365
context::Ctx=DefaultContext(),
66+
tracked_varnames::Union{Nothing,Array{<:VarName}}=nothing,
5467
) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx}
5568
return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}(
56-
f, args, defaults, context
69+
f, args, defaults, context, tracked_varnames
5770
)
5871
end
5972
end
@@ -71,22 +84,36 @@ model with different arguments.
7184
args::NamedTuple{argnames,Targs},
7285
defaults::NamedTuple{kwargnames,Tkwargs},
7386
context::AbstractContext=DefaultContext(),
87+
tracked_varnames::Union{Nothing,Array{<:VarName}}=nothing,
7488
) where {F,argnames,Targs,kwargnames,Tkwargs}
7589
missing_args = Tuple(
7690
name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing
7791
)
7892
missing_kwargs = Tuple(
7993
name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing
8094
)
81-
return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context))
95+
return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context, tracked_varnames))
8296
end
8397

84-
function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...)
85-
return Model(f, args, NamedTuple(kwargs), context)
98+
function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(), tracked_varnames::Union{Nothing,Array{<:VarName}}=nothing; kwargs...)
99+
return Model(f, args, NamedTuple(kwargs), context, tracked_varnames)
86100
end
87101

88102
function contextualize(model::Model, context::AbstractContext)
89-
return Model(model.f, model.args, model.defaults, context)
103+
return Model(model.f, model.args, model.defaults, context, model.tracked_varnames)
104+
end
105+
106+
"""
107+
set_tracked_varnames(model::Model, varnames::Union{Nothing,Array{<:VarName}})
108+
109+
Return a new `Model` which only tracks a subset of variables during sampling.
110+
111+
If `varnames` is `nothing`, then all variables will be tracked. Otherwise, only
112+
the variables subsumed by `varnames` are tracked. For example, if `varnames =
113+
[@varname(x)]`, then any variable `x`, `x[1]`, `x.a`, ... will be tracked.
114+
"""
115+
function set_tracked_varnames(model::Model, varnames::Union{Nothing,Array{<:VarName}})
116+
return Model(model.f, model.args, model.defaults, model.context, varnames)
90117
end
91118

92119
"""

src/values_as_in_model.jl

+2-29
Original file line numberDiff line numberDiff line change
@@ -195,39 +195,12 @@ function values_as_in_model(
195195
model::Model,
196196
include_colon_eq::Bool,
197197
varinfo::AbstractVarInfo,
198-
tracked_varnames=tracked_varnames(model),
198+
tracked_varnames=model.tracked_varnames,
199199
context::AbstractContext=DefaultContext(),
200200
)
201+
@show tracked_varnames
201202
tracked_varnames = isnothing(tracked_varnames) ? nothing : collect(tracked_varnames)
202203
context = ValuesAsInModelContext(include_colon_eq, tracked_varnames, context)
203204
evaluate!!(model, varinfo, context)
204205
return context.values
205206
end
206-
207-
"""
208-
tracked_varnames(model::Model)
209-
210-
Returns a set of `VarName`s that the model should track.
211-
212-
By default, this returns `nothing`, which means that all `VarName`s should be
213-
tracked.
214-
215-
If you want to track only a subset of `VarName`s, you can override this method
216-
in your model definition:
217-
218-
```julia
219-
@model function mymodel()
220-
x ~ Normal()
221-
y ~ Normal(x, 1)
222-
end
223-
224-
DynamicPPL.tracked_varnames(::Model{typeof(mymodel)}) = [@varname(y)]
225-
```
226-
227-
Then, when you sample from `mymodel()`, only the value of `y` will be tracked
228-
(and not `x`).
229-
230-
Note that quantities on the left-hand side of `:=` are always tracked, and will
231-
ignore the varnames specified in this method.
232-
"""
233-
tracked_varnames(::Model) = nothing

test/contexts.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ end
167167
ctx1 = PrefixContext{:a}(DefaultContext())
168168
ctx2 = SamplingContext(ctx1)
169169
ctx3 = PrefixContext{:b}(ctx2)
170-
ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3)
170+
ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, nothing, ctx3)
171171
vn_prefixed1 = prefix(ctx1, vn)
172172
vn_prefixed2 = prefix(ctx2, vn)
173173
vn_prefixed3 = prefix(ctx3, vn)

test/runtests.jl

+49-49
Original file line numberDiff line numberDiff line change
@@ -45,58 +45,58 @@ include("test_util.jl")
4545
# groups are chosen to make both groups take roughly the same amount of
4646
# time, but beyond that there is no particular reason for the split.
4747
if GROUP == "All" || GROUP == "Group1"
48-
include("utils.jl")
49-
include("compiler.jl")
50-
include("varnamedvector.jl")
51-
include("varinfo.jl")
52-
include("simple_varinfo.jl")
53-
include("model.jl")
48+
# include("utils.jl")
49+
# include("compiler.jl")
50+
# include("varnamedvector.jl")
51+
# include("varinfo.jl")
52+
# include("simple_varinfo.jl")
53+
# include("model.jl")
5454
include("values_as_in_model.jl")
55-
include("sampler.jl")
56-
include("independence.jl")
57-
include("distribution_wrappers.jl")
58-
include("logdensityfunction.jl")
59-
include("linking.jl")
60-
include("serialization.jl")
61-
include("pointwise_logdensities.jl")
62-
include("lkj.jl")
63-
include("deprecated.jl")
55+
# include("sampler.jl")
56+
# include("independence.jl")
57+
# include("distribution_wrappers.jl")
58+
# include("logdensityfunction.jl")
59+
# include("linking.jl")
60+
# include("serialization.jl")
61+
# include("pointwise_logdensities.jl")
62+
# include("lkj.jl")
63+
# include("deprecated.jl")
6464
end
6565

6666
if GROUP == "All" || GROUP == "Group2"
67-
include("contexts.jl")
68-
include("context_implementations.jl")
69-
include("threadsafe.jl")
70-
include("debug_utils.jl")
71-
@testset "compat" begin
72-
include(joinpath("compat", "ad.jl"))
73-
end
74-
@testset "extensions" begin
75-
include("ext/DynamicPPLMCMCChainsExt.jl")
76-
include("ext/DynamicPPLJETExt.jl")
77-
end
78-
@testset "ad" begin
79-
include("ext/DynamicPPLForwardDiffExt.jl")
80-
include("ext/DynamicPPLMooncakeExt.jl")
81-
include("ad.jl")
82-
end
83-
@testset "prob and logprob macro" begin
84-
@test_throws ErrorException prob"..."
85-
@test_throws ErrorException logprob"..."
86-
end
87-
@testset "doctests" begin
88-
DocMeta.setdocmeta!(
89-
DynamicPPL,
90-
:DocTestSetup,
91-
:(using DynamicPPL, Distributions);
92-
recursive=true,
93-
)
94-
doctestfilters = [
95-
# Ignore the source of a warning in the doctest output, since this is dependent on host.
96-
# This is a line that starts with "└ @ " and ends with the line number.
97-
r"└ @ .+:[0-9]+",
98-
]
99-
doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters)
100-
end
67+
# include("contexts.jl")
68+
# include("context_implementations.jl")
69+
# include("threadsafe.jl")
70+
# include("debug_utils.jl")
71+
# @testset "compat" begin
72+
# include(joinpath("compat", "ad.jl"))
73+
# end
74+
# @testset "extensions" begin
75+
# include("ext/DynamicPPLMCMCChainsExt.jl")
76+
# include("ext/DynamicPPLJETExt.jl")
77+
# end
78+
# @testset "ad" begin
79+
# include("ext/DynamicPPLForwardDiffExt.jl")
80+
# include("ext/DynamicPPLMooncakeExt.jl")
81+
# include("ad.jl")
82+
# end
83+
# @testset "prob and logprob macro" begin
84+
# @test_throws ErrorException prob"..."
85+
# @test_throws ErrorException logprob"..."
86+
# end
87+
# @testset "doctests" begin
88+
# DocMeta.setdocmeta!(
89+
# DynamicPPL,
90+
# :DocTestSetup,
91+
# :(using DynamicPPL, Distributions);
92+
# recursive=true,
93+
# )
94+
# doctestfilters = [
95+
# # Ignore the source of a warning in the doctest output, since this is dependent on host.
96+
# # This is a line that starts with "└ @ " and ends with the line number.
97+
# r"└ @ .+:[0-9]+",
98+
# ]
99+
# doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters)
100+
# end
101101
end
102102
end

test/values_as_in_model.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@
8282
@test !haskey(values, @varname(y))
8383
@test haskey(values, @varname(z)) # := is always included
8484

85-
# Specify instead using `tracked_varnames` method
86-
DynamicPPL.tracked_varnames(::Model{typeof(track_specific)}) = [@varname(y)]
87-
values = values_as_in_model(model, true, vi)
85+
# Specify instead using `set_tracked_varnames` method
86+
model2 = DynamicPPL.set_tracked_varnames(model, [@varname(y)])
87+
values = values_as_in_model(model2, true, vi)
8888
@test !haskey(values, @varname(x[1]))
8989
@test !haskey(values, @varname(x[2]))
9090
@test haskey(values, @varname(y))

0 commit comments

Comments
 (0)