Skip to content

Commit da39435

Browse files
authored
Add ReverseDiff support. (#23)
Add ReverseDiff support,fix #22, improve missing AD method message
1 parent b38e0ba commit da39435

5 files changed

Lines changed: 52 additions & 8 deletions

File tree

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
2121
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
2222
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
2323
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
24+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2425
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2526
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2627

2728
[targets]
28-
test = ["Distributions", "Documenter", "Flux", "ForwardDiff", "Test", "StatsBase"]
29+
test = ["Distributions", "Documenter", "Flux", "ForwardDiff", "ReverseDiff", "Test", "StatsBase"]

src/AD_ForwardDiff.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,10 @@ function show(io::IO, ℓ::ForwardDiffLogDensity)
1313
", w/ chunk size ", length(ℓ.gradientconfig.seeds))
1414
end
1515

16-
_anyargument(ℓ) = zeros(dimension(ℓ))
17-
1816
_default_chunk(ℓ) = ForwardDiff.Chunk(dimension(ℓ))
1917

2018
_default_gradientconfig(ℓ, chunk::ForwardDiff.Chunk) =
21-
ForwardDiff.GradientConfig(_value_closure(ℓ), _anyargument(ℓ), chunk)
19+
ForwardDiff.GradientConfig(_value_closure(ℓ), _vectorargument(ℓ), chunk)
2220

2321
_default_gradientconfig(ℓ, chunk::Integer) =
2422
_default_gradientconfig(ℓ, ForwardDiff.Chunk(chunk))
@@ -42,7 +40,7 @@ end
4240

4341
function logdensity(::Type{ValueGradient}, fℓ::ForwardDiffLogDensity, x::RealVector)
4442
@unpack ℓ, gradientconfig = fℓ
45-
result = DiffResults.GradientResult(_anyargument(ℓ)) # allocate a new result
43+
result = DiffResults.GradientResult(_vectorargument(ℓ)) # allocate a new result
4644
result = ForwardDiff.gradient!(result, _value_closure(ℓ), x, gradientconfig)
4745
ValueGradient(DiffResults.value(result), DiffResults.gradient(result))
4846
end

src/AD_ReverseDiff.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import .ReverseDiff
2+
3+
struct ReverseDiffLogDensity{L, C} <: ADGradientWrapper
4+
::L
5+
gradientconfig::C
6+
end
7+
8+
function ADgradient(::Val{:ReverseDiff}, ℓ::AbstractLogDensityProblem)
9+
cfg = ReverseDiff.GradientConfig(zeros(dimension(ℓ)))
10+
ReverseDiffLogDensity(ℓ, cfg)
11+
end
12+
13+
function logdensity(::Type{ValueGradient}, fℓ::ReverseDiffLogDensity, x::RealVector)
14+
@unpack ℓ, gradientconfig = fℓ
15+
result = DiffResults.GradientResult(_vectorargument(ℓ)) # allocate a new result
16+
result = ReverseDiff.gradient!(result, _value_closure(ℓ), x, gradientconfig)
17+
ValueGradient(DiffResults.value(result), DiffResults.gradient(result))
18+
end

src/LogDensityProblems.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,19 +224,32 @@ package being loaded.
224224
ADgradient(kind::Symbol, P::AbstractLogDensityProblem; kwargs...) =
225225
ADgradient(Val{kind}(), P; kwargs...)
226226

227+
function ADgradient(v::Val{kind}, P::AbstractLogDensityProblem; kwargs...) where kind
228+
@info "Don't know how to AD with $(kind), consider `import $(kind)` if there is such a package."
229+
throw(MethodError(ADgradient, (v, P)))
230+
end
231+
227232
"""
228233
$(SIGNATURES)
229234
230-
235+
A closure for the value of the log density.
231236
"""
232237
@inline _value_closure(ℓ) = x -> logdensity(Value, ℓ, x).value
233238

239+
"""
240+
$(SIGNATURES)
241+
242+
Make a vector argument for transformation `ℓ` using a Float64 vector.
243+
"""
244+
@inline _vectorargument(ℓ) = zeros(dimension(ℓ))
245+
234246

235247
# wrappers - specific
236248

237249
function __init__()
238250
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("AD_ForwardDiff.jl")
239251
@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" include("AD_Flux.jl")
252+
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("AD_ReverseDiff.jl")
240253
end
241254

242255

test/runtests.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ using LogDensityProblems: Value, ValueGradient
33
using Test
44

55
using Distributions
6-
import ForwardDiff
7-
import Flux
6+
import ForwardDiff, Flux, ReverseDiff
87
using Parameters: @unpack
98
using DocStringExtensions: SIGNATURES
109
using TransformVariables
@@ -131,6 +130,15 @@ end
131130
@test logdensity(ValueGradient, ∇ℓ, x) ValueGradient(f(x), -6 .* x)
132131
end
133132

133+
@testset "AD via ReverseDiff" begin
134+
f(x) = -3*abs2(x[1])
135+
= TransformedLogDensity(as(Array, asℝ, 1), f)
136+
∇ℓ = ADgradient(:ReverseDiff, ℓ)
137+
x = randn(1)
138+
@test logdensity(Value, ℓ, x) logdensity(Value, ∇ℓ, x)
139+
@test logdensity(ValueGradient, ∇ℓ, x) ValueGradient(f(x), -6 .* x)
140+
end
141+
134142
@testset "@iffinite" begin
135143
flag = [0]
136144
f(x) = (y = LogDensityProblems.@iffinite x; flag[1] += 1; y)
@@ -151,3 +159,9 @@ end
151159
@test logdensity(ValueGradient, ∇P, [1.0]) ValueGradient(-1.0, [-2.0])
152160
@test logdensity(ValueGradient, ∇P, [-1.0]) ValueGradient(-Inf, randn(1))
153161
end
162+
163+
@testset "ADgradient missing method" begin
164+
msg = "Don't know how to AD with Foo, consider `import Foo` if there is such a package."
165+
P = TransformedLogDensity(as(Array, 1), x -> sum(abs2, x))
166+
@test_logs((:info, msg), @test_throws(MethodError, ADgradient(:Foo, P)))
167+
end

0 commit comments

Comments
 (0)