Skip to content

Commit 9f7e81a

Browse files
authored
Merge pull request #84 from devmotion/dw/enzyme
Add support for `Enzyme`
2 parents 4d618d5 + 2cf12df commit 9f7e81a

4 files changed

Lines changed: 109 additions & 2 deletions

File tree

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogDensityProblems"
22
uuid = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
33
authors = ["Tamas K. Papp <tkpapp@gmail.com>"]
4-
version = "0.11.3"
4+
version = "0.11.4"
55

66
[deps]
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
@@ -27,6 +27,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
2727
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
2828
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
2929
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
30+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
3031
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3132
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3233
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
@@ -35,4 +36,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
3536
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3637

3738
[targets]
38-
test = ["BenchmarkTools", "Distributions", "Documenter", "ForwardDiff", "Test", "StatsBase", "StatsFuns", "Tracker", "Zygote", "ReverseDiff"]
39+
test = ["BenchmarkTools", "Distributions", "Documenter", "ForwardDiff", "Pkg", "ReverseDiff", "StatsBase", "StatsFuns", "Test", "Tracker", "Zygote"]

src/AD_Enzyme.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import .Enzyme
2+
3+
struct EnzymeGradientLogDensity{L,M<:Union{Enzyme.ForwardMode,Enzyme.ReverseMode},S} <: ADGradientWrapper
4+
::L
5+
mode::M
6+
shadow::S # only used in forward mode
7+
end
8+
9+
"""
10+
ADgradient(:Enzyme, ℓ; kwargs...)
11+
ADgradient(Val(:Enzyme), ℓ; kwargs...)
12+
13+
Gradient using algorithmic/automatic differentiation via Enzyme.
14+
15+
# Keyword arguments
16+
17+
- `mode::Enzyme.Mode`: Differentiation mode (default: `Enzyme.Reverse`).
18+
Currently only `Enzyme.Reverse` and `Enzyme.Forward` are supported.
19+
20+
- `shadow`: Collection of one-hot vectors for each entry of the inputs `x` to the log density
21+
`ℓ`, or `nothing` (default: `nothing`). This keyword argument is only used in forward
22+
mode. By default, it will be recomputed in every call of `logdensity_and_gradient(ℓ, x)`.
23+
For performance reasons it is recommended to compute it only once when calling `ADgradient`.
24+
The one-hot vectors can be constructed, e.g., with `Enzyme.onehot(x)`.
25+
"""
26+
function ADgradient(::Val{:Enzyme}, ℓ; mode::Enzyme.Mode = Enzyme.Reverse, shadow = nothing)
27+
mode isa Union{Enzyme.ForwardMode,Enzyme.ReverseMode} ||
28+
throw(ArgumentError("currently automatic differentiation via Enzyme only supports " *
29+
"`Enzyme.Forward` and `Enzyme.Reverse` modes"))
30+
if mode isa Enzyme.ReverseMode && shadow !== nothing
31+
@info "keyword argument `shadow` is ignored in reverse mode"
32+
shadow = nothing
33+
end
34+
return EnzymeGradientLogDensity(ℓ, mode, shadow)
35+
end
36+
37+
function Base.show(io::IO, ∇ℓ::EnzymeGradientLogDensity)
38+
print(io, "Enzyme AD wrapper for ", ∇ℓ.ℓ, " with ",
39+
∇ℓ.mode isa Enzyme.ForwardMode ? "forward" : "reverse", " mode")
40+
end
41+
42+
function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme.ForwardMode},
43+
x::AbstractVector)
44+
@unpack ℓ, mode, shadow = ∇ℓ
45+
_shadow = shadow === nothing ? Enzyme.onehot(x) : shadow
46+
y, ∂ℓ_∂x = Enzyme.autodiff(mode, Base.Fix1(logdensity, ℓ), Enzyme.BatchDuplicated,
47+
Enzyme.BatchDuplicated(x, _shadow))
48+
return y, collect(∂ℓ_∂x)
49+
end
50+
51+
function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme.ReverseMode},
52+
x::AbstractVector)
53+
@unpack ℓ, mode = ∇ℓ
54+
# Currently it is not possible to retrieve the primal together with the derivatives.
55+
# Ref: https://github.com/EnzymeAD/Enzyme.jl/issues/107
56+
y = logdensity(ℓ, x)
57+
∂ℓ_∂x = zero(x)
58+
Enzyme.autodiff(mode, Base.Fix1(logdensity, ℓ), Enzyme.Active,
59+
Enzyme.Duplicated(x, ∂ℓ_∂x))
60+
y, ∂ℓ_∂x
61+
end

src/LogDensityProblems.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ function __init__()
241241
@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("AD_Tracker.jl")
242242
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("AD_Zygote.jl")
243243
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("AD_ReverseDiff.jl")
244+
@require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" include("AD_Enzyme.jl")
244245
end
245246

246247
####

test/runtests.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
@static if VERSION >= v"1.6"
2+
# Enzyme only supports Julia >= 1.6
3+
# We add it without messing with the existing, possibly precompiled, dependencies
4+
using Pkg
5+
Pkg.add(Pkg.PackageSpec(; name="Enzyme", uuid="7da242da-08ed-463a-9acd-ee780be4f1d9");
6+
preserve=Pkg.PRESERVE_ALL)
7+
8+
import Enzyme
9+
struct EnzymeTestMode <: Enzyme.Mode end
10+
end
11+
112
using LogDensityProblems, Test, Distributions, TransformVariables, BenchmarkTools
213
import LogDensityProblems: capabilities, dimension, logdensity
314
using LogDensityProblems: logdensity_and_gradient, LogDensityOrder
@@ -149,6 +160,39 @@ end
149160
end
150161
end
151162

163+
@static if VERSION >= v"1.6"
164+
@testset "AD via Enzyme" begin
165+
= TestLogDensity(test_logdensity1)
166+
167+
∇ℓ_reverse = ADgradient(:Enzyme, ℓ)
168+
@test ∇ℓ_reverse === ADgradient(:Enzyme, ℓ; mode=Enzyme.Reverse)
169+
@test repr(∇ℓ_reverse) == "Enzyme AD wrapper for " * repr(ℓ) * " with reverse mode"
170+
171+
∇ℓ_forward = ADgradient(:Enzyme, ℓ; mode=Enzyme.Forward)
172+
∇ℓ_forward_shadow = ADgradient(:Enzyme, ℓ;
173+
mode=Enzyme.Forward,
174+
shadow=Enzyme.onehot(Vector{Float64}(undef, dimension(ℓ))))
175+
for ∇ℓ in (∇ℓ_forward, ∇ℓ_forward_shadow)
176+
@test repr(∇ℓ) == "Enzyme AD wrapper for " * repr(ℓ) * " with forward mode"
177+
end
178+
179+
for ∇ℓ in (∇ℓ_reverse, ∇ℓ_forward, ∇ℓ_forward_shadow)
180+
@test dimension(∇ℓ) == 3
181+
@test capabilities(∇ℓ) LogDensityOrder(1)
182+
for _ in 1:100
183+
x = randn(3)
184+
@test @inferred(logdensity(∇ℓ, x)) test_logdensity1(x)
185+
@test logdensity_and_gradient(∇ℓ, x) (test_logdensity1(x), test_gradient(x))
186+
end
187+
end
188+
189+
# Branches in `ADgradient`
190+
@test_throws ArgumentError ADgradient(:Enzyme, ℓ; mode=EnzymeTestMode())
191+
∇ℓ = @test_logs (:info, "keyword argument `shadow` is ignored in reverse mode") ADgradient(:Enzyme, ℓ; shadow = (1,))
192+
@test ∇ℓ.shadow === nothing
193+
end
194+
end
195+
152196
@testset "ADgradient missing method" begin
153197
msg = "Don't know how to AD with Foo, consider `import Foo` if there is such a package."
154198
P = TransformedLogDensity(as(Array, 1), x -> sum(abs2, x))

0 commit comments

Comments
 (0)