From 0b6a162071d94bde1916914cfa6b1d87d452455a Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Thu, 24 Oct 2024 18:54:48 +0100 Subject: [PATCH] issue 492: Create distribution that returns `<: Union{Int,BigInt}` (#497) * Make RealUnivariateDistribution type * test RealUnivariateDistribution * update method * reformat * modify procf * Explicit use of Distributions * fix unit test fails * Discrete valued dists expecting rand return of Union{int,BigInt} * Make RealUnivariateDistribution type * test RealUnivariateDistribution * update method * reformat * modify procf * Explicit use of Distributions * fix unit test fails * Discrete valued dists expecting rand return of Union{int,BigInt} * Typo * Revert CI test --------- Co-authored-by: Sam Abbott --- EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl | 5 ++++- EpiAware/src/EpiAwareUtils/SafeInt.jl | 16 ++++++++++++++++ .../src/EpiAwareUtils/SafeNegativeBinomial.jl | 2 +- EpiAware/src/EpiAwareUtils/SafePoisson.jl | 6 +++--- EpiAware/test/EpiAwareUtils/SafeInt.jl | 8 ++++++++ .../test/EpiAwareUtils/SafeNegativeBinomial.jl | 2 +- EpiAware/test/EpiAwareUtils/SafePoisson.jl | 8 ++++---- .../test/EpiObsModels/modifiers/LatentDelay.jl | 4 +--- .../modifiers/PrefixObservationModel.jl | 2 +- 9 files changed, 39 insertions(+), 14 deletions(-) create mode 100644 EpiAware/src/EpiAwareUtils/SafeInt.jl create mode 100644 EpiAware/test/EpiAwareUtils/SafeInt.jl diff --git a/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl b/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl index bcabad2d0..22aceaa91 100644 --- a/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl +++ b/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl @@ -10,11 +10,13 @@ using DynamicPPL: Model, fix, condition, @submodel, @model using MCMCChains: Chains using Random: AbstractRNG, randexp using Tables: rowtable +import Base: eltype using Distributions, DocStringExtensions, QuadGK, Statistics, Turing #Export Structures -export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial +export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial, SafeIntValued, SafeInt, + SafeDiscreteUnivariateDistribution #Export functions export scan, spread_draws, censored_cdf, censored_pmf, get_param_array, prefix_submodel, ∫F @@ -32,6 +34,7 @@ include("turing-methods.jl") include("DirectSample.jl") include("post-inference.jl") include("get_param_array.jl") +include("SafeInt.jl") include("SafePoisson.jl") include("SafeNegativeBinomial.jl") diff --git a/EpiAware/src/EpiAwareUtils/SafeInt.jl b/EpiAware/src/EpiAwareUtils/SafeInt.jl new file mode 100644 index 000000000..565327eed --- /dev/null +++ b/EpiAware/src/EpiAwareUtils/SafeInt.jl @@ -0,0 +1,16 @@ +const SafeInt = Union{Int, BigInt} + +""" +A type to represent real-valued distributions, the purpose of this type is to avoid problems +with the `eltype` function when having `rand` calls in the model. +""" +struct SafeIntValued <: Distributions.ValueSupport end +function Base.eltype(::Type{<:Distributions.Sampleable{F, SafeIntValued}}) where {F} + SafeInt +end + +""" +A constant alias for `Distribution{Univariate, SafeIntValued}`. This type represents a univariate distribution with real-valued outcomes. +""" +const SafeDiscreteUnivariateDistribution = Distributions.Distribution{ + Distributions.Univariate, SafeIntValued} diff --git a/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl b/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl index 4341bcc28..e57714b80 100644 --- a/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl +++ b/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl @@ -65,7 +65,7 @@ var(d) 2.4617291430060293e40 ``` " -struct SafeNegativeBinomial{T <: Real} <: DiscreteUnivariateDistribution +struct SafeNegativeBinomial{T <: Real} <: SafeDiscreteUnivariateDistribution r::T p::T diff --git a/EpiAware/src/EpiAwareUtils/SafePoisson.jl b/EpiAware/src/EpiAwareUtils/SafePoisson.jl index 46e5201ad..c85243ebc 100644 --- a/EpiAware/src/EpiAwareUtils/SafePoisson.jl +++ b/EpiAware/src/EpiAwareUtils/SafePoisson.jl @@ -45,7 +45,7 @@ var(d) 7.016735912097631e20 ``` " -struct SafePoisson{T <: Real} <: DiscreteUnivariateDistribution +struct SafePoisson{T <: Real} <: SafeDiscreteUnivariateDistribution λ::T SafePoisson{T}(λ::Real) where {T <: Real} = new{T}(λ) @@ -86,7 +86,7 @@ Distributions.rate(d::SafePoisson) = d.λ ### Statistics Distributions.mean(d::SafePoisson) = d.λ -Distributions.mode(d::SafePoisson) = _safe_int_floor(d.λ) +Distributions.mode(d::SafePoisson) = floor(d.λ) Distributions.var(d::SafePoisson) = d.λ Distributions.skewness(d::SafePoisson) = one(typeof(d.λ)) / sqrt(d.λ) Distributions.kurtosis(d::SafePoisson) = one(typeof(d.λ)) / d.λ @@ -229,7 +229,7 @@ function log1pmx(x::Float64) end # Procedure F -function procf(λ, K::Int, s::Float64) +function procf(λ, K::SafeInt, s::Float64) # can be pre-computed, but does not seem to affect performance ω = 0.3989422804014327 / s b1 = 0.041666666666666664 / λ diff --git a/EpiAware/test/EpiAwareUtils/SafeInt.jl b/EpiAware/test/EpiAwareUtils/SafeInt.jl new file mode 100644 index 000000000..d1a59dfab --- /dev/null +++ b/EpiAware/test/EpiAwareUtils/SafeInt.jl @@ -0,0 +1,8 @@ +@testitem "SafeInt Type Tests" begin + using Distributions + struct DummySampleable <: Sampleable{Univariate, SafeIntValued} end + + @test SafeIntValued <: Distributions.ValueSupport + @test eltype(DummySampleable) <: Union{Int, BigInt} + @test SafeDiscreteUnivariateDistribution == Distribution{Univariate, SafeIntValued} +end diff --git a/EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl b/EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl index 8c828e49d..4e1ee09db 100644 --- a/EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl +++ b/EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl @@ -70,7 +70,7 @@ end dist = SafeNegativeBinomial(r, p) @testset "Large value of mean samples a BigInt with SafePoisson" begin - @test rand(dist) isa BigInt + @test rand(dist) isa Union{Int, BigInt} end @testset "Large value of mean sample failure with Poisson" begin _dist = EpiAware.EpiAwareUtils._negbin(dist) diff --git a/EpiAware/test/EpiAwareUtils/SafePoisson.jl b/EpiAware/test/EpiAwareUtils/SafePoisson.jl index 6630e066b..874cf3423 100644 --- a/EpiAware/test/EpiAwareUtils/SafePoisson.jl +++ b/EpiAware/test/EpiAwareUtils/SafePoisson.jl @@ -2,9 +2,9 @@ λ = 10.0 dist = SafePoisson(λ) @test typeof(dist) <: SafePoisson - @test rand(dist) isa Int - @test rand(dist, 10) isa Vector{Int} - @test rand(dist, 10, 10) isa Array{Int} + @test rand(dist) isa SafeInt + @test rand(dist, 10) isa Vector{SafeInt} + @test rand(dist, 10, 10) isa Array{SafeInt} end @testitem "Check distribution properties of SafePoisson" begin @@ -54,7 +54,7 @@ end bigλ = exp(48.0) #Large value of λ dist = SafePoisson(bigλ) @testset "Large value of mean samples a BigInt with SafePoisson" begin - @test rand(dist) isa BigInt + @test rand(dist) isa SafeInt end @testset "Large value of mean sample failure with Poisson" begin _dist = Poisson(dist.λ) diff --git a/EpiAware/test/EpiObsModels/modifiers/LatentDelay.jl b/EpiAware/test/EpiObsModels/modifiers/LatentDelay.jl index f1bddcb22..34ccc1721 100644 --- a/EpiAware/test/EpiObsModels/modifiers/LatentDelay.jl +++ b/EpiAware/test/EpiObsModels/modifiers/LatentDelay.jl @@ -175,9 +175,7 @@ end ExpGrowthRate, Renewal] .|> em_type -> em_type( - data = EpiData([0.2, 0.5, 0.3], - em_type == Renewal ? softplus : exp - ), + data = EpiData([0.2, 0.5, 0.3], exp), initialisation_prior = Normal(log(100.0), 0.01) ) diff --git a/EpiAware/test/EpiObsModels/modifiers/PrefixObservationModel.jl b/EpiAware/test/EpiObsModels/modifiers/PrefixObservationModel.jl index b46ba54c4..346d2b9ba 100644 --- a/EpiAware/test/EpiObsModels/modifiers/PrefixObservationModel.jl +++ b/EpiAware/test/EpiObsModels/modifiers/PrefixObservationModel.jl @@ -10,5 +10,5 @@ end mdl = generate_observations(model, missing, 10) draw = rand(mdl) - @test typeof(draw[:var"Test.y_t[1]"]) <: Int + @test typeof(draw[:var"Test.y_t[1]"]) <: Real end