-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
489 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
## Script to make figure 1 | ||
using Pkg | ||
Pkg.activate(joinpath(@__DIR__(), "..")) | ||
|
||
using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, Plots, DataFramesMeta, | ||
Statistics, Distributions, CSV | ||
|
||
## | ||
pipelines = [ | ||
SmoothOutbreakPipeline(), MeasuresOutbreakPipeline(), | ||
SmoothEndemicPipeline(), RoughEndemicPipeline()] | ||
|
||
## load some data and create a dataframe for the plot | ||
truth_data_files = readdir(datadir("truth_data")) |> | ||
strs -> filter(s -> occursin("jld2", s), strs) | ||
analysis_df = CSV.File(plotsdir("analysis_df.csv")) |> DataFrame | ||
truth_df = mapreduce(vcat, truth_data_files) do filename | ||
D = load(joinpath(datadir("truth_data"), filename)) | ||
make_truthdata_dataframe(filename, D, pipelines) | ||
end | ||
|
||
## Make mainfigure plots | ||
|
||
# Define scenario titles and reference times for figure 1 | ||
scenario_dict = Dict( | ||
"measures_outbreak" => (title = "Outbreak with measures", T = 28), | ||
"smooth_outbreak" => (title = "Outbreak no measures", T = 35), | ||
"smooth_endemic" => (title = "Smooth endemic", T = 35), | ||
"rough_endemic" => (title = "Rough endemic", T = 35) | ||
) | ||
|
||
fig1 = figureone(truth_df, analysis_df, scenario_dict) | ||
|
||
## Save the figure | ||
save(plotsdir("figure1.png"), fig1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
using Pkg | ||
Pkg.activate(joinpath(@__DIR__(), "..")) | ||
|
||
using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, Plots, DataFramesMeta, | ||
Statistics, Distributions, DrWatson | ||
|
||
## load some data and create a dataframe for the plot | ||
files = readdir(datadir("epiaware_observables")) |> | ||
strs -> filter(s -> occursin("jld2", s), strs) | ||
|
||
## Define scenarios | ||
pipelines = [ | ||
SmoothOutbreakPipeline(), MeasuresOutbreakPipeline(), | ||
SmoothEndemicPipeline(), RoughEndemicPipeline()] | ||
|
||
## Set up EpiData objects: Used in the prediction dataframe for infection generating | ||
## processes that don't use directly in simulation. | ||
gi_params = make_gi_params(pipelines[1]) | ||
epi_datas = map(gi_params["gi_means"]) do μ | ||
σ = gi_params["gi_stds"][1] | ||
shape = (μ / σ)^2 | ||
scale = σ^2 / μ | ||
Gamma(shape, scale) | ||
end .|> gen_dist -> EpiData(gen_distribution = gen_dist) | ||
|
||
## Calculate the prediction dataframe | ||
prediction_df = mapreduce(vcat, files) do filename | ||
output = load(joinpath(datadir("epiaware_observables"), filename)) | ||
make_prediction_dataframe_from_output(filename, output, epi_datas, pipelines) | ||
end | ||
|
||
## Save the prediction dataframe | ||
CSV.write(plotsdir("analysis_df.csv"), prediction_df) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
include("make_truthdata_dataframe.jl") | ||
include("make_prediction_dataframe_from_output.jl") |
64 changes: 64 additions & 0 deletions
64
pipeline/src/analysis/make_prediction_dataframe_from_output.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
""" | ||
Create a dataframe containing prediction results based on the given output and input data. | ||
# Arguments | ||
- `filename`: The name of the file. | ||
- `output`: The output data containing inference configuration, IGP model, and other information. | ||
- `epi_datas`: The input data for the epidemiological model. | ||
- `qs`: An optional array of quantiles to calculate. Default is `[0.025, 0.5, 0.975]`. | ||
# Returns | ||
A dataframe containing the prediction results. | ||
""" | ||
function make_prediction_dataframe_from_output( | ||
filename, output, epi_datas, pipelines; qs = [0.025, 0.5, 0.975]) | ||
#Get the scenario, IGP model, latent model and true mean GI | ||
inference_config = output["inference_config"] | ||
igp_model = output["inference_config"].igp |> string | ||
scenario = EpiAwarePipeline._get_scenario_from_filename(filename, pipelines) | ||
latent_model = EpiAwarePipeline._get_latent_model_from_filename(filename) | ||
true_mean_gi = EpiAwarePipeline._get_true_gi_mean_from_filename(filename) | ||
|
||
#Get the quantiles for the targets across the gi mean scenarios | ||
#if Renewal model, then we use the underlying epi model | ||
#otherwise we use the epi datas to loop over different gi mean implications | ||
used_epi_datas = igp_model == "Renewal" ? [output["epiprob"].epi_model.data] : epi_datas | ||
|
||
preds = nothing | ||
try | ||
preds = map(used_epi_datas) do epi_data | ||
generate_quantiles_for_targets(output, epi_data, qs) | ||
end | ||
used_gi_means = igp_model == "Renewal" ? | ||
[EpiAwarePipeline._get_used_gi_mean_from_filename(filename)] : | ||
make_gi_params(EpiAwareExamplePipeline())["gi_means"] | ||
|
||
#Create the dataframe columnwise | ||
df = mapreduce(vcat, preds, used_gi_means) do pred, used_gi_mean | ||
mapreduce(vcat, keys(pred)) do target | ||
target_mat = pred[target] | ||
target_times = collect(1:size(target_mat, 1)) .+ | ||
(inference_config.tspan[1] - 1) | ||
_df = DataFrame(target_times = target_times) | ||
_df[!, "Scenario"] .= scenario | ||
_df[!, "IGP_Model"] .= igp_model | ||
_df[!, "Latent_Model"] .= latent_model | ||
_df[!, "True_GI_Mean"] .= true_mean_gi | ||
_df[!, "Used_GI_Mean"] .= used_gi_mean | ||
_df[!, "Reference_Time"] .= inference_config.tspan[2] | ||
_df[!, "Target"] .= string(target) | ||
# quantile predictions | ||
for (j, q) in enumerate(qs) | ||
q_str = split(string(q), ".")[end] | ||
_df[!, "q_$(q_str)"] = target_mat[:, j] | ||
end | ||
return _df | ||
end | ||
end | ||
return df | ||
catch | ||
@warn "Error in generating quantiles for targets in file $filename" | ||
return nothing | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
|
||
""" | ||
make_truthdata_dataframe(filename, truth_data, pipelines; I_0 = 100.0) | ||
Create a DataFrame containing truth data for analysis. | ||
# Arguments | ||
- `filename::String`: The name of the file. | ||
- `truth_data::Dict`: A dictionary containing truth data. | ||
- `pipelines::Array`: An array of pipelines. | ||
- `I_0::Float64`: Initial value for I_t (default: 100.0). | ||
# Returns | ||
- `df::DataFrame`: A DataFrame containing the truth data. | ||
""" | ||
function make_truthdata_dataframe(filename, truth_data, pipelines; I_0 = 100.0) | ||
I_t = truth_data["I_t"] | ||
true_mean_gi = truth_data["truth_gi_mean"] | ||
log_It = _calc_log_infections(I_t) | ||
rt = _calc_rt(I_t, I_0) | ||
scenario = _get_scenario_from_filename(filename, pipelines) | ||
truth_procs = (; log_I_t = log_It, rt, Rt = truth_data["truth_process"]) | ||
|
||
df = mapreduce(vcat, keys(truth_procs)) do target | ||
proc = truth_procs[target] | ||
_df = DataFrame( | ||
target_times = 1:length(proc), | ||
target_values = proc | ||
) | ||
_df[!, "Scenario"] .= scenario | ||
_df[!, "True_GI_Mean"] .= true_mean_gi | ||
_df[!, "Target"] .= string(target) | ||
return _df | ||
end | ||
|
||
return df | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
""" | ||
Internal method to check if the required columns are present in the truth dataframe. | ||
# Arguments | ||
- `truth_df`: The truth dataframe to be checked. | ||
""" | ||
function _figure_one_truth_dataframe_checks(truth_df) | ||
@assert "True_GI_Mean" ∈ names(truth_df) "True_GI_Mean col not in truth data" | ||
@assert "Scenario" ∈ names(truth_df) "Scenario col not in truth data" | ||
@assert "target_times" ∈ names(truth_df) "target_times col not in truth data" | ||
@assert "target_values" ∈ names(truth_df) "target_values col not in truth data" | ||
end | ||
|
||
""" | ||
Internal method to perform checks on the analysis dataframe to ensure that it contains the required columns. | ||
# Arguments | ||
- `analysis_df`: The analysis dataframe to be checked. | ||
# Raises | ||
- `AssertionError`: If any of the required columns are missing in the analysis dataframe. | ||
""" | ||
function _figure_one_analysis_dataframe_checks(analysis_df) | ||
@assert "True_GI_Mean" ∈ names(analysis_df) "True_GI_Mean col not in analysis data" | ||
@assert "Used_GI_Mean" ∈ names(analysis_df) "Used_GI_Mean col not in analysis data" | ||
@assert "Reference_Time" ∈ names(analysis_df) "Reference_Time col not in analysis data" | ||
@assert "Scenario" ∈ names(analysis_df) "Scenario col not in analysis data" | ||
@assert "IGP_Model" ∈ names(analysis_df) "IGP_Model col not in analysis data" | ||
@assert "Latent_Model" ∈ names(analysis_df) "Latent_Model col not in analysis data" | ||
@assert "target_times" ∈ names(analysis_df) "target_times col not in analysis data" | ||
end | ||
|
||
""" | ||
Internal method to perform checks on the truth and analysis dataframes for Figure One. | ||
# Arguments | ||
- `truth_df::DataFrame`: The truth dataframe. | ||
- `analysis_df::DataFrame`: The analysis dataframe. | ||
- `scenario_dict::Dict{String, Any}`: A dictionary containing scenario information. | ||
# Raises | ||
- `AssertionError`: If the scenarios in the truth and analysis dataframes do not match, or if the scenarios in the truth dataframe do not match the keys in the scenario dictionary. | ||
""" | ||
function _figure_one_dataframe_checks(truth_df, analysis_df, scenario_dict) | ||
@assert issetequal(unique(truth_df.Scenario), unique(analysis_df.Scenario)) "Truth and analysis data scenarios do not match" | ||
@assert issetequal(unique(truth_df.Scenario), keys(scenario_dict)) "Truth and analysis data True_GI_Mean do not match" | ||
_figure_one_truth_dataframe_checks(truth_df) | ||
_figure_one_analysis_dataframe_checks(analysis_df) | ||
end | ||
|
||
""" | ||
Internal method for creating a figure of model inference for a specific scenario | ||
using the given analysis data. | ||
# Arguments | ||
- `analysis_df`: The analysis data frame. | ||
- `scenario`: The scenario to plot. | ||
- `reference_time`: The reference time. | ||
- `true_gi_choice`: The true GI choice. | ||
- `used_gi_choice`: The used GI choice. | ||
- `lower_sym`: The symbol for the lower quantile (default is `:q_025`). | ||
- `upper_sym`: The symbol for the upper quantile (default is `:q_975`). | ||
# Returns | ||
- `plt_model`: The plot object. | ||
""" | ||
function _figure_one_scenario(analysis_df, scenario; reference_time, true_gi_choice, | ||
used_gi_choice, lower_sym = :q_025, upper_sym = :q_975) | ||
model_plotting_data = analysis_df |> | ||
df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> | ||
df -> @subset(df, :Used_GI_Mean.==used_gi_choice) |> | ||
df -> @subset(df, :Reference_Time.==reference_time) |> | ||
df -> @subset(df, :Scenario.==scenario) |> | ||
data | ||
|
||
plt_model = model_plotting_data * | ||
mapping(:target_times => "T", :q_5 => "Process values", | ||
col = :Target, row = :IGP_Model => "IGP model", | ||
color = :Latent_Model => "Latent model") * | ||
mapping(lower = lower_sym, upper = upper_sym) * visual(LinesFill) | ||
|
||
return plt_model | ||
end | ||
|
||
""" | ||
Internal method that generates a plot of the truth data for a specific scenario. | ||
## Arguments | ||
- `truth_df`: The truth data DataFrame. | ||
- `scenario`: The scenario for which the truth data should be plotted. | ||
- `true_gi_choice`: The choice of true GI mean. | ||
## Returns | ||
- `plt_truth`: The plot of the truth data. | ||
""" | ||
function _figure_one_scenario_truth_data(truth_df, scenario; true_gi_choice) | ||
truth_plotting_data = truth_df |> | ||
df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> | ||
df -> @subset(df, :Scenario.==scenario) |> data | ||
plt_truth = truth_plotting_data * | ||
mapping(:target_times => "T", :target_values => "values", | ||
col = :Target, color = :Latent_Model) * | ||
visual(Lines) | ||
return plt_truth | ||
end | ||
|
||
""" | ||
Generate figure 1 showing the analysis and truth data for different scenarios. | ||
## Arguments | ||
- `truth_df`: DataFrame containing the truth data. | ||
- `analysis_df`: DataFrame containing the analysis data. | ||
- `scenario_dict`: Dictionary containing information about the scenarios. | ||
## Keyword Arguments | ||
- `fig_kws`: Keyword arguments for the Figure object. Default is `(; size = (1000, 2000))`. | ||
- `true_gi_choice`: Value for the true generation interval choice. Default is `10.0`. | ||
- `used_gi_choice`: Value for the used generation interval choice. Default is `10.0`. | ||
- `legend_title`: Title for the legend. Default is `"Process type"`. | ||
## Returns | ||
- `fig`: Figure object containing the generated figure. | ||
""" | ||
function figureone( | ||
truth_df, analysis_df, scenario_dict; fig_kws = (; size = (1000, 2000)), | ||
true_gi_choice = 10.0, used_gi_choice = 10.0, legend_title = "Process type") | ||
# Perform checks on the dataframes | ||
_figure_one_dataframe_checks(truth_df, analysis_df, scenario_dict) | ||
# Treat the truth data as a Latent model option | ||
truth_df[!, "Latent_Model"] .= "Truth data" | ||
|
||
scenarios = analysis_df.Scenario |> unique | ||
plt_truth_vect = map(scenarios) do scenario | ||
_figure_one_scenario_truth_data(truth_df, scenario; true_gi_choice) | ||
end | ||
plt_analysis_vect = map(scenarios) do scenario | ||
_figure_one_scenario( | ||
analysis_df, scenario; reference_time = scenario_dict[scenario].T, | ||
true_gi_choice, used_gi_choice) | ||
end | ||
|
||
fig = Figure(; fig_kws...) | ||
leg = nothing | ||
for (i, scenario) in enumerate(scenarios) | ||
sf = fig[i, :] | ||
ag = draw!( | ||
sf, plt_analysis_vect[i] + plt_truth_vect[i], facet = (; linkyaxes = :none)) | ||
leg = AlgebraOfGraphics.compute_legend(ag) | ||
Label(sf[0, :], scenario_dict[scenario].title, fontsize = 24, font = :bold) | ||
end | ||
|
||
Label(fig[:, 0], "Process values", fontsize = 28, font = :bold, rotation = pi / 2) | ||
Label(fig[:, 2], "Infection generating process", | ||
fontsize = 24, font = :bold, rotation = -pi / 2) | ||
_leg = (leg[1], leg[2], [legend_title]) | ||
Legend(fig[:, 3], _leg...) | ||
|
||
return fig | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
include("figureone.jl") |
Oops, something went wrong.