Skip to content

Commit 0329564

Browse files
committed
apply David's comments and add tests
1 parent 7f7651b commit 0329564

9 files changed

+202
-286
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.3.2"
66
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
77
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
88
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
9+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
910
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
@@ -21,6 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2122
Combinatorics = "0.7"
2223
Distributions = "0.22"
2324
DiffRules = "0.1, 1.0"
25+
FillArrays = "0.8"
2426
ForwardDiff = "0.10.6"
2527
PDMats = "0.9"
2628
SpecialFunctions = "0.8, 0.9, 0.10"

src/DistributionsAD.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@ using PDMats,
1111
StatsFuns
1212

1313
using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
14-
TrackedVecOrMat, track, data
15-
using ZygoteRules: ZygoteRules, pullback
14+
TrackedVecOrMat, track, @grad, data
15+
using ZygoteRules: ZygoteRules, @adjoint, pullback
1616
using LinearAlgebra: copytri!
1717
using Distributions: AbstractMvLogNormal,
1818
ContinuousMultivariateDistribution
19-
using DiffRules, SpecialFunctions
19+
using DiffRules, SpecialFunctions, FillArrays
2020
using ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here
2121
using Base.Iterators: drop
2222

@@ -39,15 +39,15 @@ export TuringScalMvNormal,
3939
TuringPoissonBinomial,
4040
TuringWishart,
4141
TuringInverseWishart,
42-
Multi,
43-
ArrayDist
42+
ArrayDist,
43+
FillDist
4444

4545
include("common.jl")
4646
include("univariate.jl")
4747
include("multivariate.jl")
4848
include("matrixvariate.jl")
49-
include("multi.jl")
5049
include("flatten.jl")
5150
include("array_dist.jl")
51+
include("multi.jl")
5252

5353
end

src/array_dist.jl

+63-125
Original file line numberDiff line numberDiff line change
@@ -1,152 +1,90 @@
1-
# Multivariate continuous
1+
# Univariate
22

3-
struct ProductVectorContinuousMultivariate{
4-
Tdists <: AbstractVector{<:ContinuousMultivariateDistribution},
5-
} <: ContinuousMatrixDistribution
6-
dists::Tdists
7-
end
8-
Base.size(dist::ProductVectorContinuousMultivariate) = (length(dist.dists[1]), length(dist))
9-
Base.length(dist::ProductVectorContinuousMultivariate) = length(dist.dists)
10-
function ArrayDist(dists::AbstractVector{<:ContinuousMultivariateDistribution})
11-
return ProductVectorContinuousMultivariate(dists)
12-
end
13-
function Distributions.logpdf(
14-
dist::ProductVectorContinuousMultivariate,
15-
x::AbstractMatrix{<:Real},
16-
)
17-
return sum(logpdf(dist.dists[i], x[:,i]) for i in 1:length(dist))
18-
end
19-
function Distributions.logpdf(
20-
dist::ProductVectorContinuousMultivariate,
21-
x::AbstractVector{<:AbstractVector{<:Real}},
22-
)
23-
return sum(logpdf(dist.dists[i], x[i]) for i in 1:length(dist))
24-
end
25-
function Distributions.rand(
26-
rng::Random.AbstractRNG,
27-
dist::ProductVectorContinuousMultivariate,
28-
)
29-
return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 1:length(dist))
30-
end
3+
const VectorOfUnivariate{
4+
S <: ValueSupport,
5+
Tdist <: UnivariateDistribution{S},
6+
Tdists <: AbstractVector{Tdist},
7+
} = Distributions.Product{S, Tdist, Tdists}
318

32-
# Multivariate discrete
9+
function ArrayDist(dists::AbstractVector{<:Normal{T}}) where {T}
10+
if T <: TrackedReal
11+
init_m = dists[1].μ
12+
means = mapreduce(vcat, drop(dists, 1); init = init_m) do d
13+
d.μ
14+
end
15+
init_v = dists[1].σ^2
16+
vars = mapreduce(vcat, drop(dists, 1); init = init_v) do d
17+
d.σ^2
18+
end
19+
else
20+
means = [d.μ for d in dists]
21+
vars = [d.σ^2 for d in dists]
22+
end
3323

34-
struct ProductVectorDiscreteMultivariate{
35-
Tdists <: AbstractVector{<:DiscreteMultivariateDistribution},
36-
} <: DiscreteMatrixDistribution
37-
dists::Tdists
24+
return MvNormal(means, vars)
3825
end
39-
Base.size(dist::ProductVectorDiscreteMultivariate) = (length(dist.dists[1]), length(dist))
40-
Base.length(dist::ProductVectorDiscreteMultivariate) = length(dist.dists)
41-
function ArrayDist(dists::AbstractVector{<:DiscreteMultivariateDistribution})
42-
return ProductVectorDiscreteMultivariate(dists)
43-
end
44-
function Distributions.logpdf(
45-
dist::ProductVectorDiscreteMultivariate,
46-
x::AbstractMatrix{<:Integer},
47-
)
48-
return sum(logpdf(dist.dists[i], x[:,i]) for i in 1:length(dist))
26+
function ArrayDist(dists::AbstractVector{<:UnivariateDistribution})
27+
return Distributions.Product(dists)
4928
end
50-
function Distributions.logpdf(
51-
dist::ProductVectorDiscreteMultivariate,
52-
x::AbstractVector{<:AbstractVector{<:Integer}},
53-
)
54-
return sum(logpdf(dist.dists[i], x[i]) for i in 1:length(dist))
29+
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real})
30+
return sum(logpdf.(dist.v, x))
5531
end
56-
function Distributions.rand(
57-
rng::Random.AbstractRNG,
58-
dist::ProductVectorDiscreteMultivariate,
59-
)
60-
return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 1:length(dist))
61-
end
62-
63-
# Univariate continuous
64-
65-
struct ProductVectorContinuousUnivariate{
66-
Tdists <: AbstractVector{<:ContinuousUnivariateDistribution},
67-
} <: ContinuousMultivariateDistribution
68-
dists::Tdists
69-
end
70-
Base.length(dist::ProductVectorContinuousUnivariate) = length(dist.dists)
71-
Base.size(dist::ProductVectorContinuousUnivariate) = (length(dist),)
72-
function ArrayDist(dists::AbstractVector{<:ContinuousUnivariateDistribution})
73-
return ProductVectorContinuousUnivariate(dists)
32+
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
33+
# Any other more efficient implementation breaks Zygote
34+
return [logpdf(dist, x[:,i]) for i in 1:size(x, 2)]
7435
end
7536
function Distributions.logpdf(
76-
dist::ProductVectorContinuousUnivariate,
77-
x::AbstractVector{<:Real},
78-
)
79-
return sum(logpdf.(dist.dists, x))
80-
end
81-
function Distributions.rand(
82-
rng::Random.AbstractRNG,
83-
dist::ProductVectorContinuousUnivariate,
37+
dist::VectorOfUnivariate,
38+
x::AbstractVector{<:AbstractMatrix{<:Real}},
8439
)
85-
return rand.(Ref(rng), dist.dists)
40+
return logpdf.(Ref(dist), x)
8641
end
8742

88-
struct ProductMatrixContinuousUnivariate{
89-
Tdists <: AbstractMatrix{<:ContinuousUnivariateDistribution},
90-
} <: ContinuousMatrixDistribution
43+
struct MatrixOfUnivariate{
44+
S <: ValueSupport,
45+
Tdist <: UnivariateDistribution{S},
46+
Tdists <: AbstractMatrix{Tdist},
47+
} <: MatrixDistribution{S}
9148
dists::Tdists
9249
end
93-
Base.size(dist::ProductMatrixContinuousUnivariate) = size(dist.dists)
94-
function ArrayDist(dists::AbstractMatrix{<:ContinuousUnivariateDistribution})
95-
return ProductMatrixContinuousUnivariate(dists)
50+
Base.size(dist::MatrixOfUnivariate) = size(dist.dists)
51+
function ArrayDist(dists::AbstractMatrix{<:UnivariateDistribution})
52+
return MatrixOfUnivariate(dists)
9653
end
97-
function Distributions.logpdf(
98-
dist::ProductMatrixContinuousUnivariate,
99-
x::AbstractMatrix{<:Real},
100-
)
101-
return sum(logpdf.(dist.dists, x))
54+
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
55+
# Broadcasting here breaks Tracker for some reason
56+
return sum(zip(dist.dists, x)) do (dist, x)
57+
logpdf(dist, x)
58+
end
10259
end
103-
function Distributions.rand(
104-
rng::Random.AbstractRNG,
105-
dist::ProductMatrixContinuousUnivariate,
106-
)
60+
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
10761
return rand.(Ref(rng), dist.dists)
10862
end
10963

110-
# Univariate discrete
64+
# Multivariate continuous
11165

112-
struct ProductVectorDiscreteUnivariate{
113-
Tdists <: AbstractVector{<:DiscreteUnivariateDistribution},
114-
} <: ContinuousMultivariateDistribution
66+
struct VectorOfMultivariate{
67+
S <: ValueSupport,
68+
Tdist <: MultivariateDistribution{S},
69+
Tdists <: AbstractVector{Tdist},
70+
} <: MatrixDistribution{S}
11571
dists::Tdists
11672
end
117-
Base.length(dist::ProductVectorDiscreteUnivariate) = length(dist.dists)
118-
Base.size(dist::ProductVectorDiscreteUnivariate) = (length(dist.dists[1]), length(dist))
119-
function ArrayDist(dists::AbstractVector{<:DiscreteUnivariateDistribution})
120-
ProductVectorDiscreteUnivariate(dists)
121-
end
122-
function Distributions.logpdf(
123-
dist::ProductVectorDiscreteUnivariate,
124-
x::AbstractVector{<:Integer},
125-
)
126-
return sum(logpdf.(dist.dists, x))
127-
end
128-
function Distributions.rand(
129-
rng::Random.AbstractRNG,
130-
dist::ProductVectorDiscreteUnivariate,
131-
)
132-
return rand.(Ref(rng), dist.dists)
133-
end
134-
135-
struct ProductMatrixDiscreteUnivariate{
136-
Tdists <: AbstractMatrix{<:DiscreteUnivariateDistribution},
137-
} <: DiscreteMatrixDistribution
138-
dists::Tdists
73+
Base.size(dist::VectorOfMultivariate) = (length(dist.dists[1]), length(dist))
74+
Base.length(dist::VectorOfMultivariate) = length(dist.dists)
75+
function ArrayDist(dists::AbstractVector{<:MultivariateDistribution})
76+
return VectorOfMultivariate(dists)
13977
end
140-
Base.size(dists::ProductMatrixDiscreteUnivariate) = size(dist.dists)
141-
function ArrayDist(dists::AbstractMatrix{<:DiscreteUnivariateDistribution})
142-
return ProductMatrixDiscreteUnivariate(dists)
78+
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
79+
return sum(logpdf(dist.dists[i], x[:,i]) for i in 1:length(dist))
14380
end
14481
function Distributions.logpdf(
145-
dist::ProductMatrixDiscreteUnivariate,
146-
x::AbstractMatrix{<:Real},
82+
dist::VectorOfMultivariate,
83+
x::AbstractVector{<:AbstractVector{<:Real}},
14784
)
148-
return sum(logpdf.(dist.dists, x))
85+
return sum(logpdf(dist.dists[i], x[i]) for i in 1:length(dist))
14986
end
150-
function Distributions.rand(rng::Random.AbstractRNG, dist::ProductMatrixDiscreteUnivariate)
151-
return rand.(Ref(rng), dist.dists)
87+
function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)
88+
init = reshape(rand(rng, dist.dists[1]), :, 1)
89+
return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init)
15290
end

src/common.jl

+37-9
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ function Base.fill(
66
)
77
return track(fill, value, dims...)
88
end
9-
Tracker.@grad function Base.fill(value::Real, dims...)
9+
@grad function Base.fill(value::Real, dims...)
1010
return fill(data(value), dims...), function(Δ)
1111
size(Δ) ≢ dims && error("Dimension mismatch")
1212
return (sum(Δ), map(_->nothing, dims)...)
@@ -16,15 +16,15 @@ end
1616
## StatsFuns ##
1717

1818
logsumexp(x::TrackedArray) = track(logsumexp, x)
19-
Tracker.@grad function logsumexp(x::TrackedArray)
19+
@grad function logsumexp(x::TrackedArray)
2020
lse = logsumexp(data(x))
2121
return lse, Δ -> (Δ .* exp.(x .- lse),)
2222
end
2323

2424
## Linear algebra ##
2525

2626
LinearAlgebra.UpperTriangular(A::TrackedMatrix) = track(UpperTriangular, A)
27-
Tracker.@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
27+
@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
2828
return UpperTriangular(data(A)), Δ->(UpperTriangular(Δ),)
2929
end
3030

@@ -39,27 +39,27 @@ function turing_chol(A::AbstractMatrix, check)
3939
(chol.factors, chol.info)
4040
end
4141
turing_chol(A::TrackedMatrix, check) = track(turing_chol, A, check)
42-
Tracker.@grad function turing_chol(A::AbstractMatrix, check)
42+
@grad function turing_chol(A::AbstractMatrix, check)
4343
C, back = pullback(unsafe_cholesky, data(A), data(check))
4444
return (C.factors, C.info), Δ->back((factors=data(Δ[1]),))
4545
end
4646

4747
unsafe_cholesky(x, check) = cholesky(x, check=check)
48-
ZygoteRules.@adjoint function unsafe_cholesky(Σ::Real, check)
48+
@adjoint function unsafe_cholesky(Σ::Real, check)
4949
C = cholesky(Σ; check=check)
5050
return C, function(Δ::NamedTuple)
5151
issuccess(C) || return (zero(Σ), nothing)
5252
(Δ.factors[1, 1] / (2 * C.U[1, 1]), nothing)
5353
end
5454
end
55-
ZygoteRules.@adjoint function unsafe_cholesky(Σ::Diagonal, check)
55+
@adjoint function unsafe_cholesky(Σ::Diagonal, check)
5656
C = cholesky(Σ; check=check)
5757
return C, function(Δ::NamedTuple)
5858
issuccess(C) || (Diagonal(zero(diag(Δ.factors))), nothing)
5959
(Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)), nothing)
6060
end
6161
end
62-
ZygoteRules.@adjoint function unsafe_cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
62+
@adjoint function unsafe_cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
6363
C = cholesky(Σ; check=check)
6464
return C, function(Δ::NamedTuple)
6565
issuccess(C) || return (zero(Δ.factors), nothing)
@@ -78,7 +78,7 @@ end
7878
# Specialised logdet for cholesky to target the triangle directly.
7979
logdet_chol_tri(U::AbstractMatrix) = 2 * sum(log, U[diagind(U)])
8080
logdet_chol_tri(U::TrackedMatrix) = track(logdet_chol_tri, U)
81-
Tracker.@grad function logdet_chol_tri(U::AbstractMatrix)
81+
@grad function logdet_chol_tri(U::AbstractMatrix)
8282
U_data = data(U)
8383
return logdet_chol_tri(U_data), Δ->(Matrix(Diagonal(2 .* Δ ./ diag(U_data))),)
8484
end
@@ -97,7 +97,7 @@ function zygote_ldiv(A::TrackedMatrix, B::AbstractVecOrMat)
9797
return track(zygote_ldiv, A, B)
9898
end
9999
zygote_ldiv(A::AbstractMatrix, B::TrackedVecOrMat) = track(zygote_ldiv, A, B)
100-
Tracker.@grad function zygote_ldiv(A, B)
100+
@grad function zygote_ldiv(A, B)
101101
Y, back = pullback(\, data(A), data(B))
102102
return Y, Δ->back(data(Δ))
103103
end
@@ -112,3 +112,31 @@ function SpecialFunctions.logabsgamma(x::TrackedReal)
112112
v = loggamma(x)
113113
return v, sign(data(v))
114114
end
115+
116+
# Some Tracker fixes
117+
118+
for i = 0:2, c = Tracker.combinations([:AbstractArray, :TrackedArray, :TrackedReal, :Number], i), f = [:hcat, :vcat]
119+
if :TrackedReal in c
120+
cnames = map(_ -> gensym(), c)
121+
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) =
122+
track($f, $(cnames...), x, xs...)
123+
end
124+
end
125+
@grad function vcat(x::Real)
126+
vcat(data(x)), (Δ) -> (Δ[1],)
127+
end
128+
@grad function vcat(x1::Real, x2::Real)
129+
vcat(data(x1), data(x2)), (Δ) -> (Δ[1], Δ[2])
130+
end
131+
@grad function vcat(x1::AbstractVector, x2::Real)
132+
vcat(data(x1), data(x2)), (Δ) -> (Δ[1:length(x1)], Δ[length(x1)+1])
133+
end
134+
135+
# Zygote fill has issues with non-numbers
136+
137+
@adjoint function fill(x::T, dims...) where {T}
138+
function zfill(x, dims...,)
139+
return reshape([x for i in 1:prod(dims)], dims)
140+
end
141+
pullback(zfill, x, dims...)
142+
end

src/matrixvariate.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,10 @@ end
201201

202202
## Adjoints
203203

204-
ZygoteRules.@adjoint function Distributions.Wishart(df::Real, S::AbstractMatrix{<:Real})
204+
@adjoint function Distributions.Wishart(df::Real, S::AbstractMatrix{<:Real})
205205
return pullback(TuringWishart, df, S)
206206
end
207-
ZygoteRules.@adjoint function Distributions.InverseWishart(df::Real, S::AbstractMatrix{<:Real})
207+
@adjoint function Distributions.InverseWishart(df::Real, S::AbstractMatrix{<:Real})
208208
return pullback(TuringInverseWishart, df, S)
209209
end
210210

0 commit comments

Comments
 (0)