Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the prior interface for make epi model inference #61

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions EpiAware/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
- Solid lines indicate implemented features/analysis.
- Dashed lines indicate planned features/analysis.

## Proposed `EpiAware` model diagram
```mermaid
flowchart TD
flowchart LR

A["Underlying dists.
and specify length of sims
Expand All @@ -29,24 +30,40 @@ C["Observational Data
Obs. cases y_t"]
D["Latent processes
---------------------
Random Walk"]
E[Turing model constructor]
F["Latent Process priors"]
random_walk"]
E["Turing model constructor
---------------------
make_epi_inference_model"]
F["Latent Process priors
---------------------
default_rw_priors"]
G[Posterior draws]
H[Posterior checking]
I[Post-processing]
DataW[Data wrangling and QC]
J["Observation Model
J["Observation models
---------------------
delay_observations"]
K["Observation model priors
---------------------
default_delay_obs_priors"]
ObservationModel["ObservationModel
---------------------
delay_observations_model"]
LatentProcess["LatentProcess
---------------------
random_walk_process"]

A --> EpiModel
B --> EpiModel
EpiModel -->E
C-->E
D-->|random_walk| E
J-->E
F-->|default_rw_priors|E
D-->LatentProcess
F-->LatentProcess
J-->ObservationModel
K-->ObservationModel
LatentProcess-->E
ObservationModel-->E
E-->|sample...NUTS...| G
G-.->H
H-.->I
Expand Down
4 changes: 2 additions & 2 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ export create_discrete_pmf, default_rw_priors, default_delay_obs_priors, spread_
export EpiData, Renewal, ExpGrowthRate, DirectInfections

# Exported Turing model constructors
export make_epi_inference_model, random_walk, delay_observations
export make_epi_inference_model, delay_observations_model, random_walk_process

include("epimodel.jl")
include("utilities.jl")
include("models.jl")
include("latent-processes.jl")
include("observation-processes.jl")
include("models.jl")

end
31 changes: 31 additions & 0 deletions EpiAware/src/latent-processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,34 @@ end
end
return rw, init, (; σ_RW,)
end

"""
struct LatentProcess{F<:Function}

A struct representing a latent process with its priors.

# Fields
- `latent_process`: The latent process function for a `Turing` model.
- `latent_process_priors`: NamedTuple containing the priors for the latent process.

"""
struct LatentProcess{F <: Function}
latent_process::F
latent_process_priors::NamedTuple
end

"""
random_walk_process(; latent_process_priors = default_rw_priors())

Create a `LatentProcess` struct reflecting a random walk process with optional priors.

# Arguments
- `latent_process_priors`: Optional priors for the random walk process.

# Returns
- `LatentProcess`: A random walk process.

"""
function random_walk_process(; latent_process_priors = default_rw_priors())
LatentProcess(random_walk, latent_process_priors)
end
15 changes: 8 additions & 7 deletions EpiAware/src/models.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
@model function make_epi_inference_model(
y_t,
epimodel::AbstractEpiModel,
latent_process,
observation_process;
process_priors,
latent_process_obj::LatentProcess,
observation_process_obj::ObservationModel;
pos_shift = 1e-6
)
#Latent process
time_steps = epimodel.data.time_horizon
@submodel latent_process, init, latent_process_aux = latent_process(
time_steps; latent_process_priors = process_priors)
@submodel latent_process, init, latent_process_aux = latent_process_obj.latent_process(
time_steps;
latent_process_priors = latent_process_obj.latent_process_priors
)

#Transform into infections
I_t = epimodel(latent_process, init)

#Predictive distribution of ascerted cases
@submodel generated_y_t, generated_y_t_aux = observation_process(
@submodel generated_y_t, generated_y_t_aux = observation_process_obj.observation_model(
y_t,
I_t,
epimodel::AbstractEpiModel;
observation_process_priors = process_priors,
observation_process_priors = observation_process_obj.observation_model_priors,
pos_shift = pos_shift
)

Expand Down
31 changes: 31 additions & 0 deletions EpiAware/src/observation-processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,34 @@ end

return y_t, (; neg_bin_cluster_factor,)
end

"""
struct ObservationModel{F<:Function}

A struct representing an observation model with its priors.

# Fields
- `observation_model`: The observation model function for a `Turing` model.
- `observation_model_priors`: NamedTuple containing the priors for the observation model.

"""
struct ObservationModel{F <: Function}
observation_model::F
observation_model_priors::NamedTuple
end

"""
delay_observations_model(; latent_process_priors = default_rw_priors())

Create an `ObservationModel` struct reflecting a delayed observation process with optional priors.

# Arguments
- `latent_process_priors`: Optional priors for the delayed observation process.

# Returns
- `ObservationModel`: An observation model with delayed observations.

"""
function delay_observations_model(; observation_model_priors = default_delay_obs_priors())
ObservationModel(delay_observations, observation_model_priors)
end
19 changes: 4 additions & 15 deletions EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,16 @@ In this case we use the `DirectInfections` model.
=#

toy_log_infs = DirectInfections(model_data)
rwp = random_walk_process()
obs_mdl = delay_observations_model()

#=
## Generate a `Turing` `Model`
We don't have observed data, so we use `missing` value for `y_t`.
=#

log_infs_model = make_epi_inference_model(
missing,
toy_log_infs,
random_walk,
delay_observations;
process_priors = merge(default_rw_priors(), default_delay_obs_priors()),
pos_shift = 1e-6
)
missing, toy_log_infs, rwp, obs_mdl; pos_shift = 1e-6)

#=
## Sample from the model
Expand Down Expand Up @@ -147,14 +143,7 @@ We treat the generated data as observed data and attempt to infer underlying inf

truth_data = random_epidemic.y_t

model = make_epi_inference_model(
truth_data,
toy_log_infs,
random_walk,
delay_observations;
process_priors = merge(default_rw_priors(), default_delay_obs_priors()),
pos_shift = 1e-6
)
model = make_epi_inference_model(truth_data, toy_log_infs, rwp, obs_mdl; pos_shift = 1e-6)

@time chn = sample(
model,
Expand Down
2 changes: 1 addition & 1 deletion EpiAware/test/test_latent-processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
@testitem "Testing random_walk against theoretical properties" begin
using DynamicPPL, Turing
n = 5
model = random_walk(n)
model = EpiAware.random_walk(n)
fixed_model = fix(model, (σ²_RW = 1.0, init_rw_value = 0.0)) #Fixing the standard deviation of the random walk process
n_samples = 1000
samples_day_5 = sample(fixed_model, Prior(), n_samples) |>
Expand Down
37 changes: 9 additions & 28 deletions EpiAware/test/test_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,13 @@
# Define test inputs
y_t = missing # Data will be generated from the model
data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp)
process_priors = merge(default_rw_priors(), default_delay_obs_priors())
pos_shift = 1e-6

epimodel = DirectInfections(data)

rwp = random_walk_process()
obs_mdl = delay_observations_model()
# Call the function
test_mdl = make_epi_inference_model(
y_t,
epimodel,
random_walk,
delay_observations;
process_priors,
pos_shift
)
test_mdl = make_epi_inference_model(y_t, epimodel, rwp, obs_mdl; pos_shift)

# Define expected outputs for a conditional model
# Underlying log-infections are const value 1 for all time steps and
Expand All @@ -38,20 +31,14 @@ end
# Define test inputs
y_t = missing # Data will be generated from the model
data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp)
process_priors = merge(default_rw_priors(), default_delay_obs_priors())
pos_shift = 1e-6

epimodel = ExpGrowthRate(data)
rwp = random_walk_process()
obs_mdl = delay_observations_model()

# Call the function
test_mdl = make_epi_inference_model(
y_t,
epimodel,
random_walk,
delay_observations;
process_priors,
pos_shift
)
test_mdl = make_epi_inference_model(y_t, epimodel, rwp, obs_mdl; pos_shift)

# Define expected outputs for a conditional model
# Underlying log-infections are const value 1 for all time steps and
Expand All @@ -76,16 +63,10 @@ end
pos_shift = 1e-6

epimodel = Renewal(data)

rwp = random_walk_process()
obs_mdl = delay_observations_model()
# Call the function
test_mdl = make_epi_inference_model(
y_t,
epimodel,
random_walk,
delay_observations;
process_priors,
pos_shift
)
test_mdl = make_epi_inference_model(y_t, epimodel, rwp, obs_mdl; pos_shift)

# Define expected outputs for a conditional model
# Underlying log-infections are const value 1 for all time steps and
Expand Down
2 changes: 1 addition & 1 deletion EpiAware/test/test_observation-processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
observation_process_priors = default_delay_obs_priors()

# Call the function
mdl = delay_observations(
mdl = EpiAware.delay_observations(
missing,
I_t,
epimodel;
Expand Down
Loading