Skip to content

Commit af37402

Browse files
authored
Issue 252: Composing complex models (#296)
* add prefixing to ConcatLatentModels * start working on concatlatentmodels * fix concat constructor * add tests for ConcatLatentModels * add draft ascertainment method * work on tests * use latent_model not latentmodel * use latent_model not latentmodel * fix Ascertainment constructors * add tests for prefix_submodel * start to add helper strucs * add tests and constructors for Prefix wrappers around prefix_submodel * reduce custom code by using Prefix constructors * switch stackobservation models to using new prefixwrapper * fix ascertaiment to use latent models * fix testing issues related to prefix_submodel * fix final tests
1 parent 20298b2 commit af37402

27 files changed

+358
-113
lines changed

EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ module EpiAwareUtils
66
using ..EpiAwareBase
77

88
using DataFramesMeta: DataFrame, @rename!
9-
using DynamicPPL: Model, fix, condition
9+
using DynamicPPL: Model, fix, condition, @submodel, @model
1010
using MCMCChains: Chains
1111
using Random: AbstractRNG
1212
using Tables: rowtable
@@ -17,12 +17,13 @@ using Distributions, DocStringExtensions, QuadGK, Statistics, Turing
1717
export HalfNormal, DirectSample
1818

1919
#Export functions
20-
export scan, spread_draws, censored_pmf, get_param_array
20+
export scan, spread_draws, censored_pmf, get_param_array, prefix_submodel
2121

2222
include("docstrings.jl")
2323
include("censored_pmf.jl")
2424
include("HalfNormal.jl")
2525
include("scan.jl")
26+
include("prefix_submodel.jl")
2627
include("turing-methods.jl")
2728
include("DirectSample.jl")
2829
include("post-inference.jl")
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
@doc raw"
2+
Generate a submodel with an optional prefix. A lightweight wrapper around the `@submodel` macro from DynamicPPL.jl.
3+
4+
# Arguments
5+
6+
- `model::AbstractModel`: The model to be used.
7+
- `fn::Function`: The Turing @model function to be applied to the model.
8+
- `prefix::String`: The prefix to be used. If the prefix is an empty string, the submodel is created without a prefix.
9+
10+
# Returns
11+
12+
- `submodel`: The returns from the submodel are passed through.
13+
14+
# Examples
15+
16+
```julia
17+
using EpiAware
18+
submodel = prefix_submodel(CombineLatentModels([FixedIntercept(0.1), AR()]), generate_latent, \"Test\", 10)
19+
rand(submodel)
20+
```
21+
"
22+
@model function prefix_submodel(
23+
model::AbstractModel, fn::Function, prefix::String, kwargs...)
24+
if prefix == ""
25+
@submodel submodel = fn(model, kwargs...)
26+
else
27+
@submodel prefix=eval(prefix) submodel=fn(model, kwargs...)
28+
end
29+
return submodel
30+
end

EpiAware/src/EpiLatentModels/EpiLatentModels.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ module EpiLatentModels
55

66
using ..EpiAwareBase
77

8-
using ..EpiAwareUtils: HalfNormal
8+
using ..EpiAwareUtils: HalfNormal, prefix_submodel
99

1010
using LogExpFunctions: softmax
1111

@@ -26,7 +26,7 @@ export RepeatEach, RepeatBlock
2626
export broadcast_dayofweek, broadcast_weekly, equal_dimensions
2727

2828
# Export tools for modifying latent models
29-
export DiffLatentModel, TransformLatentModel
29+
export DiffLatentModel, TransformLatentModel, PrefixLatentModel
3030

3131
include("docstrings.jl")
3232
include("models/Intercept.jl")
@@ -35,9 +35,9 @@ include("models/AR.jl")
3535
include("models/HierarchicalNormal.jl")
3636
include("modifiers/DiffLatentModel.jl")
3737
include("modifiers/TransformLatentModel.jl")
38+
include("modifiers/PrefixLatentModel.jl")
3839
include("manipulators/CombineLatentModels.jl")
3940
include("manipulators/ConcatLatentModels.jl")
40-
4141
include("manipulators/broadcast/LatentModel.jl")
4242
include("manipulators/broadcast/rules.jl")
4343
include("manipulators/broadcast/helpers.jl")

EpiAware/src/EpiLatentModels/manipulators/CombineLatentModels.jl

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
@doc raw"
22
The `CombineLatentModels` struct.
33
4-
This struct is used to combine multiple latent models into a single latent model.
4+
This struct is used to combine multiple latent models into a single latent model. If a prefix is supplied wraps each model with `PrefixLatentModel`.
55
66
# Constructors
7-
8-
- `CombineLatentModels(models::M) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `CombineLatentModels` instance with specified models, ensuring that there are at least two models.
9-
- `CombineLatentModels(; models::M) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `CombineLatentModels` instance with specified models, ensuring that there are at least two models.
7+
- `CombineLatentModels(models::M, prefixes::P) where {M <: AbstractVector{<:AbstractTuringLatentModel}, P <: AbstractVector{<:String}}`: Constructs a `CombineLatentModels` instance with specified models and prefixes, ensuring that there are at least two models and the number of models and prefixes are equal.
8+
- `CombineLatentModels(models::M) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `CombineLatentModels` instance with specified models, automatically generating prefixes for each model. The
9+
automatic prefixes are of the form `Combine.1`, `Combine.2`, etc.
1010
1111
# Examples
1212
@@ -17,15 +17,33 @@ latent_model = generate_latent(combined_model, 10)
1717
latent_model()
1818
```
1919
"
20-
@kwdef struct CombineLatentModels{M <: AbstractVector{<:AbstractTuringLatentModel}} <:
20+
@kwdef struct CombineLatentModels{
21+
M <: AbstractVector{<:AbstractTuringLatentModel}, P <: AbstractVector{<:String}} <:
2122
AbstractTuringLatentModel
2223
"A vector of latent models"
2324
models::M
25+
"A vector of prefixes for the latent models"
26+
prefixes::P
2427

25-
function CombineLatentModels(models::M) where {M <:
26-
AbstractVector{<:AbstractTuringLatentModel}}
28+
function CombineLatentModels(models::M,
29+
prefixes::P) where {
30+
M <: AbstractVector{<:AbstractTuringLatentModel},
31+
P <: AbstractVector{<:String}}
2732
@assert length(models)>1 "At least two models are required"
28-
return new{AbstractVector{<:AbstractTuringLatentModel}}(models)
33+
@assert length(models)==length(prefixes) "The number of models and prefixes must be equal"
34+
for i in eachindex(models)
35+
if (prefixes[i] != "")
36+
models[i] = PrefixLatentModel(models[i], prefixes[i])
37+
end
38+
end
39+
return new{AbstractVector{<:AbstractTuringLatentModel}, AbstractVector{<:String}}(
40+
models, prefixes)
41+
end
42+
43+
function CombineLatentModels(models::M) where {
44+
M <: AbstractVector{<:AbstractTuringLatentModel}}
45+
prefixes = "Combine." .* string.(1:length(models))
46+
return CombineLatentModels(models, prefixes)
2947
end
3048
end
3149

@@ -49,7 +67,8 @@ Generate latent variables using a combination of multiple latent models.
4967
return final_latent, (; latent_aux...)
5068
end
5169

52-
@model function _accumulate_latents(models, index, acc_latent, acc_aux, n, n_models)
70+
@model function _accumulate_latents(
71+
models, index, acc_latent, acc_aux, n, n_models)
5372
if index > n_models
5473
return acc_latent, (; acc_aux...)
5574
else

EpiAware/src/EpiLatentModels/manipulators/ConcatLatentModels.jl

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ This struct is used to concatenate multiple latent models into a single latent m
55
66
# Constructors
77
8-
- `ConcatLatentModels(models::M, no_models::Int, dimension_adaptor::Function) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `ConcatLatentModels` instance with specified models, number of models, and dimension adaptor.
9-
- `ConcatLatentModels(models::M, dimension_adaptor::Function) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `ConcatLatentModels` instance with specified models and dimension adaptor, ensuring that there are at least two models. The default dimension adaptor is `equal_dimensions`.
10-
- `ConcatLatentModels(; models::M, dimension_adaptor::Function) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `ConcatLatentModels` instance with specified models and dimension adaptor, ensuring that there are at least two models. The default dimension adaptor is `equal_dimensions`.
8+
- `ConcatLatentModels(models::M, no_models::I, dimension_adaptor::F, prefixes::P) where {M <: AbstractVector{<:AbstractTuringLatentModel}, I <: Int, F <: Function, P <: AbstractVector{String}}`: Constructs a `ConcatLatentModels` instance with specified models, number of models, dimension adaptor, and prefixes.
9+
- `ConcatLatentModels(models::M, dimension_adaptor::F; prefixes::P = \"Concat.\" * string.(1:length(models))) where {M <: AbstractVector{<:AbstractTuringLatentModel}, F <: Function}`: Constructs a `ConcatLatentModels` instance with specified models and dimension adaptor. The number of models is automatically determined as are the prefixes (of the form `Concat.1`, `Concat.2`, etc.) by default.
10+
- `ConcatLatentModels(models::M; dimension_adaptor::Function, prefixes::P) where {M <: AbstractVector{<:AbstractTuringLatentModel}, P <: AbstractVector{String}}`: Constructs a `ConcatLatentModels` instance with specified models, dimension adaptor, prefixes, and automatically determines the number of models.The default dimension adaptor is `equal_dimensions`. The default prefixes are of the form `Concat.1`, `Concat.2`, etc.
11+
- `ConcatLatentModels(; models::M, dimension_adaptor::Function, prefixes::P) where {M <: AbstractVector{<:AbstractTuringLatentModel}, P <: AbstractVector{String}}`: Constructs a `ConcatLatentModels` instance with specified models, dimension adaptor, prefixes, and automatically determines the number of models. The default dimension adaptor is `equal_dimensions`. The default prefixes are of the form `Concat.1`, `Concat.2`, etc.
1112
1213
# Examples
1314
@@ -19,46 +20,62 @@ latent_model()
1920
```
2021
"
2122
struct ConcatLatentModels{
22-
M <: AbstractVector{<:AbstractTuringLatentModel}, N <: Int, F <: Function} <:
23+
M <: AbstractVector{<:AbstractTuringLatentModel}, N <: Int, F <: Function, P <:
24+
AbstractVector{<:String}} <:
2325
AbstractTuringLatentModel
2426
"A vector of latent models"
2527
models::M
2628
"The number of models in the collection"
2729
no_models::N
2830
"The dimension function for the latent variables. By default this divides the number of latent variables by the number of models and returns a vector of dimensions rounding up the first element and rounding down the rest."
2931
dimension_adaptor::F
32+
"A vector of prefixes for the latent models"
33+
prefixes::P
3034

3135
function ConcatLatentModels(models::M,
3236
no_models::I,
33-
dimension_adaptor::F) where {
37+
dimension_adaptor::F, prefixes::P) where {
3438
M <: AbstractVector{<:AbstractTuringLatentModel}, I <: Int,
35-
F <: Function}
39+
F <: Function, P <: AbstractVector{<:String}}
3640
@assert length(models)>1 "At least two models are required"
3741
@assert length(models)==no_models "no_models must be equal to the number of models"
3842
# check all dimension functions take a single n and return an integer
3943
check_dim = dimension_adaptor(no_models, no_models)
4044
@assert typeof(check_dim)<:AbstractVector{Int} "Output of dimension_adaptor must be a vector of integers"
4145
@assert length(check_dim)==no_models "The vector of dimensions must have the same length as the number of models"
42-
return new{AbstractVector{<:AbstractTuringLatentModel}, Int, Function}(
43-
models, no_models, dimension_adaptor)
46+
@assert length(prefixes)==no_models "The number of models and prefixes must be equal"
47+
for i in eachindex(models)
48+
if (prefixes[i] != "")
49+
models[i] = PrefixLatentModel(models[i], prefixes[i])
50+
end
51+
end
52+
return new{
53+
AbstractVector{<:AbstractTuringLatentModel}, Int, Function,
54+
AbstractVector{<:String}}(
55+
models, no_models, dimension_adaptor, prefixes)
4456
end
4557

46-
function ConcatLatentModels(models::M,
47-
dimension_adaptor::Function) where {
58+
function ConcatLatentModels(models::M, dimension_adaptor::Function;
59+
prefixes = nothing) where {
4860
M <: AbstractVector{<:AbstractTuringLatentModel}}
49-
return ConcatLatentModels(models, length(models), dimension_adaptor)
61+
no_models = length(models)
62+
if isnothing(prefixes)
63+
prefixes = "Concat." .* string.(1:no_models)
64+
end
65+
return ConcatLatentModels(models, no_models, dimension_adaptor, prefixes)
5066
end
5167

5268
function ConcatLatentModels(models::M;
53-
dimension_adaptor::Function = equal_dimensions) where {
69+
dimension_adaptor::Function = equal_dimensions,
70+
prefixes = nothing) where {
5471
M <: AbstractVector{<:AbstractTuringLatentModel}}
55-
return ConcatLatentModels(models, dimension_adaptor)
72+
return ConcatLatentModels(models, dimension_adaptor; prefixes = prefixes)
5673
end
5774

5875
function ConcatLatentModels(; models::M,
59-
dimension_adaptor::Function = equal_dimensions) where {
76+
dimension_adaptor::Function = equal_dimensions, prefixes = nothing) where {
6077
M <: AbstractVector{<:AbstractTuringLatentModel}}
61-
return ConcatLatentModels(models, dimension_adaptor)
78+
return ConcatLatentModels(models, dimension_adaptor; prefixes = prefixes)
6279
end
6380
end
6481

@@ -102,7 +119,8 @@ Generate latent variables by concatenating multiple latent models.
102119
end
103120

104121
@model function _concat_latents(
105-
models, index::Int, acc_latent, acc_aux, dims::AbstractVector{<:Int}, n_models::Int)
122+
models, index::Int, acc_latent, acc_aux,
123+
dims::AbstractVector{<:Int}, n_models::Int)
106124
if index > n_models
107125
return acc_latent, (; acc_aux...)
108126
else

EpiAware/src/EpiLatentModels/modifiers/DiffLatentModel.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ terms ``Z_1, \ldots, Z_d`` are inferred.
1818
1919
## Constructors
2020
21-
- `DiffLatentModel(latentmodel, init_prior_distribution::Distribution; d::Int)`
22-
Constructs a `DiffLatentModel` for `d`-fold differencing with `latentmodel` as the
21+
- `DiffLatentModel(latent_model, init_prior_distribution::Distribution; d::Int)`
22+
Constructs a `DiffLatentModel` for `d`-fold differencing with `latent_model` as the
2323
undifferenced latent process. All initial terms have common prior
2424
`init_prior_distribution`.
2525
- `DiffLatentModel(;model, init_priors::Vector{D} where {D <: Distribution})`
26-
Constructs a `DiffLatentModel` for `d`-fold differencing with `latentmodel` as the
26+
Constructs a `DiffLatentModel` for `d`-fold differencing with `latent_model` as the
2727
undifferenced latent process. The `d` initial terms have priors given by the vector
2828
`init_priors`, therefore `length(init_priors)` sets `d`.
2929
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
@doc raw"
2+
Generate a latent model with a prefix. A lightweight wrapper around `EpiAwareUtils.prefix_submodel`.
3+
4+
# Constructors
5+
- `PrefixLatentModel(model::M, prefix::P)`: Create a `PrefixLatentModel` with the latent model `model` and the prefix `prefix`.
6+
- `PrefixLatentModel(; model::M, prefix::P)`: Create a `PrefixLatentModel` with the latent model `model` and the prefix `prefix`.
7+
8+
# Examples
9+
```julia
10+
using EpiAware
11+
latent_model = PrefixLatentModel(model = HierarchicalNormal(), prefix = \"Test\")
12+
mdl = generate_latent(latent_model, 10)
13+
rand(mdl)
14+
```
15+
"
16+
@kwdef struct PrefixLatentModel{M <: AbstractTuringLatentModel, P <: String} <:
17+
AbstractTuringLatentModel
18+
"The latent model"
19+
model::M
20+
"The prefix for the latent model"
21+
prefix::P
22+
end
23+
24+
@model function EpiAwareBase.generate_latent(latent_model::PrefixLatentModel, n)
25+
@submodel submodel = prefix_submodel(
26+
latent_model.model, generate_latent, latent_model.prefix, n)
27+
return submodel
28+
end

EpiAware/src/EpiObsModels/EpiObsModels.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ module EpiObsModels
55

66
using ..EpiAwareBase
77

8-
using ..EpiAwareUtils: censored_pmf, HalfNormal
8+
using ..EpiAwareUtils: censored_pmf, HalfNormal, prefix_submodel
99

10-
using ..EpiLatentModels: HierarchicalNormal, broadcast_dayofweek
10+
using ..EpiLatentModels: HierarchicalNormal, broadcast_dayofweek, PrefixLatentModel
1111

1212
using Turing, Distributions, DocStringExtensions, SparseArrays
1313

@@ -18,15 +18,19 @@ export PoissonError, NegativeBinomialError
1818
export generate_observation_error_priors, observation_error
1919

2020
# Observation model modifiers
21-
export LatentDelay, Ascertainment, StackObservationModels
21+
export LatentDelay, Ascertainment, PrefixObservationModel
22+
23+
# Observation model manipulators
24+
export StackObservationModels
2225

2326
# helper functions
2427
export ascertainment_dayofweek
2528

2629
include("docstrings.jl")
27-
include("LatentDelay.jl")
28-
include("ascertainment/Ascertainment.jl")
29-
include("ascertainment/helpers.jl")
30+
include("modifiers/LatentDelay.jl")
31+
include("modifiers/ascertainment/Ascertainment.jl")
32+
include("modifiers/ascertainment/helpers.jl")
33+
include("modifiers/PrefixObservationModel.jl")
3034
include("StackObservationModels.jl")
3135
include("ObservationErrorModels/methods.jl")
3236
include("ObservationErrorModels/NegativeBinomialError.jl")

EpiAware/src/EpiObsModels/StackObservationModels.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
@doc raw"
22
33
A stack of observation models that are looped over to generate observations for
4-
each model in the stack. Note that the model names are used to prefix the parameters in each model (so if I have a model named `cases` and a parameter `y_t`, the parameter in the model will be `cases.y_t`).
4+
each model in the stack. Note that the model names are used to prefix the parameters in each model (so if I have a model named `cases` and a parameter `y_t`, the parameter in the model will be `cases.y_t`). Inside the constructor `PrefixObservationModel` is wrapped around each observation model.
55
66
## Constructors
77
@@ -48,7 +48,10 @@ deaths_y_t
4848
N <: AbstractString
4949
}
5050
@assert length(models)==length(model_names) "The number of models and model names must be equal."
51-
new{typeof(models), typeof(model_names)}(models, model_names)
51+
wrapped_models = [PrefixObservationModel(models[i], model_names[i])
52+
for i in eachindex(models)]
53+
new{AbstractVector{<:AbstractTuringObservationModel}, typeof(model_names)}(
54+
wrapped_models, model_names)
5255
end
5356

5457
function StackObservationModels(models::NamedTuple{
@@ -77,7 +80,7 @@ Generate observations from a stack of observation models. Assumes a 1 to 1 mappi
7780

7881
obs = ()
7982
for (model, model_name) in zip(obs_model.models, obs_model.model_names)
80-
@submodel prefix=eval(model_name) obs_tmp=generate_observations(
83+
@submodel obs_tmp = generate_observations(
8184
model, y_t[Symbol(model_name)], Y_t[Symbol(model_name)])
8285
obs = obs..., obs_tmp...
8386
end

0 commit comments

Comments
 (0)