Integrate DifferentiationInterface #416
Changes to `src/AdvancedHMC.jl`:
```diff
@@ -1,5 +1,6 @@
 module AdvancedHMC

+using ADTypes: AbstractADType, AutoForwardDiff
 using Statistics: mean, var, middle
 using LinearAlgebra:
     Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling
```
```diff
@@ -21,7 +22,10 @@ using AbstractMCMC: AbstractMCMC, LogDensityModel

 import StatsBase: sample

+import DifferentiationInterface: DifferentiationInterface as DI
+
 const DEFAULT_FLOAT_TYPE = typeof(float(0))
+const DEFAULT_ADTYPE = AutoForwardDiff()

 include("utilities.jl")
```
```diff
@@ -134,48 +138,32 @@ function Hamiltonian(metric::AbstractMetric, ℓ::LogDensityModel; kwargs...)
     return Hamiltonian(metric, ℓ.logdensity; kwargs...)
 end
 function Hamiltonian(metric::AbstractMetric, ℓ; kwargs...)
-    cap = LogDensityProblems.capabilities(ℓ)
-    if cap === nothing
-        throw(
-            ArgumentError(
-                "The log density function does not support the LogDensityProblems.jl interface",
-            ),
-        )
-    end
-    # Check if we're capable of computing gradients.
-    ℓπ = if cap === LogDensityProblems.LogDensityOrder{0}()
-        # In this case ℓ does not support evaluation of the gradient of the log density function
-        # We use ForwardDiff to compute the gradient
-        LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; kwargs...)
-    else
-        # In this case ℓ already supports evaluation of the gradient of the log density function
-        ℓ
-    end
-    return Hamiltonian(
-        metric,
-        Base.Fix1(LogDensityProblems.logdensity, ℓπ),
-        Base.Fix1(LogDensityProblems.logdensity_and_gradient, ℓπ),
-    )
+    return Hamiltonian(metric, ℓ, DEFAULT_ADTYPE; kwargs...)
 end

 ## With explicit AD specification
 function Hamiltonian(
-    metric::AbstractMetric, ℓπ::LogDensityModel, kind::Union{Symbol,Val,Module}; kwargs...
+    metric::AbstractMetric, ℓπ::LogDensityModel, adtype::AbstractADType; kwargs...
 )
-    return Hamiltonian(metric, ℓπ.logdensity, kind; kwargs...)
+    return Hamiltonian(metric, ℓπ.logdensity, adtype; kwargs...)
 end
-function Hamiltonian(metric::AbstractMetric, ℓπ, kind::Union{Symbol,Val,Module}; kwargs...)
-    if LogDensityProblems.capabilities(ℓπ) === nothing
+function Hamiltonian(metric::AbstractMetric, ℓπ, adtype::AbstractADType; kwargs...)
+    cap = LogDensityProblems.capabilities(ℓπ)
+    if cap === nothing
         throw(
             ArgumentError(
                 "The log density function does not support the LogDensityProblems.jl interface",
             ),
         )
     end
-    ℓ = LogDensityProblemsAD.ADgradient(
-        kind isa Val ? kind : Val(Symbol(kind)), ℓπ; kwargs...
-    )
-    return Hamiltonian(metric, ℓ)
+    _logdensity = Base.Fix1(LogDensityProblems.logdensity, ℓπ)
+    _logdensity_and_gradient = if cap === LogDensityProblems.LogDensityOrder{0}()
+        # In this case ℓπ does not support evaluation of the gradient of the log density function
+        x -> DI.value_and_gradient(_logdensity, adtype, x)
```
Review thread on the `DI.value_and_gradient` line above:

> It's still unclear to me why it should be advantageous to avoid LogDensityProblemsAD… But whatever, if you'd like to remove it, you should copy all the things that LogDensityProblemsAD does. For example:
>
> ```julia
> x -> DI.value_and_gradient(LogDensityProblems.logdensity, adtype, DI.Constant(ℓπ), x)
> ```
>
> But then DI doesn't allow active after non-active arguments, so you have to use something like
>
> ```julia
> switched_logdensity(x, f::F) where {F} = LogDensityProblems.logdensity(f, x)
> ```
>
> LogDensityProblemsAD does this already for you… And it even allows you to specify a typical `x`.

> @gdalle, is there any plan to incorporate these optimisations in DI?

> The goal is probably to rely on all the AD bindings being in a single place, aka DI, for mutualized maintenance, testing and updates. LogDensityProblemsAD is a set of AD bindings that was specific to the probabilistic programming landscape, while DI aims to support the whole Julia ecosystem. As a result, when someone needs a feature, everyone gets it.
>
> Specifying a typical `x` …
>
> No, changing the DI API would be enormously breaking now that around 850 packages depend on it, directly or indirectly. For the foreseeable future it will keep taking the active argument first and the context arguments next. I'm not claiming this is the best design choice for each individual application. And indeed it seems you need to add something like these lines (or what @devmotion suggested) in your specific case:
>
> ```julia
> struct SwitchArgs{F}
>     f::F
> end
> (s::SwitchArgs)(x, y) = s.f(y, x)
> ```
>
> It's up to users to decide whether this (minor?) inconvenience is worth the switch to DI.

> Thanks, @gdalle. …

> Although for Enzyme specifically, David's solution is probably better, since it avoids creating a functor. Either way it's a one-liner.

> Actually, if you'd continue using LogDensityProblemsAD, you'd get these AD-specific optimizations automatically: for AD types that are supported via DI it will use …

> Since these kinds of tweaks are beneficial to several users, maybe it makes sense to integrate them into DI, i.e. with the …

> Ah, I somehow missed the kwargs forwarding, my bad.
>
> Yes, totally agree on that :) I should've been clearer, but I was mainly mentioning it so that @ErikQQY was aware of what we had done.

> Isn't …

> No, if we actually introduce this distinction (which I'm not yet 100% sure is a good idea) I'll have to pick different names, like …
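The argument-order constraint discussed in this thread can be made concrete with a small sketch. This is illustrative only, not code from the PR: the toy `toy_logdensity` function, the `switched_logdensity` helper, and the `nothing` placeholder problem object are all hypothetical stand-ins, and the example assumes DifferentiationInterface, ADTypes, and ForwardDiff are installed.

```julia
# Illustrative sketch (not PR code): DI requires the active (differentiated)
# argument to come first, so the problem object is passed as a DI.Constant
# context and the argument order is switched by a small helper.
import DifferentiationInterface as DI
using ADTypes: AutoForwardDiff
import ForwardDiff

# Hypothetical stand-in for LogDensityProblems.logdensity(ℓπ, x):
# a standard-normal log density that ignores the problem object.
toy_logdensity(ℓπ, x) = -sum(abs2, x) / 2

# Switch arguments so the active `x` comes first, as DI requires.
switched_logdensity(x, ℓπ) = toy_logdensity(ℓπ, x)

adtype = AutoForwardDiff()
x = [1.0, 2.0]
y, g = DI.value_and_gradient(switched_logdensity, adtype, x, DI.Constant(nothing))
# For this toy density, y == -2.5 and g == [-1.0, -2.0]
```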
```diff
+    else
+        Base.Fix1(LogDensityProblems.logdensity_and_gradient, ℓπ)
+    end
+    return Hamiltonian(metric, _logdensity, _logdensity_and_gradient)
 end

 ### Init
```
> It looks to me that the changes here (and in the previous ADTypes PR) are breaking, so the version should be incremented to 0.8. Could you also update the calls to `Hamiltonian(..., ..., ForwardDiff)` in the README?
> Yeah, we need a minor release to reflect that it's a breaking change. I am planning to update the calls once the docs are up in #414.
> AFAICT #405 is non-breaking (it just adds a new, previously non-existent constructor), so it could be released in a non-breaking 0.7.x release.
> Oh, yes, that's true.
> But with this PR, the previous `Hamiltonian(metric, demo, ForwardDiff)` will no longer work. So I am wondering about adding an extension for LogDensityProblemsAD to maintain the previous usage.
> The ADTypes interface is more expressive than specifying a symbol / module, so I think it's a good thing to switch over and drop the old way in a breaking release. I think @devmotion's point was just that #405 in itself wasn't breaking (but this PR will be).
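One way to see why an `AbstractADType` is more expressive than a bare symbol or module: backend objects can carry configuration. A sketch assuming only ADTypes is installed; the constructors and keyword names come from ADTypes, and whether AdvancedHMC honours each option is a separate question not shown here.

```julia
using ADTypes

# A Symbol or Module can only name a backend:
#   Hamiltonian(metric, demo, ForwardDiff)   # old style, no options
# An AbstractADType value can also encode backend-specific options:
fd = AutoForwardDiff(; chunksize=4)   # fixed ForwardDiff chunk size
rd = AutoReverseDiff(; compile=true)  # compiled ReverseDiff tape
```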