Integrate DifferentiationInterface #416


Closed · wants to merge 5 commits

Conversation

@ErikQQY (Collaborator) commented Apr 1, 2025:

Continue #405

@yebai requested review from penelopeysm and devmotion, April 1, 2025 11:45
@penelopeysm (Member) left a comment:

Hello @ErikQQY & thanks for the PR! Some comments below.

```diff
@@ -3,8 +3,10 @@ uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 version = "0.7.0"
```
@penelopeysm (Member), Apr 1, 2025:

  1. It looks to me like the changes here (and in the previous ADTypes PR) are breaking, so the version should be incremented to 0.8.

  2. Could you also update the calls to Hamiltonian(..., ..., ForwardDiff) in the README? (A sketch follows below.)
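For reference, the README calls would presumably change from the module-based form to an ADTypes form, roughly like this (a sketch; the exact constructor signature is assumed, and metric/ℓπ are placeholders):

```julia
using ADTypes: AutoForwardDiff

# Old module-based call, which this PR breaks:
# hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)

# New ADTypes-based call (assuming the constructor accepts an AbstractADType):
hamiltonian = Hamiltonian(metric, ℓπ, AutoForwardDiff())
```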

Collaborator (Author):

Yeah, we need a minor release to reflect that it's a breaking change.

> Could you also update the calls to Hamiltonian(..., ..., ForwardDiff) in the README?

I am planning to update the calls once docs are up in #414

Member:

AFAICT #405 is non-breaking (it just adds a new constructor that didn't previously exist), so it could go out in a non-breaking 0.7.x release.

Member:

Oh, yes, that's true.

@ErikQQY (Collaborator, Author), Apr 1, 2025:

But with this PR, the previous Hamiltonian(metric, demo, ForwardDiff) will no longer work. So I'm wondering whether we could add an extension for LogDensityProblemsAD to maintain the previous usage.
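A rough sketch of what such a compatibility extension might look like (entirely hypothetical; the extension module name and the assumption that Hamiltonian accepts the wrapped gradient object are made up):

```julia
# Hypothetical package extension (sketch): keep the old module/Symbol-based
# call style working by routing it through LogDensityProblemsAD.ADgradient.
module AdvancedHMCLogDensityProblemsADExt

using AdvancedHMC: AdvancedHMC, Hamiltonian
import LogDensityProblemsAD

function AdvancedHMC.Hamiltonian(metric, ℓπ, kind::Union{Symbol,Module}; kwargs...)
    # Symbol(ForwardDiff) === :ForwardDiff, so both spellings are accepted.
    ∇ℓπ = LogDensityProblemsAD.ADgradient(Symbol(kind), ℓπ; kwargs...)
    return Hamiltonian(metric, ∇ℓπ)  # assumes a method taking the wrapped problem
end

end # module
```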

Member:

The ADTypes interface is more expressive than specifying a symbol / module, so I think it's a good thing to switch over and drop the old way in a breaking release. I think @devmotion's point was just that #405 in itself wasn't breaking (but this PR will be).
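For example, an ADTypes backend object can carry configuration that a bare symbol or module cannot:

```julia
using ADTypes

AutoForwardDiff(; chunksize = 4)   # ForwardDiff with a fixed chunk size
AutoReverseDiff(; compile = true)  # ReverseDiff with a compiled tape
```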

```julia
_logdensity = Base.Fix1(LogDensityProblems.logdensity, ℓπ)
_logdensity_and_gradient = if cap === LogDensityProblems.LogDensityOrder{0}()
    # In this case ℓπ does not support evaluation of the gradient of the log density function
    x -> DI.value_and_gradient(_logdensity, adtype, x)
```
@devmotion (Member), Apr 1, 2025:

It's still unclear to me why it should be advantageous to avoid LogDensityProblemsAD... But whatever - if you'd like to remove it, you should copy all the things that LogDensityProblemsAD.ADgradient does automatically, to ensure good compatibility with all ADTypes. For instance, Enzyme does not like Base.Fix1/Fix2 closures. So this should in principle be

```julia
x -> DI.value_and_gradient(LogDensityProblems.logdensity, adtype, DI.Constant(ℓπ), x)
```

But DI doesn't allow an active argument after non-active ones, so you have to use something like

Suggested change:

```diff
- x -> DI.value_and_gradient(_logdensity, adtype, x)
+ x -> DI.value_and_gradient(switched_logdensity, adtype, x, DI.Constant(ℓπ))
```

where

```julia
switched_logdensity(x, f::F) where {F} = LogDensityProblems.logdensity(f, x)
```

LogDensityProblemsAD does this already for you... And it even allows you to specify a typical x, which would vastly improve performance for some backends - which this approach doesn't. So I'm still not convinced 😄
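To make the argument-switching concrete, here is a minimal runnable sketch (the DemoProblem log-density is made up purely for illustration):

```julia
import LogDensityProblems
import DifferentiationInterface as DI
import ForwardDiff
using ADTypes: AutoForwardDiff

# A toy log-density problem (standard normal), for illustration only:
struct DemoProblem end
LogDensityProblems.logdensity(::DemoProblem, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::DemoProblem) = 2
LogDensityProblems.capabilities(::Type{DemoProblem}) = LogDensityProblems.LogDensityOrder{0}()

# Active argument first, the problem passed as a non-differentiated context:
switched_logdensity(x, f) = LogDensityProblems.logdensity(f, x)

ℓπ = DemoProblem()
x = randn(2)
val, grad = DI.value_and_gradient(switched_logdensity, AutoForwardDiff(), x, DI.Constant(ℓπ))
```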

Member:

@gdalle, is there any plan to incorporate these optimisations from LogDensityProblemsAD into DI?

@gdalle, Apr 1, 2025:

> It's still unclear to me why it should be advantageous to avoid LogDensityProblemsAD...

The goal is probably to rely on all the AD bindings being in a single place, aka DI, for mutualized maintenance, testing, and updates. LogDensityProblemsAD is a set of AD bindings specific to the probabilistic programming landscape, while DI aims to support the whole Julia ecosystem. As a result, when someone needs a feature, everyone gets it.

> And it even allows you to specify a typical x, which would vastly improve performance for some backends - which this approach doesn't. So I'm still not convinced

Specifying a typical x to speed things up is the whole point of DI's preparation mechanism (which LogDensityProblemsAD also relies on), so it is indeed supported. Whether downstream packages are able to provide such an x is a different issue, but not an API limitation.
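Continuing the DemoProblem sketch from above, preparation with a typical input would look roughly like this:

```julia
# Pay the backend's setup cost once, for an input of typical shape:
prep = DI.prepare_gradient(switched_logdensity, AutoForwardDiff(), zeros(2), DI.Constant(ℓπ))

# Reuse the preparation result on every subsequent call:
val, grad = DI.value_and_gradient(switched_logdensity, prep, AutoForwardDiff(), x, DI.Constant(ℓπ))
```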

> @gdalle, is there any plan to incorporate these optimisations from LogDensityProblemsAD into DI?

No, changing the DI API would be enormously breaking now that around 850 packages depend on it, directly or indirectly. For the foreseeable future it will keep taking the active argument first, and the context arguments next.

I'm not claiming this is the best design choice for each individual application. And indeed it seems you need to add something like these lines (or what @devmotion suggested) in your specific case:

```julia
struct SwitchArgs{F}
    f::F
end

(s::SwitchArgs)(x, y) = s.f(y, x)
```

It's up to users to decide whether this (minor?) inconvenience is worth the switch to DI.
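For illustration, the functor would then be used like this (a sketch reusing adtype, x, and ℓπ from the DemoProblem example above):

```julia
# Wrap logdensity so the active argument x comes first, as DI requires:
flipped = SwitchArgs(LogDensityProblems.logdensity)
val, grad = DI.value_and_gradient(flipped, adtype, x, DI.Constant(ℓπ))
```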

Member:

Thanks, @gdalle. The SwitchArgs{F} trick looks like a reasonable workaround.

Reply:

Although for Enzyme specifically, David's solution is probably better, since it avoids creating a functor. Either way, it's a one-liner.

@devmotion (Member), Apr 2, 2025:

The point about using a closure vs DI.Constant is something we looked at in DynamicPPL as well (see the resulting code in https://github.com/TuringLang/DynamicPPL.jl/blob/eed80e5901cc641e286bd923a628e0bf488fc676/src/logdensityfunction.jl#L252-L280). Basically, I ran some benchmarks to determine whether it's better to use a closure or a constant, and for a few backends the Constant wins out, except for ReverseDiff, where the closure wins. You could adopt a similar approach here if you wanted.

Actually, if you continued using LogDensityProblemsAD, you'd get these AD-specific optimizations automatically: for AD types that are supported via DI it will use DI.Constant (and prepare the gradient cache if possible), and for others that are natively supported, such as AutoReverseDiff, it will use an optimized approach such as a Base.Fix1 "closure":
https://github.com/tpapp/LogDensityProblemsAD.jl/blob/ee6f9fef6ee550cb35a3bd3abdbd860e342e7b60/ext/LogDensityProblemsADReverseDiffExt.jl#L70
https://github.com/tpapp/LogDensityProblemsAD.jl/blob/ee6f9fef6ee550cb35a3bd3abdbd860e342e7b60/ext/LogDensityProblemsADReverseDiffExt.jl#L54
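To make the two alternatives concrete (a sketch reusing names from the examples above; which variant wins is backend-dependent, per the benchmarks mentioned):

```julia
# (a) Closure over ℓπ via Base.Fix1 (the winner for ReverseDiff):
closed = Base.Fix1(LogDensityProblems.logdensity, ℓπ)
val_a, grad_a = DI.value_and_gradient(closed, adtype, x)

# (b) ℓπ passed as a DI.Constant context (the winner for several other backends):
val_b, grad_b = DI.value_and_gradient(switched_logdensity, adtype, x, DI.Constant(ℓπ))
```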

Reply:

Since these kinds of tweaks are beneficial to several users, maybe it makes sense to integrate them into DI, i.e. with the ConstantConstant / VariableConstant distinction outlined above (the names are voluntarily tongue-in-cheek)?

Member:

> x = zeros(5)

Ah, I somehow missed the kwargs forwarding, my bad.

> AD-specific optimizations

Yes, totally agree on that :) I should've been clearer, but I was mainly mentioning it so that @ErikQQY was aware of what we had done.

@yebai (Member), Apr 2, 2025:

> with the ConstantConstant / VariableConstant distinction outlined above (the names are voluntarily tongue-in-cheek)?

Aren't Constant and ConstantConstant better?

Reply:

No, if we actually introduce this distinction (which I'm not yet 100% sure is a good idea), I'll have to pick different names, like FixedParameter and VariableParameter (and set Constant = VariableParameter until the next breaking release).

@yebai (Member) commented Apr 8, 2025:

I suggest that we close this PR for now and revisit it later:

  • LogDensityProblemsAD already uses DI under the hood where DI is better.
  • LogDensityProblemsAD contains further optimisations that are currently missing from DI.

@yebai closed this Apr 8, 2025.