Issue 559: Diagnostic analysis over all inference runs #560

Merged: 10 commits, Dec 19, 2024
1 change: 1 addition & 0 deletions manuscript/Project.toml
@@ -1,4 +1,5 @@
[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964"
DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1"
43 changes: 24 additions & 19 deletions manuscript/index.qmd
@@ -24,9 +24,9 @@ index_location = @__DIR__()
Pkg.activate(index_location)
Pkg.resolve()
Pkg.instantiate()
Pkg.add(["CairoMakie", "JLD2", "DataFramesMeta", "DrWatson"])
Pkg.add(["CairoMakie", "JLD2", "DataFramesMeta", "DrWatson", "CSV"])

using DataFramesMeta, JLD2
using DataFramesMeta, JLD2, CSV

```

@@ -66,23 +66,28 @@ We noted that for a substantial number of the model configurations there were mo
priorpred_dir = joinpath(@__DIR__(),"..", "pipeline/data/priorpredictive/")
priorpred_datafiles = readdir(priorpred_dir) |>
fns -> filter(fn -> contains(fn, ".jld2"), fns) #filter for .jld2 files

priorpred_outcomes_df = mapreduce(vcat, priorpred_datafiles) do fn
D = load(joinpath(priorpred_dir, fn))
igp = D["inference_config"]["igp"]
latent_model = D["inference_config"]["latent_model"]
gi_mean = D["inference_config"]["gi_mean"]
T1, T2 = split(D["inference_config"]["tspan"], "_")
runsuccess = D["priorpredictive"] .== "Pass"
df = DataFrame(
infection_gen_proc = igp,
latent_model = latent_model,
gi_mean = gi_mean,
T1 = T1,
T2 = T2,
T_diff = parse(Int, T2) - parse(Int, T1),
runsuccess = runsuccess,
)
priorpred_outcomes_df = DataFrame()
if !isfile(joinpath(index_location, "pass_fail_rdn1.csv"))
priorpred_outcomes_df = mapreduce(vcat, priorpred_datafiles) do fn
D = load(joinpath(priorpred_dir, fn))
igp = D["inference_config"]["igp"]
latent_model = D["inference_config"]["latent_model"]
gi_mean = D["inference_config"]["gi_mean"]
T1, T2 = split(D["inference_config"]["tspan"], "_")
runsuccess = D["priorpredictive"] .== "Pass"
df = DataFrame(
infection_gen_proc = igp,
latent_model = latent_model,
gi_mean = gi_mean,
T1 = T1,
T2 = T2,
T_diff = parse(Int, T2) - parse(Int, T1),
runsuccess = runsuccess,
)
end
CSV.write(joinpath(index_location, "pass_fail_rdn1.csv"), priorpred_outcomes_df)
else
priorpred_outcomes_df = CSV.File(joinpath(index_location, "pass_fail_rdn1.csv")) |> DataFrame
end
```

20 changes: 20 additions & 0 deletions pipeline/scripts/create_mcmc_diagonostic_script.jl
@@ -0,0 +1,20 @@
## Analysis of the prediction dataframes for mcmc diagnostics
diagnostic_df = mapreduce(vcat, scenarios) do scenario
mapreduce(vcat, true_gi_means) do true_gi_mean
target_str = "truth_gi_mean_" * string(true_gi_mean) * "_"
files = readdir(datadir("epiaware_observables/" * scenario)) |>
strs -> filter(s -> occursin("jld2", s), strs) |>
strs -> filter(s -> occursin(target_str, s), strs)

mapreduce(vcat, files) do filename
output = load(joinpath(datadir("epiaware_observables"), scenario, filename))
try
make_mcmc_diagnostic_dataframe(output, true_gi_mean, scenario)
catch e
# Mirror create_prediction_dataframe.jl: warn and return an empty DataFrame so the vcat reduction skips this file
@warn "Error in $filename"
return DataFrame()
end
end
end
end

## Save the mcmc diagnostics
CSV.write("manuscript/inference_diagnostics_rnd2.csv", diagnostic_df)
10 changes: 9 additions & 1 deletion pipeline/scripts/create_postprocessing_dataframes.jl
@@ -1,8 +1,11 @@
using EpiAwarePipeline, EpiAware, JLD2, DrWatson, DataFramesMeta, CSV
using EpiAwarePipeline, EpiAware, JLD2, DrWatson, DataFramesMeta, CSV, MCMCChains

## Define scenarios
scenarios = ["measures_outbreak", "smooth_outbreak", "smooth_endemic", "rough_endemic"]

## Define true GI means
true_gi_means = [2.0, 10.0, 20.0]

if !isfile(plotsdir("plotting_data/predictions.csv"))
@info "Prediction dataframe does not exist, generating now"
include("create_prediction_dataframe.jl")
@@ -12,3 +15,8 @@ if !isfile(plotsdir("plotting_data/truthdata.csv"))
@info "Truth dataframe does not exist, generating now"
include("create_truth_dataframe.jl")
end

if !isfile("manuscript/inference_pass_fail_rnd2.csv")
@info "Diagnostic dataframe does not exist, generating now"
include("create_mcmc_diagonostic_script.jl")
end
26 changes: 11 additions & 15 deletions pipeline/scripts/create_prediction_dataframe.jl
@@ -1,9 +1,8 @@
## Define true GI means
true_gi_means = [2.0, 10.0, 20.0]

## Load the prediction dataframes or record fails
## Structure to record success/failure
success_configs = Dict[]
failed_configs = Dict[]

## Analysis of the prediction dataframes
dfs = mapreduce(vcat, scenarios) do scenario
mapreduce(vcat, true_gi_means) do true_gi_mean
target_str = "truth_gi_mean_" * string(true_gi_mean) * "_"
@@ -14,39 +13,36 @@ dfs = mapreduce(vcat, scenarios) do scenario
mapreduce(vcat, files) do filename
output = load(joinpath(datadir("epiaware_observables"), scenario, filename))
try
push!(success_configs,
merge(output["inference_config"], Dict("runsuccess" => true)))
make_prediction_dataframe_from_output(output, true_gi_mean, scenario)
catch e
@warn "Error in $filename"
push!(failed_configs, output["inference_config"])
push!(failed_configs,
merge(output["inference_config"], Dict("runsuccess" => false)))
return DataFrame()
end
end
end
end

## Gather the failed data
failed_df = mapreduce(vcat, failed_configs) do D
## Gather the pass/failed data
pass_fail_df = mapreduce(vcat, [success_configs; failed_configs]) do D
igp = D["igp"] |> str -> split(str, ".")[end]
latent_model = D["latent_model"]
gi_mean = D["gi_mean"]
T1, T2 = split(D["tspan"], "_")
runsuccess = D["priorpredictive"] .== "Pass"
df = DataFrame(
infection_gen_proc = igp,
latent_model = latent_model,
gi_mean = gi_mean,
T1 = T1,
T2 = T2,
T_diff = parse(Int, T2) - parse(Int, T1),
runsuccess = runsuccess
runsuccess = D["runsuccess"]
)
end

##
grped_failed_df = failed_df |>
df -> @groupby(df, :infection_gen_proc, :latent_model) |>
gd -> @combine(gd, :n_fail=sum(1 .- :runsuccess))

## Save the prediction and failed dataframes
CSV.write(plotsdir("plotting_data/predictions.csv"), dfs)
CSV.write(plotsdir("plotting_data/failed_preds.csv"), failed_df)
CSV.write("manuscript/inference_pass_fail_rnd2.csv", pass_fail_df)
2 changes: 1 addition & 1 deletion pipeline/src/EpiAwarePipeline.jl
@@ -46,7 +46,7 @@ export score_parameters, simple_crps, summarise_crps

# Exported functions: Analysis functions for constructing dataframes
export make_prediction_dataframe_from_output, make_truthdata_dataframe,
make_scoring_dataframe_from_output
make_scoring_dataframe_from_output, make_mcmc_diagnostic_dataframe

# Exported functions: Make main plots
export figureone, figuretwo
2 changes: 2 additions & 0 deletions pipeline/src/analysis/analysis.jl
@@ -1,3 +1,5 @@
include("config_mappings.jl")
include("make_truthdata_dataframe.jl")
include("make_prediction_dataframe_from_output.jl")
include("make_scoring_dataframe_from_output.jl")
include("make_mcmc_diagnostic_dataframe.jl")
33 changes: 33 additions & 0 deletions pipeline/src/analysis/config_mappings.jl
@@ -0,0 +1,33 @@
"""
Extracts and returns relevant information from the given inference configuration dictionary.

# Returns
- `NamedTuple`: A named tuple containing the following fields:
- `igp_model::String`: The IGP model name extracted from the configuration.
- `latent_model::String`: The latent model name from the configuration.
- `used_gi_mean::Float64`: The mean generation interval (GI) used in the configuration.
- `used_gi_std::Float64`: The standard deviation of the generation interval (GI) used in the configuration.
- `start_time::Int`: The start time parsed from the configuration's time span.
- `reference_time::Int`: The reference time parsed from the configuration's time span.
- `used_gi_means::Vector{Float64}`: A vector of GI means, either a single value if the IGP model is "Renewal" or a list of values generated by `make_gi_params` otherwise.
"""
function _get_info_from_config(inference_config)
#Get the scenario, IGP model, latent model and true mean GI
igp_model = inference_config["igp"] |> igp_name -> split(igp_name, ".")[end]
latent_model = inference_config["latent_model"]
used_gi_mean = inference_config["gi_mean"]
used_gi_std = inference_config["gi_std"]
(start_time, reference_time) = inference_config["tspan"] |>
tspan -> split(tspan, "_") |>
tspan -> (
parse(Int, tspan[1]), parse(Int, tspan[2]))

#Determine which GI means to report against:
#if the IGP model is Renewal, use only the GI mean actually used in inference;
#otherwise loop over the GI means defined by make_gi_params
used_gi_means = igp_model == "Renewal" ?
[used_gi_mean] :
make_gi_params(EpiAwareExamplePipeline())["gi_means"]
return (; igp_model, latent_model, used_gi_mean, used_gi_std,
start_time, reference_time, used_gi_means)
end
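
As a quick illustration of the returned `NamedTuple` (the configuration values below are made up for this sketch, not taken from a real pipeline run), the helper could be exercised like this:

```julia
# Hypothetical inference config; real ones are read from the pipeline's saved JLD2 outputs.
example_config = Dict(
    "igp" => "EpiAware.EpiInfModels.Renewal",
    "latent_model" => "ar",
    "gi_mean" => 2.0,
    "gi_std" => 2.0,
    "tspan" => "1_35"
)

info = _get_info_from_config(example_config)
info.igp_model       # "Renewal" (module prefix stripped by the split on ".")
info.start_time      # 1
info.reference_time  # 35
info.used_gi_means   # [2.0], since Renewal models keep only the GI mean used in inference
```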
48 changes: 48 additions & 0 deletions pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl
@@ -0,0 +1,48 @@
"""
Generate a DataFrame containing MCMC diagnostic metrics. The metrics are the proportion of
parameters that pass the bulk effective sample size (ESS) threshold, the proportion of
parameters that pass the tail ESS threshold, the proportion of parameters that pass the R-hat
absolute difference from 1 threshold, whether the model has a cluster factor parameter, and the tail ESS
of the cluster factor parameter.

# Arguments
- `output::Dict`: A dictionary containing the inference results.
- `bulk_ess_threshold::Int`: The threshold for bulk effective sample size (ESS). Default is 500.
- `tail_ess_threshold::Int`: The threshold for tail effective sample size (ESS). Default is 100.
- `rhat_diff_threshold::Float64`: The threshold for the difference of R-hat from 1. Default is 0.02.
"""
function make_mcmc_diagnostic_dataframe(
output, true_mean_gi, scenario; bulk_ess_threshold = 500,
tail_ess_threshold = 100, rhat_diff_threshold = 0.02)
#Get the scenario, IGP model, latent model and true mean GI
inference_config = output["inference_config"]
info = _get_info_from_config(inference_config)
#Get the convergence diagnostics
chn_nt = output["inference_results"].samples |> summarize |> summary -> summary.nt
cluster_factor_idxs = chn_nt.parameters .== Symbol("obs.cluster_factor")
has_cluster_factor = any(cluster_factor_idxs)
not_cluster_factor = .~cluster_factor_idxs
# Only index the tail ESS of the cluster factor when it is present (avoids indexing an empty view)
cluster_factor_tail = has_cluster_factor ? chn_nt.ess_tail[cluster_factor_idxs][1] : missing

#Create the dataframe
df = mapreduce(vcat, info.used_gi_means) do used_gi_mean
DataFrame(
Scenario = scenario,
igp_model = info.igp_model,
latent_model = info.latent_model,
True_GI_Mean = true_mean_gi,
used_gi_mean = used_gi_mean,
reference_time = info.reference_time,
bulk_ess_threshold = (chn_nt.ess_bulk[not_cluster_factor] .>
bulk_ess_threshold) |>
mean,
tail_ess_threshold = (chn_nt.ess_tail[not_cluster_factor] .>
tail_ess_threshold) |>
mean,
rhat_diff_threshold = (abs.(chn_nt.rhat[not_cluster_factor] .- 1) .<
rhat_diff_threshold) |> mean,
has_cluster_factor = has_cluster_factor,
cluster_factor_tail = has_cluster_factor ? cluster_factor_tail : missing)
end
return df
end
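
A minimal usage sketch, assuming `output` is one of the pipeline's saved inference-result dictionaries (the file name below is illustrative, not an actual artifact from this run):

```julia
using EpiAwarePipeline, JLD2, DataFrames, MCMCChains

# Illustrative path; real outputs live under pipeline/data/epiaware_observables/<scenario>/
output = load("pipeline/data/epiaware_observables/measures_outbreak/" *
              "observables_truth_gi_mean_2.0_example.jld2")

diag_df = make_mcmc_diagnostic_dataframe(output, 2.0, "measures_outbreak";
    bulk_ess_threshold = 500, tail_ess_threshold = 100, rhat_diff_threshold = 0.02)

# One row per used GI mean; the threshold-named columns hold the proportion of
# non-cluster-factor parameters that pass each diagnostic check.
```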
36 changes: 11 additions & 25 deletions pipeline/src/analysis/make_prediction_dataframe_from_output.jl
@@ -20,42 +20,28 @@ function make_prediction_dataframe_from_output(
inference_config = output["inference_config"]
forecasts = output["forecast_results"]
#Get the scenario, IGP model, latent model and true mean GI
igp_model = inference_config["igp"] |> igp_name -> split(igp_name, ".")[end]
latent_model = inference_config["latent_model"]
used_gi_mean = inference_config["gi_mean"]
used_gi_std = inference_config["gi_std"]
(start_time, reference_time) = inference_config["tspan"] |>
tspan -> split(tspan, "_") |>
tspan -> (
parse(Int, tspan[1]), parse(Int, tspan[2]))

#Get the quantiles for the targets across the gi mean scenarios
#if Renewal model, then we use the underlying epi model
#otherwise we use the epi datas to loop over different gi mean implications
used_gi_means = igp_model == "Renewal" ?
[used_gi_mean] :
make_gi_params(EpiAwareExamplePipeline())["gi_means"]

used_epidatas = map(used_gi_means) do ḡ
_make_epidata(ḡ, used_gi_std; transformation = transformation)
info = _get_info_from_config(inference_config)
#Get the epi datas
used_epidatas = map(info.used_gi_means) do ḡ
_make_epidata(ḡ, info.used_gi_std; transformation = transformation)
end

#Generate the quantiles for the targets
preds = map(used_epidatas) do epi_data
generate_quantiles_for_targets(forecasts, epi_data, qs)
end

#Create the dataframe columnwise
df = mapreduce(vcat, preds, used_gi_means) do pred, used_gi_mean
df = mapreduce(vcat, preds, info.used_gi_means) do pred, used_gi_mean
mapreduce(vcat, keys(pred)) do target
target_mat = pred[target]
target_times = collect(1:size(target_mat, 1)) .+ (start_time - 1)
target_times = collect(1:size(target_mat, 1)) .+ (info.start_time - 1)
_df = DataFrame(target_times = target_times)
_df[!, "Scenario"] .= scenario
_df[!, "IGP_Model"] .= igp_model
_df[!, "Latent_Model"] .= latent_model
_df[!, "igp_model"] .= info.igp_model
_df[!, "latent_model"] .= info.latent_model
_df[!, "True_GI_Mean"] .= true_mean_gi
_df[!, "Used_GI_Mean"] .= used_gi_mean
_df[!, "Reference_Time"] .= reference_time
_df[!, "used_gi_mean"] .= used_gi_mean
_df[!, "reference_time"] .= info.reference_time
_df[!, "Target"] .= string(target)
# quantile predictions
for (j, q) in enumerate(qs)