Skip to content

Commit

Permalink
Merge pull request #39 from CDCgov/38-modular-inference
Browse files Browse the repository at this point in the history
Implement a proposed modular API to specify a data (i.e. case) generating process
  • Loading branch information
seabbs authored Feb 15, 2024
2 parents e786bfb + d8eeded commit 22a5f5d
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 201 deletions.
11 changes: 8 additions & 3 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,18 @@ using Distributions,
Parameters,
QuadGK

# Exported utilities
export scan,
create_discrete_pmf,
growth_rate_to_reproductive_ratio,
generate_observation_kernel,
EpiModel,
log_infections,
random_walk
default_rw_priors

# Exported types
export EpiData, Renewal, ExpGrowthRate, DirectInfections

# Exported Turing model constructors
export make_epi_inference_model, random_walk

include("utilities.jl")
include("epimodel.jl")
Expand Down
91 changes: 39 additions & 52 deletions EpiAware/src/epimodel.jl
Original file line number Diff line number Diff line change
@@ -1,96 +1,83 @@
abstract type AbstractEpiModel end



"""
struct EpiModel{T<:Real} <: AbstractEpiModel
EpiModel represents an epidemiological model with generation intervals, delay intervals, and observation delay kernel.
# Fields
- `gen_int::Vector{T}`: Discrete generation inteval, runs from 1, 2, ... to the end of the vector.
- `delay_int::Vector{T}`: Discrete delay distribution runs from 0, 1, ... to the end of the vector less 1.
- `delay_kernel::SparseMatrixCSC{T,Integer}`: Sparse matrix representing the observation delay kernel.
- `cluster_coeff::T`: Cluster coefficient for negative binomial observations.
- `len_gen_int::Integer`: Length of `gen_int`.
- `len_delay_int::Integer`: Length of `delay_int`.
- `time_horizon::Integer`: Length of the generated data.
# Constructors
- `EpiModel(gen_int, delay_int, cluster_coeff, time_horizon::Integer)`: Constructs an EpiModel object with given generation intervals, delay intervals, cluster coefficient, and time horizon.
- `EpiModel(gen_distribution::ContinuousDistribution, delay_distribution::ContinuousDistribution, cluster_coeff, time_horizon::Integer; Δd = 1.0, D_gen, D_delay)`: Constructs an EpiModel object with generation and delay distributions, cluster coefficient, time horizon, and optional parameters.
"""
struct EpiModel{T<:Real} <: AbstractEpiModel
struct EpiData{T<:Real,F<:Function}
gen_int::Vector{T}
delay_int::Vector{T}
delay_kernel::SparseMatrixCSC{T,Integer}
cluster_coeff::T
len_gen_int::Integer #length(gen_int) just to save recalc
len_delay_int::Integer #length(delay_int) just to save recalc
len_gen_int::Integer
len_delay_int::Integer
time_horizon::Integer
transformation::F

#Inner constructors for EpiModel object
function EpiModel(gen_int, delay_int, cluster_coeff, time_horizon::Integer)
#Inner constructors for EpiData object
function EpiData(
gen_int,
delay_int,
cluster_coeff,
time_horizon::Integer,
transformation::Function,
)
@assert all(gen_int .>= 0) "Generation interval must be non-negative"
@assert all(delay_int .>= 0) "Delay interval must be non-negative"
@assert sum(gen_int) 1 "Generation interval must sum to 1"
@assert sum(delay_int) 1 "Delay interval must sum to 1"

K = generate_observation_kernel(delay_int, time_horizon)

new{eltype(gen_int)}(
new{eltype(gen_int),typeof(transformation)}(
gen_int,
delay_int,
K,
cluster_coeff,
length(gen_int),
length(delay_int),
time_horizon,
transformation,
)
end

function EpiModel(
function EpiData(
gen_distribution::ContinuousDistribution,
delay_distribution::ContinuousDistribution,
cluster_coeff,
time_horizon::Integer;
Δd = 1.0,
D_gen,
D_delay,
Δd = 1.0,
transformation::Function = exp,
)
gen_int =
create_discrete_pmf(gen_distribution, Δd = Δd, D = D_gen) |>
p -> p[2:end] ./ sum(p[2:end])
delay_int = create_discrete_pmf(delay_distribution, Δd = Δd, D = D_delay)

K = generate_observation_kernel(delay_int, time_horizon)

new{eltype(gen_int)}(
gen_int,
delay_int,
K,
cluster_coeff,
length(gen_int),
length(delay_int),
time_horizon,
)
return EpiData(gen_int, delay_int, cluster_coeff, time_horizon, transformation)
end
end

"""
(epi_model::EpiModel)(recent_incidence, Rt)
struct DirectInfections <: AbstractEpiModel
data::EpiData
end

Apply the EpiModel to calculate new incidence based on recent incidence and Rt.
function (epi_model::DirectInfections)(recent_incidence, unc_I_t)
nothing, epi_model.data.transformation(unc_I_t)
end

struct ExpGrowthRate <: AbstractEpiModel
data::EpiData
end

# Arguments
- `recent_incidence`: Array of recent incidence values.
- `Rt`: Reproduction number.
function (epi_model::ExpGrowthRate)(unc_recent_incidence, rt)
new_unc_recent_incidence = unc_recent_incidence + rt
new_unc_recent_incidence, epi_model.data.transformation(new_unc_recent_incidence)
end

struct Renewal <: AbstractEpiModel
data::EpiData
end

# Returns
- `new_incidence`: Array of new incidence values.
"""
function (epi_model::EpiModel)(recent_incidence, Rt)
new_incidence = Rt * dot(recent_incidence, epi_model.gen_int)
[new_incidence; recent_incidence[1:(epi_model.len_gen_int-1)]], new_incidence
function (epi_model::Renewal)(recent_incidence, Rt)
new_incidence = Rt * dot(recent_incidence, epi_model.data.gen_int)
[new_incidence; recent_incidence[1:(epi_model.data.len_gen_int-1)]], new_incidence
end
35 changes: 14 additions & 21 deletions EpiAware/src/latent-processes.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,25 @@
const STANDARD_RW_PRIORS =
(var_RW_dist = truncated(Normal(0.0, 0.05), 0.0, Inf), init_rw_value_dist = Normal())


"""
random_walk(n, ϵ_t = missing; latent_process_priors = (var_RW_dist = truncated(Normal(0., 0.05), 0., Inf),), ::Type{T} = Float64) where {T <: Real}
Constructs a random walk model.
# Arguments
- `n`: The number of time steps.
- `ϵ_t`: The random noise vector. Defaults to `missing`, in which case it is sampled from the standard multivariate normal distribution.
- `latent_process_priors`: The prior distribution for the latent process parameters. Defaults to `(var_RW_dist = truncated(Normal(0., 0.05), 0., Inf),)`.
function default_rw_priors()
return (
var_RW_dist = truncated(Normal(0.0, 0.05), 0.0, Inf),
init_rw_value_dist = Normal(),
)
end

# Returns
- `rw`: The random walk process.
- `σ_RW`: The standard deviation of the random walk process.
"""
@model function random_walk(
n,
ϵ_t = missing,
::Type{T} = Float64;
latent_process_priors = STANDARD_RW_PRIORS,
) where {T<:Real}
latent_process_priors = default_rw_priors(),
) where {T<:AbstractFloat}
rw = Vector{T}(undef, n)
ϵ_t ~ MvNormal(ones(n))
σ²_RW ~ latent_process_priors.var_RW_dist
init_rw_value ~ latent_process_priors.init_rw_value_dist
σ_RW = sqrt(σ²_RW)
rw .= init_rw_value .+ cumsum(σ_RW * ϵ_t)
return rw, (; σ_RW, init_rw_value)

rw[1] = init_rw_value + σ_RW * ϵ_t[1]
for t = 2:n
rw[t] = rw[t-1] + σ_RW * ϵ_t[t]
end
return rw, (; σ_RW, init_rw_value, init = rw[1])
end
43 changes: 6 additions & 37 deletions EpiAware/src/models.jl
Original file line number Diff line number Diff line change
@@ -1,39 +1,8 @@
"""
log_infections(y_t, epimodel::EpiModel, latent_process;
latent_process_priors,
transform_function = exp,
n_generate_ahead = 0,
pos_shift = 1e-6,
neg_bin_cluster_factor = missing,
neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3))
A Turing model for Log-infections undelying observed epidemiological data.
This function defines a log-infections model for epidemiological data.
It takes the observed data `y_t`, an `EpiModel` object `epimodel`, and a `latent_process`
model. It also accepts optional arguments for the `latent_process_priors`, `transform_function`,
`n_generate_ahead`, `pos_shift`, `neg_bin_cluster_factor`, and `neg_bin_cluster_factor_prior`.
## Arguments
- `y_t`: Observed data.
- `epimodel`: Epidemiological model.
- `latent_process`: Latent process model.
- `latent_process_priors`: Priors for the latent process model.
- `transform_function`: Function to transform the latent process into infections. Default is `exp`.
- `n_generate_ahead`: Number of time steps to generate ahead. Default is `0`.
- `pos_shift`: Positive shift to avoid zero values. Default is `1e-6`.
- `neg_bin_cluster_factor`: Missing value for the negative binomial cluster factor. Default is `missing`.
- `neg_bin_cluster_factor_prior`: Prior distribution for the negative binomial cluster factor. Default is `Gamma(3, 0.05 / 3)`.
## Returns
A named tuple containing the generated quantities `I_t` and `latent_process_parameters`.
"""
@model function log_infections(
@model function make_epi_inference_model(
y_t,
epimodel::EpiModel,
epimodel::AbstractEpiModel,
latent_process;
latent_process_priors,
transform_function = exp,
pos_shift = 1e-6,
neg_bin_cluster_factor = missing,
neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3),
Expand All @@ -42,16 +11,16 @@ A named tuple containing the generated quantities `I_t` and `latent_process_para
neg_bin_cluster_factor ~ neg_bin_cluster_factor_prior

#Latent process
time_steps = epimodel.time_horizon
@submodel _I_t, latent_process_parameters =
time_steps = epimodel.data.time_horizon
@submodel latent_process, latent_process_parameters =
latent_process(time_steps; latent_process_priors = latent_process_priors)

#Transform into infections
I_t = transform_function.(_I_t)
I_t, _ = scan(epimodel, latent_process_parameters.init, latent_process)

#Predictive distribution
case_pred_dists =
(epimodel.delay_kernel * I_t) .+ pos_shift .|>
(epimodel.data.delay_kernel * I_t) .+ pos_shift .|>
μ -> mean_cc_neg_bin(μ, neg_bin_cluster_factor)

#Likelihood
Expand Down
59 changes: 35 additions & 24 deletions EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Pkg: generate
#=
# Toy model for running analysis:
Expand Down Expand Up @@ -47,18 +46,16 @@ r &\sim \text{Gamma}(3, 0.05/3).
\end{align}
```
## Load dependencies in `TestEnv`
## Load dependencies
This script should be run from the root folder of `EpiAware` and with the active environment.
=#

split(pwd(), "/")[end] != "EpiAware" && begin
cd("./EpiAware")
using Pkg
Pkg.activate(".")

using TestEnv
TestEnv.activate()
end

using TestEnv # Run in Test environment mode
TestEnv.activate()

using EpiAware
using Turing
Expand All @@ -70,16 +67,18 @@ Random.seed!(0)

#=
## Create an `EpiModel` struct
Somewhat randomly chosen parameters for the `EpiModel` struct.
- Medium length generation interval distribution.
- Median 2 day, std 4.3 day delay distribution.
- 100 days of simulations
=#

truth_GI = Gamma(1, 2)
truth_delay = Uniform(0.0, 5.0)
truth_GI = Gamma(2, 5)
truth_delay = LogNormal(2.0, 1.0)
neg_bin_cluster_factor = 0.05
time_horizon = 100

epimodel = EpiModel(
model_data = EpiData(
truth_GI,
truth_delay,
neg_bin_cluster_factor,
Expand All @@ -89,29 +88,41 @@ epimodel = EpiModel(
)

#=
## Define a log-infections model
The log-infections model is defined by a Turing model `log_infections`.
## Define the data generating process
In this case we don't have observed data, so we use `missing` value for `y_t`.
In this case we use the `DirectInfections` model.
=#
toy_log_infs = log_infections(

toy_log_infs = DirectInfections(model_data)

#=
## Generate a `Turing` `Model`
We don't have observed data, so we use `missing` value for `y_t`.
=#

log_infs_model = make_epi_inference_model(
missing,
epimodel,
random_walk;
latent_process_priors = EpiAware.STANDARD_RW_PRIORS,
toy_log_infs,
random_walk,
latent_process_priors = default_rw_priors(),
pos_shift = 1e-6,
neg_bin_cluster_factor = 0.5,
neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3),
)



#=
## Sample from the model
I define a fixed version of the model with initial infections set to 10 and variance of the random walk process set to 0.1.
We can sample from the model using the `rand` function, and plot the generated infections against generated cases.
=#
cond_toy = fix(toy_log_infs, (init_rw_value = log(10.0), σ²_RW = 0.1))
random_epidemic = rand(cond_toy)

# We can get the generated infections using `generated_quantities` function. Because the observed
# cases are "defined" with a `~` operator they can be accessed directly from the randomly sampled
# process.

cond_toy = fix(log_infs_model, (init_rw_value = log(10.0), σ²_RW = 0.1))
random_epidemic = rand(cond_toy)
gen = generated_quantities(cond_toy, random_epidemic)
plot(
gen.I_t,
Expand All @@ -120,4 +131,4 @@ plot(
ylabel = "Infections",
title = "Generated Infections",
)
scatter!(X.y_t, lab = "generated cases")
scatter!(random_epidemic.y_t, lab = "generated cases")
Loading

0 comments on commit 22a5f5d

Please sign in to comment.