From 3b8d38e07059035d4c9f6d068d4916daad4a1415 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 18 Dec 2024 15:56:06 +0000 Subject: [PATCH 1/9] Create make_mcmc_diagnostic_dataframe.jl --- pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl diff --git a/pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl b/pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl new file mode 100644 index 000000000..01e3f4cfc --- /dev/null +++ b/pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl @@ -0,0 +1,2 @@ +function make_mcmc_diagnostic_dataframe(output) +end From 3b21e096b6e43de852629b1d501a1913320991d1 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Wed, 18 Dec 2024 18:05:33 +0000 Subject: [PATCH 2/9] reorg scripts and add more success/fail analysis --- .../scripts/create_mcmc_diagonostic_script.jl | 20 ++++++++++++++ .../create_postprocessing_dataframes.jl | 10 ++++++- .../scripts/create_prediction_dataframe.jl | 26 ++++++++----------- 3 files changed, 40 insertions(+), 16 deletions(-) create mode 100644 pipeline/scripts/create_mcmc_diagonostic_script.jl diff --git a/pipeline/scripts/create_mcmc_diagonostic_script.jl b/pipeline/scripts/create_mcmc_diagonostic_script.jl new file mode 100644 index 000000000..493fb2914 --- /dev/null +++ b/pipeline/scripts/create_mcmc_diagonostic_script.jl @@ -0,0 +1,20 @@ +## Analysis of the prediction dataframes for mcmc diagnostics +diagnostic_df = mapreduce(vcat, scenarios) do scenario + mapreduce(vcat, true_gi_means) do true_gi_mean + target_str = "truth_gi_mean_" * string(true_gi_mean) * "_" + files = readdir(datadir("epiaware_observables/" * scenario)) |> + strs -> filter(s -> occursin("jld2", s), strs) |> + strs -> filter(s -> occursin(target_str, s), strs) + + mapreduce(vcat, files) do filename + output = load(joinpath(datadir("epiaware_observables"), scenario, filename)) + try + make_mcmc_diagnostic_dataframe(output, true_gi_mean, scenario) + catch e + end + end + end +end + +## Save the mcmc diagnostics +CSV.write("manuscript/inference_diagnostics_rnd2.csv", diagnostic_df) diff --git a/pipeline/scripts/create_postprocessing_dataframes.jl b/pipeline/scripts/create_postprocessing_dataframes.jl index 4b3ea1568..eb8a4fd4d 100644 --- a/pipeline/scripts/create_postprocessing_dataframes.jl +++ b/pipeline/scripts/create_postprocessing_dataframes.jl @@ -1,8 +1,11 @@ -using EpiAwarePipeline, EpiAware, JLD2, DrWatson, DataFramesMeta, CSV +using EpiAwarePipeline, EpiAware, JLD2, DrWatson, DataFramesMeta, CSV, MCMCChains ## Define scenarios scenarios = ["measures_outbreak", "smooth_outbreak", "smooth_endemic", "rough_endemic"] +## Define true GI means +true_gi_means = [2.0, 10.0, 20.0] + if !isfile(plotsdir("plotting_data/predictions.csv")) @info "Prediction dataframe does not exist, generating now" include("create_prediction_dataframe.jl") @@ -12,3 +15,8 @@ if !isfile(plotsdir("plotting_data/truthdata.csv")) @info "Truth dataframe does not exist, generating now" include("create_truth_dataframe.jl") end + +if !isfile("manuscript/inference_pass_fail_rnd2.csv") + @info "Diagnostic dataframe does not exist, generating now" + include("create_mcmc_diagonostic_script.jl") +end diff --git a/pipeline/scripts/create_prediction_dataframe.jl b/pipeline/scripts/create_prediction_dataframe.jl index c0e61409d..9c0d6dad1 100644 --- a/pipeline/scripts/create_prediction_dataframe.jl +++ b/pipeline/scripts/create_prediction_dataframe.jl @@ -1,9 +1,8 @@ -## Define true GI means -true_gi_means = [2.0, 10.0, 20.0] - -## Load the prediction dataframes or record fails +## Structure to record success/failure +success_configs = Dict[] failed_configs = Dict[] +## Analysis of the prediction dataframes dfs = mapreduce(vcat, scenarios) do scenario mapreduce(vcat, true_gi_means) do true_gi_mean target_str = "truth_gi_mean_" * string(true_gi_mean) * "_" @@ -14,23 +13,25 @@ dfs = mapreduce(vcat, scenarios) do scenario mapreduce(vcat, files) do filename output = load(joinpath(datadir("epiaware_observables"), scenario, filename)) try + push!(success_configs, + merge(output["inference_config"], Dict("runsuccess" => true))) make_prediction_dataframe_from_output(output, true_gi_mean, scenario) catch e @warn "Error in $filename" - push!(failed_configs, output["inference_config"]) + push!(failed_configs, + merge(output["inference_config"], Dict("runsuccess" => false))) return DataFrame() end end end end -## Gather the failed data -failed_df = mapreduce(vcat, failed_configs) do D +## Gather the pass/failed data +pass_fail_df = mapreduce(vcat, [success_configs; failed_configs]) do D igp = D["igp"] |> str -> split(str, ".")[end] latent_model = D["latent_model"] gi_mean = D["gi_mean"] T1, T2 = split(D["tspan"], "_") - runsuccess = D["priorpredictive"] .== "Pass" df = DataFrame( infection_gen_proc = igp, latent_model = latent_model, @@ -38,15 +39,10 @@ failed_df = mapreduce(vcat, failed_configs) do D T1 = T1, T2 = T2, T_diff = parse(Int, T2) - parse(Int, T1), - runsuccess = runsuccess + runsuccess = D["runsuccess"] ) end -## -grped_failed_df = failed_df |> - df -> @groupby(df, :infection_gen_proc, :latent_model) |> - gd -> @combine(gd, :n_fail=sum(1 .- :runsuccess)) - ## Save the prediction and failed dataframes CSV.write(plotsdir("plotting_data/predictions.csv"), dfs) -CSV.write(plotsdir("plotting_data/failed_preds.csv"), failed_df) +CSV.write("manuscript/inference_pass_fail_rnd2.csv", pass_fail_df) From c900fe8f7c22fab726bcc4a89541c32d17e32d95 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Wed, 18 Dec 2024 18:06:04 +0000 Subject: [PATCH 3/9] Add function to get run info to avoid DRY --- pipeline/src/analysis/config_mappings.jl | 33 +++++++++++++++++ .../make_prediction_dataframe_from_output.jl | 36 ++++++------------- 2 files changed, 44 insertions(+), 25 deletions(-) create mode 100644 pipeline/src/analysis/config_mappings.jl diff --git a/pipeline/src/analysis/config_mappings.jl b/pipeline/src/analysis/config_mappings.jl new file mode 100644 index 000000000..057258fdb --- /dev/null +++ b/pipeline/src/analysis/config_mappings.jl @@ -0,0 +1,33 @@ +""" +Extracts and returns relevant information from the given inference configuration dictionary. + +# Returns +- `NamedTuple`: A named tuple containing the following fields: + - `igp_model::String`: The IGP model name extracted from the configuration. + - `latent_model::String`: The latent model name from the configuration. + - `used_gi_mean::Float64`: The mean generation interval (GI) used in the configuration. + - `used_gi_std::Float64`: The standard deviation of the generation interval (GI) used in the configuration. + - `start_time::Int`: The start time parsed from the configuration's time span. + - `reference_time::Int`: The reference time parsed from the configuration's time span. + - `used_gi_means::Vector{Float64}`: A vector of GI means, either a single value if the IGP model is "Renewal" or a list of values generated by `make_gi_params` otherwise. +""" +function _get_info_from_config(inference_config) + #Get the scenario, IGP model, latent model and true mean GI + igp_model = inference_config["igp"] |> igp_name -> split(igp_name, ".")[end] + latent_model = inference_config["latent_model"] + used_gi_mean = inference_config["gi_mean"] + used_gi_std = inference_config["gi_std"] + (start_time, reference_time) = inference_config["tspan"] |> + tspan -> split(tspan, "_") |> + tspan -> ( + parse(Int, tspan[1]), parse(Int, tspan[2])) + + #Get the quantiles for the targets across the gi mean scenarios + #if Renewal model, then we use the underlying epi model + #otherwise we use the epi datas to loop over different gi mean implications + used_gi_means = igp_model == "Renewal" ? + [used_gi_mean] : + make_gi_params(EpiAwareExamplePipeline())["gi_means"] + return (; igp_model, latent_model, used_gi_mean, used_gi_std, + start_time, reference_time, used_gi_means) +end diff --git a/pipeline/src/analysis/make_prediction_dataframe_from_output.jl b/pipeline/src/analysis/make_prediction_dataframe_from_output.jl index cc68012ba..7cad4e4f2 100644 --- a/pipeline/src/analysis/make_prediction_dataframe_from_output.jl +++ b/pipeline/src/analysis/make_prediction_dataframe_from_output.jl @@ -20,42 +20,28 @@ function make_prediction_dataframe_from_output( inference_config = output["inference_config"] forecasts = output["forecast_results"] #Get the scenario, IGP model, latent model and true mean GI - igp_model = inference_config["igp"] |> igp_name -> split(igp_name, ".")[end] - latent_model = inference_config["latent_model"] - used_gi_mean = inference_config["gi_mean"] - used_gi_std = inference_config["gi_std"] - (start_time, reference_time) = inference_config["tspan"] |> - tspan -> split(tspan, "_") |> - tspan -> ( - parse(Int, tspan[1]), parse(Int, tspan[2])) - - #Get the quantiles for the targets across the gi mean scenarios - #if Renewal model, then we use the underlying epi model - #otherwise we use the epi datas to loop over different gi mean implications - used_gi_means = igp_model == "Renewal" ? - [used_gi_mean] : - make_gi_params(EpiAwareExamplePipeline())["gi_means"] - - used_epidatas = map(used_gi_means) do ḡ - _make_epidata(ḡ, used_gi_std; transformation = transformation) + info = _get_info_from_config(inference_config) + #Get the epi datas + used_epidatas = map(info.used_gi_means) do ḡ + _make_epidata(ḡ, info.used_gi_std; transformation = transformation) end - + #Generate the quantiles for the targets preds = map(used_epidatas) do epi_data generate_quantiles_for_targets(forecasts, epi_data, qs) end #Create the dataframe columnwise - df = mapreduce(vcat, preds, used_gi_means) do pred, used_gi_mean + df = mapreduce(vcat, preds, info.used_gi_means) do pred, used_gi_mean mapreduce(vcat, keys(pred)) do target target_mat = pred[target] - target_times = collect(1:size(target_mat, 1)) .+ (start_time - 1) + target_times = collect(1:size(target_mat, 1)) .+ (info.start_time - 1) _df = DataFrame(target_times = target_times) _df[!, "Scenario"] .= scenario - _df[!, "IGP_Model"] .= igp_model - _df[!, "Latent_Model"] .= latent_model + _df[!, "igp_model"] .= info.igp_model + _df[!, "latent_model"] .= info.latent_model _df[!, "True_GI_Mean"] .= true_mean_gi - _df[!, "Used_GI_Mean"] .= used_gi_mean - _df[!, "Reference_Time"] .= reference_time + _df[!, "used_gi_mean"] .= used_gi_mean + _df[!, "reference_time"] .= info.reference_time _df[!, "Target"] .= string(target) # quantile predictions for (j, q) in enumerate(qs) From 60a2295bd6151bbb0a0bf3ec93ae9b6abcaac48a Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Wed, 18 Dec 2024 18:06:21 +0000 Subject: [PATCH 4/9] Add function to do diagnostics --- .../make_mcmc_diagnostic_dataframe.jl | 48 ++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl b/pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl index 01e3f4cfc..80ef7d27a 100644 --- a/pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl +++ b/pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl @@ -1,2 +1,48 @@ -function make_mcmc_diagnostic_dataframe(output) +""" +Generate a DataFrame containing MCMC diagnostic metrics. The metrics are the proportion of +parameters that pass the bulk effective sample size (ESS) threshold, the proportion of +parameters that pass the tail ESS threshold, the proportion of parameters that pass the R-hat +absolute difference from 1 threshold, whether the model has a cluster factor parameter, and the tail ESS +of the cluster factor parameter. + +# Arguments +- `output::Dict`: A dictionary containing the inference results. +- `bulk_ess_threshold::Int`: The threshold for bulk effective sample size (ESS). Default is 500. +- `tail_ess_threshold::Int`: The threshold for tail effective sample size (ESS). Default is 100. +- `rhat_diff_threshold::Float64`: The threshold for the difference of R-hat from 1. Default is 0.02. +""" +function make_mcmc_diagnostic_dataframe( + output, true_mean_gi, scenario; bulk_ess_threshold = 500, + tail_ess_threshold = 100, rhat_diff_threshold = 0.02) + #Get the scenario, IGP model, latent model and true mean GI + inference_config = output["inference_config"] + info = _get_info_from_config(inference_config) + #Get the convergence diagnostics + chn_nt = output["inference_results"].samples |> summarize |> summary -> summary.nt + cluster_factor_idxs = chn_nt.parameters .== Symbol("obs.cluster_factor") + has_cluster_factor = any(cluster_factor_idxs) + not_cluster_factor = .~cluster_factor_idxs + cluster_factor_tail = chn_nt.ess_tail[cluster_factor_idxs][1] + + #Create the dataframe + df = mapreduce(vcat, info.used_gi_means) do used_gi_mean + DataFrame( + Scenario = scenario, + igp_model = info.igp_model, + latent_model = info.latent_model, + True_GI_Mean = true_mean_gi, + used_gi_mean = used_gi_mean, + reference_time = info.reference_time, + bulk_ess_threshold = (chn_nt.ess_bulk[not_cluster_factor] .> + bulk_ess_threshold) |> + mean, + tail_ess_threshold = (chn_nt.ess_tail[not_cluster_factor] .> + tail_ess_threshold) |> + mean, + rhat_diff_threshold = (abs.(chn_nt.rhat[not_cluster_factor] .- 1) .< + rhat_diff_threshold) |> mean, + has_cluster_factor = has_cluster_factor, + cluster_factor_tail = has_cluster_factor ? cluster_factor_tail : missing) + end + return df end From 677d222597f86ce7d12725436bc125baa82b9247 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Wed, 18 Dec 2024 18:06:29 +0000 Subject: [PATCH 5/9] export new func --- pipeline/src/EpiAwarePipeline.jl | 2 +- pipeline/src/analysis/analysis.jl | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pipeline/src/EpiAwarePipeline.jl b/pipeline/src/EpiAwarePipeline.jl index ed328d28f..e2771f3c1 100644 --- a/pipeline/src/EpiAwarePipeline.jl +++ b/pipeline/src/EpiAwarePipeline.jl @@ -46,7 +46,7 @@ export score_parameters, simple_crps, summarise_crps # Exported functions: Analysis functions for constructing dataframes export make_prediction_dataframe_from_output, make_truthdata_dataframe, - make_scoring_dataframe_from_output + make_scoring_dataframe_from_output, make_mcmc_diagnostic_dataframe # Exported functions: Make main plots export figureone, figuretwo diff --git a/pipeline/src/analysis/analysis.jl b/pipeline/src/analysis/analysis.jl index 559f15b13..e0c83aa34 100644 --- a/pipeline/src/analysis/analysis.jl +++ b/pipeline/src/analysis/analysis.jl @@ -1,3 +1,5 @@ +include("config_mappings.jl") include("make_truthdata_dataframe.jl") include("make_prediction_dataframe_from_output.jl") include("make_scoring_dataframe_from_output.jl") +include("make_mcmc_diagnostic_dataframe.jl") From 3d624dca1d3ed1a077eab06737f9610bb2cf1524 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 18 Dec 2024 23:03:43 +0000 Subject: [PATCH 6/9] update SI --- manuscript/Project.toml | 1 + manuscript/index.qmd | 43 +++++++++++++++++++++++------------------ 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/manuscript/Project.toml b/manuscript/Project.toml index bc7ef347d..f3a5ee9ca 100644 --- a/manuscript/Project.toml +++ b/manuscript/Project.toml @@ -1,4 +1,5 @@ [deps] +CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964" DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1" diff --git a/manuscript/index.qmd b/manuscript/index.qmd index 05306e4ca..743fd5da5 100644 --- a/manuscript/index.qmd +++ b/manuscript/index.qmd @@ -24,9 +24,9 @@ index_location = @__DIR__() Pkg.activate(index_location) Pkg.resolve() Pkg.instantiate() -Pkg.add(["CairoMakie", "JLD2", "DataFramesMeta", "DrWatson"]) +Pkg.add(["CairoMakie", "JLD2", "DataFramesMeta", "DrWatson", "CSV"]) -using DataFramesMeta, JLD2 +using DataFramesMeta, JLD2, CSV ``` @@ -66,23 +66,28 @@ We noted that for a substantial number of the model configurations there were mo priorpred_dir = joinpath(@__DIR__(),"..", "pipeline/data/priorpredictive/") priorpred_datafiles = readdir(priorpred_dir) |> fns -> filter(fn -> contains(fn, ".jld2"), fns) #filter for .jld2 files - -priorpred_outcomes_df = mapreduce(vcat, priorpred_datafiles) do fn - D = load(joinpath(priorpred_dir, fn)) - igp = D["inference_config"]["igp"] - latent_model = D["inference_config"]["latent_model"] - gi_mean = D["inference_config"]["gi_mean"] - T1, T2 = split(D["inference_config"]["tspan"], "_") - runsuccess = D["priorpredictive"] .== "Pass" - df = DataFrame( - infection_gen_proc = igp, - latent_model = latent_model, - gi_mean = gi_mean, - T1 = T1, - T2 = T2, - T_diff = parse(Int, T2) - parse(Int, T1), - runsuccess = runsuccess, - ) +priorpred_outcomes_df = DataFrame() +if !isfile(joinpath(index_location, "pass_fail_rdn1.csv")) + priorpred_outcomes_df = mapreduce(vcat, priorpred_datafiles) do fn + D = load(joinpath(priorpred_dir, fn)) + igp = D["inference_config"]["igp"] + latent_model = D["inference_config"]["latent_model"] + gi_mean = D["inference_config"]["gi_mean"] + T1, T2 = split(D["inference_config"]["tspan"], "_") + runsuccess = D["priorpredictive"] .== "Pass" + df = DataFrame( + infection_gen_proc = igp, + latent_model = latent_model, + gi_mean = gi_mean, + T1 = T1, + T2 = T2, + T_diff = parse(Int, T2) - parse(Int, T1), + runsuccess = runsuccess, + ) + end + CSV.write(joinpath(index_location, "pass_fail_rdn1.csv"), priorpred_outcomes_df) +else + priorpred_outcomes_df = CSV.File(joinpath(index_location, "pass_fail_rdn1.csv")) |> DataFrame end ``` From e894504f9812d9440ff98df371efbf030d1bd2c1 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Thu, 19 Dec 2024 11:32:21 +0000 Subject: [PATCH 7/9] Issue 561: Soft min transformation (#562) Also removed unnecessary call to `fetch` --- .../constructors/make_observation_model.jl | 22 ++++++++++++++++++- .../test/pipeline/test_pipelinefunctions.jl | 8 +++---- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/pipeline/src/constructors/make_observation_model.jl b/pipeline/src/constructors/make_observation_model.jl index 1fb1255aa..9ad4d6020 100644 --- a/pipeline/src/constructors/make_observation_model.jl +++ b/pipeline/src/constructors/make_observation_model.jl @@ -1,5 +1,5 @@ """ -Constructs an observation model for the given pipeline. This is the defualt method. +Constructs an observation model for the given pipeline. This is the default method. # Arguments - `pipeline::AbstractEpiAwarePipeline`: The pipeline for which the observation model is constructed. @@ -18,3 +18,23 @@ function make_observation_model(pipeline::AbstractEpiAwarePipeline) obs = LatentDelay(dayofweek_logit_ascert, delay_distribution) return obs end + +const negC = -1e15 +""" +Soft minimum function for a smooth transition from `x -> x` to a maximum value of 1e15. +""" +_softmin(x) = -logaddexp(negC, -x) + +function make_observation_model(pipeline::AbstractRtwithoutRenewalPipeline) + default_params = make_default_params(pipeline) + #Model for ascertainment based on day of the week + dayofweek_logit_ascert = ascertainment_dayofweek( + NegativeBinomialError(cluster_factor_prior = HalfNormal(default_params["cluster_factor"])); + transform = (x, y) -> _softmin.(x .* y)) + + #Default continuous-time model for latent delay in observations + delay_distribution = make_delay_distribution(pipeline) + #Model for latent delay in observations + obs = LatentDelay(dayofweek_logit_ascert, delay_distribution) + return obs +end diff --git a/pipeline/test/pipeline/test_pipelinefunctions.jl b/pipeline/test/pipeline/test_pipelinefunctions.jl index af96cb2cb..6a27c88b3 100644 --- a/pipeline/test/pipeline/test_pipelinefunctions.jl +++ b/pipeline/test/pipeline/test_pipelinefunctions.jl @@ -13,16 +13,14 @@ end @testset "do_inference tests" begin function make_inference(pipeline) - truthdata_dg_task = do_truthdata(pipeline) - truthdata = fetch.(truthdata_dg_task) + truthdata = do_truthdata(pipeline) do_inference(truthdata[1], pipeline) end for pipetype in [SmoothOutbreakPipeline, MeasuresOutbreakPipeline, SmoothEndemicPipeline, RoughEndemicPipeline] - pipeline = pipetype(; ndraws = 20, nchains = 1, testmode = true) - inference_results_tsk = make_inference(pipeline) - inference_results = fetch.(inference_results_tsk) + pipeline = pipetype(; ndraws = 1000, nchains = 1, testmode = true) + inference_results = make_inference(pipeline) @test length(inference_results) == 1 @test all([result["inference_results"] isa EpiAwareObservables for result in inference_results]) From 73e80c174eb9f944766223026218c163944c14c0 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Thu, 19 Dec 2024 11:56:25 +0000 Subject: [PATCH 8/9] base values on pipeline types --- .../create_postprocessing_dataframes.jl | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/pipeline/scripts/create_postprocessing_dataframes.jl b/pipeline/scripts/create_postprocessing_dataframes.jl index eb8a4fd4d..2cce5ea6c 100644 --- a/pipeline/scripts/create_postprocessing_dataframes.jl +++ b/pipeline/scripts/create_postprocessing_dataframes.jl @@ -1,10 +1,24 @@ using EpiAwarePipeline, EpiAware, JLD2, DrWatson, DataFramesMeta, CSV, MCMCChains +pipelinetypes = [ + MeasuresOutbreakPipeline, + SmoothOutbreakPipeline, + SmoothEndemicPipeline, + RoughEndemicPipeline +] ## Define scenarios -scenarios = ["measures_outbreak", "smooth_outbreak", "smooth_endemic", "rough_endemic"] + +scenarios = pipelinetypes .|> pipetype -> pipetype().prefix ## Define true GI means -true_gi_means = [2.0, 10.0, 20.0] +# Errors if not the same for all pipeline types +true_gi_means = map(pipelinetypes) do pipetype + make_gi_params(pipetype())["gi_means"] +end |> + ensemble_gi_means -> all([gi_means == ensemble_gi_means[1] + for gi_means in ensemble_gi_means]) ? + ensemble_gi_means[1] : + error("GI means are not the same") if !isfile(plotsdir("plotting_data/predictions.csv")) @info "Prediction dataframe does not exist, generating now" From 727a827ea557fb1e9a4aef91749be64500f80895 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Thu, 19 Dec 2024 12:46:02 +0000 Subject: [PATCH 9/9] breakdown mcmc convergence test function Adds more stats and a unit test --- .../make_mcmc_diagnostic_dataframe.jl | 44 +++++++++++++++---- .../make_mcmc_diagnostic_dataframe.jl | 37 ++++++++++++++++ pipeline/test/analysis/test_analysis.jl | 1 + 3 files changed, 74 insertions(+), 8 deletions(-) create mode 100644 pipeline/test/analysis/make_mcmc_diagnostic_dataframe.jl diff --git a/pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl b/pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl index 80ef7d27a..b51a5c706 100644 --- a/pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl +++ b/pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl @@ -1,3 +1,28 @@ +""" +Collects the statistics of a vector `x` that are relevant for MCMC diagnostics. +""" +function _get_stats(x, threshold; pass_above = true) + if pass_above + return (; x_mean = mean(x), prop_pass = mean(x .>= threshold), + x_min = minimum(x), x_max = maximum(x)) + else + return (; x_mean = mean(x), prop_pass = mean(x .<= threshold), + x_min = minimum(x), x_max = maximum(x)) + end +end + +""" +Collects the convergence statistics over the parameters that are not cluster factor. +""" +function _collect_stats(chn_nt, not_cluster_factor; bulk_ess_threshold, + tail_ess_threshold, rhat_diff_threshold) + ess_bulk = chn_nt.ess_bulk[not_cluster_factor] |> x -> _get_stats(x, bulk_ess_threshold) + ess_tail = chn_nt.ess_tail[not_cluster_factor] |> x -> _get_stats(x, tail_ess_threshold) + rhat_diff = abs.(chn_nt.rhat[not_cluster_factor] .- 1) |> + x -> _get_stats(x, rhat_diff_threshold; pass_above = false) + return (; ess_bulk, ess_tail, rhat_diff) +end + """ Generate a DataFrame containing MCMC diagnostic metrics. The metrics are the proportion of parameters that pass the bulk effective sample size (ESS) threshold, the proportion of @@ -23,6 +48,9 @@ function make_mcmc_diagnostic_dataframe( has_cluster_factor = any(cluster_factor_idxs) not_cluster_factor = .~cluster_factor_idxs cluster_factor_tail = chn_nt.ess_tail[cluster_factor_idxs][1] + #Collect the statistics + stats_for_targets = _collect_stats(chn_nt, not_cluster_factor; bulk_ess_threshold, + tail_ess_threshold, rhat_diff_threshold) #Create the dataframe df = mapreduce(vcat, info.used_gi_means) do used_gi_mean @@ -33,16 +61,16 @@ function make_mcmc_diagnostic_dataframe( True_GI_Mean = true_mean_gi, used_gi_mean = used_gi_mean, reference_time = info.reference_time, - bulk_ess_threshold = (chn_nt.ess_bulk[not_cluster_factor] .> - bulk_ess_threshold) |> - mean, - tail_ess_threshold = (chn_nt.ess_tail[not_cluster_factor] .> - tail_ess_threshold) |> - mean, - rhat_diff_threshold = (abs.(chn_nt.rhat[not_cluster_factor] .- 1) .< - rhat_diff_threshold) |> mean, has_cluster_factor = has_cluster_factor, cluster_factor_tail = has_cluster_factor ? cluster_factor_tail : missing) end + #Add stats columns + for key in keys(stats_for_targets) + stats = getfield(stats_for_targets, key) + df[!, string(key) * "_" * "mean"] .= stats.x_mean + df[!, string(key) * "_" * "prop_pass"] .= stats.prop_pass + df[!, string(key) * "_" * "min"] .= stats.x_min + df[!, string(key) * "_" * "max"] .= stats.x_max + end return df end diff --git a/pipeline/test/analysis/make_mcmc_diagnostic_dataframe.jl b/pipeline/test/analysis/make_mcmc_diagnostic_dataframe.jl new file mode 100644 index 000000000..b2473d918 --- /dev/null +++ b/pipeline/test/analysis/make_mcmc_diagnostic_dataframe.jl @@ -0,0 +1,37 @@ +@testset "test MCMC convergence analysis on toy obs model" begin + using JLD2, DataFramesMeta, Turing, EpiAware + # Reuse the local config + _output = load(joinpath(@__DIR__(), "test_data.jld2")) + inference_config = _output["inference_config"] + # Create a simple test model to test mcmc diagnostics via prior sampling + obs = make_observation_model(SmoothEndemicPipeline()) + @model function test_model() + x ~ filldist(Normal(0, 1), 20) + @submodel prefix="obs" y_t=generate_observations(obs, missing, exp.(x)) + end + n = 1000 + samples = sample(test_model(), Prior(), n) + + # Create a simple output to test the function + output = Dict( + "inference_config" => inference_config, + "inference_results" => (; samples,) + ) + + true_mean_gi = 10.0 + scenario = "rough_endemic" + df = make_mcmc_diagnostic_dataframe( + output, true_mean_gi, "rough_endemic") + # Check pass throughs + @test typeof(df) == DataFrame + @test size(df, 1) == 3 # Number of rows should match the length of used_gi_means + @test df[1, :Scenario] == scenario + @test df[1, :latent_model] == inference_config["latent_model"] + @test df[1, :True_GI_Mean] == true_mean_gi + # Prior sampling should be uncorrelated and meet all the convergence criteria + @test all(df[:, :ess_bulk_prop_pass] .== 1.0) + @test all(df[:, :ess_tail_prop_pass] .== 1.0) + @test all(df[:, :rhat_diff_prop_pass] .== 1.0) + @test all(df[:, :has_cluster_factor] .== true) + @test all(df[1, :cluster_factor_tail] .> n / 2) +end diff --git a/pipeline/test/analysis/test_analysis.jl b/pipeline/test/analysis/test_analysis.jl index 1e2608aea..6a45d2214 100644 --- a/pipeline/test/analysis/test_analysis.jl +++ b/pipeline/test/analysis/test_analysis.jl @@ -1,2 +1,3 @@ include("make_prediction_dataframe_from_output.jl") include("make_truthdata_dataframe.jl") +include("make_mcmc_diagnostic_dataframe.jl")