Skip to content

Commit 3624ca0

Browse files
authored
Add calculate processes function to pipeline (#314)
* calculate_processes function * unit tests * add expected_Rt so not to repeat * change to using first element of `rt` due possible weekly variation * calculate_processes function * unit tests * add expected_Rt so not to repeat * change to using first element of `rt` due possible weekly variation * move infection seeding to a new function
1 parent 788ddce commit 3624ca0

File tree

7 files changed

+121
-1
lines changed

7 files changed

+121
-1
lines changed

pipeline/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1"
1515
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
1616
EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855"
1717
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
18+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1819
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1920
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
2021
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
2122
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
23+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2224
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2325
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2426
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

pipeline/src/EpiAwarePipeline.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,15 @@ module EpiAwarePipeline
1212

1313
using CSV, Dagger, DataFramesMeta, Dates, Distributions, DocStringExtensions, DrWatson,
1414
EpiAware, Plots, Statistics, ADTypes, AbstractMCMC, Plots, JLD2, MCMCChains, Turing,
15-
DynamicPPL, LogExpFunctions, RCall
15+
DynamicPPL, LogExpFunctions, RCall, LinearAlgebra
1616

1717
# Exported pipeline types
1818
export AbstractEpiAwarePipeline, EpiAwarePipeline, RtwithoutRenewalPipeline,
1919
RtwithoutRenewalPriorPipeline, EpiAwareExamplePipeline
2020

21+
# Exported utility functions
22+
export calculate_processes
23+
2124
# Exported configuration types
2225
export TruthSimulationConfig, InferenceConfig
2326

@@ -47,6 +50,7 @@ export plot_truth_data, plot_Rt
4750

4851
include("docstrings.jl")
4952
include("pipeline/pipeline.jl")
53+
include("utils/utils.jl")
5054
include("constructors/constructors.jl")
5155
include("simulate/simulate.jl")
5256
include("infer/infer.jl")
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""
2+
Internal function for calculating the log of the infections with an informative
3+
error message if the infections are not positive definite.
4+
"""
5+
function _calc_log_infections(I_t)
6+
@assert all(I_t .> 0) "Infections must be positive definite."
7+
log.(I_t)
8+
end
9+
10+
"""
11+
Internal function for calculating the exponential growth rate with an informative
12+
error message if the infections are not positive definite.
13+
"""
14+
function _calc_rt(I_t, I0)
15+
@assert all(I_t .> 0) "Infections must be positive definite."
16+
@assert I0>0 "Initial infections must be positive definite."
17+
log.([I0; I_t]) .- log(I0) |> diff
18+
end
19+
20+
"""
21+
Internal function for seeding the infections. Method dispatches on the pipeline
22+
type to determine the seeding method. This is the default seeding method which
23+
assumes backward exponential growth with initial infections `I0` from initial
24+
estimate of `rt`.
25+
26+
"""
27+
function _infection_seeding(I_t, I0, data::EpiData, pipeline::AbstractEpiAwarePipeline)
28+
n = length(data.gen_int)
29+
init_rt = _calc_rt(I_t[1:2], I0) |> x -> x[1]
30+
[I0 * exp(-init_rt * (n - i)) for i in 1:n]
31+
end
32+
33+
"""
34+
Internal function for calculating the _instantaneous_ reproduction number `Rt`
35+
using the method of [Fraser (2007)](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0000758).
36+
Left truncation handling method is determined by the pipeline type. The default
37+
left truncation in `I_t` is handled by extending the series with backwards exponential
38+
growth from the initial infections `I0` and the exponential growth rate `init_rt`.
39+
40+
# Arguments
41+
- `I_t`: Incident infections.
42+
- `I0`: Initial infections at time zero.
43+
- `init_rt`: Initial exponential growth rate.
44+
- `data::EpiData`: An instance of the `EpiData` type containing generation interval data.
45+
- `pipeline::AbstractEpiAwarePipeline`: An instance of the `AbstractEpiAwarePipeline` type.
46+
"""
47+
function _calc_Rt(I_t, I0, data::EpiData, pipeline::AbstractEpiAwarePipeline)
48+
@assert all(I_t .> 0) "Log infections must be positive definite."
49+
@assert I0>0 "Initial infections must be positive definite."
50+
51+
aug_I_t = vcat(_infection_seeding(I_t, I0, data, pipeline), I_t)
52+
53+
Rt = expected_Rt(data, aug_I_t)
54+
55+
return Rt
56+
end
57+
58+
"""
59+
Calculate the log of infections `log_I_t`, exponential growth values `rt`, and
60+
instaneous reproductive number `Rt` for a given time series of infections. The
61+
reproductive number calculation deals with left truncation in `I_t` by extending
62+
`I_t` with backwards exponential growth using the mean exponential growth rate
63+
from the first 7 time steps of `rt`.
64+
65+
# Arguments
66+
- `I_t`: An array representing the time series of infections.
67+
- `I0`: The initial number of infections.
68+
- `pmf`: The probability mass function used to calculate Rt.
69+
70+
# Returns
71+
A named tuple containing the calculated values for `log_I_t`, `rt`, and `Rt`.
72+
73+
"""
74+
function calculate_processes(I_t, I0, data::EpiData, pipeline::AbstractEpiAwarePipeline)
75+
log_I_t = _calc_log_infections(I_t)
76+
rt = _calc_rt(I_t, I0)
77+
Rt = _calc_Rt(I_t, I0, data, pipeline)
78+
return (; log_I_t, rt, Rt)
79+
end

pipeline/src/utils/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
include("calculate_processes.jl")

pipeline/test/end-to-end/test_prior_predictive.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@ using Test
1515

1616
inference_results = generate_inference_results(
1717
truthdata, inference_config, pipeline; tspan, inference_method)
18+
1819
@test inference_results["inference_results"] isa EpiAwareObservables
1920
end

pipeline/test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ quickactivate(@__DIR__(), "EpiAwarePipeline")
44
# Run tests
55
include("pipeline/test_pipelinetypes.jl");
66
include("pipeline/test_pipelinefunctions.jl");
7+
include("utils/test_calculate_processes.jl");
78
include("constructors/test_constructors.jl");
89
include("simulate/test_TruthSimulationConfig.jl");
910
include("simulate/test_SimulationConfig.jl");
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using Test
2+
3+
@testset "calculate_processes" begin
4+
using EpiAware, EpiAwarePipeline
5+
using Random
6+
rng = MersenneTwister(1234)
7+
I0 = 10.0
8+
rt = randn(rng, 20)
9+
I_t = cumsum(rt) .+ log(I0) .|> exp
10+
pmf = [1.0]
11+
pipeline = EpiAwareExamplePipeline()
12+
13+
data = EpiData(pmf, exp)
14+
15+
result = calculate_processes(I_t, I0, data, pipeline)
16+
17+
# Check if the log of infections is calculated correctly
18+
@testset "Log of infections" begin
19+
expected_log_I_t = log.(I_t)
20+
@test result.log_I_t expected_log_I_t
21+
end
22+
23+
# Check if the exponential growth rate is calculated correctly
24+
@testset "Exponential growth rate" begin
25+
@test result.rt rt
26+
end
27+
28+
# In this special case (pmf = [1.0]), the Rt = exp(rt)
29+
@testset "Instantaneous reproduction number" begin
30+
@test result.Rt exp.(rt)
31+
end
32+
end

0 commit comments

Comments
 (0)