Skip to content

Allow user to track specific varnames #846

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,30 @@
# DynamicPPL Changelog

## 0.36.0

**Breaking changes**

### `set_tracked_varnames`

There is now a new method `set_tracked_varnames(::Model, ::Union{Nothing,Array{<:VarName}})`, which allows you to specify the variables that are collected when `values_as_in_model` is run.
Internally in DynamicPPL this does not have much impact.
However, Turing.jl uses `values_as_in_model` to collect the variable names and values during sampling, and so this method will effectively allow you to control which variables are ultimately stored in a chain.

Example usage:

```julia
@model function f()
x ~ Normal()
y ~ Normal()
return x, y
end

model = f()
model = set_tracked_varnames(model, [@varname(y)])
```

If you then sample from `model`, only the value of `y` will be stored in the chain, and not `x`.

## 0.35.0

**Breaking changes**
Expand Down
7 changes: 7 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,13 @@ Safe extraction of values from a given [`AbstractVarInfo`](@ref) as they are see
values_as_in_model
```

`values_as_in_model` also uses the `tracked_varnames` field on a [`Model`](@ref) to determine which variables are extracted.
To change the value of this field, you can use [`set_tracked_varnames`](@ref).

```@docs
set_tracked_varnames
```

```@docs
NamedDist
```
Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ export AbstractVarInfo,
pointwise_loglikelihoods,
condition,
decondition,
set_tracked_varnames,
fix,
unfix,
predict,
Expand Down
51 changes: 43 additions & 8 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,20 @@
args::NamedTuple{argnames,Targs}
defaults::NamedTuple{defaultnames,Tdefaults}
context::Ctx=DefaultContext()
tracked_varnames::Union{Nothing,Array{<:VarName}}
end

A `Model` struct with model evaluation function of type `F`, arguments of names `argnames`
types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, missing
arguments `missings`, and evaluation context of type `Ctx`.

`tracked_varnames` is an array of VarNames that should be tracked during sampling. During
model evaluation (with `DynamicPPL.evaluate!!`) all random variables are tracked; however,
at the end of each iteration of MCMC sampling, `DynamicPPL.values_as_in_model` is used to
extract the values of _only_ the tracked variables. This allows the user to control which
variables are ultimately stored in the chain. This field can be set using the
[`set_tracked_varnames`](@ref) function.

Here `argnames`, `defaultargnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`.
`context` is by default `DefaultContext()`.

Expand All @@ -23,14 +31,17 @@ different arguments.
# Examples

```julia
julia> f(x) = x + 1 # Dummy function
f (generic function with 1 method)

julia> Model(f, (x = 1.0, y = 2.0))
Model{typeof(f),(:x, :y),(),(),Tuple{Float64,Float64},Tuple{}}(f, (x = 1.0, y = 2.0), NamedTuple())
Model{typeof(f), (:x, :y), (), (), Tuple{Float64, Float64}, Tuple{}, DefaultContext}(f, (x = 1.0, y = 2.0), NamedTuple(), DefaultContext(), nothing)

julia> Model(f, (x = 1.0, y = 2.0), (x = 42,))
Model{typeof(f),(:x, :y),(:x,),(),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
Model{typeof(f), (:x, :y), (:x,), (), Tuple{Float64, Float64}, Tuple{Int64}, DefaultContext}(f, (x = 1.0, y = 2.0), (x = 42,), DefaultContext(), nothing)

julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition of missings
Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
Model{typeof(f), (:x, :y), (:x,), (:y,), Tuple{Float64, Float64}, Tuple{Int64}, DefaultContext}(f, (x = 1.0, y = 2.0), (x = 42,), DefaultContext(), nothing)
```
"""
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <:
Expand All @@ -39,6 +50,7 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte
args::NamedTuple{argnames,Targs}
defaults::NamedTuple{defaultnames,Tdefaults}
context::Ctx
tracked_varnames::Union{Nothing,Array{<:VarName}}

@doc """
Model{missings}(f, args::NamedTuple, defaults::NamedTuple)
Expand All @@ -51,9 +63,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte
args::NamedTuple{argnames,Targs},
defaults::NamedTuple{defaultnames,Tdefaults},
context::Ctx=DefaultContext(),
tracked_varnames::Union{Nothing,Array{<:VarName}}=nothing,
) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx}
return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}(
f, args, defaults, context
f, args, defaults, context, tracked_varnames
)
end
end
Expand All @@ -71,22 +84,44 @@ model with different arguments.
args::NamedTuple{argnames,Targs},
defaults::NamedTuple{kwargnames,Tkwargs},
context::AbstractContext=DefaultContext(),
tracked_varnames::Union{Nothing,Array{<:VarName}}=nothing,
) where {F,argnames,Targs,kwargnames,Tkwargs}
missing_args = Tuple(
name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing
)
missing_kwargs = Tuple(
name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing
)
return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context))
return :(Model{$(missing_args..., missing_kwargs...)}(
f, args, defaults, context, tracked_varnames
))
end

function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...)
return Model(f, args, NamedTuple(kwargs), context)
function Model(
f,
args::NamedTuple,
context::AbstractContext=DefaultContext(),
tracked_varnames::Union{Nothing,Array{<:VarName}}=nothing;
kwargs...,
)
return Model(f, args, NamedTuple(kwargs), context, tracked_varnames)
end

function contextualize(model::Model, context::AbstractContext)
return Model(model.f, model.args, model.defaults, context)
return Model(model.f, model.args, model.defaults, context, model.tracked_varnames)
end

"""
set_tracked_varnames(model::Model, varnames::Union{Nothing,Array{<:VarName}})

Return a new `Model` which only tracks a subset of variables during sampling.

If `varnames` is `nothing`, then all variables will be tracked. Otherwise, only
the variables subsumed by `varnames` are tracked. For example, if `varnames =
[@varname(x)]`, then any variable `x`, `x[1]`, `x.a`, ... will be tracked.
"""
function set_tracked_varnames(model::Model, varnames::Union{Nothing,Array{<:VarName}})
return Model(model.f, model.args, model.defaults, model.context, varnames)
end

"""
Expand Down
48 changes: 39 additions & 9 deletions src/values_as_in_model.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
"""
TrackedValue{T}

A struct that wraps something on the right-hand side of `:=`. This is needed
because the DynamicPPL compiler actually converts `lhs := rhs` to `lhs ~
TrackedValue(rhs)` (so that we can hit the `tilde_assume` method below). Having
the rhs wrapped in a TrackedValue makes sure that the logpdf of the rhs is not
computed (as it wouldn't make sense).
"""
struct TrackedValue{T}
value::T
end
Expand All @@ -24,17 +33,27 @@
values::OrderedDict
"whether to extract variables on the LHS of :="
include_colon_eq::Bool
"varnames to be tracked; `nothing` means track all varnames"
tracked_varnames::Union{Nothing,Array{<:VarName}}
"child context"
context::C
end
function ValuesAsInModelContext(include_colon_eq, context::AbstractContext)
return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context)
function ValuesAsInModelContext(
include_colon_eq::Bool,
tracked_varnames::Union{Nothing,Array{<:VarName}},
context::AbstractContext,
)
return ValuesAsInModelContext(
OrderedDict(), include_colon_eq, tracked_varnames, context
)
end

NodeTrait(::ValuesAsInModelContext) = IsParent()
childcontext(context::ValuesAsInModelContext) = context.context
function setchildcontext(context::ValuesAsInModelContext, child)
return ValuesAsInModelContext(context.values, context.include_colon_eq, child)
return ValuesAsInModelContext(
context.values, context.include_colon_eq, context.tracked_varnames, child
)
end

is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq
Expand Down Expand Up @@ -63,29 +82,38 @@

# `tilde_asssume`
function tilde_assume(context::ValuesAsInModelContext, right, vn, vi)
if is_tracked_value(right)
is_tracked_value_right = is_tracked_value(right)
if is_tracked_value_right
value = right.value
logp = zero(getlogp(vi))
else
value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
end
# Save the value.
push!(context, vn, value)
# Save the value.
if is_tracked_value_right ||
isnothing(context.tracked_varnames) ||
any(tracked_vn -> subsumes(tracked_vn, vn), context.tracked_varnames)
push!(context, vn, value)
end
# Pass on.
return value, logp, vi
end
function tilde_assume(
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi
)
if is_tracked_value(right)
is_tracked_value_right = is_tracked_value(right)
if is_tracked_value_right

Check warning on line 105 in src/values_as_in_model.jl

View check run for this annotation

Codecov / codecov/patch

src/values_as_in_model.jl#L104-L105

Added lines #L104 - L105 were not covered by tests
value = right.value
logp = zero(getlogp(vi))
else
value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
end
# Save the value.
push!(context, vn, value)
if is_tracked_value_right ||

Check warning on line 112 in src/values_as_in_model.jl

View check run for this annotation

Codecov / codecov/patch

src/values_as_in_model.jl#L112

Added line #L112 was not covered by tests
isnothing(context.tracked_varnames) ||
any(tracked_vn -> subsumes(tracked_vn, vn), context.tracked_varnames)
push!(context, vn, value)

Check warning on line 115 in src/values_as_in_model.jl

View check run for this annotation

Codecov / codecov/patch

src/values_as_in_model.jl#L114-L115

Added lines #L114 - L115 were not covered by tests
end
# Pass on.
return value, logp, vi
end
Expand Down Expand Up @@ -167,9 +195,11 @@
model::Model,
include_colon_eq::Bool,
varinfo::AbstractVarInfo,
tracked_varnames=model.tracked_varnames,
context::AbstractContext=DefaultContext(),
)
context = ValuesAsInModelContext(include_colon_eq, context)
tracked_varnames = isnothing(tracked_varnames) ? nothing : collect(tracked_varnames)
context = ValuesAsInModelContext(include_colon_eq, tracked_varnames, context)
evaluate!!(model, varinfo, context)
return context.values
end
12 changes: 0 additions & 12 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -728,18 +728,6 @@ module Issue537 end
varinfo = VarInfo(model)
@test haskey(varinfo, @varname(x))
@test !haskey(varinfo, @varname(y))

# While `values_as_in_model` should contain both `x` and `y`, if
# include_colon_eq is set to `true`.
values = values_as_in_model(model, true, deepcopy(varinfo))
@test haskey(values, @varname(x))
@test haskey(values, @varname(y))

# And if include_colon_eq is set to `false`, then `values` should
# only contain `x`.
values = values_as_in_model(model, false, deepcopy(varinfo))
@test haskey(values, @varname(x))
@test !haskey(values, @varname(y))
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ end
ctx1 = PrefixContext{:a}(DefaultContext())
ctx2 = SamplingContext(ctx1)
ctx3 = PrefixContext{:b}(ctx2)
ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3)
ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, nothing, ctx3)
vn_prefixed1 = prefix(ctx1, vn)
vn_prefixed2 = prefix(ctx2, vn)
vn_prefixed3 = prefix(ctx3, vn)
Expand Down
42 changes: 0 additions & 42 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,48 +410,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
end
end

@testset "values_as_in_model" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
vns = DynamicPPL.TestUtils.varnames(model)
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
# We can set the include_colon_eq arg to false because none of
# the demo models contain :=. The behaviour when
# include_colon_eq is true is tested in test/compiler.jl
realizations = values_as_in_model(model, false, varinfo)
# Ensure that all variables are found.
vns_found = collect(keys(realizations))
@test vns ∩ vns_found == vns ∪ vns_found
# Ensure that the values are the same.
for vn in vns
@test realizations[vn] == varinfo[vn]
end
end
end

@testset "Prefixing" begin
@model inner() = x ~ Normal()

@model function outer_auto_prefix()
a ~ to_submodel(inner(), true)
b ~ to_submodel(inner(), true)
return nothing
end
@model function outer_manual_prefix()
a ~ to_submodel(prefix(inner(), :a), false)
b ~ to_submodel(prefix(inner(), :b), false)
return nothing
end

for model in (outer_auto_prefix(), outer_manual_prefix())
vi = VarInfo(model)
vns = Set(keys(values_as_in_model(model, false, vi)))
@test vns == Set([@varname(var"a.x"), @varname(var"b.x")])
end
end
end

@testset "Erroneous model call" begin
# Calling a model with the wrong arguments used to lead to infinite recursion, see
# https://github.com/TuringLang/Turing.jl/issues/2182. This guards against it.
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ include("test_util.jl")
include("varinfo.jl")
include("simple_varinfo.jl")
include("model.jl")
include("values_as_in_model.jl")
include("sampler.jl")
include("independence.jl")
include("distribution_wrappers.jl")
Expand Down
Loading
Loading