-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #39 from CDCgov/38-modular-inference
Implement a proposed modular API to specify a data (i.e. case) generating process
- Loading branch information
Showing
8 changed files
with
199 additions
and
201 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,96 +1,83 @@ | ||
abstract type AbstractEpiModel end | ||
|
||
|
||
|
||
""" | ||
struct EpiModel{T<:Real} <: AbstractEpiModel | ||
EpiModel represents an epidemiological model with generation intervals, delay intervals, and observation delay kernel. | ||
# Fields | ||
- `gen_int::Vector{T}`: Discrete generation inteval, runs from 1, 2, ... to the end of the vector. | ||
- `delay_int::Vector{T}`: Discrete delay distribution runs from 0, 1, ... to the end of the vector less 1. | ||
- `delay_kernel::SparseMatrixCSC{T,Integer}`: Sparse matrix representing the observation delay kernel. | ||
- `cluster_coeff::T`: Cluster coefficient for negative binomial observations. | ||
- `len_gen_int::Integer`: Length of `gen_int`. | ||
- `len_delay_int::Integer`: Length of `delay_int`. | ||
- `time_horizon::Integer`: Length of the generated data. | ||
# Constructors | ||
- `EpiModel(gen_int, delay_int, cluster_coeff, time_horizon::Integer)`: Constructs an EpiModel object with given generation intervals, delay intervals, cluster coefficient, and time horizon. | ||
- `EpiModel(gen_distribution::ContinuousDistribution, delay_distribution::ContinuousDistribution, cluster_coeff, time_horizon::Integer; Δd = 1.0, D_gen, D_delay)`: Constructs an EpiModel object with generation and delay distributions, cluster coefficient, time horizon, and optional parameters. | ||
""" | ||
struct EpiModel{T<:Real} <: AbstractEpiModel | ||
struct EpiData{T<:Real,F<:Function} | ||
gen_int::Vector{T} | ||
delay_int::Vector{T} | ||
delay_kernel::SparseMatrixCSC{T,Integer} | ||
cluster_coeff::T | ||
len_gen_int::Integer #length(gen_int) just to save recalc | ||
len_delay_int::Integer #length(delay_int) just to save recalc | ||
len_gen_int::Integer | ||
len_delay_int::Integer | ||
time_horizon::Integer | ||
transformation::F | ||
|
||
#Inner constructors for EpiModel object | ||
function EpiModel(gen_int, delay_int, cluster_coeff, time_horizon::Integer) | ||
#Inner constructors for EpiData object | ||
function EpiData( | ||
gen_int, | ||
delay_int, | ||
cluster_coeff, | ||
time_horizon::Integer, | ||
transformation::Function, | ||
) | ||
@assert all(gen_int .>= 0) "Generation interval must be non-negative" | ||
@assert all(delay_int .>= 0) "Delay interval must be non-negative" | ||
@assert sum(gen_int) ≈ 1 "Generation interval must sum to 1" | ||
@assert sum(delay_int) ≈ 1 "Delay interval must sum to 1" | ||
|
||
K = generate_observation_kernel(delay_int, time_horizon) | ||
|
||
new{eltype(gen_int)}( | ||
new{eltype(gen_int),typeof(transformation)}( | ||
gen_int, | ||
delay_int, | ||
K, | ||
cluster_coeff, | ||
length(gen_int), | ||
length(delay_int), | ||
time_horizon, | ||
transformation, | ||
) | ||
end | ||
|
||
function EpiModel( | ||
function EpiData( | ||
gen_distribution::ContinuousDistribution, | ||
delay_distribution::ContinuousDistribution, | ||
cluster_coeff, | ||
time_horizon::Integer; | ||
Δd = 1.0, | ||
D_gen, | ||
D_delay, | ||
Δd = 1.0, | ||
transformation::Function = exp, | ||
) | ||
gen_int = | ||
create_discrete_pmf(gen_distribution, Δd = Δd, D = D_gen) |> | ||
p -> p[2:end] ./ sum(p[2:end]) | ||
delay_int = create_discrete_pmf(delay_distribution, Δd = Δd, D = D_delay) | ||
|
||
K = generate_observation_kernel(delay_int, time_horizon) | ||
|
||
new{eltype(gen_int)}( | ||
gen_int, | ||
delay_int, | ||
K, | ||
cluster_coeff, | ||
length(gen_int), | ||
length(delay_int), | ||
time_horizon, | ||
) | ||
return EpiData(gen_int, delay_int, cluster_coeff, time_horizon, transformation) | ||
end | ||
end | ||
|
||
""" | ||
(epi_model::EpiModel)(recent_incidence, Rt) | ||
struct DirectInfections <: AbstractEpiModel | ||
data::EpiData | ||
end | ||
|
||
Apply the EpiModel to calculate new incidence based on recent incidence and Rt. | ||
function (epi_model::DirectInfections)(recent_incidence, unc_I_t) | ||
nothing, epi_model.data.transformation(unc_I_t) | ||
end | ||
|
||
struct ExpGrowthRate <: AbstractEpiModel | ||
data::EpiData | ||
end | ||
|
||
# Arguments | ||
- `recent_incidence`: Array of recent incidence values. | ||
- `Rt`: Reproduction number. | ||
function (epi_model::ExpGrowthRate)(unc_recent_incidence, rt) | ||
new_unc_recent_incidence = unc_recent_incidence + rt | ||
new_unc_recent_incidence, epi_model.data.transformation(new_unc_recent_incidence) | ||
end | ||
|
||
struct Renewal <: AbstractEpiModel | ||
data::EpiData | ||
end | ||
|
||
# Returns | ||
- `new_incidence`: Array of new incidence values. | ||
""" | ||
function (epi_model::EpiModel)(recent_incidence, Rt) | ||
new_incidence = Rt * dot(recent_incidence, epi_model.gen_int) | ||
[new_incidence; recent_incidence[1:(epi_model.len_gen_int-1)]], new_incidence | ||
function (epi_model::Renewal)(recent_incidence, Rt) | ||
new_incidence = Rt * dot(recent_incidence, epi_model.data.gen_int) | ||
[new_incidence; recent_incidence[1:(epi_model.data.len_gen_int-1)]], new_incidence | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,25 @@ | ||
const STANDARD_RW_PRIORS = | ||
(var_RW_dist = truncated(Normal(0.0, 0.05), 0.0, Inf), init_rw_value_dist = Normal()) | ||
|
||
|
||
""" | ||
random_walk(n, ϵ_t = missing; latent_process_priors = (var_RW_dist = truncated(Normal(0., 0.05), 0., Inf),), ::Type{T} = Float64) where {T <: Real} | ||
Constructs a random walk model. | ||
# Arguments | ||
- `n`: The number of time steps. | ||
- `ϵ_t`: The random noise vector. Defaults to `missing`, in which case it is sampled from the standard multivariate normal distribution. | ||
- `latent_process_priors`: The prior distribution for the latent process parameters. Defaults to `(var_RW_dist = truncated(Normal(0., 0.05), 0., Inf),)`. | ||
function default_rw_priors() | ||
return ( | ||
var_RW_dist = truncated(Normal(0.0, 0.05), 0.0, Inf), | ||
init_rw_value_dist = Normal(), | ||
) | ||
end | ||
|
||
# Returns | ||
- `rw`: The random walk process. | ||
- `σ_RW`: The standard deviation of the random walk process. | ||
""" | ||
@model function random_walk( | ||
n, | ||
ϵ_t = missing, | ||
::Type{T} = Float64; | ||
latent_process_priors = STANDARD_RW_PRIORS, | ||
) where {T<:Real} | ||
latent_process_priors = default_rw_priors(), | ||
) where {T<:AbstractFloat} | ||
rw = Vector{T}(undef, n) | ||
ϵ_t ~ MvNormal(ones(n)) | ||
σ²_RW ~ latent_process_priors.var_RW_dist | ||
init_rw_value ~ latent_process_priors.init_rw_value_dist | ||
σ_RW = sqrt(σ²_RW) | ||
rw .= init_rw_value .+ cumsum(σ_RW * ϵ_t) | ||
return rw, (; σ_RW, init_rw_value) | ||
|
||
rw[1] = init_rw_value + σ_RW * ϵ_t[1] | ||
for t = 2:n | ||
rw[t] = rw[t-1] + σ_RW * ϵ_t[t] | ||
end | ||
return rw, (; σ_RW, init_rw_value, init = rw[1]) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.