Skip to content

Commit

Permalink
Add calculate processes function to pipeline (#314)
Browse files Browse the repository at this point in the history
* calculate_processes function

* unit tests

* add expected_Rt so not to repeat

* change to using first element of `rt` due possible weekly variation

* calculate_processes function

* unit tests

* add expected_Rt so not to repeat

* change to using first element of `rt` due possible weekly variation

* move infection seeding to a new function
  • Loading branch information
SamuelBrand1 authored Jul 1, 2024
1 parent 788ddce commit 3624ca0
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 1 deletion.
2 changes: 2 additions & 0 deletions pipeline/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Expand Down
6 changes: 5 additions & 1 deletion pipeline/src/EpiAwarePipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ module EpiAwarePipeline

using CSV, Dagger, DataFramesMeta, Dates, Distributions, DocStringExtensions, DrWatson,
EpiAware, Plots, Statistics, ADTypes, AbstractMCMC, Plots, JLD2, MCMCChains, Turing,
DynamicPPL, LogExpFunctions, RCall
DynamicPPL, LogExpFunctions, RCall, LinearAlgebra

# Exported pipeline types
export AbstractEpiAwarePipeline, EpiAwarePipeline, RtwithoutRenewalPipeline,
RtwithoutRenewalPriorPipeline, EpiAwareExamplePipeline

# Exported utility functions
export calculate_processes

# Exported configuration types
export TruthSimulationConfig, InferenceConfig

Expand Down Expand Up @@ -47,6 +50,7 @@ export plot_truth_data, plot_Rt

include("docstrings.jl")
include("pipeline/pipeline.jl")
include("utils/utils.jl")
include("constructors/constructors.jl")
include("simulate/simulate.jl")
include("infer/infer.jl")
Expand Down
79 changes: 79 additions & 0 deletions pipeline/src/utils/calculate_processes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
Internal function for calculating the log of the infections with an informative
error message if the infections are not positive definite.
"""
function _calc_log_infections(I_t)
@assert all(I_t .> 0) "Infections must be positive definite."
log.(I_t)
end

"""
Internal function for calculating the exponential growth rate with an informative
error message if the infections are not positive definite.
"""
function _calc_rt(I_t, I0)
@assert all(I_t .> 0) "Infections must be positive definite."
@assert I0>0 "Initial infections must be positive definite."
log.([I0; I_t]) .- log(I0) |> diff
end

"""
Internal function for seeding the infections. Method dispatches on the pipeline
type to determine the seeding method. This is the default seeding method which
assumes backward exponential growth with initial infections `I0` from initial
estimate of `rt`.
"""
function _infection_seeding(I_t, I0, data::EpiData, pipeline::AbstractEpiAwarePipeline)
n = length(data.gen_int)
init_rt = _calc_rt(I_t[1:2], I0) |> x -> x[1]
[I0 * exp(-init_rt * (n - i)) for i in 1:n]
end

"""
Internal function for calculating the _instantaneous_ reproduction number `Rt`
using the method of [Fraser (2007)](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0000758).
Left truncation handling method is determined by the pipeline type. The default
left truncation in `I_t` is handled by extending the series with backwards exponential
growth from the initial infections `I0` and the exponential growth rate `init_rt`.
# Arguments
- `I_t`: Incident infections.
- `I0`: Initial infections at time zero.
- `init_rt`: Initial exponential growth rate.
- `data::EpiData`: An instance of the `EpiData` type containing generation interval data.
- `pipeline::AbstractEpiAwarePipeline`: An instance of the `AbstractEpiAwarePipeline` type.
"""
function _calc_Rt(I_t, I0, data::EpiData, pipeline::AbstractEpiAwarePipeline)
@assert all(I_t .> 0) "Log infections must be positive definite."
@assert I0>0 "Initial infections must be positive definite."

aug_I_t = vcat(_infection_seeding(I_t, I0, data, pipeline), I_t)

Rt = expected_Rt(data, aug_I_t)

return Rt
end

"""
Calculate the log of infections `log_I_t`, exponential growth values `rt`, and
instaneous reproductive number `Rt` for a given time series of infections. The
reproductive number calculation deals with left truncation in `I_t` by extending
`I_t` with backwards exponential growth using the mean exponential growth rate
from the first 7 time steps of `rt`.
# Arguments
- `I_t`: An array representing the time series of infections.
- `I0`: The initial number of infections.
- `pmf`: The probability mass function used to calculate Rt.
# Returns
A named tuple containing the calculated values for `log_I_t`, `rt`, and `Rt`.
"""
function calculate_processes(I_t, I0, data::EpiData, pipeline::AbstractEpiAwarePipeline)
log_I_t = _calc_log_infections(I_t)
rt = _calc_rt(I_t, I0)
Rt = _calc_Rt(I_t, I0, data, pipeline)
return (; log_I_t, rt, Rt)
end
1 change: 1 addition & 0 deletions pipeline/src/utils/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include("calculate_processes.jl")
1 change: 1 addition & 0 deletions pipeline/test/end-to-end/test_prior_predictive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ using Test

inference_results = generate_inference_results(
truthdata, inference_config, pipeline; tspan, inference_method)

@test inference_results["inference_results"] isa EpiAwareObservables
end
1 change: 1 addition & 0 deletions pipeline/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ quickactivate(@__DIR__(), "EpiAwarePipeline")
# Run tests
include("pipeline/test_pipelinetypes.jl");
include("pipeline/test_pipelinefunctions.jl");
include("utils/test_calculate_processes.jl");
include("constructors/test_constructors.jl");
include("simulate/test_TruthSimulationConfig.jl");
include("simulate/test_SimulationConfig.jl");
Expand Down
32 changes: 32 additions & 0 deletions pipeline/test/utils/test_calculate_processes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using Test

@testset "calculate_processes" begin
using EpiAware, EpiAwarePipeline
using Random
rng = MersenneTwister(1234)
I0 = 10.0
rt = randn(rng, 20)
I_t = cumsum(rt) .+ log(I0) .|> exp
pmf = [1.0]
pipeline = EpiAwareExamplePipeline()

data = EpiData(pmf, exp)

result = calculate_processes(I_t, I0, data, pipeline)

# Check if the log of infections is calculated correctly
@testset "Log of infections" begin
expected_log_I_t = log.(I_t)
@test result.log_I_t expected_log_I_t
end

# Check if the exponential growth rate is calculated correctly
@testset "Exponential growth rate" begin
@test result.rt rt
end

# In this special case (pmf = [1.0]), the Rt = exp(rt)
@testset "Instantaneous reproduction number" begin
@test result.Rt exp.(rt)
end
end

0 comments on commit 3624ca0

Please sign in to comment.