|
1 |
| -# Multivariate continuous |
| 1 | +# Univariate |
2 | 2 |
|
3 |
| -struct ProductVectorContinuousMultivariate{ |
4 |
| - Tdists <: AbstractVector{<:ContinuousMultivariateDistribution}, |
5 |
| -} <: ContinuousMatrixDistribution |
6 |
| - dists::Tdists |
7 |
| -end |
8 |
| -Base.size(dist::ProductVectorContinuousMultivariate) = (length(dist.dists[1]), length(dist)) |
9 |
| -Base.length(dist::ProductVectorContinuousMultivariate) = length(dist.dists) |
10 |
| -function ArrayDist(dists::AbstractVector{<:ContinuousMultivariateDistribution}) |
11 |
| - return ProductVectorContinuousMultivariate(dists) |
12 |
| -end |
13 |
| -function Distributions.logpdf( |
14 |
| - dist::ProductVectorContinuousMultivariate, |
15 |
| - x::AbstractMatrix{<:Real}, |
16 |
| -) |
17 |
| - return sum(logpdf(dist.dists[i], x[:,i]) for i in 1:length(dist)) |
18 |
| -end |
19 |
| -function Distributions.logpdf( |
20 |
| - dist::ProductVectorContinuousMultivariate, |
21 |
| - x::AbstractVector{<:AbstractVector{<:Real}}, |
22 |
| -) |
23 |
| - return sum(logpdf(dist.dists[i], x[i]) for i in 1:length(dist)) |
24 |
| -end |
25 |
| -function Distributions.rand( |
26 |
| - rng::Random.AbstractRNG, |
27 |
| - dist::ProductVectorContinuousMultivariate, |
28 |
| -) |
29 |
| - return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 1:length(dist)) |
30 |
| -end |
| 3 | +const VectorOfUnivariate{ |
| 4 | + S <: ValueSupport, |
| 5 | + Tdist <: UnivariateDistribution{S}, |
| 6 | + Tdists <: AbstractVector{Tdist}, |
| 7 | +} = Distributions.Product{S, Tdist, Tdists} |
31 | 8 |
|
32 |
| -# Multivariate discrete |
| 9 | +function ArrayDist(dists::AbstractVector{<:Normal{T}}) where {T} |
| 10 | + if T <: TrackedReal |
| 11 | + init_m = dists[1].μ |
| 12 | + means = mapreduce(vcat, drop(dists, 1); init = init_m) do d |
| 13 | + d.μ |
| 14 | + end |
| 15 | + init_v = dists[1].σ^2 |
| 16 | + vars = mapreduce(vcat, drop(dists, 1); init = init_v) do d |
| 17 | + d.σ^2 |
| 18 | + end |
| 19 | + else |
| 20 | + means = [d.μ for d in dists] |
| 21 | + vars = [d.σ^2 for d in dists] |
| 22 | + end |
33 | 23 |
|
34 |
| -struct ProductVectorDiscreteMultivariate{ |
35 |
| - Tdists <: AbstractVector{<:DiscreteMultivariateDistribution}, |
36 |
| -} <: DiscreteMatrixDistribution |
37 |
| - dists::Tdists |
| 24 | + return MvNormal(means, vars) |
38 | 25 | end
|
39 |
| -Base.size(dist::ProductVectorDiscreteMultivariate) = (length(dist.dists[1]), length(dist)) |
40 |
| -Base.length(dist::ProductVectorDiscreteMultivariate) = length(dist.dists) |
41 |
| -function ArrayDist(dists::AbstractVector{<:DiscreteMultivariateDistribution}) |
42 |
| - return ProductVectorDiscreteMultivariate(dists) |
43 |
| -end |
44 |
| -function Distributions.logpdf( |
45 |
| - dist::ProductVectorDiscreteMultivariate, |
46 |
| - x::AbstractMatrix{<:Integer}, |
47 |
| -) |
48 |
| - return sum(logpdf(dist.dists[i], x[:,i]) for i in 1:length(dist)) |
| 26 | +function ArrayDist(dists::AbstractVector{<:UnivariateDistribution}) |
| 27 | + return Distributions.Product(dists) |
49 | 28 | end
|
50 |
| -function Distributions.logpdf( |
51 |
| - dist::ProductVectorDiscreteMultivariate, |
52 |
| - x::AbstractVector{<:AbstractVector{<:Integer}}, |
53 |
| -) |
54 |
| - return sum(logpdf(dist.dists[i], x[i]) for i in 1:length(dist)) |
| 29 | +function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real}) |
| 30 | + return sum(logpdf.(dist.v, x)) |
55 | 31 | end
|
56 |
| -function Distributions.rand( |
57 |
| - rng::Random.AbstractRNG, |
58 |
| - dist::ProductVectorDiscreteMultivariate, |
59 |
| -) |
60 |
| - return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 1:length(dist)) |
61 |
| -end |
62 |
| - |
63 |
| -# Univariate continuous |
64 |
| - |
65 |
| -struct ProductVectorContinuousUnivariate{ |
66 |
| - Tdists <: AbstractVector{<:ContinuousUnivariateDistribution}, |
67 |
| -} <: ContinuousMultivariateDistribution |
68 |
| - dists::Tdists |
69 |
| -end |
70 |
| -Base.length(dist::ProductVectorContinuousUnivariate) = length(dist.dists) |
71 |
| -Base.size(dist::ProductVectorContinuousUnivariate) = (length(dist),) |
72 |
| -function ArrayDist(dists::AbstractVector{<:ContinuousUnivariateDistribution}) |
73 |
| - return ProductVectorContinuousUnivariate(dists) |
| 32 | +function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real}) |
| 33 | + # Any other more efficient implementation breaks Zygote |
| 34 | + return [logpdf(dist, x[:,i]) for i in 1:size(x, 2)] |
74 | 35 | end
|
75 | 36 | function Distributions.logpdf(
|
76 |
| - dist::ProductVectorContinuousUnivariate, |
77 |
| - x::AbstractVector{<:Real}, |
78 |
| -) |
79 |
| - return sum(logpdf.(dist.dists, x)) |
80 |
| -end |
81 |
| -function Distributions.rand( |
82 |
| - rng::Random.AbstractRNG, |
83 |
| - dist::ProductVectorContinuousUnivariate, |
| 37 | + dist::VectorOfUnivariate, |
| 38 | + x::AbstractVector{<:AbstractMatrix{<:Real}}, |
84 | 39 | )
|
85 |
| - return rand.(Ref(rng), dist.dists) |
| 40 | + return logpdf.(Ref(dist), x) |
86 | 41 | end
|
87 | 42 |
|
88 |
| -struct ProductMatrixContinuousUnivariate{ |
89 |
| - Tdists <: AbstractMatrix{<:ContinuousUnivariateDistribution}, |
90 |
| -} <: ContinuousMatrixDistribution |
| 43 | +struct MatrixOfUnivariate{ |
| 44 | + S <: ValueSupport, |
| 45 | + Tdist <: UnivariateDistribution{S}, |
| 46 | + Tdists <: AbstractMatrix{Tdist}, |
| 47 | +} <: MatrixDistribution{S} |
91 | 48 | dists::Tdists
|
92 | 49 | end
|
93 |
| -Base.size(dist::ProductMatrixContinuousUnivariate) = size(dist.dists) |
94 |
| -function ArrayDist(dists::AbstractMatrix{<:ContinuousUnivariateDistribution}) |
95 |
| - return ProductMatrixContinuousUnivariate(dists) |
| 50 | +Base.size(dist::MatrixOfUnivariate) = size(dist.dists) |
| 51 | +function ArrayDist(dists::AbstractMatrix{<:UnivariateDistribution}) |
| 52 | + return MatrixOfUnivariate(dists) |
96 | 53 | end
|
97 |
| -function Distributions.logpdf( |
98 |
| - dist::ProductMatrixContinuousUnivariate, |
99 |
| - x::AbstractMatrix{<:Real}, |
100 |
| -) |
101 |
| - return sum(logpdf.(dist.dists, x)) |
| 54 | +function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real}) |
| 55 | + # Broadcasting here breaks Tracker for some reason |
| 56 | + return sum(zip(dist.dists, x)) do (dist, x) |
| 57 | + logpdf(dist, x) |
| 58 | + end |
102 | 59 | end
|
103 |
| -function Distributions.rand( |
104 |
| - rng::Random.AbstractRNG, |
105 |
| - dist::ProductMatrixContinuousUnivariate, |
106 |
| -) |
| 60 | +function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate) |
107 | 61 | return rand.(Ref(rng), dist.dists)
|
108 | 62 | end
|
109 | 63 |
|
110 |
| -# Univariate discrete |
| 64 | +# Multivariate continuous |
111 | 65 |
|
112 |
| -struct ProductVectorDiscreteUnivariate{ |
113 |
| - Tdists <: AbstractVector{<:DiscreteUnivariateDistribution}, |
114 |
| -} <: ContinuousMultivariateDistribution |
| 66 | +struct VectorOfMultivariate{ |
| 67 | + S <: ValueSupport, |
| 68 | + Tdist <: MultivariateDistribution{S}, |
| 69 | + Tdists <: AbstractVector{Tdist}, |
| 70 | +} <: MatrixDistribution{S} |
115 | 71 | dists::Tdists
|
116 | 72 | end
|
117 |
| -Base.length(dist::ProductVectorDiscreteUnivariate) = length(dist.dists) |
118 |
| -Base.size(dist::ProductVectorDiscreteUnivariate) = (length(dist.dists[1]), length(dist)) |
119 |
| -function ArrayDist(dists::AbstractVector{<:DiscreteUnivariateDistribution}) |
120 |
| - ProductVectorDiscreteUnivariate(dists) |
121 |
| -end |
122 |
| -function Distributions.logpdf( |
123 |
| - dist::ProductVectorDiscreteUnivariate, |
124 |
| - x::AbstractVector{<:Integer}, |
125 |
| -) |
126 |
| - return sum(logpdf.(dist.dists, x)) |
127 |
| -end |
128 |
| -function Distributions.rand( |
129 |
| - rng::Random.AbstractRNG, |
130 |
| - dist::ProductVectorDiscreteUnivariate, |
131 |
| -) |
132 |
| - return rand.(Ref(rng), dist.dists) |
133 |
| -end |
134 |
| - |
135 |
| -struct ProductMatrixDiscreteUnivariate{ |
136 |
| - Tdists <: AbstractMatrix{<:DiscreteUnivariateDistribution}, |
137 |
| -} <: DiscreteMatrixDistribution |
138 |
| - dists::Tdists |
| 73 | +Base.size(dist::VectorOfMultivariate) = (length(dist.dists[1]), length(dist)) |
| 74 | +Base.length(dist::VectorOfMultivariate) = length(dist.dists) |
| 75 | +function ArrayDist(dists::AbstractVector{<:MultivariateDistribution}) |
| 76 | + return VectorOfMultivariate(dists) |
139 | 77 | end
|
140 |
| -Base.size(dists::ProductMatrixDiscreteUnivariate) = size(dist.dists) |
141 |
| -function ArrayDist(dists::AbstractMatrix{<:DiscreteUnivariateDistribution}) |
142 |
| - return ProductMatrixDiscreteUnivariate(dists) |
| 78 | +function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real}) |
| 79 | + return sum(logpdf(dist.dists[i], x[:,i]) for i in 1:length(dist)) |
143 | 80 | end
|
144 | 81 | function Distributions.logpdf(
|
145 |
| - dist::ProductMatrixDiscreteUnivariate, |
146 |
| - x::AbstractMatrix{<:Real}, |
| 82 | + dist::VectorOfMultivariate, |
| 83 | + x::AbstractVector{<:AbstractVector{<:Real}}, |
147 | 84 | )
|
148 |
| - return sum(logpdf.(dist.dists, x)) |
| 85 | + return sum(logpdf(dist.dists[i], x[i]) for i in 1:length(dist)) |
149 | 86 | end
|
150 |
| -function Distributions.rand(rng::Random.AbstractRNG, dist::ProductMatrixDiscreteUnivariate) |
151 |
| - return rand.(Ref(rng), dist.dists) |
| 87 | +function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate) |
| 88 | + init = reshape(rand(rng, dist.dists[1]), :, 1) |
| 89 | + return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init) |
152 | 90 | end
|
0 commit comments