Skip to content

Commit 3cc12f2

Browse files
authored
Merge pull request #192 from tpapp/tp/update-test-framework
update test framework
2 parents 42bbb2c + c7185c1 commit 3cc12f2

4 files changed

Lines changed: 16 additions & 8 deletions

File tree

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1919
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
2020

2121
[compat]
22-
LogDensityTestSuite = "0.6"
22+
LogDensityTestSuite = "0.7"

test/sample-correctness_tests.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,20 @@ end
9999

100100
@testset "NUTS tests with heavier tails and skewness" begin
101101
K = 5
102+
𝒩 = StandardMultivariateNormal(K)
102103

103104
# somewhat nasty, relaxed requirements
104-
= elongate(1.1, StandardMultivariateNormal(K))
105-
NUTS_tests(RNG, ℓ, "elongate(1.1, 𝑁)", 10000; p_alert = 1e-5, EBFMI_alert = 0.2, R̂_fail = 1.2)
105+
= elongate(1.1)(𝒩)
106+
NUTS_tests(RNG, ℓ, "elongate(1.1, 𝑁)",
107+
10000; p_alert = 0.05, EBFMI_alert = 0.2, R̂_fail = 1.05, τ_fail = 0.3)
106108

107109
# this has very nasty tails so we relax requirements a bit
108-
= elongate(1.1, shift(ones(K), StandardMultivariateNormal(K)))
109-
NUTS_tests(RNG, ℓ, "skew elongate(1.1, 𝑁)", 10000; τ_alert = 0.1, EBFMI_alert = 0.2)
110+
= (elongate(1.1) shift(ones(K)))(𝒩)
111+
NUTS_tests(RNG, ℓ, "skew elongate(1.1, 𝑁)",
112+
10000; τ_alert = 0.1, EBFMI_alert = 0.2, R̂_fail = 1.05, p_fail = 0.001)
113+
114+
# funnel, mixed with a normal
115+
= mix(0.8, funnel()(𝒩), 𝒩)
116+
NUTS_tests(RNG, ℓ, "funnel", 10000;
117+
EBFMI_alert = 0.2, τ_alert = 0.1, p_fail = 5e-3, R̂_fail = 1.05)
110118
end

test/test_hamiltonian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ end
110110

111111
@testset "invalid values" begin
112112
n = 3
113-
= multivariate_normal(randn(RNG, n), I)
113+
= multivariate_normal(randn(RNG, n), I(n))
114114
@test_throws DynamicHMCError evaluate_ℓ(ℓ, fill(NaN, n))
115115
end
116116
end

test/utilities.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ thus decorrelates the density perfectly.
6161
find_stable_ϵ::GaussianKineticEnergy, Σ) = eigmin.W'*Σ*κ.W)
6262

6363
"Multivariate normal with `Σ = LL'`."
64-
multivariate_normal(μ, L) = shift, linear(L, StandardMultivariateNormal(length))))
64+
multivariate_normal(μ, L) = (shift) linear(L))(StandardMultivariateNormal(length(μ)))
6565

6666
"Multivariate normal with diagonal `Σ` (constant `v` variance)."
67-
multivariate_normal(μ, v::Real = 1) = multivariate_normal(μ, I * v)
67+
multivariate_normal(μ, v::Real = 1) = multivariate_normal(μ, I(length(μ)) * v)
6868

6969
"""
7070
$(SIGNATURES)

0 commit comments

Comments
 (0)