|
| 1 | +import .Enzyme |
| 2 | + |
"""
    EnzymeGradientLogDensity(ℓ, mode, shadow)

Wrapper that pairs a log density `ℓ` with the Enzyme differentiation `mode` (forward or
reverse) used to compute its gradient, plus an optional `shadow` for forward mode.
"""
struct EnzymeGradientLogDensity{
    L,M<:Union{Enzyme.ForwardMode,Enzyme.ReverseMode},S
} <: ADGradientWrapper
    # wrapped log density
    ℓ::L
    # Enzyme differentiation mode, restricted to forward or reverse by the type bound
    mode::M
    # only used in forward mode
    shadow::S
end
| 8 | + |
"""
    ADgradient(:Enzyme, ℓ; kwargs...)
    ADgradient(Val(:Enzyme), ℓ; kwargs...)

Wrap the log density `ℓ` so that its gradient is computed with algorithmic/automatic
differentiation via Enzyme.

# Keyword arguments

- `mode::Enzyme.Mode`: differentiation mode (default: `Enzyme.Reverse`). Only
  `Enzyme.Reverse` and `Enzyme.Forward` are currently supported; any other mode throws an
  `ArgumentError`.

- `shadow`: collection of one-hot vectors, one for each entry of the input `x` of the log
  density `ℓ`, or `nothing` (default: `nothing`). It is only used in forward mode; in
  reverse mode a non-`nothing` value is discarded after an informational log message.
  When it is `nothing`, the one-hot vectors are recomputed in every call of
  `logdensity_and_gradient(ℓ, x)`, so for performance it is recommended to build them once
  (e.g. with `Enzyme.onehot(x)`) and pass them when calling `ADgradient`.
"""
function ADgradient(::Val{:Enzyme}, ℓ; mode::Enzyme.Mode = Enzyme.Reverse, shadow = nothing)
    if !(mode isa Union{Enzyme.ForwardMode,Enzyme.ReverseMode})
        msg = "currently automatic differentiation via Enzyme only supports " *
              "`Enzyme.Forward` and `Enzyme.Reverse` modes"
        throw(ArgumentError(msg))
    end

    # Forward mode keeps whatever shadow the caller supplied (possibly `nothing`).
    mode isa Enzyme.ReverseMode || return EnzymeGradientLogDensity(ℓ, mode, shadow)

    # Reverse mode never uses a shadow: tell the caller if one was passed, then drop it.
    if shadow !== nothing
        @info "keyword argument `shadow` is ignored in reverse mode"
    end
    return EnzymeGradientLogDensity(ℓ, mode, nothing)
end
| 36 | + |
# Print a one-line summary naming the wrapped log density and the differentiation mode.
function Base.show(io::IO, ∇ℓ::EnzymeGradientLogDensity)
    direction = if ∇ℓ.mode isa Enzyme.ForwardMode
        "forward"
    else
        "reverse"
    end
    return print(io, "Enzyme AD wrapper for ", ∇ℓ.ℓ, " with ", direction, " mode")
end
| 41 | + |
# Forward-mode method: evaluate the log density at `x` and its gradient in one batched
# `Enzyme.autodiff` call, returning `(value, gradient)` with the gradient `collect`ed.
function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme.ForwardMode},
                                 x::AbstractVector)
    # Use the stored shadow when available; otherwise build the one-hot vectors for this `x`.
    seeds = ∇ℓ.shadow === nothing ? Enzyme.onehot(x) : ∇ℓ.shadow
    f = Base.Fix1(logdensity, ∇ℓ.ℓ)
    value, partials = Enzyme.autodiff(∇ℓ.mode, f, Enzyme.BatchDuplicated,
                                      Enzyme.BatchDuplicated(x, seeds))
    return value, collect(partials)
end
| 50 | + |
# Reverse-mode method: evaluate the log density at `x`, then let `Enzyme.autodiff`
# accumulate the gradient into a zero-initialised buffer; returns `(value, gradient)`.
function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme.ReverseMode},
                                 x::AbstractVector)
    f = Base.Fix1(logdensity, ∇ℓ.ℓ)
    # Currently it is not possible to retrieve the primal together with the derivatives,
    # so the value is computed by a separate evaluation of the log density.
    # Ref: https://github.com/EnzymeAD/Enzyme.jl/issues/107
    value = f(x)
    grad = zero(x)
    Enzyme.autodiff(∇ℓ.mode, f, Enzyme.Active, Enzyme.Duplicated(x, grad))
    return value, grad
end
0 commit comments