Skip to content

Commit 02147c9

Browse files
committed
fix bug, add funnel
1 parent c6d2f4a commit 02147c9

2 files changed

Lines changed: 8 additions & 3 deletions

File tree

test/sample-correctness_tests.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,19 @@ end
9999

100100
@testset "NUTS tests with heavier tails and skewness" begin
101101
K = 5
102+
𝒩 = StandardMultivariateNormal(K)
102103

103104
# somewhat nasty, relaxed requirements
104-
= elongate(1.1)(StandardMultivariateNormal(K))
105+
= elongate(1.1)(𝒩)
105106
NUTS_tests(RNG, ℓ, "elongate(1.1, 𝑁)",
106107
10000; p_alert = 0.05, EBFMI_alert = 0.2, R̂_fail = 1.05)
107108

108109
# this has very nasty tails so we relax requirements a bit
109-
= (elongate(1.1) shift(ones(K)))(StandardMultivariateNormal(K))
110+
= (elongate(1.1) shift(ones(K)))(𝒩)
110111
NUTS_tests(RNG, ℓ, "skew elongate(1.1, 𝑁)",
111112
10000; τ_alert = 0.1, EBFMI_alert = 0.2, R̂_fail = 1.05, p_fail = 0.001)
113+
114+
# funnel, mixed with a normal
115+
= mix(0.8, funnel()(𝒩), 𝒩)
116+
NUTS_tests(RNG, ℓ, "funnel", 10000; EBFMI_alert = 0.2, τ_alert = 0.1)
112117
end

test/test_hamiltonian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ end
110110

111111
@testset "invalid values" begin
112112
n = 3
113-
= multivariate_normal(randn(RNG, n), I)
113+
= multivariate_normal(randn(RNG, n), I(n))
114114
@test_throws DynamicHMCError evaluate_ℓ(ℓ, fill(NaN, n))
115115
end
116116
end

0 commit comments

Comments
 (0)