Skip to content

Commit 2d0e7b7

Browse files
authored
Fix simplex_logpdf for values outside of the support (#173)
1 parent 68d6748 commit 2d0e7b7

File tree

2 files changed

+48
-6
lines changed

2 files changed

+48
-6
lines changed

src/multivariate.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ TuringDirichlet(d::Dirichlet) = TuringDirichlet(d.alpha, d.alpha0, d.lmnB)
2828

2929
Base.length(d::TuringDirichlet) = length(d.alpha)
3030

31+
function Distributions.insupport(d::TuringDirichlet, x::AbstractVector{<:Real})
32+
return dirichlet_insupport(x, length(d))
33+
end
34+
function dirichlet_insupport(x::AbstractVector{<:Real}, d::Int)
35+
return d == length(x) && all(x -> zero(x) <= x <= one(x), x) && sum(x) 1
36+
end
37+
3138
# copied from Distributions
3239
# TODO: remove and use `Dirichlet`?
3340
function Distributions._rand!(
@@ -65,9 +72,23 @@ ZygoteRules.@adjoint function Distributions.Dirichlet(d, alpha)
6572
return ZygoteRules.pullback(TuringDirichlet, d, alpha)
6673
end
6774

68-
simplex_logpdf(alpha, lmnB, x::AbstractVector) = sum(xlogy.(alpha .- 1, x)) - lmnB
75+
function xlogy_or_neginf(x, y)
76+
z = zero(y)
77+
return y >= z ? xlogy(x, y) : xlogy(one(x), z)
78+
end
79+
function identity_or_neginf(x::Real, insupport::Bool)
80+
return insupport ? float(x) : log(zero(x))
81+
end
82+
83+
function simplex_logpdf(alpha, lmnB, x::AbstractVector)
84+
logp = sum(xlogy_or_neginf.(alpha .- 1, x)) - lmnB
85+
return identity_or_neginf(logp, dirichlet_insupport(x, length(alpha)))
86+
end
6987
function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
70-
return vec(sum(xlogy.(alpha .- 1, x); dims=1)) .- lmnB
88+
return identity_or_neginf.(
89+
vec(sum(xlogy_or_neginf.(alpha .- 1, x); dims=1)) .- lmnB,
90+
dirichlet_insupport.(eachcol(x), length(alpha)),
91+
)
7192
end
7293

7394
ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractVector)

test/others.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@
130130

131131
# Use finite differencing to compute reverse-mode sensitivities.
132132
x̄s_fdm = FDM.j′vp(central_fdm(5, 1), f, ȳ, x...)
133-
133+
134134
if AD == "All" || AD == "Zygote"
135135
# Use Zygote to compute reverse-mode sensitivities.
136136
y_zygote, back_zygote = Zygote.pullback(f, x...)
@@ -180,14 +180,14 @@
180180

181181
# Check that Tracker forwards-pass produces the correct answer.
182182
@test isapprox(y, Tracker.data(y_tracker), atol=atol, rtol=rtol)
183-
183+
184184
# Check that Tracker reverse-mode sensitivities are correct.
185185
@test all(zip(x̄s_tracker, x̄s_fdm)) do (x̄_tracker, x̄_fdm)
186186
isapprox(Tracker.data(x̄_tracker), x̄_fdm; atol=atol, rtol=rtol)
187187
end
188188
end
189189
end
190-
_to_cov(B) = B + B' + 2 * size(B, 1) * Matrix(I, size(B)...)
190+
_to_cov(B) = B + B' + 2 * size(B, 1) * Matrix(I, size(B)...)
191191

192192
@testset "logsumexp" begin
193193
x, y = rand(3), rand()
@@ -257,7 +257,7 @@
257257
@testset "Params" begin
258258
m = rand(10)
259259
sigmas = randexp(10)
260-
260+
261261
d = TuringDiagMvNormal(m, sigmas)
262262
@test params(d) == (m, sigmas)
263263

@@ -335,5 +335,26 @@
335335
@test s2 isa Matrix{Float64}
336336
@test size(s2) == (dim, n)
337337
end
338+
339+
# https://github.com/TuringLang/DistributionsAD.jl/issues/158
340+
let
341+
d = TuringDirichlet(rand(2))
342+
z = rand(d)
343+
logpdf_z = logpdf(d, z)
344+
pdf_z = pdf(d, z)
345+
346+
for x in ([0.5, 0.8], [-0.5, 1.5])
347+
@test logpdf(d, x) == -Inf
348+
@test iszero(pdf(d, x))
349+
350+
xmat = hcat(x, x)
351+
@test all(==(-Inf), logpdf(d, xmat))
352+
@test all(iszero, pdf(d, xmat))
353+
354+
xzmat = hcat(x, z)
355+
@test logpdf(d, xzmat) == [-Inf, logpdf_z]
356+
@test pdf(d, xzmat) == [0, pdf_z]
357+
end
358+
end
338359
end
339360
end

0 commit comments

Comments
 (0)