162
162
end
163
163
end
164
164
165
- @testitem " LatentDelay parameter recovery with mix of IGP + latent processes: Negative binomial errors + EpiProblem interface" begin
165
+ @testitem " LatentDelay parameter recovery with mix of IGP + latent processes: Poisson errors + EpiProblem interface" begin
166
166
using Random, Turing, Distributions, LinearAlgebra, DynamicPPL, StatsBase, ReverseDiff,
167
167
Suppressor, LogExpFunctions
168
168
# using PairPlots, CairoMakie
178
178
data = EpiData ([0.2 , 0.5 , 0.3 ],
179
179
em_type == Renewal ? softplus : exp
180
180
),
181
- initialisation_prior = Normal (log (100.0 ), 0.25 )
181
+ initialisation_prior = Normal (log (100.0 ), 0.01 )
182
182
)
183
183
184
184
latentprocess_types = [RandomWalk, AR, DiffLatentModel]
190
190
return (; init_prior, std_prior)
191
191
elseif epimodel isa ExpGrowthRate
192
192
init_prior = Normal (0.1 , 0.025 )
193
- std_prior = HalfNormal ( 0.025 )
193
+ std_prior = LogNormal ( log ( 0.025 ), 0.01 )
194
194
return (; init_prior, std_prior)
195
195
elseif epimodel isa DirectInfections
196
196
init_prior = Normal (log (100.0 ), 0.25 )
@@ -204,11 +204,11 @@ end
204
204
if latentprocess_type == RandomWalk
205
205
return RandomWalk (init_prior, std_prior)
206
206
elseif latentprocess_type == AR
207
- return AR (damp_priors = [Beta (8 , 2 ; check_args = false )],
207
+ return AR (damp_priors = [Beta (2 , 8 ; check_args = false )],
208
208
std_prior = std_prior, init_priors = [init_prior])
209
209
elseif latentprocess_type == DiffLatentModel
210
210
return DiffLatentModel (
211
- AR (damp_priors = [Beta (8 , 2 ; check_args = false )],
211
+ AR (damp_priors = [Beta (2 , 8 ; check_args = false )],
212
212
std_prior = std_prior, init_priors = [Normal (0.0 , 0.25 )]),
213
213
init_prior; d = 1 )
214
214
end
@@ -217,15 +217,14 @@ end
217
217
function test_full_process (epimodel, latentprocess, n;
218
218
ad = AutoReverseDiff (; compile = true ), posterior_p_tol = 0.005 )
219
219
# Fix observation model
220
- obs = LatentDelay (
221
- NegativeBinomialError (cluster_factor_prior = HalfNormal (0.05 )), Gamma (3 , 7 / 3 ))
220
+ obs = LatentDelay (PoissonError (), Gamma (3 , 7 / 3 ))
222
221
223
222
# Inference method
224
223
inference_method = EpiMethod (
225
- pre_sampler_steps = [ManyPathfinder (nruns = 4 , maxiters = 100 )],
224
+ pre_sampler_steps = [ManyPathfinder (nruns = 4 , maxiters = 50 )],
226
225
sampler = NUTSampler (adtype = ad,
227
- ndraws = 1000 ,
228
- nchains = 4 ,
226
+ ndraws = 2000 ,
227
+ nchains = 2 ,
229
228
mcmc_parallel = MCMCThreads ())
230
229
)
231
230
@@ -237,15 +236,15 @@ end
237
236
)
238
237
239
238
# Generate data from generative model (i.e. data unconditioned)
240
- generative_mdl = generate_epiaware (
241
- epi_prob, (y_t = Vector {Union{Int, Missing}} (missing , n),))
239
+ generative_mdl = generate_epiaware (epi_prob, (y_t = missing ,))
242
240
θ_true = rand (generative_mdl)
243
241
gen_data = condition (generative_mdl, θ_true)()
244
242
245
243
# Apply inference method to inference model (i.e. generative model conditioned on data)
246
244
inference_results = apply_method (epi_prob,
247
245
inference_method,
248
- (y_t = gen_data. generated_y_t,)
246
+ (y_t = gen_data. generated_y_t,);
247
+ progress = false
249
248
)
250
249
251
250
chn = inference_results. samples
265
264
@testset " Check true parameters are within 99% central post. prob.: " begin
266
265
@testset for latentprocess_type in latentprocess_types, epimodel in epimodels
267
266
latentprocess = set_latent_process (epimodel, latentprocess_type)
268
- @suppress _ = test_full_process (epimodel, latentprocess, 50 )
267
+ @suppress _ = test_full_process (epimodel, latentprocess, 40 )
269
268
end
270
269
end
271
270
end
0 commit comments