Skip to content

Commit

Permalink
Issue 252: Composing complex models (#296)
Browse files Browse the repository at this point in the history
* add prefixing to ConcatLatentModels

* start working on concatlatentmodels

* fix concat constructor

* add tests for ConcatLatentModels

* add draft ascertainment method

* work on tests

* use latent_model not latentmodel

* use latent_model not latentmodel

* fix Ascertainment constructors

* add tests for prefix_submodel

* start to add helper strucs

* add tests and constructors for Prefix wrappers around prefix_submodel

* reduce custom code by using Prefix constructors

* switch stackobservation models to using new prefixwrapper

* fix ascertaiment to use latent models

* fix testing issues related to prefix_submodel

* fix final tests
  • Loading branch information
seabbs authored Jun 27, 2024
1 parent 20298b2 commit af37402
Show file tree
Hide file tree
Showing 27 changed files with 358 additions and 113 deletions.
5 changes: 3 additions & 2 deletions EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ module EpiAwareUtils
using ..EpiAwareBase

using DataFramesMeta: DataFrame, @rename!
using DynamicPPL: Model, fix, condition
using DynamicPPL: Model, fix, condition, @submodel, @model
using MCMCChains: Chains
using Random: AbstractRNG
using Tables: rowtable
Expand All @@ -17,12 +17,13 @@ using Distributions, DocStringExtensions, QuadGK, Statistics, Turing
export HalfNormal, DirectSample

#Export functions
export scan, spread_draws, censored_pmf, get_param_array
export scan, spread_draws, censored_pmf, get_param_array, prefix_submodel

include("docstrings.jl")
include("censored_pmf.jl")
include("HalfNormal.jl")
include("scan.jl")
include("prefix_submodel.jl")
include("turing-methods.jl")
include("DirectSample.jl")
include("post-inference.jl")
Expand Down
30 changes: 30 additions & 0 deletions EpiAware/src/EpiAwareUtils/prefix_submodel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
@doc raw"
Generate a submodel with an optional prefix. A lightweight wrapper around the `@submodel` macro from DynamicPPL.jl.
# Arguments
- `model::AbstractModel`: The model to be used.
- `fn::Function`: The Turing @model function to be applied to the model.
- `prefix::String`: The prefix to be used. If the prefix is an empty string, the submodel is created without a prefix.
# Returns
- `submodel`: The returns from the submodel are passed through.
# Examples
```julia
using EpiAware
submodel = prefix_submodel(CombineLatentModels([FixedIntercept(0.1), AR()]), generate_latent, \"Test\", 10)
rand(submodel)
```
"
@model function prefix_submodel(
model::AbstractModel, fn::Function, prefix::String, kwargs...)
if prefix == ""
@submodel submodel = fn(model, kwargs...)
else
@submodel prefix=eval(prefix) submodel=fn(model, kwargs...)
end
return submodel
end
6 changes: 3 additions & 3 deletions EpiAware/src/EpiLatentModels/EpiLatentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module EpiLatentModels

using ..EpiAwareBase

using ..EpiAwareUtils: HalfNormal
using ..EpiAwareUtils: HalfNormal, prefix_submodel

using LogExpFunctions: softmax

Expand All @@ -26,7 +26,7 @@ export RepeatEach, RepeatBlock
export broadcast_dayofweek, broadcast_weekly, equal_dimensions

# Export tools for modifying latent models
export DiffLatentModel, TransformLatentModel
export DiffLatentModel, TransformLatentModel, PrefixLatentModel

include("docstrings.jl")
include("models/Intercept.jl")
Expand All @@ -35,9 +35,9 @@ include("models/AR.jl")
include("models/HierarchicalNormal.jl")
include("modifiers/DiffLatentModel.jl")
include("modifiers/TransformLatentModel.jl")
include("modifiers/PrefixLatentModel.jl")
include("manipulators/CombineLatentModels.jl")
include("manipulators/ConcatLatentModels.jl")

include("manipulators/broadcast/LatentModel.jl")
include("manipulators/broadcast/rules.jl")
include("manipulators/broadcast/helpers.jl")
Expand Down
37 changes: 28 additions & 9 deletions EpiAware/src/EpiLatentModels/manipulators/CombineLatentModels.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
@doc raw"
The `CombineLatentModels` struct.
This struct is used to combine multiple latent models into a single latent model.
This struct is used to combine multiple latent models into a single latent model. If a prefix is supplied wraps each model with `PrefixLatentModel`.
# Constructors
- `CombineLatentModels(models::M) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `CombineLatentModels` instance with specified models, ensuring that there are at least two models.
- `CombineLatentModels(; models::M) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `CombineLatentModels` instance with specified models, ensuring that there are at least two models.
- `CombineLatentModels(models::M, prefixes::P) where {M <: AbstractVector{<:AbstractTuringLatentModel}, P <: AbstractVector{<:String}}`: Constructs a `CombineLatentModels` instance with specified models and prefixes, ensuring that there are at least two models and the number of models and prefixes are equal.
- `CombineLatentModels(models::M) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `CombineLatentModels` instance with specified models, automatically generating prefixes for each model. The
automatic prefixes are of the form `Combine.1`, `Combine.2`, etc.
# Examples
Expand All @@ -17,15 +17,33 @@ latent_model = generate_latent(combined_model, 10)
latent_model()
```
"
@kwdef struct CombineLatentModels{M <: AbstractVector{<:AbstractTuringLatentModel}} <:
@kwdef struct CombineLatentModels{
M <: AbstractVector{<:AbstractTuringLatentModel}, P <: AbstractVector{<:String}} <:
AbstractTuringLatentModel
"A vector of latent models"
models::M
"A vector of prefixes for the latent models"
prefixes::P

function CombineLatentModels(models::M) where {M <:
AbstractVector{<:AbstractTuringLatentModel}}
function CombineLatentModels(models::M,
prefixes::P) where {
M <: AbstractVector{<:AbstractTuringLatentModel},
P <: AbstractVector{<:String}}
@assert length(models)>1 "At least two models are required"
return new{AbstractVector{<:AbstractTuringLatentModel}}(models)
@assert length(models)==length(prefixes) "The number of models and prefixes must be equal"
for i in eachindex(models)
if (prefixes[i] != "")
models[i] = PrefixLatentModel(models[i], prefixes[i])
end
end
return new{AbstractVector{<:AbstractTuringLatentModel}, AbstractVector{<:String}}(
models, prefixes)
end

function CombineLatentModels(models::M) where {
M <: AbstractVector{<:AbstractTuringLatentModel}}
prefixes = "Combine." .* string.(1:length(models))
return CombineLatentModels(models, prefixes)
end
end

Expand All @@ -49,7 +67,8 @@ Generate latent variables using a combination of multiple latent models.
return final_latent, (; latent_aux...)
end

@model function _accumulate_latents(models, index, acc_latent, acc_aux, n, n_models)
@model function _accumulate_latents(
models, index, acc_latent, acc_aux, n, n_models)
if index > n_models
return acc_latent, (; acc_aux...)
else
Expand Down
50 changes: 34 additions & 16 deletions EpiAware/src/EpiLatentModels/manipulators/ConcatLatentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ This struct is used to concatenate multiple latent models into a single latent m
# Constructors
- `ConcatLatentModels(models::M, no_models::Int, dimension_adaptor::Function) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `ConcatLatentModels` instance with specified models, number of models, and dimension adaptor.
- `ConcatLatentModels(models::M, dimension_adaptor::Function) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `ConcatLatentModels` instance with specified models and dimension adaptor, ensuring that there are at least two models. The default dimension adaptor is `equal_dimensions`.
- `ConcatLatentModels(; models::M, dimension_adaptor::Function) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `ConcatLatentModels` instance with specified models and dimension adaptor, ensuring that there are at least two models. The default dimension adaptor is `equal_dimensions`.
- `ConcatLatentModels(models::M, no_models::I, dimension_adaptor::F, prefixes::P) where {M <: AbstractVector{<:AbstractTuringLatentModel}, I <: Int, F <: Function, P <: AbstractVector{String}}`: Constructs a `ConcatLatentModels` instance with specified models, number of models, dimension adaptor, and prefixes.
- `ConcatLatentModels(models::M, dimension_adaptor::F; prefixes::P = \"Concat.\" * string.(1:length(models))) where {M <: AbstractVector{<:AbstractTuringLatentModel}, F <: Function}`: Constructs a `ConcatLatentModels` instance with specified models and dimension adaptor. The number of models is automatically determined as are the prefixes (of the form `Concat.1`, `Concat.2`, etc.) by default.
- `ConcatLatentModels(models::M; dimension_adaptor::Function, prefixes::P) where {M <: AbstractVector{<:AbstractTuringLatentModel}, P <: AbstractVector{String}}`: Constructs a `ConcatLatentModels` instance with specified models, dimension adaptor, prefixes, and automatically determines the number of models.The default dimension adaptor is `equal_dimensions`. The default prefixes are of the form `Concat.1`, `Concat.2`, etc.
- `ConcatLatentModels(; models::M, dimension_adaptor::Function, prefixes::P) where {M <: AbstractVector{<:AbstractTuringLatentModel}, P <: AbstractVector{String}}`: Constructs a `ConcatLatentModels` instance with specified models, dimension adaptor, prefixes, and automatically determines the number of models. The default dimension adaptor is `equal_dimensions`. The default prefixes are of the form `Concat.1`, `Concat.2`, etc.
# Examples
Expand All @@ -19,46 +20,62 @@ latent_model()
```
"
struct ConcatLatentModels{
M <: AbstractVector{<:AbstractTuringLatentModel}, N <: Int, F <: Function} <:
M <: AbstractVector{<:AbstractTuringLatentModel}, N <: Int, F <: Function, P <:
AbstractVector{<:String}} <:
AbstractTuringLatentModel
"A vector of latent models"
models::M
"The number of models in the collection"
no_models::N
"The dimension function for the latent variables. By default this divides the number of latent variables by the number of models and returns a vector of dimensions rounding up the first element and rounding down the rest."
dimension_adaptor::F
"A vector of prefixes for the latent models"
prefixes::P

function ConcatLatentModels(models::M,
no_models::I,
dimension_adaptor::F) where {
dimension_adaptor::F, prefixes::P) where {
M <: AbstractVector{<:AbstractTuringLatentModel}, I <: Int,
F <: Function}
F <: Function, P <: AbstractVector{<:String}}
@assert length(models)>1 "At least two models are required"
@assert length(models)==no_models "no_models must be equal to the number of models"
# check all dimension functions take a single n and return an integer
check_dim = dimension_adaptor(no_models, no_models)
@assert typeof(check_dim)<:AbstractVector{Int} "Output of dimension_adaptor must be a vector of integers"
@assert length(check_dim)==no_models "The vector of dimensions must have the same length as the number of models"
return new{AbstractVector{<:AbstractTuringLatentModel}, Int, Function}(
models, no_models, dimension_adaptor)
@assert length(prefixes)==no_models "The number of models and prefixes must be equal"
for i in eachindex(models)
if (prefixes[i] != "")
models[i] = PrefixLatentModel(models[i], prefixes[i])
end
end
return new{
AbstractVector{<:AbstractTuringLatentModel}, Int, Function,
AbstractVector{<:String}}(
models, no_models, dimension_adaptor, prefixes)
end

function ConcatLatentModels(models::M,
dimension_adaptor::Function) where {
function ConcatLatentModels(models::M, dimension_adaptor::Function;
prefixes = nothing) where {
M <: AbstractVector{<:AbstractTuringLatentModel}}
return ConcatLatentModels(models, length(models), dimension_adaptor)
no_models = length(models)
if isnothing(prefixes)
prefixes = "Concat." .* string.(1:no_models)
end
return ConcatLatentModels(models, no_models, dimension_adaptor, prefixes)
end

function ConcatLatentModels(models::M;
dimension_adaptor::Function = equal_dimensions) where {
dimension_adaptor::Function = equal_dimensions,
prefixes = nothing) where {
M <: AbstractVector{<:AbstractTuringLatentModel}}
return ConcatLatentModels(models, dimension_adaptor)
return ConcatLatentModels(models, dimension_adaptor; prefixes = prefixes)
end

function ConcatLatentModels(; models::M,
dimension_adaptor::Function = equal_dimensions) where {
dimension_adaptor::Function = equal_dimensions, prefixes = nothing) where {
M <: AbstractVector{<:AbstractTuringLatentModel}}
return ConcatLatentModels(models, dimension_adaptor)
return ConcatLatentModels(models, dimension_adaptor; prefixes = prefixes)
end
end

Expand Down Expand Up @@ -102,7 +119,8 @@ Generate latent variables by concatenating multiple latent models.
end

@model function _concat_latents(
models, index::Int, acc_latent, acc_aux, dims::AbstractVector{<:Int}, n_models::Int)
models, index::Int, acc_latent, acc_aux,
dims::AbstractVector{<:Int}, n_models::Int)
if index > n_models
return acc_latent, (; acc_aux...)
else
Expand Down
6 changes: 3 additions & 3 deletions EpiAware/src/EpiLatentModels/modifiers/DiffLatentModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ terms ``Z_1, \ldots, Z_d`` are inferred.
## Constructors
- `DiffLatentModel(latentmodel, init_prior_distribution::Distribution; d::Int)`
Constructs a `DiffLatentModel` for `d`-fold differencing with `latentmodel` as the
- `DiffLatentModel(latent_model, init_prior_distribution::Distribution; d::Int)`
Constructs a `DiffLatentModel` for `d`-fold differencing with `latent_model` as the
undifferenced latent process. All initial terms have common prior
`init_prior_distribution`.
- `DiffLatentModel(;model, init_priors::Vector{D} where {D <: Distribution})`
Constructs a `DiffLatentModel` for `d`-fold differencing with `latentmodel` as the
Constructs a `DiffLatentModel` for `d`-fold differencing with `latent_model` as the
undifferenced latent process. The `d` initial terms have priors given by the vector
`init_priors`, therefore `length(init_priors)` sets `d`.
Expand Down
28 changes: 28 additions & 0 deletions EpiAware/src/EpiLatentModels/modifiers/PrefixLatentModel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
@doc raw"
Generate a latent model with a prefix. A lightweight wrapper around `EpiAwareUtils.prefix_submodel`.
# Constructors
- `PrefixLatentModel(model::M, prefix::P)`: Create a `PrefixLatentModel` with the latent model `model` and the prefix `prefix`.
- `PrefixLatentModel(; model::M, prefix::P)`: Create a `PrefixLatentModel` with the latent model `model` and the prefix `prefix`.
# Examples
```julia
using EpiAware
latent_model = PrefixLatentModel(model = HierarchicalNormal(), prefix = \"Test\")
mdl = generate_latent(latent_model, 10)
rand(mdl)
```
"
@kwdef struct PrefixLatentModel{M <: AbstractTuringLatentModel, P <: String} <:
AbstractTuringLatentModel
"The latent model"
model::M
"The prefix for the latent model"
prefix::P
end

@model function EpiAwareBase.generate_latent(latent_model::PrefixLatentModel, n)
@submodel submodel = prefix_submodel(
latent_model.model, generate_latent, latent_model.prefix, n)
return submodel
end
16 changes: 10 additions & 6 deletions EpiAware/src/EpiObsModels/EpiObsModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ module EpiObsModels

using ..EpiAwareBase

using ..EpiAwareUtils: censored_pmf, HalfNormal
using ..EpiAwareUtils: censored_pmf, HalfNormal, prefix_submodel

using ..EpiLatentModels: HierarchicalNormal, broadcast_dayofweek
using ..EpiLatentModels: HierarchicalNormal, broadcast_dayofweek, PrefixLatentModel

using Turing, Distributions, DocStringExtensions, SparseArrays

Expand All @@ -18,15 +18,19 @@ export PoissonError, NegativeBinomialError
export generate_observation_error_priors, observation_error

# Observation model modifiers
export LatentDelay, Ascertainment, StackObservationModels
export LatentDelay, Ascertainment, PrefixObservationModel

# Observation model manipulators
export StackObservationModels

# helper functions
export ascertainment_dayofweek

include("docstrings.jl")
include("LatentDelay.jl")
include("ascertainment/Ascertainment.jl")
include("ascertainment/helpers.jl")
include("modifiers/LatentDelay.jl")
include("modifiers/ascertainment/Ascertainment.jl")
include("modifiers/ascertainment/helpers.jl")
include("modifiers/PrefixObservationModel.jl")
include("StackObservationModels.jl")
include("ObservationErrorModels/methods.jl")
include("ObservationErrorModels/NegativeBinomialError.jl")
Expand Down
9 changes: 6 additions & 3 deletions EpiAware/src/EpiObsModels/StackObservationModels.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@doc raw"
A stack of observation models that are looped over to generate observations for
each model in the stack. Note that the model names are used to prefix the parameters in each model (so if I have a model named `cases` and a parameter `y_t`, the parameter in the model will be `cases.y_t`).
each model in the stack. Note that the model names are used to prefix the parameters in each model (so if I have a model named `cases` and a parameter `y_t`, the parameter in the model will be `cases.y_t`). Inside the constructor `PrefixObservationModel` is wrapped around each observation model.
## Constructors
Expand Down Expand Up @@ -48,7 +48,10 @@ deaths_y_t
N <: AbstractString
}
@assert length(models)==length(model_names) "The number of models and model names must be equal."
new{typeof(models), typeof(model_names)}(models, model_names)
wrapped_models = [PrefixObservationModel(models[i], model_names[i])
for i in eachindex(models)]
new{AbstractVector{<:AbstractTuringObservationModel}, typeof(model_names)}(
wrapped_models, model_names)
end

function StackObservationModels(models::NamedTuple{
Expand Down Expand Up @@ -77,7 +80,7 @@ Generate observations from a stack of observation models. Assumes a 1 to 1 mappi

obs = ()
for (model, model_name) in zip(obs_model.models, obs_model.model_names)
@submodel prefix=eval(model_name) obs_tmp=generate_observations(
@submodel obs_tmp = generate_observations(
model, y_t[Symbol(model_name)], Y_t[Symbol(model_name)])
obs = obs..., obs_tmp...
end
Expand Down
Loading

0 comments on commit af37402

Please sign in to comment.