diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl index f4eeaeb3a..26778ddad 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl @@ -3,8 +3,8 @@ The `NegativeBinomialError` struct represents an observation model for negative binomial errors. It is a subtype of `AbstractTuringObservationModel`. ## Constructors -- `NegativeBinomialError(; cluster_factor_prior::Distribution = HalfNormal(0.1), pos_shift::AbstractFloat = 1e-6)`: Constructs a `NegativeBinomialError` object with default values for the cluster factor prior and positive shift. -- `NegativeBinomialError(cluster_factor_prior::Distribution; pos_shift::AbstractFloat = 1e-6)`: Constructs a `NegativeBinomialError` object with a specified cluster factor prior and default value for the positive shift. +- `NegativeBinomialError(; cluster_factor_prior::Distribution = HalfNormal(0.1))`: Constructs a `NegativeBinomialError` object with default values for the cluster factor prior. +- `NegativeBinomialError(cluster_factor_prior::Distribution)`: Constructs a `NegativeBinomialError` object with a specified cluster factor prior. ## Examples ```julia @@ -14,25 +14,10 @@ nb_model = generate_observations(nb, missing, fill(10, 10)) rand(nb_model) ``` " -struct NegativeBinomialError{S <: Sampleable, T <: AbstractFloat} <: - AbstractTuringObservationErrorModel +@kwdef struct NegativeBinomialError{S <: Sampleable, T <: AbstractFloat} <: + AbstractTuringObservationErrorModel "The prior distribution for the cluster factor." - cluster_factor_prior::S - "The positive shift value." - pos_shift::T - - function NegativeBinomialError(; - cluster_factor_prior::Distribution = HalfNormal(0.01), - pos_shift::AbstractFloat = 1e-6) - new{typeof(cluster_factor_prior), typeof(pos_shift)}( - cluster_factor_prior, pos_shift) - end - - function NegativeBinomialError(cluster_factor_prior::Distribution; - pos_shift::AbstractFloat = 1e-6) - new{typeof(cluster_factor_prior), typeof(pos_shift)}( - cluster_factor_prior, pos_shift) - end + cluster_factor_prior::S = HalfNormal(0.01) end @doc raw" @@ -49,6 +34,6 @@ end This function generates the observation error model based on the negative binomial error model with a positive shift. It dispatches to the `NegativeBinomialMeanClust` distribution. " function observation_error(obs_model::NegativeBinomialError, Y_t, sq_cluster_factor) - return NegativeBinomialMeanClust(Y_t + obs_model.pos_shift, + return NegativeBinomialMeanClust(Y_t, sq_cluster_factor) end diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl index cea48fffd..1424c7a60 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl @@ -3,8 +3,7 @@ The `PoissonError` struct represents an observation model for Poisson errors. It is a subtype of `AbstractTuringObservationErrorModel`. ## Constructors -- `PoissonError(; pos_shift::AbstractFloat = 0.)`: Constructs a `PoissonError` -object with default values for the cluster factor prior and positive shift. +- `PoissonError()`: Constructs a `PoissonError` object. ## Examples ```julia @@ -15,19 +14,12 @@ rand(poi_model) ``` " struct PoissonError{T <: AbstractFloat} <: AbstractTuringObservationErrorModel - "The positive shift value." - pos_shift::T - - function PoissonError(; pos_shift::AbstractFloat = 0.0) - @assert pos_shift>=0.0 "The positive shift value must be non-negative." - new{typeof(pos_shift)}(pos_shift) - end end @doc raw" The observation error model for Poisson errors. This function generates the -observation error model based on the Poisson error model with a positive shift. +observation error model based on the Poisson error model. " function observation_error(obs_model::PoissonError, Y_t) - return Poisson(Y_t + obs_model.pos_shift) + return Poisson(Y_t) end diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl index b45cf0f00..2bbab14fa 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl @@ -1,5 +1,7 @@ @doc raw" -Generates observations from an observation error model. It provides support for missing values in observations (`y_t`), and missing values at the beginning of the expected observations (`Y_t`). It dispatches to the `observation_error` function to generate the observation error distribution which uses priors generated by `generate_observation_error_priors` submodel. For most observation error models specific implementations of `observation_error` and `generate_observation_error_priors` are required but a specific implementation of `generate_observations` is not required. +Generates observations from an observation error model. It provides support for missing values in observations (`y_t`), and missing values at the beginning of the expected observations (`Y_t`). It also pads the expected observations with a small value (1e-6) to mitigate potential numerical issues. + +It dispatches to the `observation_error` function to generate the observation error distribution which uses priors generated by `generate_observation_error_priors` submodel. For most observation error models specific implementations of `observation_error` and `generate_observation_error_priors` are required but a specific implementation of `generate_observations` is not required. " @model function EpiAwareBase.generate_observations( obs_model::AbstractTuringObservationErrorModel, @@ -13,8 +15,10 @@ Generates observations from an observation error model. It provides support for @assert length(y_t)==length(Y_t) "The observation vector and expected observation vector must have the same length." end + pad_Y_t = Y_t + 1e-6 + for i in findfirst(!ismissing, Y_t):length(Y_t) - y_t[i] ~ observation_error(obs_model, Y_t[i], priors...) + y_t[i] ~ observation_error(obs_model, pad_Y_t[i], priors...) end return y_t, priors diff --git a/EpiAware/test/EpiAwareUtils/generate_epiware.jl b/EpiAware/test/EpiAwareUtils/generate_epiware.jl index e73fd3878..0b8175fa8 100644 --- a/EpiAware/test/EpiAwareUtils/generate_epiware.jl +++ b/EpiAware/test/EpiAwareUtils/generate_epiware.jl @@ -1,10 +1,8 @@ - @testitem "`generate_epiaware` with direct infections and RW latent process runs" begin using Distributions, Turing, DynamicPPL # Define test inputs y_t = missing # Data will be generated from the model data = EpiData([0.2, 0.3, 0.5], exp) - pos_shift = 1e-6 time_horizon = 100 #Define the epi_model @@ -42,7 +40,6 @@ end # Define test inputs y_t = missing# rand(1:10, 365) # Data will be generated from the model data = EpiData([0.2, 0.3, 0.5], exp) - pos_shift = 1e-6 #Define the epi_model epi_model = ExpGrowthRate(data, Normal()) @@ -56,7 +53,7 @@ end #Define the observation model - no delay model time_horizon = 5 obs_model = NegativeBinomialError( - truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0); pos_shift + truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0) ) # Create full epi model and sample from it @@ -80,7 +77,6 @@ end # Define test inputs y_t = missing# rand(1:10, 365) # Data will be generated from the model data = EpiData([0.2, 0.3, 0.5], exp) - pos_shift = 1e-6 #Define the epi_model epi_model = Renewal(data, Normal()) @@ -94,8 +90,7 @@ end #Define the observation model - no delay model time_horizon = 5 obs_model = NegativeBinomialError( - truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0); - pos_shift + truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0) ) # Create full epi model and sample from it diff --git a/EpiAware/test/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl b/EpiAware/test/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl index 9e6d0c70b..2b1c8fd54 100644 --- a/EpiAware/test/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl +++ b/EpiAware/test/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl @@ -2,20 +2,19 @@ using Distributions # Test default constructor nb = NegativeBinomialError() + @test typeof(nb) <: NegativeBinomialError + @test typeof(nb) <: TuringObservationErrorModel @test all(rand(nb.cluster_factor_prior, 100) .>= 0.0) @test isapprox(mean(nb.cluster_factor_prior), 0.01) - @test nb.pos_shift ≈ 1e-6 # Test constructor with custom prior prior = Gamma(2.0, 1.0) nb = NegativeBinomialError(prior) @test nb.cluster_factor_prior == prior - @test nb.pos_shift ≈ 1e-6 - # Test constructor with custom prior and pos_shift - nb = NegativeBinomialError(prior; pos_shift = 1e-3) + # Test constructor with custom prior + nb = NegativeBinomialError(prior) @test nb.cluster_factor_prior == prior - @test nb.pos_shift ≈ 1e-3 end @testitem "Testing NegativeBinomialError against theoretical properties" begin @@ -27,7 +26,7 @@ end α = 0.2 # Cluster factor (dispersion parameter) # Define the observation model - nb_obs_model = NegativeBinomialError(pos_shift = 0.0) + nb_obs_model = NegativeBinomialError() # Generate observations from the model Y_t = fill(μ, n) # True values diff --git a/EpiAware/test/EpiObsModels/ObservationErrorModels/PoissonError.jl b/EpiAware/test/EpiObsModels/ObservationErrorModels/PoissonError.jl index 732709ce6..55bc6ff74 100644 --- a/EpiAware/test/EpiObsModels/ObservationErrorModels/PoissonError.jl +++ b/EpiAware/test/EpiObsModels/ObservationErrorModels/PoissonError.jl @@ -1,14 +1,9 @@ @testitem "PoissonErrorConstructor" begin using Distributions # Test default constructor - poi = PoissonError() - @test poi.pos_shift ≈ zero(Float64) - poi_float = PoissonError(; pos_shift = 0.0f0) - @test poi_float.pos_shift ≈ zero(Float32) - - # Test constructor with pos_shift - poi2 = PoissonError(; pos_shift = 1e-3) - @test poi2.pos_shift ≈ 1e-3 + poi = PoissonError + @test typeof(poi) <: PoissonError + @test typeof(poi) <: TuringObservationErrorModel end @testitem "Testing PoissonError against theoretical properties" begin @@ -19,7 +14,7 @@ end μ = 10.0 # Mean of the poisson distribution # Define the observation model - poi_obs_model = PoissonError(pos_shift = 0.0) + poi_obs_model = PoissonError() # Generate observations from the model Y_t = fill(μ, n) # True values