Skip to content

Commit 46c939d

Browse files
authored
Merge pull request #85 from devmotion/dw/reversediff
ReverseDiff: Do not always compile tape
2 parents 9f7e81a + 77e05a9 commit 46c939d

3 files changed

Lines changed: 73 additions & 22 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogDensityProblems"
22
uuid = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
33
authors = ["Tamas K. Papp <tkpapp@gmail.com>"]
4-
version = "0.11.4"
4+
version = "0.11.5"
55

66
[deps]
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

src/AD_ReverseDiff.jl

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,49 @@ struct ReverseDiffLogDensity{L,C} <: ADGradientWrapper
1010
end
1111

1212
"""
13-
ADgradient(:ReverseDiff, ℓ)
14-
ADgradient(Val(:ReverseDiff), ℓ)
13+
ADgradient(:ReverseDiff, ℓ; compile=Val(false), x=nothing)
14+
ADgradient(Val(:ReverseDiff), ℓ; compile=Val(false), x=nothing)
1515
1616
Gradient using algorithmic/automatic differentiation via ReverseDiff.
17+
18+
If `compile isa Val{true}`, a tape of the log density computation is created upon construction of the gradient function and used in every evaluation of the gradient.
19+
One may provide an example input `x::AbstractVector` of the log density function.
20+
If `x` is `nothing` (the default), the tape is created with input `zeros(dimension(ℓ))`.
21+
22+
By default, no tape is created.
23+
24+
!!! note
25+
Using a compiled tape can lead to significant performance improvements when the gradient of the log density
26+
is evaluated multiple times (possibly for different inputs).
27+
However, if the log density contains branches, use of a compiled tape can lead to silently incorrect results.
1728
"""
18-
ADgradient(::Val{:ReverseDiff}, ℓ) = begin
19-
f = _logdensity_closure(ℓ)
20-
x = rand(dimension(ℓ)) #init random parameters
21-
tape = ReverseDiff.GradientTape(f, x)
22-
compiledtape = ReverseDiff.compile(tape)
23-
ReverseDiffLogDensity(ℓ, compiledtape)
29+
function ADgradient(::Val{:ReverseDiff}, ℓ;
30+
compile::Union{Val{true},Val{false}}=Val(false), x::Union{Nothing,AbstractVector}=nothing)
31+
ReverseDiffLogDensity(ℓ, _compiledtape(ℓ, compile, x))
2432
end
2533

26-
Base.show(io::IO, ∇ℓ::ReverseDiffLogDensity) = print(io, "ReverseDiff AD wrapper for ", ∇ℓ.ℓ)
34+
_compiledtape(ℓ, compile, x) = nothing
35+
_compiledtape(ℓ, ::Val{true}, ::Nothing) = _compiledtape(ℓ, Val(true), zeros(dimension(ℓ)))
36+
function _compiledtape(ℓ, ::Val{true}, x)
37+
tape = ReverseDiff.GradientTape(Base.Fix1(logdensity, ℓ), x)
38+
return ReverseDiff.compile(tape)
39+
end
40+
41+
function Base.show(io::IO, ∇ℓ::ReverseDiffLogDensity)
42+
print(io, "ReverseDiff AD wrapper for ", ∇ℓ.ℓ, " (")
43+
if ∇ℓ.compiledtape === nothing
44+
print(io, "no ")
45+
end
46+
print(io, "compiled tape)")
47+
end
2748

28-
function logdensity_and_gradient(∇ℓ::ReverseDiffLogDensity, x::AbstractVector{T}) where {T}
49+
function logdensity_and_gradient(∇ℓ::ReverseDiffLogDensity, x::AbstractVector)
2950
@unpack ℓ, compiledtape = ∇ℓ
3051
buffer = _diffresults_buffer(ℓ, x)
31-
result = ReverseDiff.gradient!(buffer, compiledtape, x)
52+
if compiledtape === nothing
53+
result = ReverseDiff.gradient!(buffer, Base.Fix1(logdensity, ℓ), x)
54+
else
55+
result = ReverseDiff.gradient!(buffer, compiledtape, x)
56+
end
3257
_diffresults_extract(result)
3358
end

test/runtests.jl

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,42 @@ end
9696
####
9797

9898
@testset "AD via ReverseDiff" begin
99-
= TestLogDensity(test_logdensity1)
100-
∇ℓ = ADgradient(:ReverseDiff, ℓ)
101-
@test repr(∇ℓ) == "ReverseDiff AD wrapper for " * repr(ℓ)
102-
@test dimension(∇ℓ) == 3
103-
@test capabilities(∇ℓ) LogDensityOrder(1)
104-
for _ in 1:100
105-
x = randn(3)
106-
@test @inferred(logdensity(∇ℓ, x)) test_logdensity1(x)
107-
@test @inferred(logdensity_and_gradient(∇ℓ, x))
108-
(test_logdensity1(x), test_gradient(x))
99+
= TestLogDensity()
100+
101+
∇ℓ_default = ADgradient(:ReverseDiff, ℓ)
102+
∇ℓ_nocompile = ADgradient(:ReverseDiff, ℓ; compile=Val(false))
103+
for ∇ℓ in (∇ℓ_default, ∇ℓ_nocompile)
104+
@test repr(∇ℓ) == "ReverseDiff AD wrapper for " * repr(ℓ) * " (no compiled tape)"
105+
end
106+
107+
∇ℓ_compile = ADgradient(:ReverseDiff, ℓ; compile=Val(true))
108+
∇ℓ_compile_x = ADgradient(:ReverseDiff, ℓ; compile=Val(true), x=rand(3))
109+
for ∇ℓ in (∇ℓ_compile, ∇ℓ_compile_x)
110+
@test repr(∇ℓ) == "ReverseDiff AD wrapper for " * repr(ℓ) * " (compiled tape)"
111+
end
112+
113+
for ∇ℓ in (∇ℓ_default, ∇ℓ_nocompile, ∇ℓ_compile, ∇ℓ_compile_x)
114+
@test dimension(∇ℓ) == 3
115+
@test capabilities(∇ℓ) LogDensityOrder(1)
116+
117+
for _ in 1:100
118+
x = rand(3)
119+
@test @inferred(logdensity(∇ℓ, x)) test_logdensity(x)
120+
@test @inferred(logdensity_and_gradient(∇ℓ, x))
121+
(test_logdensity(x), test_gradient(x))
122+
123+
x = -x
124+
@test @inferred(logdensity(∇ℓ, x)) test_logdensity(x)
125+
if ∇ℓ.compiledtape === nothing
126+
# Recompute tape => correct results
127+
@test @inferred(logdensity_and_gradient(∇ℓ, x))
128+
(test_logdensity(x), zero(x))
129+
else
130+
# Tape not recomputed => incorrect results, uses always the same branch
131+
@test @inferred(logdensity_and_gradient(∇ℓ, x))
132+
(test_logdensity1(x), test_gradient(x))
133+
end
134+
end
109135
end
110136
end
111137

0 commit comments

Comments
 (0)