Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

issue 492: Create distribution that returns <: Real #494

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ using DynamicPPL: Model, fix, condition, @submodel, @model
using MCMCChains: Chains
using Random: AbstractRNG, randexp
using Tables: rowtable
import Base: eltype

using Distributions, DocStringExtensions, QuadGK, Statistics, Turing

#Export Structures
export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial
export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial, RealValued,
RealUnivariateDistribution

#Export functions
export scan, spread_draws, censored_cdf, censored_pmf, get_param_array, prefix_submodel, ∫F
Expand All @@ -32,6 +34,7 @@ include("turing-methods.jl")
include("DirectSample.jl")
include("post-inference.jl")
include("get_param_array.jl")
include("RealValued.jl")
include("SafePoisson.jl")
include("SafeNegativeBinomial.jl")

Expand Down
12 changes: 12 additions & 0 deletions EpiAware/src/EpiAwareUtils/RealValued.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
A type to represent real-valued distributions, the purpose of this type is to avoid problems
with the `eltype` function when having `rand` calls in the model.
"""
struct RealValued <: Distributions.ValueSupport end
Base.eltype(::Type{<:Distributions.Sampleable{F, RealValued}}) where {F} = Real

"""
A constant alias for `Distribution{Univariate, RealValued}`. This type represents a univariate distribution with real-valued outcomes.
"""
const RealUnivariateDistribution = Distributions.Distribution{
Distributions.Univariate, RealValued}
2 changes: 1 addition & 1 deletion EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ var(d)
2.4617291430060293e40
```
"
struct SafeNegativeBinomial{T <: Real} <: DiscreteUnivariateDistribution
struct SafeNegativeBinomial{T <: Real} <: RealUnivariateDistribution
r::T
p::T

Expand Down
14 changes: 7 additions & 7 deletions EpiAware/src/EpiAwareUtils/SafePoisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ var(d)
7.016735912097631e20
```
"
struct SafePoisson{T <: Real} <: DiscreteUnivariateDistribution
struct SafePoisson{T <: Real} <: RealUnivariateDistribution
λ::T

SafePoisson{T}(λ::Real) where {T <: Real} = new{T}(λ)
Expand Down Expand Up @@ -86,7 +86,7 @@ Distributions.rate(d::SafePoisson) = d.λ
### Statistics

Distributions.mean(d::SafePoisson) = d.λ
Distributions.mode(d::SafePoisson) = _safe_int_floor(d.λ)
Distributions.mode(d::SafePoisson) = floor(d.λ)
Distributions.var(d::SafePoisson) = d.λ
Distributions.skewness(d::SafePoisson) = one(typeof(d.λ)) / sqrt(d.λ)
Distributions.kurtosis(d::SafePoisson) = one(typeof(d.λ)) / d.λ
Expand Down Expand Up @@ -142,12 +142,12 @@ ad_rand(λ) = ad_rand(Random.GLOBAL_RNG, λ)
function ad_rand(rng::AbstractRNG, λ)
s = sqrt(λ)
d = 6.0 * λ^2
L = _safe_int_floor(λ - 1.1484)
L = floor(λ - 1.1484)
# Step N
G = λ + s * randn(rng)

if G >= 0.0
K = _safe_int_floor(G)
K = floor(G)
# Step I
if K >= L
return K
Expand Down Expand Up @@ -177,7 +177,7 @@ function ad_rand(rng::AbstractRNG, λ)
continue
end

K = _safe_int_floor(λ + s * T)
K = floor(λ + s * T)
px, py, fx, fy = procf(λ, K, s)
c = 0.1069 / λ

Expand Down Expand Up @@ -229,7 +229,7 @@ function log1pmx(x::Float64)
end

# Procedure F
function procf(λ, K::Int, s::Float64)
function procf(λ, K, s::Float64)
# can be pre-computed, but does not seem to affect performance
ω = 0.3989422804014327 / s
b1 = 0.041666666666666664 / λ
Expand All @@ -241,7 +241,7 @@ function procf(λ, K::Int, s::Float64)

if K < 10
px = -float(λ)
py = λ^K / factorial(K)
py = λ^K / factorial(floor(Int, K))
else
δ = 0.08333333333333333 / K
δ -= 4.8 * δ^3
Expand Down
8 changes: 8 additions & 0 deletions EpiAware/test/EpiAwareUtils/RealValued.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
@testitem "RealValued Type Tests" begin
using Distributions
struct DummySampleable <: Sampleable{Univariate, RealValued} end

@test RealValued <: Distributions.ValueSupport
@test eltype(DummySampleable) == Real
@test RealUnivariateDistribution == Distribution{Univariate, RealValued}
end
2 changes: 1 addition & 1 deletion EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ end

dist = SafeNegativeBinomial(r, p)
@testset "Large value of mean samples a BigInt with SafePoisson" begin
@test rand(dist) isa BigInt
@test rand(dist) isa Real
end
@testset "Large value of mean sample failure with Poisson" begin
_dist = EpiAware.EpiAwareUtils._negbin(dist)
Expand Down
8 changes: 4 additions & 4 deletions EpiAware/test/EpiAwareUtils/SafePoisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
λ = 10.0
dist = SafePoisson(λ)
@test typeof(dist) <: SafePoisson
@test rand(dist) isa Int
@test rand(dist, 10) isa Vector{Int}
@test rand(dist, 10, 10) isa Array{Int}
@test rand(dist) isa Real
@test rand(dist, 10) isa Vector{Real}
@test rand(dist, 10, 10) isa Array{Real}
end

@testitem "Check distribution properties of SafePoisson" begin
Expand Down Expand Up @@ -54,7 +54,7 @@ end
bigλ = exp(48.0) #Large value of λ
dist = SafePoisson(bigλ)
@testset "Large value of mean samples a BigInt with SafePoisson" begin
@test rand(dist) isa BigInt
@test rand(dist) isa Real
end
@testset "Large value of mean sample failure with Poisson" begin
_dist = Poisson(dist.λ)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ end

mdl = generate_observations(model, missing, 10)
draw = rand(mdl)
@test typeof(draw[:var"Test.y_t[1]"]) <: Int
@test typeof(draw[:var"Test.y_t[1]"]) <: Real
end
Loading