Skip to content

Commit 1b434a7

Browse files
committed
Specialize logdensityof for DensityMeasure
Ensures proper type propagation (until future refactor of density calculation engine).
1 parent c159659 commit 1b434a7

File tree

5 files changed

+61
-7
lines changed

5 files changed

+61
-7
lines changed

ext/MeasureBaseChainRulesCoreExt.jl

+18
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ using MeasureBase
66
using ChainRulesCore: NoTangent, ZeroTangent
77
import ChainRulesCore
88

9+
# = utils ====================================================================
10+
11+
using MeasureBase: isneginf, isposinf
12+
13+
_isneginf_pullback(::Any) = (NoTangent(), ZeroTangent())
14+
ChainRulesCore.rrule(::typeof(isneginf), x) = isneginf(x), _logdensityof_rt_pullback
15+
16+
_isposinf_pullback(::Any) = (NoTangent(), ZeroTangent())
17+
ChainRulesCore.rrule(::typeof(isposinf), x) = isposinf(x), _isposinf_pullback
918

1019
# = insupport & friends ======================================================
1120

@@ -33,4 +42,13 @@ ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_do
3342
_checked_arg_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ
3443
ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checked_arg_pullback
3544

45+
# = return type inference ====================================================
46+
47+
using MeasureBase: logdensityof_rt
48+
49+
_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent())
50+
function ChainRulesCore.rrule(::typeof(logdensityof_rt), target, v)
51+
logdensityof_rt(target, v), _logdensityof_rt_pullback
52+
end
53+
3654
end # module MeasureBaseChainRulesCoreExt

src/density-core.jl

+4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ To compute a log-density relative to a specific base-measure, see
3333
_checksupport(insupport(μ, x), result)
3434
end
3535

36+
@inline function logdensityof_rt(::T, ::U) where {T,U}
37+
Core.Compiler.return_type(logdensityof, Tuple{T,U})
38+
end
39+
3640
_checksupport(cond, result) = ifelse(cond == true, result, oftype(result, -Inf))
3741

3842
export unsafe_logdensityof

src/density.jl

+18
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,24 @@ logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x)
163163

164164
density_def::DensityMeasure, x) = densityof.f, x)
165165

166+
function logdensityof::DensityMeasure, x::Any)
167+
integrand, μ_base = μ.f, μ.base
168+
169+
base_logval = logdensityof(μ_base, x)
170+
171+
T = typeof(base_logval)
172+
U = logdensityof_rt(integrand, x)
173+
R = promote_type(T, U)
174+
175+
# Don't evaluate base measure if integrand is zero or NaN
176+
if isneginf(base_logval)
177+
R(-Inf)
178+
else
179+
integrand_logval = logdensityof(integrand, x)
180+
convert(R, integrand_logval + base_logval)::R
181+
end
182+
end
183+
166184
"""
167185
rebase(μ, ν)
168186

src/utils.jl

+5-7
Original file line numberDiff line numberDiff line change
@@ -165,20 +165,18 @@ using InverseFunctions: FunctionWithInverse
165165
unwrap(f) = f
166166
unwrap(f::FunctionWithInverse) = f.f
167167

168-
169168
fcomp(f, g) = fchain(g, f)
170169
fcomp(::typeof(identity), g) = g
171170
fcomp(f, ::typeof(identity)) = f
172171
fcomp(::typeof(identity), ::typeof(identity)) = identity
173172

173+
near_neg_inf(::Type{T}) where {T<:Real} = T(-1E38) # Still fits into Float32
174174

175-
near_neg_inf(::Type{T}) where T<:Real = T(-1E38) # Still fits into Float32
176-
177-
isneginf(x) = isinf(x) && x < 0
178-
isposinf(x) = isinf(x) && x > 0
175+
isneginf(x) = isinf(x) && x < zero(x)
176+
isposinf(x) = isinf(x) && x > zero(x)
179177

180-
isapproxzero(x::T) where T<:Real = x zero(T)
178+
isapproxzero(x::T) where {T<:Real} = x zero(T)
181179
isapproxzero(A::AbstractArray) = all(isapproxzero, A)
182180

183-
isapproxone(x::T) where T<:Real = x one(T)
181+
isapproxone(x::T) where {T<:Real} = x one(T)
184182
isapproxone(A::AbstractArray) = all(isapproxone, A)

test/test_basics.jl

+16
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,22 @@ end
189189
end
190190
end
191191

192+
@testset "logdensityof" begin
193+
f1 = let A = randn(Float32, 3, 3)
194+
x -> sum(A * x)
195+
end
196+
f2 = x -> sqrt(abs(sum(x)))
197+
f3 = x -> 2 * sum(x)
198+
f4 = x -> sum(sqrt.(abs.(x)))
199+
m = @inferred ∫exp(f1, ∫exp(f2, ∫exp(f3, ∫exp(f4, StdUniform()^3))))
200+
201+
for x in [Float32[0.7, 0.2, 0.5], Float32[-0.7, 0.2, 0.5]]
202+
@test @inferred(logdensityof(m, x)) isa Float32
203+
@test logdensityof(m, x)
204+
f1(x) + f2(x) + f3(x) + f4(x) + logdensityof(StdUniform()^3, x)
205+
end
206+
end
207+
192208
@testset "logdensity_rel" begin
193209
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Dirac(1.0), 0.0) == Inf
194210
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Dirac(1.0), 1.0) == -Inf

0 commit comments

Comments
 (0)