-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add calculate processes function to pipeline (#314)
* calculate_processes function * unit tests * add expected_Rt so not to repeat * change to using first element of `rt` due possible weekly variation * calculate_processes function * unit tests * add expected_Rt so not to repeat * change to using first element of `rt` due possible weekly variation * move infection seeding to a new function
- Loading branch information
1 parent
788ddce
commit 3624ca0
Showing
7 changed files
with
121 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
""" | ||
Internal function for calculating the log of the infections with an informative | ||
error message if the infections are not positive definite. | ||
""" | ||
function _calc_log_infections(I_t) | ||
@assert all(I_t .> 0) "Infections must be positive definite." | ||
log.(I_t) | ||
end | ||
|
||
""" | ||
Internal function for calculating the exponential growth rate with an informative | ||
error message if the infections are not positive definite. | ||
""" | ||
function _calc_rt(I_t, I0) | ||
@assert all(I_t .> 0) "Infections must be positive definite." | ||
@assert I0>0 "Initial infections must be positive definite." | ||
log.([I0; I_t]) .- log(I0) |> diff | ||
end | ||
|
||
""" | ||
Internal function for seeding the infections. Method dispatches on the pipeline | ||
type to determine the seeding method. This is the default seeding method which | ||
assumes backward exponential growth with initial infections `I0` from initial | ||
estimate of `rt`. | ||
""" | ||
function _infection_seeding(I_t, I0, data::EpiData, pipeline::AbstractEpiAwarePipeline) | ||
n = length(data.gen_int) | ||
init_rt = _calc_rt(I_t[1:2], I0) |> x -> x[1] | ||
[I0 * exp(-init_rt * (n - i)) for i in 1:n] | ||
end | ||
|
||
""" | ||
Internal function for calculating the _instantaneous_ reproduction number `Rt` | ||
using the method of [Fraser (2007)](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0000758). | ||
Left truncation handling method is determined by the pipeline type. The default | ||
left truncation in `I_t` is handled by extending the series with backwards exponential | ||
growth from the initial infections `I0` and the exponential growth rate `init_rt`. | ||
# Arguments | ||
- `I_t`: Incident infections. | ||
- `I0`: Initial infections at time zero. | ||
- `init_rt`: Initial exponential growth rate. | ||
- `data::EpiData`: An instance of the `EpiData` type containing generation interval data. | ||
- `pipeline::AbstractEpiAwarePipeline`: An instance of the `AbstractEpiAwarePipeline` type. | ||
""" | ||
function _calc_Rt(I_t, I0, data::EpiData, pipeline::AbstractEpiAwarePipeline) | ||
@assert all(I_t .> 0) "Log infections must be positive definite." | ||
@assert I0>0 "Initial infections must be positive definite." | ||
|
||
aug_I_t = vcat(_infection_seeding(I_t, I0, data, pipeline), I_t) | ||
|
||
Rt = expected_Rt(data, aug_I_t) | ||
|
||
return Rt | ||
end | ||
|
||
""" | ||
Calculate the log of infections `log_I_t`, exponential growth values `rt`, and | ||
instaneous reproductive number `Rt` for a given time series of infections. The | ||
reproductive number calculation deals with left truncation in `I_t` by extending | ||
`I_t` with backwards exponential growth using the mean exponential growth rate | ||
from the first 7 time steps of `rt`. | ||
# Arguments | ||
- `I_t`: An array representing the time series of infections. | ||
- `I0`: The initial number of infections. | ||
- `pmf`: The probability mass function used to calculate Rt. | ||
# Returns | ||
A named tuple containing the calculated values for `log_I_t`, `rt`, and `Rt`. | ||
""" | ||
function calculate_processes(I_t, I0, data::EpiData, pipeline::AbstractEpiAwarePipeline) | ||
log_I_t = _calc_log_infections(I_t) | ||
rt = _calc_rt(I_t, I0) | ||
Rt = _calc_Rt(I_t, I0, data, pipeline) | ||
return (; log_I_t, rt, Rt) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
include("calculate_processes.jl") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
using Test | ||
|
||
@testset "calculate_processes" begin | ||
using EpiAware, EpiAwarePipeline | ||
using Random | ||
rng = MersenneTwister(1234) | ||
I0 = 10.0 | ||
rt = randn(rng, 20) | ||
I_t = cumsum(rt) .+ log(I0) .|> exp | ||
pmf = [1.0] | ||
pipeline = EpiAwareExamplePipeline() | ||
|
||
data = EpiData(pmf, exp) | ||
|
||
result = calculate_processes(I_t, I0, data, pipeline) | ||
|
||
# Check if the log of infections is calculated correctly | ||
@testset "Log of infections" begin | ||
expected_log_I_t = log.(I_t) | ||
@test result.log_I_t ≈ expected_log_I_t | ||
end | ||
|
||
# Check if the exponential growth rate is calculated correctly | ||
@testset "Exponential growth rate" begin | ||
@test result.rt ≈ rt | ||
end | ||
|
||
# In this special case (pmf = [1.0]), the Rt = exp(rt) | ||
@testset "Instantaneous reproduction number" begin | ||
@test result.Rt ≈ exp.(rt) | ||
end | ||
end |