Skip to content

Commit

Permalink
remove numerical pading as an option and hard code
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Jun 11, 2024
1 parent 0ec11bd commit ea2b6df
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
The `NegativeBinomialError` struct represents an observation model for negative binomial errors. It is a subtype of `AbstractTuringObservationModel`.
## Constructors
- `NegativeBinomialError(; cluster_factor_prior::Distribution = HalfNormal(0.1), pos_shift::AbstractFloat = 1e-6)`: Constructs a `NegativeBinomialError` object with default values for the cluster factor prior and positive shift.
- `NegativeBinomialError(cluster_factor_prior::Distribution; pos_shift::AbstractFloat = 1e-6)`: Constructs a `NegativeBinomialError` object with a specified cluster factor prior and default value for the positive shift.
- `NegativeBinomialError(; cluster_factor_prior::Distribution = HalfNormal(0.1))`: Constructs a `NegativeBinomialError` object with default values for the cluster factor prior.
- `NegativeBinomialError(cluster_factor_prior::Distribution)`: Constructs a `NegativeBinomialError` object with a specified cluster factor prior.
## Examples
```julia
Expand All @@ -14,25 +14,10 @@ nb_model = generate_observations(nb, missing, fill(10, 10))
rand(nb_model)
```
"
struct NegativeBinomialError{S <: Sampleable, T <: AbstractFloat} <:
AbstractTuringObservationErrorModel
@kwdef struct NegativeBinomialError{S <: Sampleable, T <: AbstractFloat} <:
AbstractTuringObservationErrorModel
"The prior distribution for the cluster factor."
cluster_factor_prior::S
"The positive shift value."
pos_shift::T

function NegativeBinomialError(;
cluster_factor_prior::Distribution = HalfNormal(0.01),
pos_shift::AbstractFloat = 1e-6)
new{typeof(cluster_factor_prior), typeof(pos_shift)}(
cluster_factor_prior, pos_shift)
end

function NegativeBinomialError(cluster_factor_prior::Distribution;
pos_shift::AbstractFloat = 1e-6)
new{typeof(cluster_factor_prior), typeof(pos_shift)}(
cluster_factor_prior, pos_shift)
end
cluster_factor_prior::S = HalfNormal(0.01)
end

@doc raw"
Expand All @@ -49,6 +34,6 @@ end
This function generates the observation error model based on the negative binomial error model with a positive shift. It dispatches to the `NegativeBinomialMeanClust` distribution.
"
function observation_error(obs_model::NegativeBinomialError, Y_t, sq_cluster_factor)
return NegativeBinomialMeanClust(Y_t + obs_model.pos_shift,
return NegativeBinomialMeanClust(Y_t,
sq_cluster_factor)
end
14 changes: 3 additions & 11 deletions EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ The `PoissonError` struct represents an observation model for Poisson errors. It
is a subtype of `AbstractTuringObservationErrorModel`.
## Constructors
- `PoissonError(; pos_shift::AbstractFloat = 0.)`: Constructs a `PoissonError`
object with default values for the cluster factor prior and positive shift.
- `PoissonError()`: Constructs a `PoissonError` object.
## Examples
```julia
Expand All @@ -15,19 +14,12 @@ rand(poi_model)
```
"
struct PoissonError{T <: AbstractFloat} <: AbstractTuringObservationErrorModel
"The positive shift value."
pos_shift::T

function PoissonError(; pos_shift::AbstractFloat = 0.0)
@assert pos_shift>=0.0 "The positive shift value must be non-negative."
new{typeof(pos_shift)}(pos_shift)
end
end

@doc raw"
The observation error model for Poisson errors. This function generates the
observation error model based on the Poisson error model with a positive shift.
observation error model based on the Poisson error model.
"
function observation_error(obs_model::PoissonError, Y_t)
return Poisson(Y_t + obs_model.pos_shift)
return Poisson(Y_t)
end
8 changes: 6 additions & 2 deletions EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
@doc raw"
Generates observations from an observation error model. It provides support for missing values in observations (`y_t`), and missing values at the beginning of the expected observations (`Y_t`). It dispatches to the `observation_error` function to generate the observation error distribution which uses priors generated by `generate_observation_error_priors` submodel. For most observation error models specific implementations of `observation_error` and `generate_observation_error_priors` are required but a specific implementation of `generate_observations` is not required.
Generates observations from an observation error model. It provides support for missing values in observations (`y_t`), and missing values at the beginning of the expected observations (`Y_t`). It also pads the expected observations with a small value (1e-6) to mitigate potential numerical issues.
It dispatches to the `observation_error` function to generate the observation error distribution which uses priors generated by `generate_observation_error_priors` submodel. For most observation error models specific implementations of `observation_error` and `generate_observation_error_priors` are required but a specific implementation of `generate_observations` is not required.
"
@model function EpiAwareBase.generate_observations(
obs_model::AbstractTuringObservationErrorModel,
Expand All @@ -13,8 +15,10 @@ Generates observations from an observation error model. It provides support for
@assert length(y_t)==length(Y_t) "The observation vector and expected observation vector must have the same length."
end

pad_Y_t = Y_t + 1e-6

for i in findfirst(!ismissing, Y_t):length(Y_t)
y_t[i] ~ observation_error(obs_model, Y_t[i], priors...)
y_t[i] ~ observation_error(obs_model, pad_Y_t[i], priors...)
end

return y_t, priors
Expand Down
9 changes: 2 additions & 7 deletions EpiAware/test/EpiAwareUtils/generate_epiware.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@

@testitem "`generate_epiaware` with direct infections and RW latent process runs" begin
using Distributions, Turing, DynamicPPL
# Define test inputs
y_t = missing # Data will be generated from the model
data = EpiData([0.2, 0.3, 0.5], exp)
pos_shift = 1e-6
time_horizon = 100

#Define the epi_model
Expand Down Expand Up @@ -42,7 +40,6 @@ end
# Define test inputs
y_t = missing# rand(1:10, 365) # Data will be generated from the model
data = EpiData([0.2, 0.3, 0.5], exp)
pos_shift = 1e-6

#Define the epi_model
epi_model = ExpGrowthRate(data, Normal())
Expand All @@ -56,7 +53,7 @@ end
#Define the observation model - no delay model
time_horizon = 5
obs_model = NegativeBinomialError(
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0); pos_shift
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)
)

# Create full epi model and sample from it
Expand All @@ -80,7 +77,6 @@ end
# Define test inputs
y_t = missing# rand(1:10, 365) # Data will be generated from the model
data = EpiData([0.2, 0.3, 0.5], exp)
pos_shift = 1e-6

#Define the epi_model
epi_model = Renewal(data, Normal())
Expand All @@ -94,8 +90,7 @@ end
#Define the observation model - no delay model
time_horizon = 5
obs_model = NegativeBinomialError(
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0);
pos_shift
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)
)

# Create full epi model and sample from it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,19 @@
using Distributions
# Test default constructor
nb = NegativeBinomialError()
@test typeof(nb) <: NegativeBinomialError
@test typeof(nb) <: TuringObservationErrorModel
@test all(rand(nb.cluster_factor_prior, 100) .>= 0.0)
@test isapprox(mean(nb.cluster_factor_prior), 0.01)
@test nb.pos_shift 1e-6

# Test constructor with custom prior
prior = Gamma(2.0, 1.0)
nb = NegativeBinomialError(prior)
@test nb.cluster_factor_prior == prior
@test nb.pos_shift 1e-6

# Test constructor with custom prior and pos_shift
nb = NegativeBinomialError(prior; pos_shift = 1e-3)
# Test constructor with custom prior
nb = NegativeBinomialError(prior)
@test nb.cluster_factor_prior == prior
@test nb.pos_shift 1e-3
end

@testitem "Testing NegativeBinomialError against theoretical properties" begin
Expand All @@ -27,7 +26,7 @@ end
α = 0.2 # Cluster factor (dispersion parameter)

# Define the observation model
nb_obs_model = NegativeBinomialError(pos_shift = 0.0)
nb_obs_model = NegativeBinomialError()

# Generate observations from the model
Y_t = fill(μ, n) # True values
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
@testitem "PoissonErrorConstructor" begin
using Distributions
# Test default constructor
poi = PoissonError()
@test poi.pos_shift zero(Float64)
poi_float = PoissonError(; pos_shift = 0.0f0)
@test poi_float.pos_shift zero(Float32)

# Test constructor with pos_shift
poi2 = PoissonError(; pos_shift = 1e-3)
@test poi2.pos_shift 1e-3
poi = PoissonError
@test typeof(poi) <: PoissonError
@test typeof(poi) <: TuringObservationErrorModel
end

@testitem "Testing PoissonError against theoretical properties" begin
Expand All @@ -19,7 +14,7 @@ end
μ = 10.0 # Mean of the poisson distribution

# Define the observation model
poi_obs_model = PoissonError(pos_shift = 0.0)
poi_obs_model = PoissonError()

# Generate observations from the model
Y_t = fill(μ, n) # True values
Expand Down

0 comments on commit ea2b6df

Please sign in to comment.