Skip to content

Commit

Permalink
Tweak parameter recovery testing (#503)
Browse files Browse the repository at this point in the history
* Pass in missing directly and tweak handling progress and number of chains

* aim to catch remaining sammple misisng progress

* Update RandomWalk.jl - to Poisson

* Update LatentDelay.jl - switch AR priors and use Poisson

* Update RandomWalk.jl

* Update LatentDelay.jl

* Update EpiAware/test/EpiLatentModels/models/RandomWalk.jl

Co-authored-by: Samuel Brand <[email protected]>

* sharpen priors for r_t (exp growth rate)

More samples but shorter generated data

* remove debug @info

---------

Co-authored-by: Samuel Brand <[email protected]>
Co-authored-by: Samuel Brand <[email protected]>
  • Loading branch information
3 people authored Oct 24, 2024
1 parent 822a0aa commit 119f244
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 21 deletions.
14 changes: 7 additions & 7 deletions EpiAware/test/EpiLatentModels/models/RandomWalk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,23 @@ end
ReverseDiff
Random.seed!(1234)

rw_process = RandomWalk()
obs_nb = NegativeBinomialError()
process = RandomWalk()
obs = PoissonError()

@model function test_negbin_errors(rw, obs, y_t)
@model function test_poisson_errors(proc, obs, y_t)
n = length(y_t)
@submodel Z_t = generate_latent(rw, n)
@submodel Z_t = generate_latent(proc, n)
@submodel y_t = generate_observations(obs, y_t, exp.(Z_t))
return Z_t, y_t
end

generative_mdl = test_negbin_errors(rw_process, obs_nb, fill(missing, 40))
generative_mdl = test_poisson_errors(process, obs, fill(missing, 40))
θ_true = rand(generative_mdl)
Z_t_obs, y_t_obs = condition(generative_mdl, θ_true)()

mdl = test_negbin_errors(rw_process, obs_nb, Int.(y_t_obs))
mdl = test_poisson_errors(process, obs, Int.(y_t_obs))
chn = sample(
mdl, NUTS(adtype = AutoReverseDiff(; compile = Val(true))), 1000, progess = false)
mdl, NUTS(adtype = AutoReverseDiff(; compile = Val(true))), 1000; progess = false)

#Check that are in central 99.9% of the posterior predictive distribution
#Therefore, this should be unlikely to fail if the model is correctly implemented
Expand Down
27 changes: 13 additions & 14 deletions EpiAware/test/EpiObsModels/modifiers/LatentDelay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ end
end
end

@testitem "LatentDelay parameter recovery with mix of IGP + latent processes: Negative binomial errors + EpiProblem interface" begin
@testitem "LatentDelay parameter recovery with mix of IGP + latent processes: Poisson errors + EpiProblem interface" begin
using Random, Turing, Distributions, LinearAlgebra, DynamicPPL, StatsBase, ReverseDiff,
Suppressor, LogExpFunctions
# using PairPlots, CairoMakie
Expand All @@ -178,7 +178,7 @@ end
data = EpiData([0.2, 0.5, 0.3],
em_type == Renewal ? softplus : exp
),
initialisation_prior = Normal(log(100.0), 0.25)
initialisation_prior = Normal(log(100.0), 0.01)
)

latentprocess_types = [RandomWalk, AR, DiffLatentModel]
Expand All @@ -190,7 +190,7 @@ end
return (; init_prior, std_prior)
elseif epimodel isa ExpGrowthRate
init_prior = Normal(0.1, 0.025)
std_prior = HalfNormal(0.025)
std_prior = LogNormal(log(0.025), 0.01)
return (; init_prior, std_prior)
elseif epimodel isa DirectInfections
init_prior = Normal(log(100.0), 0.25)
Expand All @@ -204,11 +204,11 @@ end
if latentprocess_type == RandomWalk
return RandomWalk(init_prior, std_prior)
elseif latentprocess_type == AR
return AR(damp_priors = [Beta(8, 2; check_args = false)],
return AR(damp_priors = [Beta(2, 8; check_args = false)],
std_prior = std_prior, init_priors = [init_prior])
elseif latentprocess_type == DiffLatentModel
return DiffLatentModel(
AR(damp_priors = [Beta(8, 2; check_args = false)],
AR(damp_priors = [Beta(2, 8; check_args = false)],
std_prior = std_prior, init_priors = [Normal(0.0, 0.25)]),
init_prior; d = 1)
end
Expand All @@ -217,15 +217,14 @@ end
function test_full_process(epimodel, latentprocess, n;
ad = AutoReverseDiff(; compile = true), posterior_p_tol = 0.005)
#Fix observation model
obs = LatentDelay(
NegativeBinomialError(cluster_factor_prior = HalfNormal(0.05)), Gamma(3, 7 / 3))
obs = LatentDelay(PoissonError(), Gamma(3, 7 / 3))

#Inference method
inference_method = EpiMethod(
pre_sampler_steps = [ManyPathfinder(nruns = 4, maxiters = 100)],
pre_sampler_steps = [ManyPathfinder(nruns = 4, maxiters = 50)],
sampler = NUTSampler(adtype = ad,
ndraws = 1000,
nchains = 4,
ndraws = 2000,
nchains = 2,
mcmc_parallel = MCMCThreads())
)

Expand All @@ -237,15 +236,15 @@ end
)

#Generate data from generative model (i.e. data unconditioned)
generative_mdl = generate_epiaware(
epi_prob, (y_t = Vector{Union{Int, Missing}}(missing, n),))
generative_mdl = generate_epiaware(epi_prob, (y_t = missing,))
θ_true = rand(generative_mdl)
gen_data = condition(generative_mdl, θ_true)()

#Apply inference method to inference model (i.e. generative model conditioned on data)
inference_results = apply_method(epi_prob,
inference_method,
(y_t = gen_data.generated_y_t,)
(y_t = gen_data.generated_y_t,);
progress = false
)

chn = inference_results.samples
Expand All @@ -265,7 +264,7 @@ end
@testset "Check true parameters are within 99% central post. prob.: " begin
@testset for latentprocess_type in latentprocess_types, epimodel in epimodels
latentprocess = set_latent_process(epimodel, latentprocess_type)
@suppress _ = test_full_process(epimodel, latentprocess, 50)
@suppress _ = test_full_process(epimodel, latentprocess, 40)
end
end
end

0 comments on commit 119f244

Please sign in to comment.