Integrate DifferentiationInterface #416
Conversation
Hello @ErikQQY & thanks for the PR! Some comments below.
```diff
@@ -3,8 +3,10 @@ uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.7.0"
```
It looks to me that the changes here (and in the previous ADTypes PR) are breaking, so this should be incremented to 0.8
Could you also update the calls to `Hamiltonian(..., ..., ForwardDiff)` in the README?
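For reference, the README update would presumably look something like this (a sketch only - `metric` and `ℓπ` stand for whatever the README already constructs, and the exact `Hamiltonian` signature should be checked against this PR):

```julia
using ADTypes: AutoForwardDiff

# Old, module-based style that this PR removes:
# hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)

# New, ADTypes-based style:
hamiltonian = Hamiltonian(metric, ℓπ, AutoForwardDiff())
```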
Yeah, we need a minor release to reflect it’s a breaking change.
> Could you also update the calls to `Hamiltonian(..., ..., ForwardDiff)` in the README?
I am planning to update the calls once docs are up in #414
AFAICT #405 is non-breaking (it just adds a new currently not existing constructor), so it could be released in a non-breaking 0.7.x release.
Oh, yes, that's true
But with this PR, the previous `Hamiltonian(metric, demo, ForwardDiff)` will no longer work. So I am wondering about adding an extension for LogDensityProblemsAD to maintain the previous usage.
The ADTypes interface is more expressive than specifying a symbol / module so I think it's a good thing to switch over and drop the old way in a breaking release. I think @devmotion's point was just that #405 in itself wasn't breaking (but this PR will be).
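To illustrate the extra expressiveness: an ADTypes backend object can carry backend-specific configuration that a bare module or symbol cannot (the constructors below are real ADTypes names; the option values are just examples):

```julia
using ADTypes

backend_default = AutoForwardDiff()               # plain ForwardDiff
backend_chunked = AutoForwardDiff(chunksize = 4)  # with a fixed chunk size
backend_tape    = AutoReverseDiff(compile = true) # ReverseDiff with a compiled tape
```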
This reverts commit f2bf551.
```julia
_logdensity = Base.Fix1(LogDensityProblems.logdensity, ℓπ)
_logdensity_and_gradient = if cap === LogDensityProblems.LogDensityOrder{0}()
    # In this case ℓπ does not support evaluation of the gradient of the log density function
    x -> DI.value_and_gradient(_logdensity, adtype, x)
```
It's still unclear to me why it should be advantageous to avoid LogDensityProblemsAD... But whatever - if you'd like to remove it, you should copy all the things that `LogDensityProblemsAD.ADgradient` does automatically, to ensure good compatibility with all ADTypes. For instance, Enzyme does not like `Base.FixXX`. So this should in principle be

```julia
x -> DI.value_and_gradient(LogDensityProblems.logdensity, adtype, DI.Constant(ℓπ), x)
```

But then DI doesn't allow active arguments after non-active arguments, so you have to use something like
```diff
- x -> DI.value_and_gradient(_logdensity, adtype, x)
+ x -> DI.value_and_gradient(switched_logdensity, adtype, x, DI.Constant(ℓπ))
```
where

```julia
switched_logdensity(x, f::F) = LogDensityProblems.logdensity(f, x)
```
LogDensityProblemsAD does this already for you... And it even allows you to specify a typical `x`, which would vastly improve performance for some backends - which this approach doesn't. So I'm still not convinced 😄
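For context, DI's "typical `x`" mechanism is its preparation API; a minimal sketch (assuming DI ≥ 0.6, where the preparation result is passed as the second argument):

```julia
import DifferentiationInterface as DI
import ForwardDiff  # the backend package must be loaded for AutoForwardDiff to work
using ADTypes: AutoForwardDiff

f(x) = sum(abs2, x)
backend = AutoForwardDiff()

# Prepare once with a typical input...
prep = DI.prepare_gradient(f, backend, zeros(5))

# ...then reuse the preparation for every subsequent gradient call:
y, g = DI.value_and_gradient(f, prep, backend, ones(5))
```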
@gdalle, is there any plan to incorporate these optimisations in LogDensityProblemsAD into DI?
> It's still unclear to me why it should be advantageous to avoid LogDensityProblemsAD...

The goal is probably to rely on all the AD bindings being in a single place, aka DI, for mutualized maintenance, testing and updates. LogDensityProblemsAD is a set of AD bindings that was specific to the probabilistic programming landscape, while DI aims to support the whole Julia ecosystem. As a result, when someone needs a feature, everyone gets it.
> And it even allows you to specify a typical x which would vastly improve performance for some backends - which this approach doesn't. So I'm still not convinced

Specifying a typical `x` to speed things up is the whole point of DI's preparation mechanism (which LogDensityProblemsAD also relies on), so it is indeed supported. Whether downstream packages are able to provide such an `x` is a different issue, but not an API limitation.
> @gdalle, is there any plan to incorporate these optimisations in LogDensityProblemsAD into DI?
No, changing the DI API would be enormously breaking now that around 850 packages depend on it, directly or indirectly. For the foreseeable future it will keep taking the active argument first, and the context arguments next.
I'm not claiming this is the best design choice for each individual application. And indeed it seems you need to add something like these lines (or what @devmotion suggested) in your specific case:
```julia
struct SwitchArgs{F}
    f::F
end
(s::SwitchArgs)(x, y) = s.f(y, x)
```
It's up to users to decide whether this (minor?) inconvenience is worth the switch to DI.
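A self-contained sketch of how the `SwitchArgs` functor would be used with `DI.Constant` (the `toy_logdensity` and `problem` names below are illustrative stand-ins for `LogDensityProblems.logdensity` and a real log-density object):

```julia
import DifferentiationInterface as DI
import ForwardDiff  # the backend package must be loaded for AutoForwardDiff to work
using ADTypes: AutoForwardDiff

struct SwitchArgs{F}
    f::F
end
(s::SwitchArgs)(x, y) = s.f(y, x)

# Stand-in for LogDensityProblems.logdensity(ℓπ, x): problem first, point second.
toy_logdensity(problem, x) = -problem.scale * sum(abs2, x)
problem = (scale = 0.5,)

# DI wants the active argument first, so the switched functor takes x first and
# receives the problem as a non-differentiated DI.Constant context:
y, g = DI.value_and_gradient(SwitchArgs(toy_logdensity), AutoForwardDiff(),
                             ones(3), DI.Constant(problem))
```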
Thanks, @gdalle. The `SwitchArgs{F}` trick looks like a reasonable workaround.
Although for Enzyme specifically David's solution is probably better, since it avoids creating a functor. Either way it's a one-liner
The point about using a closure vs DI.Constant is something that we looked at in DynamicPPL as well (see the resulting code in https://github.com/TuringLang/DynamicPPL.jl/blob/eed80e5901cc641e286bd923a628e0bf488fc676/src/logdensityfunction.jl#L252-L280). Basically, I ran some benchmarks to determine whether it's better to use a closure vs a constant and for a few backends the Constant wins out, except for ReverseDiff where the closure wins. You could adopt a similar approach here if you wanted.
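For comparison, the two styles benchmarked there look roughly like this (toy problem only; `toy_logdensity` is illustrative, not DynamicPPL's actual code):

```julia
import DifferentiationInterface as DI
import ForwardDiff  # the backend package must be loaded for AutoForwardDiff to work
using ADTypes: AutoForwardDiff

toy_logdensity(problem, x) = -sum(abs2, x .- problem.mean)
problem = (mean = fill(0.5, 4),)
backend = AutoForwardDiff()
x = zeros(4)

# Style 1: closure capturing the problem; only x is an argument.
closure = xs -> toy_logdensity(problem, xs)
y1, g1 = DI.value_and_gradient(closure, backend, x)

# Style 2: problem passed as a DI.Constant context (active argument first).
switched(xs, p) = toy_logdensity(p, xs)
y2, g2 = DI.value_and_gradient(switched, backend, x, DI.Constant(problem))
```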
Actually, if you'd continue using LogDensityProblemsAD, you'd get these AD-specific optimizations automatically: for AD types that are supported via DI it will use `DI.Constant` (and prepare the gradient cache if possible), and for others that are natively supported, such as `AutoReverseDiff`, it will use an optimized approach such as a `Base.Fix1` "closure": https://github.com/tpapp/LogDensityProblemsAD.jl/blob/ee6f9fef6ee550cb35a3bd3abdbd860e342e7b60/ext/LogDensityProblemsADReverseDiffExt.jl#L70 https://github.com/tpapp/LogDensityProblemsAD.jl/blob/ee6f9fef6ee550cb35a3bd3abdbd860e342e7b60/ext/LogDensityProblemsADReverseDiffExt.jl#L54
Since these kinds of tweaks are beneficial to several users, maybe it makes sense to integrate them into DI, i.e. with the `ConstantConstant` / `VariableConstant` outlined above (the names are voluntarily tongue-in-cheek)?
> `x = zeros(5)`
Ah, I somehow missed the kwargs forwarding, my bad.
> AD-specific optimizations
Yes totally agree on that :) I should've been clearer but I was mainly mentioning it so that @ErikQQY was aware of what we had done.
> with the ConstantConstant / VariableConstant outlined above (the names are voluntarily tongue-in-cheek)?
Isn't `Constant` and `ConstantConstant` better?
No, if we actually introduce this distinction (which I'm not yet 100% sure is a good idea) I'll have to pick different names, like `FixedParameter` and `VariableParameter` (and set `Constant = VariableParameter` until the next breaking release).
I suggest that we close this PR for now and revisit it later:
Continue #405