Skip to content

Commit

Permalink
Issue 266: Fix kernel bug (#267)
Browse files Browse the repository at this point in the history
* add partial option

* reverse partial if gate

* add a test for new partial functionality

* simplify LatentDelay

* simplify LatentDelay

---------

Co-authored-by: Samuel Brand <[email protected]>
  • Loading branch information
seabbs and SamuelBrand1 authored Jun 11, 2024
1 parent d498669 commit 07c781a
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 24 deletions.
8 changes: 3 additions & 5 deletions EpiAware/src/EpiObsModels/LatentDelay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,10 @@ Generates observations based on the `LatentDelay` observation model.
"
@model function EpiAwareBase.generate_observations(obs_model::LatentDelay, y_t, Y_t)
unobs_y_t = length(obs_model.pmf)
nobs_Y_t = length(Y_t) - unobs_y_t + 1
@assert unobs_y_t<=length(Y_t) "The delay PMF must be shorter than or equal to the observation vector"
@assert length(obs_model.pmf)<=length(Y_t) "The delay PMF must be shorter than or equal to the observation vector"

kernel = generate_observation_kernel(obs_model.pmf, nobs_Y_t)
expected_obs = kernel * Y_t[unobs_y_t:end]
kernel = generate_observation_kernel(obs_model.pmf, length(Y_t), partial = false)
expected_obs = kernel * Y_t

@submodel y_t, obs_aux = generate_observations(
obs_model.model, y_t, expected_obs)
Expand Down
2 changes: 1 addition & 1 deletion EpiAware/src/EpiObsModels/PoissonError.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Generate observations using the `PoissonError` observation model.
end
Y_y = length(y_t) - length(Y_t)

for i in eachindex(y_t)
for i in eachindex(Y_t)
y_t[Y_y + i] ~ Poisson(Y_t[i] + obs_model.pos_shift)
end

Expand Down
21 changes: 15 additions & 6 deletions EpiAware/src/EpiObsModels/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,25 @@ Generate an observation kernel matrix based on the given delay interval and time
# Arguments
- `delay_int::Vector{Float64}`: The delay PMF vector.
- `time_horizon::Int`: The number of time steps of the observation period.
- `partial::Bool`: Whether to generate a partial observation kernel matrix.
# Returns
- `K::SparseMatrixCSC{Float64, Int}`: The observation kernel matrix.
"""
function generate_observation_kernel(delay_int, time_horizon)
K = zeros(eltype(delay_int), time_horizon, time_horizon) |> SparseMatrixCSC
for i in 1:time_horizon, j in 1:time_horizon
m = i - j
if m >= 0 && m <= (length(delay_int) - 1)
K[i, j] = delay_int[m + 1]
function generate_observation_kernel(delay_int, time_horizon; partial::Bool = true)
if (partial)
K = zeros(eltype(delay_int), time_horizon, time_horizon) |> SparseMatrixCSC
for i in 1:time_horizon, j in 1:time_horizon
m = i - j
if m >= 0 && m <= (length(delay_int) - 1)
K[i, j] = delay_int[m + 1]
end
end
else
com_time = time_horizon - length(delay_int) + 1
K = zeros(eltype(delay_int), com_time, time_horizon) |> SparseMatrixCSC
for i in 1:com_time
K[i, i:(i + length(delay_int) - 1)] = delay_int
end
end
return K
Expand Down
32 changes: 32 additions & 0 deletions EpiAware/test/EpiObsModels/LatentDelay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,35 @@ end
end
end
end

@testitem "Test LatenDelay generate_observations function" begin
using DynamicPPL
struct TestObs <: AbstractTuringObservationModel end

@model function EpiAwareBase.generate_observations(obs_model::TestObs, y_t, Y_t)
return Y_t, (;)
end

delay_int = [0.2, 0.3, 0.5]
obs_model = LatentDelay(TestObs(), delay_int)

I_t = [10.0, 20.0, 30.0, 40.0, 50.0]
expected_obs = [23.0, 33.0, 43.0]

@testset "Test with entirely missing data" begin
mdl = generate_observations(obs_model, missing, I_t)
@test mdl()[1] == expected_obs
end

@testset "Test with missing data defined as a vector" begin
mdl = generate_observations(
obs_model, [missing, missing, missing, missing, missing], I_t)
@test mdl()[1] == expected_obs
end

@testset "Test with data" begin
pois_obs_model = LatentDelay(PoissonError(), delay_int)
mdl = generate_observations(pois_obs_model, [10.0, 20.0, 30.0, 40.0, 50.0], I_t)
@test mdl()[1] == [10.0, 20.0, 30.0, 40.0, 50]
end
end
34 changes: 22 additions & 12 deletions EpiAware/test/EpiObsModels/utils.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
@testitem "Testing generate_observation_kernel function" begin
@testitem "Testing generate_observation_kernel function defaults" begin
using SparseArrays
@testset "Test case 1" begin
delay_int = [0.2, 0.5, 0.3]
time_horizon = 5
expected_K = SparseMatrixCSC([0.2 0 0 0 0
0.5 0.2 0 0 0
0.3 0.5 0.2 0 0
0 0.3 0.5 0.2 0
0 0 0.3 0.5 0.2])
K = EpiAware.EpiObsModels.generate_observation_kernel(delay_int, time_horizon)
@test K == expected_K
end
delay_int = [0.2, 0.5, 0.3]
time_horizon = 5
expected_K = SparseMatrixCSC([0.2 0 0 0 0
0.5 0.2 0 0 0
0.3 0.5 0.2 0 0
0 0.3 0.5 0.2 0
0 0 0.3 0.5 0.2])
K = EpiAware.EpiObsModels.generate_observation_kernel(delay_int, time_horizon)
@test K == expected_K
end

@testitem "Test generate_observation_kernel partial=false setting" begin
using SparseArrays
delay_int = [0.2, 0.5, 0.3]
time_horizon = 5
expected_K = SparseMatrixCSC([0.2 0.5 0.3 0 0
0 0.2 0.5 0.3 0
0 0 0.2 0.5 0.3])
K = EpiAware.EpiObsModels.generate_observation_kernel(
delay_int, time_horizon; partial = false)
@test K == expected_K
end

@testitem "Check overflow safety of Negative Binomial sampling" begin
Expand Down

0 comments on commit 07c781a

Please sign in to comment.