From df8522795deb824c71add1ccf81ac3cb073d5259 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Mon, 15 Jul 2024 22:12:29 +0100 Subject: [PATCH 1/3] Gathering samples into interpretable analysis and fig1 code (#370) * Makie deps * Extra plotting util functions * More utilities * prediction dataframe function * fix underlying utils * Create create_prediction_dataframe.jl * Update EpiAwarePipeline.jl * Analysis functions for creating data frames from underlying samples * Refactor utility functions * figure 1 code * scripts for gathering samples into dataframe * Create create_figure1.jl * fix pre-commit --- pipeline/Project.toml | 2 + pipeline/scripts/create_figure1.jl | 35 ++++ .../scripts/create_prediction_dataframe.jl | 33 ++++ pipeline/src/EpiAwarePipeline.jl | 14 +- pipeline/src/analysis/analysis.jl | 2 + .../make_prediction_dataframe_from_output.jl | 64 +++++++ .../src/analysis/make_truthdata_dataframe.jl | 38 ++++ pipeline/src/mainplots/figureone.jl | 165 ++++++++++++++++++ pipeline/src/mainplots/mainplots.jl | 1 + pipeline/src/utils/calculate_processes.jl | 26 ++- pipeline/src/utils/scenario_names_utils.jl | 54 ++++++ pipeline/src/utils/timeseries_utils.jl | 48 +++++ pipeline/src/utils/utils.jl | 2 + pipeline/test/mainplots/test_utils.jl | 21 +++ 14 files changed, 489 insertions(+), 16 deletions(-) create mode 100644 pipeline/scripts/create_figure1.jl create mode 100644 pipeline/scripts/create_prediction_dataframe.jl create mode 100644 pipeline/src/analysis/analysis.jl create mode 100644 pipeline/src/analysis/make_prediction_dataframe_from_output.jl create mode 100644 pipeline/src/analysis/make_truthdata_dataframe.jl create mode 100644 pipeline/src/mainplots/figureone.jl create mode 100644 pipeline/src/mainplots/mainplots.jl create mode 100644 pipeline/src/utils/scenario_names_utils.jl create mode 100644 pipeline/src/utils/timeseries_utils.jl create mode 100644 pipeline/test/mainplots/test_utils.jl diff --git a/pipeline/Project.toml b/pipeline/Project.toml index 063410b33..ed4e050d8 100644 --- a/pipeline/Project.toml +++ b/pipeline/Project.toml @@ -5,7 +5,9 @@ authors = ["Sam Abbott", "Sam Brand", "Zach Susswein"] [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" diff --git a/pipeline/scripts/create_figure1.jl b/pipeline/scripts/create_figure1.jl new file mode 100644 index 000000000..81c0c29b0 --- /dev/null +++ b/pipeline/scripts/create_figure1.jl @@ -0,0 +1,35 @@ +## Script to make figure 1 +using Pkg +Pkg.activate(joinpath(@__DIR__(), "..")) + +using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, Plots, DataFramesMeta, + Statistics, Distributions, CSV + +## +pipelines = [ + SmoothOutbreakPipeline(), MeasuresOutbreakPipeline(), + SmoothEndemicPipeline(), RoughEndemicPipeline()] + +## load some data and create a dataframe for the plot +truth_data_files = readdir(datadir("truth_data")) |> + strs -> filter(s -> occursin("jld2", s), strs) +analysis_df = CSV.File(plotsdir("analysis_df.csv")) |> DataFrame +truth_df = mapreduce(vcat, truth_data_files) do filename + D = load(joinpath(datadir("truth_data"), filename)) + make_truthdata_dataframe(filename, D, pipelines) +end + +## Make mainfigure plots + +# Define scenario titles and reference times for figure 1 +scenario_dict = Dict( + "measures_outbreak" => (title = "Outbreak with measures", T = 28), + "smooth_outbreak" => (title = "Outbreak no measures", T = 35), + "smooth_endemic" => (title = "Smooth endemic", T = 35), + "rough_endemic" => (title = "Rough endemic", T = 35) +) + +fig1 = figureone(truth_df, analysis_df, scenario_dict) + +## Save the figure +save(plotsdir("figure1.png"), fig1) diff --git a/pipeline/scripts/create_prediction_dataframe.jl b/pipeline/scripts/create_prediction_dataframe.jl new file mode 100644 index 000000000..c3faab056 --- /dev/null +++ b/pipeline/scripts/create_prediction_dataframe.jl @@ -0,0 +1,33 @@ +using Pkg +Pkg.activate(joinpath(@__DIR__(), "..")) + +using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, Plots, DataFramesMeta, + Statistics, Distributions, DrWatson + +## load some data and create a dataframe for the plot +files = readdir(datadir("epiaware_observables")) |> + strs -> filter(s -> occursin("jld2", s), strs) + +## Define scenarios +pipelines = [ + SmoothOutbreakPipeline(), MeasuresOutbreakPipeline(), + SmoothEndemicPipeline(), RoughEndemicPipeline()] + +## Set up EpiData objects: Used in the prediction dataframe for infection generating +## processes that don't use directly in simulation. +gi_params = make_gi_params(pipelines[1]) +epi_datas = map(gi_params["gi_means"]) do μ + σ = gi_params["gi_stds"][1] + shape = (μ / σ)^2 + scale = σ^2 / μ + Gamma(shape, scale) +end .|> gen_dist -> EpiData(gen_distribution = gen_dist) + +## Calculate the prediction dataframe +prediction_df = mapreduce(vcat, files) do filename + output = load(joinpath(datadir("epiaware_observables"), filename)) + make_prediction_dataframe_from_output(filename, output, epi_datas, pipelines) +end + +## Save the prediction dataframe +CSV.write(plotsdir("analysis_df.csv"), prediction_df) diff --git a/pipeline/src/EpiAwarePipeline.jl b/pipeline/src/EpiAwarePipeline.jl index dc1ba8709..9b7face17 100644 --- a/pipeline/src/EpiAwarePipeline.jl +++ b/pipeline/src/EpiAwarePipeline.jl @@ -12,7 +12,8 @@ module EpiAwarePipeline using CSV, Dagger, DataFramesMeta, Dates, Distributions, DocStringExtensions, DrWatson, EpiAware, Plots, Statistics, ADTypes, AbstractMCMC, Plots, JLD2, MCMCChains, Turing, - DynamicPPL, LogExpFunctions, RCall, LinearAlgebra, Random + DynamicPPL, LogExpFunctions, RCall, LinearAlgebra, Random, AlgebraOfGraphics, + CairoMakie # Exported pipeline types export AbstractEpiAwarePipeline, EpiAwarePipeline, AbstractRtwithoutRenewalPipeline, @@ -20,7 +21,8 @@ export AbstractEpiAwarePipeline, EpiAwarePipeline, AbstractRtwithoutRenewalPipel MeasuresOutbreakPipeline, SmoothEndemicPipeline, RoughEndemicPipeline # Exported utility functions -export calculate_processes +export calculate_processes, generate_quantiles_for_targets, + timeseries_samples_into_quantiles # Exported configuration types export TruthSimulationConfig, InferenceConfig @@ -46,6 +48,12 @@ export define_forecast_epiprob, generate_forecasts # Exported functions: scoring functions export score_parameters +# Exported functions: Analysis functions for constructing dataframes +export make_prediction_dataframe_from_output, make_truthdata_dataframe + +# Exported functions: Make main plots +export figureone + # Exported functions: plot functions export plot_truth_data, plot_Rt @@ -57,5 +65,7 @@ include("simulate/simulate.jl") include("infer/infer.jl") include("forecast/forecast.jl") include("scoring/score_parameters.jl") +include("analysis/analysis.jl") +include("mainplots/mainplots.jl") include("plot_functions.jl") end diff --git a/pipeline/src/analysis/analysis.jl b/pipeline/src/analysis/analysis.jl new file mode 100644 index 000000000..fffa72dca --- /dev/null +++ b/pipeline/src/analysis/analysis.jl @@ -0,0 +1,2 @@ +include("make_truthdata_dataframe.jl") +include("make_prediction_dataframe_from_output.jl") diff --git a/pipeline/src/analysis/make_prediction_dataframe_from_output.jl b/pipeline/src/analysis/make_prediction_dataframe_from_output.jl new file mode 100644 index 000000000..8abbed5c5 --- /dev/null +++ b/pipeline/src/analysis/make_prediction_dataframe_from_output.jl @@ -0,0 +1,64 @@ +""" +Create a dataframe containing prediction results based on the given output and input data. + +# Arguments +- `filename`: The name of the file. +- `output`: The output data containing inference configuration, IGP model, and other information. +- `epi_datas`: The input data for the epidemiological model. +- `qs`: An optional array of quantiles to calculate. Default is `[0.025, 0.5, 0.975]`. + +# Returns +A dataframe containing the prediction results. + +""" +function make_prediction_dataframe_from_output( + filename, output, epi_datas, pipelines; qs = [0.025, 0.5, 0.975]) + #Get the scenario, IGP model, latent model and true mean GI + inference_config = output["inference_config"] + igp_model = output["inference_config"].igp |> string + scenario = EpiAwarePipeline._get_scenario_from_filename(filename, pipelines) + latent_model = EpiAwarePipeline._get_latent_model_from_filename(filename) + true_mean_gi = EpiAwarePipeline._get_true_gi_mean_from_filename(filename) + + #Get the quantiles for the targets across the gi mean scenarios + #if Renewal model, then we use the underlying epi model + #otherwise we use the epi datas to loop over different gi mean implications + used_epi_datas = igp_model == "Renewal" ? [output["epiprob"].epi_model.data] : epi_datas + + preds = nothing + try + preds = map(used_epi_datas) do epi_data + generate_quantiles_for_targets(output, epi_data, qs) + end + used_gi_means = igp_model == "Renewal" ? + [EpiAwarePipeline._get_used_gi_mean_from_filename(filename)] : + make_gi_params(EpiAwareExamplePipeline())["gi_means"] + + #Create the dataframe columnwise + df = mapreduce(vcat, preds, used_gi_means) do pred, used_gi_mean + mapreduce(vcat, keys(pred)) do target + target_mat = pred[target] + target_times = collect(1:size(target_mat, 1)) .+ + (inference_config.tspan[1] - 1) + _df = DataFrame(target_times = target_times) + _df[!, "Scenario"] .= scenario + _df[!, "IGP_Model"] .= igp_model + _df[!, "Latent_Model"] .= latent_model + _df[!, "True_GI_Mean"] .= true_mean_gi + _df[!, "Used_GI_Mean"] .= used_gi_mean + _df[!, "Reference_Time"] .= inference_config.tspan[2] + _df[!, "Target"] .= string(target) + # quantile predictions + for (j, q) in enumerate(qs) + q_str = split(string(q), ".")[end] + _df[!, "q_$(q_str)"] = target_mat[:, j] + end + return _df + end + end + return df + catch + @warn "Error in generating quantiles for targets in file $filename" + return nothing + end +end diff --git a/pipeline/src/analysis/make_truthdata_dataframe.jl b/pipeline/src/analysis/make_truthdata_dataframe.jl new file mode 100644 index 000000000..1adedeb9b --- /dev/null +++ b/pipeline/src/analysis/make_truthdata_dataframe.jl @@ -0,0 +1,38 @@ + +""" + make_truthdata_dataframe(filename, truth_data, pipelines; I_0 = 100.0) + +Create a DataFrame containing truth data for analysis. + +# Arguments +- `filename::String`: The name of the file. +- `truth_data::Dict`: A dictionary containing truth data. +- `pipelines::Array`: An array of pipelines. +- `I_0::Float64`: Initial value for I_t (default: 100.0). + +# Returns +- `df::DataFrame`: A DataFrame containing the truth data. + +""" +function make_truthdata_dataframe(filename, truth_data, pipelines; I_0 = 100.0) + I_t = truth_data["I_t"] + true_mean_gi = truth_data["truth_gi_mean"] + log_It = _calc_log_infections(I_t) + rt = _calc_rt(I_t, I_0) + scenario = _get_scenario_from_filename(filename, pipelines) + truth_procs = (; log_I_t = log_It, rt, Rt = truth_data["truth_process"]) + + df = mapreduce(vcat, keys(truth_procs)) do target + proc = truth_procs[target] + _df = DataFrame( + target_times = 1:length(proc), + target_values = proc + ) + _df[!, "Scenario"] .= scenario + _df[!, "True_GI_Mean"] .= true_mean_gi + _df[!, "Target"] .= string(target) + return _df + end + + return df +end diff --git a/pipeline/src/mainplots/figureone.jl b/pipeline/src/mainplots/figureone.jl new file mode 100644 index 000000000..13410afe0 --- /dev/null +++ b/pipeline/src/mainplots/figureone.jl @@ -0,0 +1,165 @@ +""" +Internal method to check if the required columns are present in the truth dataframe. + +# Arguments +- `truth_df`: The truth dataframe to be checked. + +""" +function _figure_one_truth_dataframe_checks(truth_df) + @assert "True_GI_Mean" ∈ names(truth_df) "True_GI_Mean col not in truth data" + @assert "Scenario" ∈ names(truth_df) "Scenario col not in truth data" + @assert "target_times" ∈ names(truth_df) "target_times col not in truth data" + @assert "target_values" ∈ names(truth_df) "target_values col not in truth data" +end + +""" +Internal method to perform checks on the analysis dataframe to ensure that it contains the required columns. + +# Arguments +- `analysis_df`: The analysis dataframe to be checked. + +# Raises +- `AssertionError`: If any of the required columns are missing in the analysis dataframe. + +""" +function _figure_one_analysis_dataframe_checks(analysis_df) + @assert "True_GI_Mean" ∈ names(analysis_df) "True_GI_Mean col not in analysis data" + @assert "Used_GI_Mean" ∈ names(analysis_df) "Used_GI_Mean col not in analysis data" + @assert "Reference_Time" ∈ names(analysis_df) "Reference_Time col not in analysis data" + @assert "Scenario" ∈ names(analysis_df) "Scenario col not in analysis data" + @assert "IGP_Model" ∈ names(analysis_df) "IGP_Model col not in analysis data" + @assert "Latent_Model" ∈ names(analysis_df) "Latent_Model col not in analysis data" + @assert "target_times" ∈ names(analysis_df) "target_times col not in analysis data" +end + +""" +Internal method to perform checks on the truth and analysis dataframes for Figure One. + +# Arguments +- `truth_df::DataFrame`: The truth dataframe. +- `analysis_df::DataFrame`: The analysis dataframe. +- `scenario_dict::Dict{String, Any}`: A dictionary containing scenario information. + +# Raises +- `AssertionError`: If the scenarios in the truth and analysis dataframes do not match, or if the scenarios in the truth dataframe do not match the keys in the scenario dictionary. + +""" +function _figure_one_dataframe_checks(truth_df, analysis_df, scenario_dict) + @assert issetequal(unique(truth_df.Scenario), unique(analysis_df.Scenario)) "Truth and analysis data scenarios do not match" + @assert issetequal(unique(truth_df.Scenario), keys(scenario_dict)) "Truth and analysis data True_GI_Mean do not match" + _figure_one_truth_dataframe_checks(truth_df) + _figure_one_analysis_dataframe_checks(analysis_df) +end + +""" +Internal method for creating a figure of model inference for a specific scenario + using the given analysis data. + +# Arguments +- `analysis_df`: The analysis data frame. +- `scenario`: The scenario to plot. +- `reference_time`: The reference time. +- `true_gi_choice`: The true GI choice. +- `used_gi_choice`: The used GI choice. +- `lower_sym`: The symbol for the lower quantile (default is `:q_025`). +- `upper_sym`: The symbol for the upper quantile (default is `:q_975`). + +# Returns +- `plt_model`: The plot object. + +""" +function _figure_one_scenario(analysis_df, scenario; reference_time, true_gi_choice, + used_gi_choice, lower_sym = :q_025, upper_sym = :q_975) + model_plotting_data = analysis_df |> + df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> + df -> @subset(df, :Used_GI_Mean.==used_gi_choice) |> + df -> @subset(df, :Reference_Time.==reference_time) |> + df -> @subset(df, :Scenario.==scenario) |> + data + + plt_model = model_plotting_data * + mapping(:target_times => "T", :q_5 => "Process values", + col = :Target, row = :IGP_Model => "IGP model", + color = :Latent_Model => "Latent model") * + mapping(lower = lower_sym, upper = upper_sym) * visual(LinesFill) + + return plt_model +end + +""" +Internal method that generates a plot of the truth data for a specific scenario. + +## Arguments +- `truth_df`: The truth data DataFrame. +- `scenario`: The scenario for which the truth data should be plotted. +- `true_gi_choice`: The choice of true GI mean. + +## Returns +- `plt_truth`: The plot of the truth data. + +""" +function _figure_one_scenario_truth_data(truth_df, scenario; true_gi_choice) + truth_plotting_data = truth_df |> + df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> + df -> @subset(df, :Scenario.==scenario) |> data + plt_truth = truth_plotting_data * + mapping(:target_times => "T", :target_values => "values", + col = :Target, color = :Latent_Model) * + visual(Lines) + return plt_truth +end + +""" +Generate figure 1 showing the analysis and truth data for different scenarios. + +## Arguments +- `truth_df`: DataFrame containing the truth data. +- `analysis_df`: DataFrame containing the analysis data. +- `scenario_dict`: Dictionary containing information about the scenarios. + +## Keyword Arguments +- `fig_kws`: Keyword arguments for the Figure object. Default is `(; size = (1000, 2000))`. +- `true_gi_choice`: Value for the true generation interval choice. Default is `10.0`. +- `used_gi_choice`: Value for the used generation interval choice. Default is `10.0`. +- `legend_title`: Title for the legend. Default is `"Process type"`. + +## Returns +- `fig`: Figure object containing the generated figure. + +""" +function figureone( + truth_df, analysis_df, scenario_dict; fig_kws = (; size = (1000, 2000)), + true_gi_choice = 10.0, used_gi_choice = 10.0, legend_title = "Process type") + # Perform checks on the dataframes + _figure_one_dataframe_checks(truth_df, analysis_df, scenario_dict) + # Treat the truth data as a Latent model option + truth_df[!, "Latent_Model"] .= "Truth data" + + scenarios = analysis_df.Scenario |> unique + plt_truth_vect = map(scenarios) do scenario + _figure_one_scenario_truth_data(truth_df, scenario; true_gi_choice) + end + plt_analysis_vect = map(scenarios) do scenario + _figure_one_scenario( + analysis_df, scenario; reference_time = scenario_dict[scenario].T, + true_gi_choice, used_gi_choice) + end + + fig = Figure(; fig_kws...) + leg = nothing + for (i, scenario) in enumerate(scenarios) + sf = fig[i, :] + ag = draw!( + sf, plt_analysis_vect[i] + plt_truth_vect[i], facet = (; linkyaxes = :none)) + leg = AlgebraOfGraphics.compute_legend(ag) + Label(sf[0, :], scenario_dict[scenario].title, fontsize = 24, font = :bold) + end + + Label(fig[:, 0], "Process values", fontsize = 28, font = :bold, rotation = pi / 2) + Label(fig[:, 2], "Infection generating process", + fontsize = 24, font = :bold, rotation = -pi / 2) + _leg = (leg[1], leg[2], [legend_title]) + Legend(fig[:, 3], _leg...) + + return fig +end diff --git a/pipeline/src/mainplots/mainplots.jl b/pipeline/src/mainplots/mainplots.jl new file mode 100644 index 000000000..213944413 --- /dev/null +++ b/pipeline/src/mainplots/mainplots.jl @@ -0,0 +1 @@ +include("figureone.jl") diff --git a/pipeline/src/utils/calculate_processes.jl b/pipeline/src/utils/calculate_processes.jl index 873204187..33ca9bc65 100644 --- a/pipeline/src/utils/calculate_processes.jl +++ b/pipeline/src/utils/calculate_processes.jl @@ -2,19 +2,17 @@ Internal function for calculating the log of the infections with an informative error message if the infections are not positive definite. """ -function _calc_log_infections(I_t) - @assert all(I_t .> 0) "Infections must be positive definite." - log.(I_t) +function _calc_log_infections(I_t; jitter = 1e-6) + log.(I_t .+ jitter) end """ Internal function for calculating the exponential growth rate with an informative error message if the infections are not positive definite. """ -function _calc_rt(I_t, I0) - @assert all(I_t .> 0) "Infections must be positive definite." - @assert I0>0 "Initial infections must be positive definite." - log.([I0; I_t]) .- log(I0) |> diff +function _calc_rt(I_t, I0; jitter = 1e-6) + @assert I0 + jitter>0 "Initial infections must be positive definite." + log.([I0 + jitter; I_t .+ jitter]) .- log(I0 + jitter) |> diff end """ @@ -24,10 +22,11 @@ assumes backward exponential growth with initial infections `I0` from initial estimate of `rt`. """ -function _infection_seeding(I_t, I0, data::EpiData, pipeline::AbstractEpiAwarePipeline) +function _infection_seeding( + I_t, I0, data::EpiData, pipeline::AbstractEpiAwarePipeline; jitter = 1e-6) n = length(data.gen_int) - init_rt = _calc_rt(I_t[1:2], I0) |> x -> x[1] - [I0 * exp(-init_rt * (n - i)) for i in 1:n] + init_rt = _calc_rt(I_t[1:2] .+ jitter, I0 + jitter) |> x -> x[1] + [(I0 + jitter) * exp(-init_rt * (n - i)) for i in 1:n] end """ @@ -44,11 +43,10 @@ growth from the initial infections `I0` and the exponential growth rate `init_rt - `data::EpiData`: An instance of the `EpiData` type containing generation interval data. - `pipeline::AbstractEpiAwarePipeline`: An instance of the `AbstractEpiAwarePipeline` type. """ -function _calc_Rt(I_t, I0, data::EpiData, pipeline::AbstractEpiAwarePipeline) - @assert all(I_t .> 0) "Log infections must be positive definite." - @assert I0>0 "Initial infections must be positive definite." +function _calc_Rt(I_t, I0, data::EpiData, pipeline::AbstractEpiAwarePipeline; jitter = 1e-6) + @assert I0 + jitter>0 "Initial infections must be positive definite." - aug_I_t = vcat(_infection_seeding(I_t, I0, data, pipeline), I_t) + aug_I_t = vcat(_infection_seeding(I_t .+ jitter, I0 + jitter, data, pipeline), I_t) Rt = expected_Rt(data, aug_I_t) diff --git a/pipeline/src/utils/scenario_names_utils.jl b/pipeline/src/utils/scenario_names_utils.jl new file mode 100644 index 000000000..6c0e5b5f9 --- /dev/null +++ b/pipeline/src/utils/scenario_names_utils.jl @@ -0,0 +1,54 @@ +""" +Internal method to get a list of truth data scenarios from a vector of pipelines. +""" +function _get_scenario_list(pipelines::Vector{<:AbstractEpiAwarePipeline}) + map(pipelines) do pipeline + _prefix_from_pipeline(pipeline) + end +end + +""" +Internal method to get a list of truth data mean generation intervals +""" +function _get_truth_mean_gi_string_list(pipelines::Vector{<:AbstractEpiAwarePipeline}) + gi_params = make_gi_params(pipelines[1]) + gi_mean_strs = gi_params["gi_means"] .|> x -> "gi_mean=$(x)" + gi_mean_strs +end + +""" +Internal method to get the scenario from a filename. +""" +function _get_scenario_from_filename( + filename, pipelines::Vector{<:AbstractEpiAwarePipeline}) + sc_list = _get_scenario_list(pipelines) + scenario = filter(sc -> occursin(sc, filename), sc_list)[1] +end + +""" +Internal method to get the latent model from a filename. +""" +function _get_latent_model_from_filename(filename; + latent_namemodels::Vector{String} = ["wkly_ar", "wkly_diff_ar", "wkly_rw"]) + latent_model = filter(lm -> occursin(lm, filename), latent_namemodels)[1] +end + +""" +Internal method to get the true GI mean from a filename. +""" +function _get_true_gi_mean_from_filename(filename; + truth_gi_mean_strings::Vector{String} = "truth_gi_mean_" .* + string.([2.0, 10.0, 20.0])) + true_gi = filter(tgi -> occursin(tgi, filename), truth_gi_mean_strings)[1] |> + str -> split(str, "_")[end] |> s -> parse(Float64, s) +end + +""" +Internal method to get the used GI mean from a filename. +""" +function _get_used_gi_mean_from_filename(filename; + used_gi_mean_strings::Vector{String} = "_gi_mean=" .* + string.([2.0, 10.0, 20.0])) + used_gi = filter(ugi -> occursin(ugi, filename), used_gi_mean_strings)[1] |> + str -> split(str, "=")[end] |> s -> parse(Float64, s) +end diff --git a/pipeline/src/utils/timeseries_utils.jl b/pipeline/src/utils/timeseries_utils.jl new file mode 100644 index 000000000..a24e85628 --- /dev/null +++ b/pipeline/src/utils/timeseries_utils.jl @@ -0,0 +1,48 @@ + +""" +Transforms a matrix of time series samples into quantiles. + +This function takes a matrix `X` where each row represents a time series and transforms + it into a matrix of quantiles. The `qs` argument specifies the quantiles to compute. + +# Arguments +- `X`: A matrix of time series samples in shape (num_time_points, num_samples). +- `qs`: A vector of quantiles to compute. + +# Returns +A matrix where each row represents the quantiles of a time series. +""" +function timeseries_samples_into_quantiles(X, qs) + mapreduce(vcat, eachrow(X)) do row + _row = filter(x -> !isnan(x), row) + quantile(_row, qs)' + end +end + +""" +Internal function for reducing a sequence of results from calls to `calculate_processes`. +""" +_process_reduction(procs_1, procs_2) = (; log_I_t = hcat(procs_1.log_I_t, procs_2.log_I_t), + Rt = hcat(procs_1.Rt, procs_2.Rt), rt = hcat(procs_1.rt, procs_2.rt)) + +""" +Generate quantiles for targets based on the output and EpiData. + +# Arguments +- `output`: The output containing inference results. +- `D::EpiData`: The `EpiData` object containing data about underlying infection process, + e.g. the generation interval distribution. +- `qs`: The quantiles to generate. + +# Returns +An array of quantiles for each target. + +""" +function generate_quantiles_for_targets(output, D::EpiData, qs) + mapreduce(_process_reduction, output["forecast_results"].generated, + output["inference_results"].samples[:init_incidence]) do gen, logI0 + calculate_processes(gen.I_t, exp(logI0), D, EpiAwareExamplePipeline()) + end |> res -> map(res) do X + timeseries_samples_into_quantiles(X, qs) + end +end diff --git a/pipeline/src/utils/utils.jl b/pipeline/src/utils/utils.jl index c6c24a28e..cba27e2b7 100644 --- a/pipeline/src/utils/utils.jl +++ b/pipeline/src/utils/utils.jl @@ -1 +1,3 @@ include("calculate_processes.jl") +include("scenario_names_utils.jl") +include("timeseries_utils.jl") diff --git a/pipeline/test/mainplots/test_utils.jl b/pipeline/test/mainplots/test_utils.jl new file mode 100644 index 000000000..f2c694f2f --- /dev/null +++ b/pipeline/test/mainplots/test_utils.jl @@ -0,0 +1,21 @@ + +@testset "timeseries_samples_into_quantiles tests" begin + using EpiAwarePipeline + X = [1 2 3; 4 5 6; 7 8 9] + qs = [0.25, 0.5, 0.75] + expected_output = [1.5 2.0 2.5 + 4.5 5.0 5.5 + 7.5 8.0 8.5] + @test timeseries_samples_into_quantiles(X, qs) == expected_output +end + +@testset "generate_quantiles_for_targets tests" begin + output = Dict("inference_results" => Dict( + "generated" => [1, 2, 3], "samples" => Dict(:init_incidence => [0.1, 0.2, 0.3]))) + D = EpiData() + qs = [0.25, 0.5, 0.75] + expected_output = [ + [0.1 0.2 0.3; 0.1 0.2 0.3; 0.1 0.2 0.3], [0.1 0.2 0.3; 0.1 0.2 0.3; 0.1 0.2 0.3], + [0.1 0.2 0.3; 0.1 0.2 0.3; 0.1 0.2 0.3]] + @test generate_quantiles_for_targets(output, D, qs) == expected_output +end From a16f8e3eaa32e19fd10af4cf5700f996939e7ba1 Mon Sep 17 00:00:00 2001 From: Sam Abbott Date: Tue, 16 Jul 2024 09:58:15 +0100 Subject: [PATCH 2/3] Hotfix: Benchmark retuning --- benchmark/runbenchmarks.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmark/runbenchmarks.jl b/benchmark/runbenchmarks.jl index 70fe712c5..4a52c0cb5 100644 --- a/benchmark/runbenchmarks.jl +++ b/benchmark/runbenchmarks.jl @@ -6,6 +6,7 @@ benchmarkpkg( "JULIA_NUM_THREADS" => "2", "OMP_NUM_THREADS" => "2" ), - ), - resultfile = joinpath(@__DIR__, "result.json") + ); + resultfile = joinpath(@__DIR__, "result.json"). + retune = true ) From c01b13909dd2d84c58ffd64d0ab76a4217ddf471 Mon Sep 17 00:00:00 2001 From: Sam Abbott Date: Tue, 16 Jul 2024 09:58:43 +0100 Subject: [PATCH 3/3] Hotfix - Benchmark retuning --- benchmark/runbenchmarks.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/runbenchmarks.jl b/benchmark/runbenchmarks.jl index 4a52c0cb5..e5a15f174 100644 --- a/benchmark/runbenchmarks.jl +++ b/benchmark/runbenchmarks.jl @@ -7,6 +7,6 @@ benchmarkpkg( "OMP_NUM_THREADS" => "2" ), ); - resultfile = joinpath(@__DIR__, "result.json"). + resultfile = joinpath(@__DIR__, "result.json"), retune = true )