Skip to content

Commit 5f37e6a

Browse files
committed
Use model field instead
1 parent 5dfcaa6 commit 5f37e6a

File tree

5 files changed

+50
-41
lines changed

5 files changed

+50
-41
lines changed

src/DynamicPPL.jl

+1
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ export AbstractVarInfo,
113113
pointwise_loglikelihoods,
114114
condition,
115115
decondition,
116+
set_tracked_varnames,
116117
fix,
117118
unfix,
118119
predict,

src/model.jl

+43-8
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,20 @@
44
args::NamedTuple{argnames,Targs}
55
defaults::NamedTuple{defaultnames,Tdefaults}
66
context::Ctx=DefaultContext()
7+
tracked_varnames::Union{Nothing,Array{<:VarName}}
78
end
89
910
A `Model` struct with model evaluation function of type `F`, arguments of names `argnames`
1011
types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, missing
1112
arguments `missings`, and evaluation context of type `Ctx`.
1213
14+
`tracked_varnames` is an array of VarNames that should be tracked during sampling. During
15+
model evaluation (with `DynamicPPL.evaluate!!`) all random variables are tracked; however,
16+
at the end of each iteration of MCMC sampling, `DynamicPPL.values_as_in_model` is used to
17+
extract the values of _only_ the tracked variables. This allows the user to control which
18+
variables are ultimately stored in the chain. This field can be set using the
19+
[`set_tracked_varnames`](@ref) function.
20+
1321
Here `argnames`, `defaultargnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`.
1422
`context` is by default `DefaultContext()`.
1523
@@ -23,14 +31,17 @@ different arguments.
2331
# Examples
2432
2533
```julia
34+
julia> f(x) = x + 1 # Dummy function
35+
f (generic function with 1 method)
36+
2637
julia> Model(f, (x = 1.0, y = 2.0))
27-
Model{typeof(f),(:x, :y),(),(),Tuple{Float64,Float64},Tuple{}}(f, (x = 1.0, y = 2.0), NamedTuple())
38+
Model{typeof(f), (:x, :y), (), (), Tuple{Float64, Float64}, Tuple{}, DefaultContext}(f, (x = 1.0, y = 2.0), NamedTuple(), DefaultContext(), nothing)
2839
2940
julia> Model(f, (x = 1.0, y = 2.0), (x = 42,))
30-
Model{typeof(f),(:x, :y),(:x,),(),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
41+
Model{typeof(f), (:x, :y), (:x,), (), Tuple{Float64, Float64}, Tuple{Int64}, DefaultContext}(f, (x = 1.0, y = 2.0), (x = 42,), DefaultContext(), nothing)
3142
3243
julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition of missings
33-
Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
44+
Model{typeof(f), (:x, :y), (:x,), (:y,), Tuple{Float64, Float64}, Tuple{Int64}, DefaultContext}(f, (x = 1.0, y = 2.0), (x = 42,), DefaultContext(), nothing)
3445
```
3546
"""
3647
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <:
@@ -39,6 +50,7 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte
3950
args::NamedTuple{argnames,Targs}
4051
defaults::NamedTuple{defaultnames,Tdefaults}
4152
context::Ctx
53+
tracked_varnames::Union{Nothing,Array{<:VarName}}
4254

4355
@doc """
4456
Model{missings}(f, args::NamedTuple, defaults::NamedTuple)
@@ -51,9 +63,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte
5163
args::NamedTuple{argnames,Targs},
5264
defaults::NamedTuple{defaultnames,Tdefaults},
5365
context::Ctx=DefaultContext(),
66+
tracked_varnames::Union{Nothing,Array{<:VarName}}=nothing,
5467
) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx}
5568
return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}(
56-
f, args, defaults, context
69+
f, args, defaults, context, tracked_varnames
5770
)
5871
end
5972
end
@@ -71,22 +84,44 @@ model with different arguments.
7184
args::NamedTuple{argnames,Targs},
7285
defaults::NamedTuple{kwargnames,Tkwargs},
7386
context::AbstractContext=DefaultContext(),
87+
tracked_varnames::Union{Nothing,Array{<:VarName}}=nothing,
7488
) where {F,argnames,Targs,kwargnames,Tkwargs}
7589
missing_args = Tuple(
7690
name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing
7791
)
7892
missing_kwargs = Tuple(
7993
name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing
8094
)
81-
return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context))
95+
return :(Model{$(missing_args..., missing_kwargs...)}(
96+
f, args, defaults, context, tracked_varnames
97+
))
8298
end
8399

84-
function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...)
85-
return Model(f, args, NamedTuple(kwargs), context)
100+
function Model(
101+
f,
102+
args::NamedTuple,
103+
context::AbstractContext=DefaultContext(),
104+
tracked_varnames::Union{Nothing,Array{<:VarName}}=nothing;
105+
kwargs...,
106+
)
107+
return Model(f, args, NamedTuple(kwargs), context, tracked_varnames)
86108
end
87109

88110
function contextualize(model::Model, context::AbstractContext)
89-
return Model(model.f, model.args, model.defaults, context)
111+
return Model(model.f, model.args, model.defaults, context, model.tracked_varnames)
112+
end
113+
114+
"""
115+
set_tracked_varnames(model::Model, varnames::Union{Nothing,Array{<:VarName}})
116+
117+
Return a new `Model` which only tracks a subset of variables during sampling.
118+
119+
If `varnames` is `nothing`, then all variables will be tracked. Otherwise, only
120+
the variables subsumed by `varnames` are tracked. For example, if `varnames =
121+
[@varname(x)]`, then any variable `x`, `x[1]`, `x.a`, ... will be tracked.
122+
"""
123+
function set_tracked_varnames(model::Model, varnames::Union{Nothing,Array{<:VarName}})
124+
return Model(model.f, model.args, model.defaults, model.context, varnames)
90125
end
91126

92127
"""

src/values_as_in_model.jl

+2-29
Original file line numberDiff line numberDiff line change
@@ -195,39 +195,12 @@ function values_as_in_model(
195195
model::Model,
196196
include_colon_eq::Bool,
197197
varinfo::AbstractVarInfo,
198-
tracked_varnames=tracked_varnames(model),
198+
tracked_varnames=model.tracked_varnames,
199199
context::AbstractContext=DefaultContext(),
200200
)
201+
@show tracked_varnames
201202
tracked_varnames = isnothing(tracked_varnames) ? nothing : collect(tracked_varnames)
202203
context = ValuesAsInModelContext(include_colon_eq, tracked_varnames, context)
203204
evaluate!!(model, varinfo, context)
204205
return context.values
205206
end
206-
207-
"""
208-
tracked_varnames(model::Model)
209-
210-
Returns a set of `VarName`s that the model should track.
211-
212-
By default, this returns `nothing`, which means that all `VarName`s should be
213-
tracked.
214-
215-
If you want to track only a subset of `VarName`s, you can override this method
216-
in your model definition:
217-
218-
```julia
219-
@model function mymodel()
220-
x ~ Normal()
221-
y ~ Normal(x, 1)
222-
end
223-
224-
DynamicPPL.tracked_varnames(::Model{typeof(mymodel)}) = [@varname(y)]
225-
```
226-
227-
Then, when you sample from `mymodel()`, only the value of `y` will be tracked
228-
(and not `x`).
229-
230-
Note that quantities on the left-hand side of `:=` are always tracked, and will
231-
ignore the varnames specified in this method.
232-
"""
233-
tracked_varnames(::Model) = nothing

test/contexts.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ end
167167
ctx1 = PrefixContext{:a}(DefaultContext())
168168
ctx2 = SamplingContext(ctx1)
169169
ctx3 = PrefixContext{:b}(ctx2)
170-
ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3)
170+
ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, nothing, ctx3)
171171
vn_prefixed1 = prefix(ctx1, vn)
172172
vn_prefixed2 = prefix(ctx2, vn)
173173
vn_prefixed3 = prefix(ctx3, vn)

test/values_as_in_model.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@
8282
@test !haskey(values, @varname(y))
8383
@test haskey(values, @varname(z)) # := is always included
8484

85-
# Specify instead using `tracked_varnames` method
86-
DynamicPPL.tracked_varnames(::Model{typeof(track_specific)}) = [@varname(y)]
87-
values = values_as_in_model(model, true, vi)
85+
# Specify instead using `set_tracked_varnames` method
86+
model2 = DynamicPPL.set_tracked_varnames(model, [@varname(y)])
87+
values = values_as_in_model(model2, true, vi)
8888
@test !haskey(values, @varname(x[1]))
8989
@test !haskey(values, @varname(x[2]))
9090
@test haskey(values, @varname(y))

0 commit comments

Comments
 (0)