Skip to content

Commit 915059b

Browse files
authored
Fix zero-type of logjac for ReshapeTransform (#851)
* Fix zero-type of logjac for ReshapeTransform * Bump patch version to 0.35.4 * Make logpdf of NoDist be of the eltype of the argument * Use float_type_with_fallback for logjacs and logpdfs * Make LogProbType be float(Real)
1 parent 9df42bf commit 915059b

File tree

4 files changed

+50
-20
lines changed

4 files changed

+50
-20
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.35.3"
3+
version = "0.35.4"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/distribution_wrappers.jl

+20-12
Original file line numberDiff line numberDiff line change
@@ -54,30 +54,38 @@ function Distributions.rand!(
5454
) where {N}
5555
return Distributions.rand!(rng, d.dist, x)
5656
end
57-
Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
58-
Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
59-
function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
60-
return zeros(Int, size(x, 2))
57+
function Distributions.logpdf(::NoDist{<:Univariate}, x::Real)
58+
return zero(LogProbType)
59+
end
60+
function Distributions.logpdf(::NoDist{<:Multivariate}, x::AbstractVector{<:Real})
61+
return zero(LogProbType)
62+
end
63+
function Distributions.logpdf(::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
64+
return zeros(LogProbType, size(x, 2))
65+
end
66+
function Distributions.logpdf(::NoDist{<:Matrixvariate}, x::AbstractMatrix{<:Real})
67+
return zero(LogProbType)
6168
end
62-
Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0
6369
Distributions.minimum(d::NoDist) = minimum(d.dist)
6470
Distributions.maximum(d::NoDist) = maximum(d.dist)
6571

66-
Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0
72+
function Bijectors.logpdf_with_trans(::NoDist{<:Univariate}, x::Real, ::Bool)
73+
return zero(LogProbType)
74+
end
6775
function Bijectors.logpdf_with_trans(
68-
d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}, ::Bool
76+
::NoDist{<:Multivariate}, x::AbstractVector{<:Real}, ::Bool
6977
)
70-
return 0
78+
return zero(LogProbType)
7179
end
7280
function Bijectors.logpdf_with_trans(
73-
d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool
81+
::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool
7482
)
75-
return zeros(Int, size(x, 2))
83+
return zeros(LogProbType, size(x, 2))
7684
end
7785
function Bijectors.logpdf_with_trans(
78-
d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}, ::Bool
86+
::NoDist{<:Matrixvariate}, x::AbstractMatrix{<:Real}, ::Bool
7987
)
80-
return 0
88+
return zero(LogProbType)
8189
end
8290

8391
Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist)

src/utils.jl

+28-6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,16 @@ const NO_DEFAULT = NoDefault()
55
# A short-hand for a type commonly used in type signatures for VarInfo methods.
66
VarNameTuple = NTuple{N,VarName} where {N}
77

8+
# TODO(mhauru) This is currently used in the transformation functions of NoDist,
9+
# ReshapeTransform, and UnwrapSingletonTransform, and in VarInfo. We should also use it in
10+
# SimpleVarInfo and maybe other places.
11+
"""
12+
The type for all log probability variables.
13+
14+
This is Float64 on 64-bit systems and Float32 on 32-bit systems.
15+
"""
16+
const LogProbType = float(Real)
17+
818
"""
919
@addlogprob!(ex)
1020
@@ -252,12 +262,16 @@ function (f::UnwrapSingletonTransform)(x)
252262
return only(x)
253263
end
254264

255-
Bijectors.with_logabsdet_jacobian(f::UnwrapSingletonTransform, x) = (f(x), 0)
265+
function Bijectors.with_logabsdet_jacobian(f::UnwrapSingletonTransform, x)
266+
return f(x), zero(LogProbType)
267+
end
268+
256269
function Bijectors.with_logabsdet_jacobian(
257270
inv_f::Bijectors.Inverse{<:UnwrapSingletonTransform}, x
258271
)
259272
f = inv_f.orig
260-
return (reshape([x], f.input_size), 0)
273+
result = reshape([x], f.input_size)
274+
return result, zero(LogProbType)
261275
end
262276

263277
"""
@@ -306,18 +320,26 @@ function (inv_f::Bijectors.Inverse{<:ReshapeTransform})(x)
306320
return inverse(x)
307321
end
308322

309-
Bijectors.with_logabsdet_jacobian(f::ReshapeTransform, x) = (f(x), 0)
323+
function Bijectors.with_logabsdet_jacobian(f::ReshapeTransform, x)
324+
return f(x), zero(LogProbType)
325+
end
310326

311327
function Bijectors.with_logabsdet_jacobian(inv_f::Bijectors.Inverse{<:ReshapeTransform}, x)
312-
return (inv_f(x), 0)
328+
return inv_f(x), zero(LogProbType)
313329
end
314330

315331
struct ToChol <: Bijectors.Bijector
316332
uplo::Char
317333
end
318334

319-
Bijectors.with_logabsdet_jacobian(f::ToChol, x) = (Cholesky(Matrix(x), f.uplo, 0), 0)
320-
Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y::Cholesky) = (y.UL, 0)
335+
function Bijectors.with_logabsdet_jacobian(f::ToChol, x)
336+
return Cholesky(Matrix(x), f.uplo, 0), zero(LogProbType)
337+
end
338+
339+
function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y::Cholesky)
340+
return y.UL, zero(LogProbType)
341+
end
342+
321343
function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y)
322344
return error(
323345
"Inverse{ToChol} is only defined for Cholesky factorizations. " *

src/varinfo.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -821,7 +821,7 @@ end
821821

822822
# VarInfo
823823

824-
VarInfo(meta=Metadata()) = VarInfo(meta, Ref{Float64}(0.0), Ref(0))
824+
VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0))
825825

826826
function TypedVarInfo(vi::VectorVarInfo)
827827
new_metas = group_by_symbol(vi.metadata)

0 commit comments

Comments
 (0)