Skip to content

Commit 44a57e9

Browse files
authored
Fix test errors and update to ChainRules 1 and Distributions 0.25.15 (#198)
* Fix warning * Mark BetaPrime as broken * Test Julia 1.3 only with Ubuntu * Bump version * Not broken anymore?! * Fix numerical issues of BetaPrime for α=1 * Update to CR1 * Remove NNlib * Fix error * Add rrule for to_simplex * Prepare for addition of ChainRules definitions to StatsFuns * Prepare for use of CR in Distributions * Add Zygote workaround * Handle Zygote's `nothing` in tests * Remove ChainRules definitions that are moved to StatsFuns * Remove ChainRules definition that was moved to Distributions * Fix deprecation * Fix Tracker gradient of `poissonbinomial_pdf` * Add Zygote workarounds * Update list of broken distributions * Remove type piracy (default definition is sufficient) * Use `eachcol`
1 parent 0ee831b commit 44a57e9

19 files changed

+124
-339
lines changed

.github/workflows/AD.yml

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,30 @@ jobs:
1717
- '1'
1818
os:
1919
- ubuntu-latest
20-
- macOS-latest
2120
arch:
2221
- x64
2322
AD:
2423
- ForwardDiff
2524
- Tracker
2625
- ReverseDiff
2726
- Zygote
27+
include:
28+
- version: '1'
29+
os: macOS-latest
30+
arch: x64
31+
AD: ForwardDiff
32+
- version: '1'
33+
os: macOS-latest
34+
arch: x64
35+
AD: Tracker
36+
- version: '1'
37+
os: macOS-latest
38+
arch: x64
39+
AD: ReverseDiff
40+
- version: '1'
41+
os: macOS-latest
42+
arch: x64
43+
AD: Zygote
2844
steps:
2945
- uses: actions/checkout@v2
3046
- uses: julia-actions/setup-julia@v1

.github/workflows/Others.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@ jobs:
1616
- '1'
1717
os:
1818
- ubuntu-latest
19-
- macOS-latest
2019
arch:
2120
- x64
21+
include:
22+
- version: '1'
23+
os: macOS-latest
24+
arch: x64
2225
steps:
2326
- uses: actions/checkout@v2
2427
- uses: julia-actions/setup-julia@v1

Project.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DistributionsAD"
22
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
3-
version = "0.6.29"
3+
version = "0.6.30"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -23,18 +23,18 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2323

2424
[compat]
2525
Adapt = "2, 3"
26-
ChainRules = "0.7, 0.8"
27-
ChainRulesCore = "0.9.44, 0.10"
26+
ChainRules = "1"
27+
ChainRulesCore = "1"
2828
Compat = "3.6"
2929
DiffRules = "0.1, 1.0"
30-
Distributions = "0.23.3, 0.24, 0.25"
30+
Distributions = "0.25.15"
3131
FillArrays = "0.8, 0.9, 0.10, 0.11"
3232
NaNMath = "0.3"
3333
PDMats = "0.9, 0.10, 0.11"
3434
Requires = "1"
3535
SpecialFunctions = "0.8, 0.9, 0.10, 1"
3636
StaticArrays = "0.12, 1.0"
3737
StatsBase = "0.32, 0.33"
38-
StatsFuns = "0.8, 0.9"
38+
StatsFuns = "0.9.10"
3939
ZygoteRules = "0.2"
4040
julia = "1.3"

src/DistributionsAD.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,6 @@ export TuringScalMvNormal,
4545
arraydist,
4646
filldist
4747

48-
# check if Distributions >= 0.24 by checking if a generic implementation of `pdf` is defined
49-
const DISTRIBUTIONS_HAS_GENERIC_UNIVARIATE_PDF = hasmethod(pdf, Tuple{UnivariateDistribution,Real})
50-
5148
include("common.jl")
5249
include("arraydist.jl")
5350
include("filldist.jl")

src/arraydist.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,6 @@ function arraydist(dists::AbstractVector{<:UnivariateDistribution})
66
return Product(dists)
77
end
88

9-
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
10-
size(x, 1) == length(dist) ||
11-
throw(DimensionMismatch("Inconsistent array dimensions."))
12-
# `eachcol` breaks Zygote, so we use `view` directly
13-
return map(i -> sum(map(logpdf, dist.v, view(x, :, i))), axes(x, 2))
14-
end
15-
169
struct MatrixOfUnivariate{
1710
S <: ValueSupport,
1811
Tdist <: UnivariateDistribution{S},
@@ -56,8 +49,7 @@ function arraydist(dists::AbstractVector{<:MultivariateDistribution})
5649
end
5750

5851
function Distributions._logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
59-
# `eachcol` breaks Zygote, so we use `view` directly
60-
return sum(i -> logpdf(dist.dists[i], view(x, :, i)), axes(x, 2))
52+
return sum(((di, xi),) -> logpdf(di, xi), zip(dist.dists, eachcol(x)))
6153
end
6254
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
6355
return map(x -> logpdf(dist, x), x)

src/chainrules.jl

Lines changed: 1 addition & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -6,136 +6,6 @@
66
insupport = a <= x <= b,
77
diff = b - a,
88
c = insupport ? inv(diff) : inv(one(diff)),
9-
z = insupport ? zero(x) : oftype(x, NaN),
109
),
11-
(c, -c, z),
10+
(c, -c, ZeroTangent()),
1211
)
13-
14-
# StatsFuns: https://github.com/JuliaStats/StatsFuns.jl/pull/106
15-
16-
## Beta ##
17-
18-
@scalar_rule(
19-
betalogpdf::Real, β::Real, x::Number),
20-
@setup(z = digamma+ β)),
21-
(
22-
log(x) + z - digamma(α),
23-
log1p(-x) + z - digamma(β),
24-
- 1) / x + (1 - β) / (1 - x),
25-
),
26-
)
27-
28-
## Gamma ##
29-
30-
@scalar_rule(
31-
gammalogpdf(k::Real, θ::Real, x::Number),
32-
@setup(
33-
invθ = inv(θ),
34-
xoθ = invθ * x,
35-
z = xoθ - k,
36-
),
37-
(
38-
log(xoθ) - digamma(k),
39-
invθ * z,
40-
- (1 + z) / x,
41-
),
42-
)
43-
44-
## Chisq ##
45-
46-
@scalar_rule(
47-
chisqlogpdf(k::Real, x::Number),
48-
@setup(hk = k / 2),
49-
(
50-
(log(x) - logtwo - digamma(hk)) / 2,
51-
(hk - 1) / x - one(hk) / 2,
52-
),
53-
)
54-
55-
## FDist ##
56-
57-
@scalar_rule(
58-
fdistlogpdf(ν1::Real, ν2::Real, x::Number),
59-
@setup(
60-
xν1 = x * ν1,
61-
temp1 = xν1 + ν2,
62-
a = (x - 1) / temp1,
63-
νsum = ν1 + ν2,
64-
di = digamma(νsum / 2),
65-
),
66-
(
67-
(-log1p(ν2 / xν1) - ν2 * a + di - digamma(ν1 / 2)) / 2,
68-
(-log1p(xν1 / ν2) + ν1 * a + di - digamma(ν2 / 2)) / 2,
69-
((ν1 - 2) / x - ν1 * νsum / temp1) / 2,
70-
),
71-
)
72-
73-
## TDist ##
74-
75-
@scalar_rule(
76-
tdistlogpdf::Real, x::Number),
77-
@setup(
78-
νp1 = ν + 1,
79-
xsq = x^2,
80-
invν = inv(ν),
81-
a = xsq * invν,
82-
b = νp1 /+ xsq),
83-
),
84-
(
85-
(digamma(νp1 / 2) - digamma/ 2) + a * b - log1p(a) - invν) / 2,
86-
- x * b,
87-
),
88-
)
89-
90-
## Binomial ##
91-
92-
@scalar_rule(
93-
binomlogpdf(n::Real, p::Real, k::Real),
94-
@setup(z = digamma(n - k + 1)),
95-
(
96-
digamma(n + 2) - z + log1p(-p) - 1 / (1 + n),
97-
(k / p - n) / (1 - p),
98-
z - digamma(k + 1) + logit(p),
99-
),
100-
)
101-
102-
## Poisson ##
103-
104-
@scalar_rule(
105-
poislogpdf::Number, x::Number),
106-
((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, log(λ) - digamma(x + 1)),
107-
)
108-
109-
## PoissonBinomial
110-
111-
function ChainRulesCore.rrule(
112-
::typeof(Distributions.poissonbinomial_pdf_fft), p::AbstractVector{<:Real}
113-
)
114-
y = Distributions.poissonbinomial_pdf_fft(p)
115-
A = poissonbinomial_partialderivatives(p)
116-
function poissonbinomial_pdf_fft_pullback(Δy)
117-
= InplaceableThunk(
118-
@thunk(A * Δy),
119-
Δ -> LinearAlgebra.mul!(Δ, A, Δy, true, true),
120-
)
121-
return (NO_FIELDS, p̄)
122-
end
123-
return y, poissonbinomial_pdf_fft_pullback
124-
end
125-
126-
if isdefined(Distributions, :poissonbinomial_pdf)
127-
function ChainRulesCore.rrule(
128-
::typeof(Distributions.poissonbinomial_pdf), p::AbstractVector{<:Real}
129-
)
130-
y = Distributions.poissonbinomial_pdf(p)
131-
A = poissonbinomial_partialderivatives(p)
132-
function poissonbinomial_pdf_pullback(Δy)
133-
= InplaceableThunk(
134-
@thunk(A * Δy),
135-
Δ -> LinearAlgebra.mul!(Δ, A, Δy, true, true),
136-
)
137-
return (NO_FIELDS, p̄)
138-
end
139-
return y, poissonbinomial_pdf_pullback
140-
end
141-
end

src/common.jl

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -46,45 +46,3 @@ parameterless_type(x) = parameterless_type(typeof(x))
4646
parameterless_type(x::Type) = __parameterless_type(x)
4747

4848
@non_differentiable adapt_randn(::Any...)
49-
50-
# PoissonBinomial
51-
52-
# compute matrix of partial derivatives [∂P(X=j-1)/∂pᵢ]_{i=1,…,n; j=1,…,n+1}
53-
#
54-
# This uses the same dynamic programming "trick" as for the computation of the primals
55-
# in Distributions
56-
#
57-
# Reference (for the primal):
58-
#
59-
# Marlin A. Thomas & Audrey E. Taub (1982)
60-
# Calculating binomial probabilities when the trial probabilities are unequal,
61-
# Journal of Statistical Computation and Simulation, 14:2, 125-131, DOI: 10.1080/00949658208810534
62-
function poissonbinomial_partialderivatives(p)
63-
n = length(p)
64-
A = zeros(eltype(p), n, n + 1)
65-
@inbounds for j in 1:n
66-
A[j, end] = 1
67-
end
68-
@inbounds for (i, pi) in enumerate(p)
69-
qi = 1 - pi
70-
for k in (n - i + 1):n
71-
kp1 = k + 1
72-
for j in 1:(i - 1)
73-
A[j, k] = pi * A[j, k] + qi * A[j, kp1]
74-
end
75-
for j in (i+1):n
76-
A[j, k] = pi * A[j, k] + qi * A[j, kp1]
77-
end
78-
end
79-
for j in 1:(i-1)
80-
A[j, end] *= pi
81-
end
82-
for j in (i+1):n
83-
A[j, end] *= pi
84-
end
85-
end
86-
@inbounds for j in 1:n, i in 1:n
87-
A[i, j] -= A[i, j+1]
88-
end
89-
return A
90-
end

src/forwarddiff.jl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,3 @@ function nbinomlogpdf(r::ForwardDiff.Dual{T}, p::Real, k::Int) where {T}
4848
Δ_r = ForwardDiff.partials(r) * _nbinomlogpdf_grad_1(val_r, p, k)
4949
return FD(nbinomlogpdf(val_r, p, k), Δ_r)
5050
end
51-
52-
## ForwardDiff broadcasting support ##
53-
# If we use Distributions >= 0.24, then `DISTRIBUTIONS_HAS_GENERIC_UNIVARIATE_PDF` is `true`.
54-
# In Distributions 0.24 `logpdf` is defined for inputs of type `Real` which are then
55-
# converted to the support of the distributions (such as integers) in their concrete implementations.
56-
# Thus it is no needed to have a special function for dual numbers that performs the conversion
57-
# (and actually this method leads to method ambiguity errors since even discrete distributions now
58-
# define logpdf(::MyDistribution, ::Real), see, e.g.,
59-
# JuliaStats/Distributions.jl@ae2d6c5/src/univariate/discrete/binomial.jl#L119).
60-
if !DISTRIBUTIONS_HAS_GENERIC_UNIVARIATE_PDF
61-
@eval begin
62-
function Distributions.logpdf(d::DiscreteUnivariateDistribution, k::ForwardDiff.Dual)
63-
return logpdf(d, convert(Integer, ForwardDiff.value(k)))
64-
end
65-
end
66-
end

src/matrixvariate.jl

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -214,24 +214,3 @@ function Distributions._rand!(rng::AbstractRNG, d::TuringInverseWishart, A::Abst
214214
X = Distributions._rand!(rng, TuringWishart(d.df, inv(cholesky(d.S))), A)
215215
A .= inv(cholesky!(X))
216216
end
217-
218-
# Only needed in Distributions < 0.24
219-
if !DISTRIBUTIONS_HAS_GENERIC_UNIVARIATE_PDF
220-
for T in (:MatrixBeta, :MatrixNormal, :Wishart, :InverseWishart,
221-
:TuringWishart, :TuringInverseWishart,
222-
:VectorOfMultivariate, :MatrixOfUnivariate)
223-
@eval begin
224-
Distributions.loglikelihood(d::$T, X::AbstractMatrix{<:Real}) = logpdf(d, X)
225-
function Distributions.loglikelihood(d::$T, X::AbstractArray{<:Real,3})
226-
(size(X, 1), size(X, 2)) == size(d) || throw(DimensionMismatch("Inconsistent array dimensions."))
227-
return sum(i -> _logpdf(d, view(X, :, :, i)), axes(X, 3))
228-
end
229-
function Distributions.loglikelihood(
230-
d::$T,
231-
X::AbstractArray{<:AbstractMatrix{<:Real}},
232-
)
233-
return sum(x -> logpdf(d, x), X)
234-
end
235-
end
236-
end
237-
end

src/reversediff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ using ..DistributionsAD: DistributionsAD
1818

1919

2020
import SpecialFunctions, NaNMath
21-
import ..DistributionsAD: turing_chol, symm_turing_chol, _mv_categorical_logpdf, adapt_randn,
21+
import ..DistributionsAD: turing_chol, symm_turing_chol, adapt_randn,
2222
simplex_logpdf
2323
import Base.Broadcast: materialize
2424
import StatsFuns: logsumexp

src/tracker.jl

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -261,23 +261,17 @@ end
261261
PoissonBinomial(p::TrackedArray{<:Real}; check_args=true) =
262262
TuringPoissonBinomial(p; check_args = check_args)
263263

264-
poissonbinomial_pdf_fft(x::TrackedArray) = track(poissonbinomial_pdf_fft, x)
265-
@grad function poissonbinomial_pdf_fft(x::TrackedArray)
266-
x_data = data(x)
267-
value = poissonbinomial_pdf_fft(x_data)
268-
A = poissonbinomial_partialderivatives(x_data)
269-
poissonbinomial_pdf_fft_pullback(Δ) = (A * Δ,)
270-
return value, poissonbinomial_pdf_fft_pullback
271-
end
272-
273-
if isdefined(Distributions, :poissonbinomial_pdf)
274-
Distributions.poissonbinomial_pdf(x::TrackedArray) = track(Distributions.poissonbinomial_pdf, x)
275-
@grad function Distributions.poissonbinomial_pdf(x::TrackedArray)
276-
x_data = data(x)
277-
value = Distributions.poissonbinomial_pdf(x_data)
278-
A = poissonbinomial_partialderivatives(x_data)
279-
poissonbinomial_pdf_pullback(Δ) = (A * Δ,)
280-
return value, poissonbinomial_pdf_pullback
264+
for f in (:poissonbinomial_pdf, :poissonbinomial_pdf_fft)
265+
pullback = Symbol(f, :_pullback)
266+
@eval begin
267+
Distributions.$f(x::TrackedArray) = track(Distributions.$f, x)
268+
@grad function Distributions.$f(x::TrackedArray)
269+
x_data = data(x)
270+
value = Distributions.$f(x_data)
271+
A = Distributions.poissonbinomial_pdf_partialderivatives(x_data)
272+
$pullback(Δ) = (A * Δ,)
273+
return value, $pullback
274+
end
281275
end
282276
end
283277

0 commit comments

Comments
 (0)