Skip to content

Commit 0b6a162

Browse files
SamuelBrand1seabbs
andauthored
issue 492: Create distribution that returns <: Union{Int,BigInt} (#497)
* Make RealUnivariateDistribution type * test RealUnivariateDistribution * update method * reformat * modify procf * Explicit use of Distributions * fix unit test fails * Discrete valued dists expecting rand return of Union{int,BigInt} * Make RealUnivariateDistribution type * test RealUnivariateDistribution * update method * reformat * modify procf * Explicit use of Distributions * fix unit test fails * Discrete valued dists expecting rand return of Union{int,BigInt} * Typo * Revert CI test --------- Co-authored-by: Sam Abbott <[email protected]>
1 parent 119f244 commit 0b6a162

File tree

9 files changed

+39
-14
lines changed

9 files changed

+39
-14
lines changed

EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ using DynamicPPL: Model, fix, condition, @submodel, @model
1010
using MCMCChains: Chains
1111
using Random: AbstractRNG, randexp
1212
using Tables: rowtable
13+
import Base: eltype
1314

1415
using Distributions, DocStringExtensions, QuadGK, Statistics, Turing
1516

1617
#Export Structures
17-
export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial
18+
export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial, SafeIntValued, SafeInt,
19+
SafeDiscreteUnivariateDistribution
1820

1921
#Export functions
2022
export scan, spread_draws, censored_cdf, censored_pmf, get_param_array, prefix_submodel, ∫F
@@ -32,6 +34,7 @@ include("turing-methods.jl")
3234
include("DirectSample.jl")
3335
include("post-inference.jl")
3436
include("get_param_array.jl")
37+
include("SafeInt.jl")
3538
include("SafePoisson.jl")
3639
include("SafeNegativeBinomial.jl")
3740

EpiAware/src/EpiAwareUtils/SafeInt.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
const SafeInt = Union{Int, BigInt}
2+
3+
"""
4+
A type to represent real-valued distributions, the purpose of this type is to avoid problems
5+
with the `eltype` function when having `rand` calls in the model.
6+
"""
7+
struct SafeIntValued <: Distributions.ValueSupport end
8+
function Base.eltype(::Type{<:Distributions.Sampleable{F, SafeIntValued}}) where {F}
9+
SafeInt
10+
end
11+
12+
"""
13+
A constant alias for `Distribution{Univariate, SafeIntValued}`. This type represents a univariate distribution with real-valued outcomes.
14+
"""
15+
const SafeDiscreteUnivariateDistribution = Distributions.Distribution{
16+
Distributions.Univariate, SafeIntValued}

EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ var(d)
6565
2.4617291430060293e40
6666
```
6767
"
68-
struct SafeNegativeBinomial{T <: Real} <: DiscreteUnivariateDistribution
68+
struct SafeNegativeBinomial{T <: Real} <: SafeDiscreteUnivariateDistribution
6969
r::T
7070
p::T
7171

EpiAware/src/EpiAwareUtils/SafePoisson.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ var(d)
4545
7.016735912097631e20
4646
```
4747
"
48-
struct SafePoisson{T <: Real} <: DiscreteUnivariateDistribution
48+
struct SafePoisson{T <: Real} <: SafeDiscreteUnivariateDistribution
4949
λ::T
5050

5151
SafePoisson{T}::Real) where {T <: Real} = new{T}(λ)
@@ -86,7 +86,7 @@ Distributions.rate(d::SafePoisson) = d.λ
8686
### Statistics
8787

8888
Distributions.mean(d::SafePoisson) = d.λ
89-
Distributions.mode(d::SafePoisson) = _safe_int_floor(d.λ)
89+
Distributions.mode(d::SafePoisson) = floor(d.λ)
9090
Distributions.var(d::SafePoisson) = d.λ
9191
Distributions.skewness(d::SafePoisson) = one(typeof(d.λ)) / sqrt(d.λ)
9292
Distributions.kurtosis(d::SafePoisson) = one(typeof(d.λ)) / d.λ
@@ -229,7 +229,7 @@ function log1pmx(x::Float64)
229229
end
230230

231231
# Procedure F
232-
function procf(λ, K::Int, s::Float64)
232+
function procf(λ, K::SafeInt, s::Float64)
233233
# can be pre-computed, but does not seem to affect performance
234234
ω = 0.3989422804014327 / s
235235
b1 = 0.041666666666666664 / λ
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
@testitem "SafeInt Type Tests" begin
2+
using Distributions
3+
struct DummySampleable <: Sampleable{Univariate, SafeIntValued} end
4+
5+
@test SafeIntValued <: Distributions.ValueSupport
6+
@test eltype(DummySampleable) <: Union{Int, BigInt}
7+
@test SafeDiscreteUnivariateDistribution == Distribution{Univariate, SafeIntValued}
8+
end

EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ end
7070

7171
dist = SafeNegativeBinomial(r, p)
7272
@testset "Large value of mean samples a BigInt with SafePoisson" begin
73-
@test rand(dist) isa BigInt
73+
@test rand(dist) isa Union{Int, BigInt}
7474
end
7575
@testset "Large value of mean sample failure with Poisson" begin
7676
_dist = EpiAware.EpiAwareUtils._negbin(dist)

EpiAware/test/EpiAwareUtils/SafePoisson.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
λ = 10.0
33
dist = SafePoisson(λ)
44
@test typeof(dist) <: SafePoisson
5-
@test rand(dist) isa Int
6-
@test rand(dist, 10) isa Vector{Int}
7-
@test rand(dist, 10, 10) isa Array{Int}
5+
@test rand(dist) isa SafeInt
6+
@test rand(dist, 10) isa Vector{SafeInt}
7+
@test rand(dist, 10, 10) isa Array{SafeInt}
88
end
99

1010
@testitem "Check distribution properties of SafePoisson" begin
@@ -54,7 +54,7 @@ end
5454
bigλ = exp(48.0) #Large value of λ
5555
dist = SafePoisson(bigλ)
5656
@testset "Large value of mean samples a BigInt with SafePoisson" begin
57-
@test rand(dist) isa BigInt
57+
@test rand(dist) isa SafeInt
5858
end
5959
@testset "Large value of mean sample failure with Poisson" begin
6060
_dist = Poisson(dist.λ)

EpiAware/test/EpiObsModels/modifiers/LatentDelay.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,7 @@ end
175175
ExpGrowthRate,
176176
Renewal] .|>
177177
em_type -> em_type(
178-
data = EpiData([0.2, 0.5, 0.3],
179-
em_type == Renewal ? softplus : exp
180-
),
178+
data = EpiData([0.2, 0.5, 0.3], exp),
181179
initialisation_prior = Normal(log(100.0), 0.01)
182180
)
183181

EpiAware/test/EpiObsModels/modifiers/PrefixObservationModel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ end
1010

1111
mdl = generate_observations(model, missing, 10)
1212
draw = rand(mdl)
13-
@test typeof(draw[:var"Test.y_t[1]"]) <: Int
13+
@test typeof(draw[:var"Test.y_t[1]"]) <: Real
1414
end

0 commit comments

Comments
 (0)