Skip to content

Commit 119f244

Browse files
seabbsSamuelBrand1
andauthored
Tweak parameter recovery testing (#503)
* Pass in missing directly and tweak handling progress and number of chains * aim to catch remaining sammple misisng progress * Update RandomWalk.jl - to Poisson * Update LatentDelay.jl - switch AR priors and use Poisson * Update RandomWalk.jl * Update LatentDelay.jl * Update EpiAware/test/EpiLatentModels/models/RandomWalk.jl Co-authored-by: Samuel Brand <[email protected]> * sharpen priors for r_t (exp growth rate) More samples but shorter generated data * remove debug @info --------- Co-authored-by: Samuel Brand <[email protected]> Co-authored-by: Samuel Brand <[email protected]>
1 parent 822a0aa commit 119f244

File tree

2 files changed

+20
-21
lines changed

2 files changed

+20
-21
lines changed

EpiAware/test/EpiLatentModels/models/RandomWalk.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,23 +44,23 @@ end
4444
ReverseDiff
4545
Random.seed!(1234)
4646

47-
rw_process = RandomWalk()
48-
obs_nb = NegativeBinomialError()
47+
process = RandomWalk()
48+
obs = PoissonError()
4949

50-
@model function test_negbin_errors(rw, obs, y_t)
50+
@model function test_poisson_errors(proc, obs, y_t)
5151
n = length(y_t)
52-
@submodel Z_t = generate_latent(rw, n)
52+
@submodel Z_t = generate_latent(proc, n)
5353
@submodel y_t = generate_observations(obs, y_t, exp.(Z_t))
5454
return Z_t, y_t
5555
end
5656

57-
generative_mdl = test_negbin_errors(rw_process, obs_nb, fill(missing, 40))
57+
generative_mdl = test_poisson_errors(process, obs, fill(missing, 40))
5858
θ_true = rand(generative_mdl)
5959
Z_t_obs, y_t_obs = condition(generative_mdl, θ_true)()
6060

61-
mdl = test_negbin_errors(rw_process, obs_nb, Int.(y_t_obs))
61+
mdl = test_poisson_errors(process, obs, Int.(y_t_obs))
6262
chn = sample(
63-
mdl, NUTS(adtype = AutoReverseDiff(; compile = Val(true))), 1000, progess = false)
63+
mdl, NUTS(adtype = AutoReverseDiff(; compile = Val(true))), 1000; progess = false)
6464

6565
#Check that are in central 99.9% of the posterior predictive distribution
6666
#Therefore, this should be unlikely to fail if the model is correctly implemented

EpiAware/test/EpiObsModels/modifiers/LatentDelay.jl

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ end
162162
end
163163
end
164164

165-
@testitem "LatentDelay parameter recovery with mix of IGP + latent processes: Negative binomial errors + EpiProblem interface" begin
165+
@testitem "LatentDelay parameter recovery with mix of IGP + latent processes: Poisson errors + EpiProblem interface" begin
166166
using Random, Turing, Distributions, LinearAlgebra, DynamicPPL, StatsBase, ReverseDiff,
167167
Suppressor, LogExpFunctions
168168
# using PairPlots, CairoMakie
@@ -178,7 +178,7 @@ end
178178
data = EpiData([0.2, 0.5, 0.3],
179179
em_type == Renewal ? softplus : exp
180180
),
181-
initialisation_prior = Normal(log(100.0), 0.25)
181+
initialisation_prior = Normal(log(100.0), 0.01)
182182
)
183183

184184
latentprocess_types = [RandomWalk, AR, DiffLatentModel]
@@ -190,7 +190,7 @@ end
190190
return (; init_prior, std_prior)
191191
elseif epimodel isa ExpGrowthRate
192192
init_prior = Normal(0.1, 0.025)
193-
std_prior = HalfNormal(0.025)
193+
std_prior = LogNormal(log(0.025), 0.01)
194194
return (; init_prior, std_prior)
195195
elseif epimodel isa DirectInfections
196196
init_prior = Normal(log(100.0), 0.25)
@@ -204,11 +204,11 @@ end
204204
if latentprocess_type == RandomWalk
205205
return RandomWalk(init_prior, std_prior)
206206
elseif latentprocess_type == AR
207-
return AR(damp_priors = [Beta(8, 2; check_args = false)],
207+
return AR(damp_priors = [Beta(2, 8; check_args = false)],
208208
std_prior = std_prior, init_priors = [init_prior])
209209
elseif latentprocess_type == DiffLatentModel
210210
return DiffLatentModel(
211-
AR(damp_priors = [Beta(8, 2; check_args = false)],
211+
AR(damp_priors = [Beta(2, 8; check_args = false)],
212212
std_prior = std_prior, init_priors = [Normal(0.0, 0.25)]),
213213
init_prior; d = 1)
214214
end
@@ -217,15 +217,14 @@ end
217217
function test_full_process(epimodel, latentprocess, n;
218218
ad = AutoReverseDiff(; compile = true), posterior_p_tol = 0.005)
219219
#Fix observation model
220-
obs = LatentDelay(
221-
NegativeBinomialError(cluster_factor_prior = HalfNormal(0.05)), Gamma(3, 7 / 3))
220+
obs = LatentDelay(PoissonError(), Gamma(3, 7 / 3))
222221

223222
#Inference method
224223
inference_method = EpiMethod(
225-
pre_sampler_steps = [ManyPathfinder(nruns = 4, maxiters = 100)],
224+
pre_sampler_steps = [ManyPathfinder(nruns = 4, maxiters = 50)],
226225
sampler = NUTSampler(adtype = ad,
227-
ndraws = 1000,
228-
nchains = 4,
226+
ndraws = 2000,
227+
nchains = 2,
229228
mcmc_parallel = MCMCThreads())
230229
)
231230

@@ -237,15 +236,15 @@ end
237236
)
238237

239238
#Generate data from generative model (i.e. data unconditioned)
240-
generative_mdl = generate_epiaware(
241-
epi_prob, (y_t = Vector{Union{Int, Missing}}(missing, n),))
239+
generative_mdl = generate_epiaware(epi_prob, (y_t = missing,))
242240
θ_true = rand(generative_mdl)
243241
gen_data = condition(generative_mdl, θ_true)()
244242

245243
#Apply inference method to inference model (i.e. generative model conditioned on data)
246244
inference_results = apply_method(epi_prob,
247245
inference_method,
248-
(y_t = gen_data.generated_y_t,)
246+
(y_t = gen_data.generated_y_t,);
247+
progress = false
249248
)
250249

251250
chn = inference_results.samples
@@ -265,7 +264,7 @@ end
265264
@testset "Check true parameters are within 99% central post. prob.: " begin
266265
@testset for latentprocess_type in latentprocess_types, epimodel in epimodels
267266
latentprocess = set_latent_process(epimodel, latentprocess_type)
268-
@suppress _ = test_full_process(epimodel, latentprocess, 50)
267+
@suppress _ = test_full_process(epimodel, latentprocess, 40)
269268
end
270269
end
271270
end

0 commit comments

Comments
 (0)