Skip to content

Commit

Permalink
Merge branch 'main' into issue246
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs authored Jul 16, 2024
2 parents e8d67ef + c01b139 commit df7f8dc
Show file tree
Hide file tree
Showing 14 changed files with 489 additions and 16 deletions.
2 changes: 2 additions & 0 deletions pipeline/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ authors = ["Sam Abbott", "Sam Brand", "Zach Susswein"]
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down
35 changes: 35 additions & 0 deletions pipeline/scripts/create_figure1.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
## Script to make figure 1
using Pkg
Pkg.activate(joinpath(@__DIR__(), ".."))

using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, Plots,
      DataFramesMeta, Statistics, Distributions, CSV

## Scenario pipelines; passed on to `make_truthdata_dataframe` for each truth data file
pipelines = [
    SmoothOutbreakPipeline(), MeasuresOutbreakPipeline(),
    SmoothEndemicPipeline(), RoughEndemicPipeline()]

## Load the saved truth data and the analysis results into dataframes
truth_data_dir = datadir("truth_data")
# Keep only directory entries whose name mentions "jld2"
truth_data_files = filter(name -> occursin("jld2", name), readdir(truth_data_dir))
analysis_df = DataFrame(CSV.File(plotsdir("analysis_df.csv")))

# Load one truth data file and convert it into a dataframe.
function load_truth_dataframe(filename)
    truth_data = load(joinpath(truth_data_dir, filename))
    return make_truthdata_dataframe(filename, truth_data, pipelines)
end

# Stack the per-file dataframes into a single truth dataframe
truth_df = mapreduce(load_truth_dataframe, vcat, truth_data_files)

## Make mainfigure plots

# Scenario name => panel title and reference time `T` used to select the analysis rows
scenario_dict = Dict(
    "measures_outbreak" => (title = "Outbreak with measures", T = 28),
    "smooth_outbreak" => (title = "Outbreak no measures", T = 35),
    "smooth_endemic" => (title = "Smooth endemic", T = 35),
    "rough_endemic" => (title = "Rough endemic", T = 35)
)

fig1 = figureone(truth_df, analysis_df, scenario_dict)

## Save the figure
save(plotsdir("figure1.png"), fig1)
33 changes: 33 additions & 0 deletions pipeline/scripts/create_prediction_dataframe.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
## Script to build the prediction dataframe from the saved inference outputs
using Pkg
Pkg.activate(joinpath(@__DIR__(), ".."))

# CSV is loaded explicitly because this script calls `CSV.write` at the end.
using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, Plots, DataFramesMeta,
      Statistics, Distributions, CSV

## Find the saved inference outputs
observables_dir = datadir("epiaware_observables")
files = filter(s -> occursin("jld2", s), readdir(observables_dir))

## Define scenarios
pipelines = [
    SmoothOutbreakPipeline(), MeasuresOutbreakPipeline(),
    SmoothEndemicPipeline(), RoughEndemicPipeline()]

## Set up EpiData objects: used in the prediction dataframe for infection generating
## processes that don't use them directly in simulation.
gi_params = make_gi_params(pipelines[1])
# Only the first generation interval standard deviation is used for every mean.
σ = gi_params["gi_stds"][1]
epi_datas = map(gi_params["gi_means"]) do μ
    # Moment-matched Gamma: mean = shape * scale = μ and variance = shape * scale^2 = σ^2
    shape = (μ / σ)^2
    scale = σ^2 / μ
    EpiData(gen_distribution = Gamma(shape, scale))
end

## Calculate the prediction dataframe
maybe_prediction_dfs = map(files) do filename
    output = load(joinpath(observables_dir, filename))
    make_prediction_dataframe_from_output(filename, output, epi_datas, pipelines)
end
# `make_prediction_dataframe_from_output` returns `nothing` (after warning) for files it
# could not process; drop those so that `vcat` only ever sees dataframes.
prediction_dfs = [df for df in maybe_prediction_dfs if !isnothing(df)]
isempty(prediction_dfs) &&
    error("No prediction dataframes could be generated from files in $(observables_dir)")
prediction_df = reduce(vcat, prediction_dfs)

## Save the prediction dataframe
CSV.write(plotsdir("analysis_df.csv"), prediction_df)
14 changes: 12 additions & 2 deletions pipeline/src/EpiAwarePipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@ module EpiAwarePipeline

using CSV, Dagger, DataFramesMeta, Dates, Distributions, DocStringExtensions, DrWatson,
EpiAware, Plots, Statistics, ADTypes, AbstractMCMC, Plots, JLD2, MCMCChains, Turing,
DynamicPPL, LogExpFunctions, RCall, LinearAlgebra, Random
DynamicPPL, LogExpFunctions, RCall, LinearAlgebra, Random, AlgebraOfGraphics,
CairoMakie

# Exported pipeline types
export AbstractEpiAwarePipeline, EpiAwarePipeline, AbstractRtwithoutRenewalPipeline,
RtwithoutRenewalPriorPipeline, EpiAwareExamplePipeline, SmoothOutbreakPipeline,
MeasuresOutbreakPipeline, SmoothEndemicPipeline, RoughEndemicPipeline

# Exported utility functions
export calculate_processes
export calculate_processes, generate_quantiles_for_targets,
timeseries_samples_into_quantiles

# Exported configuration types
export TruthSimulationConfig, InferenceConfig
Expand All @@ -46,6 +48,12 @@ export define_forecast_epiprob, generate_forecasts
# Exported functions: scoring functions
export score_parameters

# Exported functions: Analysis functions for constructing dataframes
export make_prediction_dataframe_from_output, make_truthdata_dataframe

# Exported functions: Make main plots
export figureone

# Exported functions: plot functions
export plot_truth_data, plot_Rt

Expand All @@ -57,5 +65,7 @@ include("simulate/simulate.jl")
include("infer/infer.jl")
include("forecast/forecast.jl")
include("scoring/score_parameters.jl")
include("analysis/analysis.jl")
include("mainplots/mainplots.jl")
include("plot_functions.jl")
end
2 changes: 2 additions & 0 deletions pipeline/src/analysis/analysis.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
include("make_truthdata_dataframe.jl")
include("make_prediction_dataframe_from_output.jl")
64 changes: 64 additions & 0 deletions pipeline/src/analysis/make_prediction_dataframe_from_output.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
    make_prediction_dataframe_from_output(
        filename, output, epi_datas, pipelines; qs = [0.025, 0.5, 0.975])

Create a dataframe containing prediction results based on the given output and input data.

# Arguments
- `filename`: The name of the file. The scenario, latent model, true GI mean and (for
    the Renewal model) the used GI mean are extracted from it by internal helpers.
- `output`: The output data. `output["inference_config"]` must provide `igp` and `tspan`;
    for the Renewal model `output["epiprob"].epi_model.data` is used as the only epi data.
- `epi_datas`: The epi data objects looped over for non-Renewal models.
- `pipelines`: The pipelines used to look up the scenario from `filename`.
- `qs`: An optional array of quantiles to calculate. Default is `[0.025, 0.5, 0.975]`.

# Returns
A dataframe with one row per target and target time and one column per quantile, or
`nothing` (after logging a warning) if generating the predictions raised an error.

Quantile columns are named `q_<digits after the decimal point>`, e.g. `0.025` gives
`q_025` and `0.5` gives `q_5`.
"""
function make_prediction_dataframe_from_output(
        filename, output, epi_datas, pipelines; qs = [0.025, 0.5, 0.975])
    # Get the scenario, IGP model, latent model and true mean GI
    inference_config = output["inference_config"]
    igp_model = inference_config.igp |> string
    scenario = EpiAwarePipeline._get_scenario_from_filename(filename, pipelines)
    latent_model = EpiAwarePipeline._get_latent_model_from_filename(filename)
    true_mean_gi = EpiAwarePipeline._get_true_gi_mean_from_filename(filename)

    # Get the quantiles for the targets across the gi mean scenarios:
    # if Renewal model, then we use the underlying epi model,
    # otherwise we use the epi datas to loop over different gi mean implications
    is_renewal = igp_model == "Renewal"
    used_epi_datas = is_renewal ? [output["epiprob"].epi_model.data] : epi_datas

    try
        preds = map(used_epi_datas) do epi_data
            generate_quantiles_for_targets(output, epi_data, qs)
        end
        # NOTE(review): for non-Renewal models the GI mean labels come from the example
        # pipeline and are paired positionally with `epi_datas` — confirm they correspond.
        used_gi_means = is_renewal ?
                        [EpiAwarePipeline._get_used_gi_mean_from_filename(filename)] :
                        make_gi_params(EpiAwareExamplePipeline())["gi_means"]

        # Create the dataframe columnwise: one sub-dataframe per (GI mean, target) pair
        df = mapreduce(vcat, preds, used_gi_means) do pred, used_gi_mean
            mapreduce(vcat, keys(pred)) do target
                target_mat = pred[target]
                # Shift the row index so that times start at the first inference time
                target_times = collect(1:size(target_mat, 1)) .+
                               (inference_config.tspan[1] - 1)
                _df = DataFrame(target_times = target_times)
                _df[!, "Scenario"] .= scenario
                _df[!, "IGP_Model"] .= igp_model
                _df[!, "Latent_Model"] .= latent_model
                _df[!, "True_GI_Mean"] .= true_mean_gi
                _df[!, "Used_GI_Mean"] .= used_gi_mean
                _df[!, "Reference_Time"] .= inference_config.tspan[2]
                _df[!, "Target"] .= string(target)
                # quantile predictions: column j of the target matrix holds quantile qs[j]
                for (j, q) in enumerate(qs)
                    q_str = split(string(q), ".")[end]
                    _df[!, "q_$(q_str)"] = target_mat[:, j]
                end
                return _df
            end
        end
        return df
    catch e
        # Never swallow a user interrupt
        e isa InterruptException && rethrow()
        # Best effort: skip this file, but keep the cause visible in the log
        @warn "Error in generating quantiles for targets in file $filename" exception=(
            e, catch_backtrace())
        return nothing
    end
end
38 changes: 38 additions & 0 deletions pipeline/src/analysis/make_truthdata_dataframe.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

"""
    make_truthdata_dataframe(filename, truth_data, pipelines; I_0 = 100.0)

Build a long-format DataFrame of the truth processes stored in `truth_data`.

# Arguments
- `filename`: The name of the file; used to look up the scenario.
- `truth_data`: A dictionary-like object with entries `"I_t"`, `"truth_gi_mean"` and
    `"truth_process"`.
- `pipelines`: The pipelines used to look up the scenario from `filename`.
- `I_0`: Initial value passed to the `rt` calculation alongside `I_t` (default: 100.0).

# Returns
- A DataFrame with columns `target_times`, `target_values`, `Scenario`, `True_GI_Mean`
    and `Target`, stacking the targets `log_I_t`, `rt` and `Rt`.
"""
function make_truthdata_dataframe(filename, truth_data, pipelines; I_0 = 100.0)
    infections = truth_data["I_t"]
    gi_mean = truth_data["truth_gi_mean"]
    log_infections = _calc_log_infections(infections)
    growth_rate = _calc_rt(infections, I_0)
    scenario = _get_scenario_from_filename(filename, pipelines)
    truth_procs = (;
        log_I_t = log_infections, rt = growth_rate, Rt = truth_data["truth_process"])

    # Build the sub-dataframe for a single target process
    function target_frame(target)
        series = truth_procs[target]
        frame = DataFrame(
            target_times = 1:length(series),
            target_values = series
        )
        frame[!, "Scenario"] .= scenario
        frame[!, "True_GI_Mean"] .= gi_mean
        frame[!, "Target"] .= string(target)
        return frame
    end

    # Stack the per-target dataframes in the order of the named tuple
    return reduce(vcat, map(target_frame, keys(truth_procs)))
end
165 changes: 165 additions & 0 deletions pipeline/src/mainplots/figureone.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""
Internal method to check if the required columns are present in the truth dataframe.

# Arguments
- `truth_df`: The truth dataframe to be checked.

# Raises
- `AssertionError`: If any of the columns `True_GI_Mean`, `Scenario`, `target_times` or
    `target_values` is missing from `names(truth_df)`.
"""
function _figure_one_truth_dataframe_checks(truth_df)
    available_cols = names(truth_df)
    for required_col in ("True_GI_Mean", "Scenario", "target_times", "target_values")
        @assert in(required_col, available_cols) "$(required_col) col not in truth data"
    end
    return nothing
end

"""
Internal method to perform checks on the analysis dataframe to ensure that it contains
the required columns.

# Arguments
- `analysis_df`: The analysis dataframe to be checked.

# Raises
- `AssertionError`: If any of the required columns (`True_GI_Mean`, `Used_GI_Mean`,
    `Reference_Time`, `Scenario`, `IGP_Model`, `Latent_Model`, `target_times`) is missing
    from `names(analysis_df)`.
"""
function _figure_one_analysis_dataframe_checks(analysis_df)
    available_cols = names(analysis_df)
    required_cols = ("True_GI_Mean", "Used_GI_Mean", "Reference_Time", "Scenario",
        "IGP_Model", "Latent_Model", "target_times")
    for required_col in required_cols
        @assert in(required_col, available_cols) "$(required_col) col not in analysis data"
    end
    return nothing
end

"""
Internal method to perform checks on the truth and analysis dataframes for Figure One.

# Arguments
- `truth_df`: The truth dataframe.
- `analysis_df`: The analysis dataframe.
- `scenario_dict`: A dictionary keyed by scenario name.

# Raises
- `AssertionError`: If a required column is missing from either dataframe, if the
    scenarios in the truth and analysis dataframes do not match, or if the scenarios in
    the truth dataframe do not match the keys in the scenario dictionary.
"""
function _figure_one_dataframe_checks(truth_df, analysis_df, scenario_dict)
    # Check the columns first so that a missing `Scenario` column is reported by the
    # column checks rather than by the column access below.
    _figure_one_truth_dataframe_checks(truth_df)
    _figure_one_analysis_dataframe_checks(analysis_df)

    truth_scenarios = unique(truth_df.Scenario)
    @assert(issetequal(truth_scenarios, unique(analysis_df.Scenario)),
        "Truth and analysis data scenarios do not match")
    @assert(issetequal(truth_scenarios, keys(scenario_dict)),
        "Truth data scenarios do not match the keys of the scenario dictionary")
    return nothing
end

"""
Internal method for creating a figure of model inference for a specific scenario
using the given analysis data.

# Arguments
- `analysis_df`: The analysis data frame.
- `scenario`: The scenario to plot.
- `reference_time`: The reference time; only rows with this `Reference_Time` are kept.
- `true_gi_choice`: The true GI choice; only rows with this `True_GI_Mean` are kept.
- `used_gi_choice`: The used GI choice; only rows with this `Used_GI_Mean` are kept.
- `lower_sym`: The symbol for the lower quantile (default is `:q_025`).
- `upper_sym`: The symbol for the upper quantile (default is `:q_975`).
- `median_sym`: The symbol for the central line (default is `:q_5`).

# Returns
- `plt_model`: The plot object.
"""
function _figure_one_scenario(analysis_df, scenario; reference_time, true_gi_choice,
        used_gi_choice, lower_sym = :q_025, upper_sym = :q_975, median_sym = :q_5)
    # Keep only the rows for this scenario, reference time and GI choices
    # (multiple `@subset` conditions are combined with "and")
    selected_rows = @subset(analysis_df,
        :True_GI_Mean .== true_gi_choice,
        :Used_GI_Mean .== used_gi_choice,
        :Reference_Time .== reference_time,
        :Scenario .== scenario)
    model_plotting_data = data(selected_rows)

    # Central line with a band between the lower and upper quantile columns, faceted by
    # target (columns) and IGP model (rows), coloured by latent model
    plt_model = model_plotting_data *
                mapping(:target_times => "T", median_sym => "Process values",
                    col = :Target, row = :IGP_Model => "IGP model",
                    color = :Latent_Model => "Latent model") *
                mapping(lower = lower_sym, upper = upper_sym) * visual(LinesFill)

    return plt_model
end

"""
Internal method that builds the line plot of the truth data for a single scenario.

## Arguments
- `truth_df`: The truth data DataFrame. Besides the filter columns it must provide
    `target_times`, `target_values`, `Target` and `Latent_Model`, which are all mapped.
- `scenario`: The scenario for which the truth data should be plotted.
- `true_gi_choice`: The choice of true GI mean.

## Returns
- `plt_truth`: The plot of the truth data.
"""
function _figure_one_scenario_truth_data(truth_df, scenario; true_gi_choice)
    # Rows must match both the true GI mean and the scenario
    selected_rows = @subset(truth_df,
        :True_GI_Mean .== true_gi_choice,
        :Scenario .== scenario)
    truth_layer = data(selected_rows)
    truth_mapping = mapping(:target_times => "T", :target_values => "values",
        col = :Target, color = :Latent_Model)
    return truth_layer * truth_mapping * visual(Lines)
end

"""
    figureone(truth_df, analysis_df, scenario_dict; kwargs...)

Generate figure 1 showing the analysis and truth data for different scenarios.

## Arguments
- `truth_df`: DataFrame containing the truth data. It is not modified.
- `analysis_df`: DataFrame containing the analysis data.
- `scenario_dict`: Dictionary keyed by scenario name; each value must provide a `title`
    (panel title) and `T` (reference time used to select the analysis rows).

## Keyword Arguments
- `fig_kws`: Keyword arguments for the Figure object. Default is `(; size = (1000, 2000))`.
- `true_gi_choice`: Value for the true generation interval choice. Default is `10.0`.
- `used_gi_choice`: Value for the used generation interval choice. Default is `10.0`.
- `legend_title`: Title for the legend. Default is `"Process type"`.

## Returns
- `fig`: Figure object containing the generated figure.
"""
function figureone(
        truth_df, analysis_df, scenario_dict; fig_kws = (; size = (1000, 2000)),
        true_gi_choice = 10.0, used_gi_choice = 10.0, legend_title = "Process type")
    # Perform checks on the dataframes
    _figure_one_dataframe_checks(truth_df, analysis_df, scenario_dict)
    # Treat the truth data as a Latent model option. Work on a copy so that the caller's
    # dataframe does not gain a `Latent_Model` column as a side effect.
    labelled_truth_df = copy(truth_df)
    labelled_truth_df[!, "Latent_Model"] .= "Truth data"

    scenarios = analysis_df.Scenario |> unique
    plt_truth_vect = map(scenarios) do scenario
        _figure_one_scenario_truth_data(labelled_truth_df, scenario; true_gi_choice)
    end
    plt_analysis_vect = map(scenarios) do scenario
        _figure_one_scenario(
            analysis_df, scenario; reference_time = scenario_dict[scenario].T,
            true_gi_choice, used_gi_choice)
    end

    fig = Figure(; fig_kws...)
    leg = nothing
    # One row of panels per scenario, with the scenario title above it
    for (i, scenario) in enumerate(scenarios)
        sf = fig[i, :]
        ag = draw!(
            sf, plt_analysis_vect[i] + plt_truth_vect[i], facet = (; linkyaxes = :none))
        # The legend of the last drawn scenario is the one shown for the whole figure
        leg = AlgebraOfGraphics.compute_legend(ag)
        Label(sf[0, :], scenario_dict[scenario].title, fontsize = 24, font = :bold)
    end

    Label(fig[:, 0], "Process values", fontsize = 28, font = :bold, rotation = pi / 2)
    Label(fig[:, 2], "Infection generating process",
        fontsize = 24, font = :bold, rotation = -pi / 2)
    # `leg` is still `nothing` when there were no scenarios to draw
    if !isnothing(leg)
        _leg = (leg[1], leg[2], [legend_title])
        Legend(fig[:, 3], _leg...)
    end

    return fig
end
1 change: 1 addition & 0 deletions pipeline/src/mainplots/mainplots.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include("figureone.jl")
Loading

0 comments on commit df7f8dc

Please sign in to comment.