Integrate DifferentiationInterface #416

Closed · wants to merge 5 commits
5 changes: 3 additions & 2 deletions Project.toml
@@ -3,8 +3,10 @@ uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.7.0"
Member @penelopeysm commented on Apr 1, 2025:

  1. It looks to me like the changes here (and in the previous ADTypes PR) are breaking, so the version should be incremented to 0.8.

  2. Could you also update the calls to Hamiltonian(..., ..., ForwardDiff) in the README?

Collaborator Author commented:

Yeah, we need a minor release to reflect that it's a breaking change.

> Could you also update the calls to Hamiltonian(..., ..., ForwardDiff) in the README?

I am planning to update the calls once docs are up in #414

Member commented:

AFAICT #405 is non-breaking (it just adds a new constructor that doesn't currently exist), so it could be released in a non-breaking 0.7.x release.

Member commented:

Oh, yes, that's true

Collaborator Author @ErikQQY commented on Apr 1, 2025:

But with this PR, the previous Hamiltonian(metric, demo, ForwardDiff) will no longer work, so I am wondering whether we should add an extension for LogDensityProblemsAD to maintain the previous usage.

Member commented:

The ADTypes interface is more expressive than specifying a symbol/module, so I think it's a good thing to switch over and drop the old way in a breaking release. I think @devmotion's point was just that #405 in itself wasn't breaking (but this PR will be).
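For readers skimming the thread, here is a minimal sketch of the call-style change under discussion. The toy DemoTarget is purely illustrative; the constructor signatures follow this PR's diff below.

    using AdvancedHMC, ADTypes
    import LogDensityProblems
    import ForwardDiff  # the backend package must be loaded for AutoForwardDiff

    # A toy standard-normal target implementing the LogDensityProblems interface:
    struct DemoTarget end
    LogDensityProblems.logdensity(::DemoTarget, θ) = -sum(abs2, θ) / 2
    LogDensityProblems.dimension(::DemoTarget) = 2
    LogDensityProblems.capabilities(::Type{DemoTarget}) =
        LogDensityProblems.LogDensityOrder{0}()

    metric = DiagEuclideanMetric(2)
    # Before this PR: Hamiltonian(metric, DemoTarget(), ForwardDiff)
    # After this PR: an ADTypes struct selects the backend instead of a symbol/module
    h = Hamiltonian(metric, DemoTarget(), AutoForwardDiff())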


[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
@@ -17,13 +19,11 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[weakdeps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"

[extensions]
AdvancedHMCADTypesExt = "ADTypes"
AdvancedHMCCUDAExt = "CUDA"
AdvancedHMCMCMCChainsExt = "MCMCChains"
AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"
@@ -33,6 +33,7 @@ ADTypes = "1"
AbstractMCMC = "5.6"
ArgCheck = "1, 2"
CUDA = "3, 4, 5"
DifferentiationInterface = "0.6.49"
DocStringExtensions = "0.8, 0.9"
LinearAlgebra = "<0.1, 1"
LogDensityProblems = "2"
24 changes: 0 additions & 24 deletions ext/AdvancedHMCADTypesExt.jl

This file was deleted.

48 changes: 18 additions & 30 deletions src/AdvancedHMC.jl
@@ -1,5 +1,6 @@
module AdvancedHMC

using ADTypes: AbstractADType, AutoForwardDiff
using Statistics: mean, var, middle
using LinearAlgebra:
Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling
@@ -21,7 +22,10 @@ using AbstractMCMC: AbstractMCMC, LogDensityModel

import StatsBase: sample

import DifferentiationInterface: DifferentiationInterface as DI

const DEFAULT_FLOAT_TYPE = typeof(float(0))
const DEFAULT_ADTYPE = AutoForwardDiff()

include("utilities.jl")

@@ -134,48 +138,32 @@ function Hamiltonian(metric::AbstractMetric, ℓ::LogDensityModel; kwargs...)
return Hamiltonian(metric, ℓ.logdensity; kwargs...)
end
function Hamiltonian(metric::AbstractMetric, ℓ; kwargs...)
cap = LogDensityProblems.capabilities(ℓ)
if cap === nothing
throw(
ArgumentError(
"The log density function does not support the LogDensityProblems.jl interface",
),
)
end
# Check if we're capable of computing gradients.
ℓπ = if cap === LogDensityProblems.LogDensityOrder{0}()
# In this case ℓ does not support evaluation of the gradient of the log density function
# We use ForwardDiff to compute the gradient
LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; kwargs...)
else
# In this case ℓ already supports evaluation of the gradient of the log density function
end
return Hamiltonian(
metric,
Base.Fix1(LogDensityProblems.logdensity, ℓπ),
Base.Fix1(LogDensityProblems.logdensity_and_gradient, ℓπ),
)
return Hamiltonian(metric, ℓ, DEFAULT_ADTYPE; kwargs...)
end

## With explicit AD specification
function Hamiltonian(
metric::AbstractMetric, ℓπ::LogDensityModel, kind::Union{Symbol,Val,Module}; kwargs...
metric::AbstractMetric, ℓπ::LogDensityModel, adtype::AbstractADType; kwargs...
)
return Hamiltonian(metric, ℓπ.logdensity, kind; kwargs...)
return Hamiltonian(metric, ℓπ.logdensity, adtype; kwargs...)
end
function Hamiltonian(metric::AbstractMetric, ℓπ, kind::Union{Symbol,Val,Module}; kwargs...)
if LogDensityProblems.capabilities(ℓπ) === nothing
function Hamiltonian(metric::AbstractMetric, ℓπ, adtype::AbstractADType; kwargs...)
cap = LogDensityProblems.capabilities(ℓπ)
if cap === nothing
throw(
ArgumentError(
"The log density function does not support the LogDensityProblems.jl interface",
),
)
end
ℓ = LogDensityProblemsAD.ADgradient(
kind isa Val ? kind : Val(Symbol(kind)), ℓπ; kwargs...
)
return Hamiltonian(metric, ℓ)
_logdensity = Base.Fix1(LogDensityProblems.logdensity, ℓπ)
_logdensity_and_gradient = if cap === LogDensityProblems.LogDensityOrder{0}()
# In this case ℓπ does not support evaluation of the gradient of the log density function
x -> DI.value_and_gradient(_logdensity, adtype, x)
Member @devmotion commented on Apr 1, 2025:

It's still unclear to me why it should be advantageous to avoid LogDensityProblemsAD... But whatever: if you'd like to remove it, you should copy all the things that LogDensityProblemsAD.ADgradient does automatically to ensure good compatibility with all ADTypes. For instance, Enzyme does not like Base.FixXX. So this should in principle be:

        x -> DI.value_and_gradient(LogDensityProblems.logdensity, adtype, DI.Constant(ℓπ), x)

But then DI doesn't allow active after non-active arguments, so you have to use something like:

Suggested change:
-        x -> DI.value_and_gradient(_logdensity, adtype, x)
+        x -> DI.value_and_gradient(switched_logdensity, adtype, x, DI.Constant(ℓπ))

where

switched_logdensity(x, f::F) where {F} = LogDensityProblems.logdensity(f, x)

LogDensityProblemsAD does this already for you... And it even allows you to specify a typical x which would vastly improve performance for some backends - which this approach doesn't. So I'm still not convinced 😄
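For concreteness, here is a self-contained sketch of the argument-swapping workaround described above; the ToyTarget and variable names are illustrative, not part of the PR:

    import DifferentiationInterface as DI
    import LogDensityProblems
    using ADTypes: AutoForwardDiff
    import ForwardDiff  # the backend package must be loaded

    # A toy zeroth-order log density (standard normal), assumed for illustration:
    struct ToyTarget end
    LogDensityProblems.logdensity(::ToyTarget, x) = -sum(abs2, x) / 2

    # DI expects the active (differentiated) argument first, so swap the order:
    switched_logdensity(x, f::F) where {F} = LogDensityProblems.logdensity(f, x)

    x = randn(3)
    # The problem object rides along as a non-differentiated DI.Constant context:
    val, grad = DI.value_and_gradient(
        switched_logdensity, AutoForwardDiff(), x, DI.Constant(ToyTarget())
    )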

Member commented:

@gdalle, is there any plan to incorporate these optimisations from LogDensityProblemsAD into DI?

@gdalle commented on Apr 1, 2025:

> It's still unclear to me why it should be advantageous to avoid LogDensityProblemsAD...

The goal is probably to rely on all the AD bindings being in a single place, aka DI, for mutualized maintenance, testing, and updates. LogDensityProblemsAD is a set of AD bindings that was specific to the probabilistic programming landscape, while DI aims to support the whole Julia ecosystem. As a result, when someone needs a feature, everyone gets it.

> And it even allows you to specify a typical x which would vastly improve performance for some backends - which this approach doesn't. So I'm still not convinced

Specifying a typical x to speed things up is the whole point of DI's preparation mechanism (which LogDensityProblemsAD also relies on), so it is indeed supported. Whether downstream packages are able to provide such an x is a different issue, but not an API limitation.
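For illustration, the preparation mechanism mentioned here looks roughly like the following; f, adtype, and x_typical are placeholders, not names from this PR:

    # Prepare once on a typical input, then reuse the prep across gradient calls:
    prep = DI.prepare_gradient(f, adtype, x_typical)
    val, grad = DI.value_and_gradient(f, prep, adtype, x)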

> @gdalle, is there any plan to incorporate these optimisations from LogDensityProblemsAD into DI?

No, changing the DI API would be enormously breaking now that around 850 packages depend on it, directly or indirectly. For the foreseeable future it will keep taking the active argument first, and the context arguments next.

I'm not claiming this is the best design choice for each individual application. And indeed it seems you need to add something like these lines (or what @devmotion suggested) in your specific case:

struct SwitchArgs{F}
    f::F
end

(s::SwitchArgs)(x, y) = s.f(y, x)

It's up to users to decide whether this (minor?) inconvenience is worth the switch to DI.
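A hedged usage sketch for the functor above, reusing the ToyTarget and imports from the earlier sketch (again illustrative, not code from this PR):

    # SwitchArgs puts the active argument first, as DI requires:
    f = SwitchArgs(LogDensityProblems.logdensity)  # f(x, ℓ) == logdensity(ℓ, x)
    val, grad = DI.value_and_gradient(
        f, AutoForwardDiff(), randn(3), DI.Constant(ToyTarget())
    )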

Member commented:

Thanks, @gdalle. The SwitchArgs{F} trick looks like a reasonable workaround.

@gdalle commented:

Although for Enzyme specifically, David's solution is probably better, since it avoids creating a functor. Either way it's a one-liner.

Member @devmotion commented on Apr 2, 2025:

The point about using a closure vs DI.Constant is something that we looked at in DynamicPPL as well (see the resulting code in https://github.com/TuringLang/DynamicPPL.jl/blob/eed80e5901cc641e286bd923a628e0bf488fc676/src/logdensityfunction.jl#L252-L280). Basically, I ran some benchmarks to determine whether it's better to use a closure vs a constant, and for a few backends the Constant wins out, except for ReverseDiff where the closure wins. You could adopt a similar approach here if you wanted.

Actually, if you'd continue using LogDensityProblemsAD, you'd get these AD-specific optimizations automatically: For AD types that are supported via DI it will use DI.Constant (and prepare the gradient cache if possible), and for others that are natively supported such as e.g. AutoReverseDiff it will use an optimized approach such as a Base.Fix1 "closure": https://github.com/tpapp/LogDensityProblemsAD.jl/blob/ee6f9fef6ee550cb35a3bd3abdbd860e342e7b60/ext/LogDensityProblemsADReverseDiffExt.jl#L70 https://github.com/tpapp/LogDensityProblemsAD.jl/blob/ee6f9fef6ee550cb35a3bd3abdbd860e342e7b60/ext/LogDensityProblemsADReverseDiffExt.jl#L54
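To make the two benchmarked variants concrete, they might look like this; ℓπ, adtype, and x are placeholders, and switched_logdensity is the helper defined earlier in this thread:

    # Variant 1: a closure capturing ℓπ (the winner for ReverseDiff):
    closed = Base.Fix1(LogDensityProblems.logdensity, ℓπ)
    DI.value_and_gradient(closed, adtype, x)

    # Variant 2: ℓπ passed as a DI.Constant context (the winner for several other backends):
    DI.value_and_gradient(switched_logdensity, adtype, x, DI.Constant(ℓπ))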

@gdalle commented:

Since these kinds of tweaks are beneficial to several users, maybe it makes sense to integrate them into DI, i.e. with the ConstantConstant / VariableConstant outlined above (the names are deliberately tongue-in-cheek)?

Member commented:

> x = zeros(5)

Ah, I somehow missed the kwargs forwarding, my bad.

> AD-specific optimizations

Yes, totally agree on that :) I should've been clearer, but I was mainly mentioning it so that @ErikQQY was aware of what we had done.

Member @yebai commented on Apr 2, 2025:

> with the ConstantConstant / VariableConstant outlined above (the names are deliberately tongue-in-cheek)?

Aren't Constant and ConstantConstant better?

@gdalle commented:

No, if we actually introduce this distinction (which I'm not yet 100% sure is a good idea), I'll have to pick different names, like FixedParameter and VariableParameter (and set Constant = VariableParameter until the next breaking release).

else
Base.Fix1(LogDensityProblems.logdensity_and_gradient, ℓπ)
end
return Hamiltonian(metric, _logdensity, _logdensity_and_gradient)
end

### Init
4 changes: 2 additions & 2 deletions test/abstractmcmc.jl
@@ -1,4 +1,4 @@
using ReTest, Random, AdvancedHMC, ForwardDiff, AbstractMCMC
using ReTest, Random, AdvancedHMC, ForwardDiff, AbstractMCMC, ADTypes
using Statistics: mean

@testset "AbstractMCMC w/ gdemo" begin
@@ -18,7 +18,7 @@ using Statistics: mean
custom = HMCSampler(κ, metric, adaptor)

model = AdvancedHMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ_gdemo)
LogDensityProblemsAD.ADgradient(AutoForwardDiff(), ℓπ_gdemo)
)

@testset "getparams and setparams!!" begin
4 changes: 2 additions & 2 deletions test/adaptation.jl
@@ -1,4 +1,4 @@
using ReTest, LinearAlgebra, Distributions, AdvancedHMC, Random, ForwardDiff
using ReTest, LinearAlgebra, Distributions, AdvancedHMC, Random, ForwardDiff, ADTypes
using AdvancedHMC.Adaptation:
WelfordVar, NaiveVar, WelfordCov, NaiveCov, get_estimation, get_estimation, reset!

@@ -9,7 +9,7 @@ function runnuts(ℓπ, metric; n_samples=10_000)
rng = MersenneTwister(0)

nuts = NUTS(0.8)
h = Hamiltonian(metric, ℓπ, ForwardDiff)
h = Hamiltonian(metric, ℓπ, AutoForwardDiff())
step_size = AdvancedHMC.make_step_size(rng, nuts, h, θ_init)
integrator = AdvancedHMC.make_integrator(nuts, step_size)
κ = AdvancedHMC.make_kernel(nuts, integrator)
6 changes: 3 additions & 3 deletions test/contrib.jl
@@ -1,11 +1,11 @@
using ReTest, AdvancedHMC, ForwardDiff, Zygote
using ReTest, AdvancedHMC, ForwardDiff, Zygote, ADTypes

@testset "contrib" begin
@testset "ad" begin
metric = UnitEuclideanMetric(D)
h_hand = Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
h_forwarddiff = Hamiltonian(metric, ℓπ, ForwardDiff)
h_zygote = Hamiltonian(metric, ℓπ, Zygote)
h_forwarddiff = Hamiltonian(metric, ℓπ, AutoForwardDiff())
h_zygote = Hamiltonian(metric, ℓπ, AutoZygote())
for x in [rand(D), rand(D, 10)]
v_hand, g_hand = h_hand.∂ℓπ∂θ(x)
v_forwarddiff, g_forwarddiff = h_forwarddiff.∂ℓπ∂θ(x)
6 changes: 3 additions & 3 deletions test/demo.jl
@@ -1,5 +1,5 @@
using ReTest
using AdvancedHMC, Distributions, ForwardDiff, ComponentArrays, AbstractMCMC
using AdvancedHMC, Distributions, ForwardDiff, ComponentArrays, AbstractMCMC, ADTypes
using LinearAlgebra, ADTypes

@testset "Demo" begin
@@ -23,7 +23,7 @@

# Define a Hamiltonian system
metric = DiagEuclideanMetric(D)
hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)
hamiltonian = Hamiltonian(metric, ℓπ, AutoForwardDiff())

# Define a leapfrog solver, with initial step size chosen heuristically
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
@@ -76,7 +76,7 @@ end
metric = DiagEuclideanMetric(D)

# choose AD framework or provide a function manually
hamiltonian = Hamiltonian(metric, ℓπ, Val(:ForwardDiff); x=p1)
hamiltonian = Hamiltonian(metric, ℓπ, AutoForwardDiff(); x=p1)

# Define a leapfrog solver, with initial step size chosen heuristically
initial_ϵ = find_good_stepsize(hamiltonian, p1)
4 changes: 2 additions & 2 deletions test/integrator.jl
@@ -1,4 +1,4 @@
using ReTest, Random, AdvancedHMC, ForwardDiff
using ReTest, Random, AdvancedHMC, ForwardDiff, ADTypes

using OrdinaryDiffEq
using LinearAlgebra: dot
@@ -120,7 +120,7 @@ using Statistics: mean
ϵ = 0.01
for lf in [Leapfrog(ϵ), DiffEqIntegrator(ϵ, VerletLeapfrog())]
q_init = randn(1)
h = Hamiltonian(UnitEuclideanMetric(1), negU, ForwardDiff)
h = Hamiltonian(UnitEuclideanMetric(1), negU, AutoForwardDiff())
p_init = AdvancedHMC.rand_momentum(
Random.default_rng(), h.metric, h.kinetic, q_init
)
Expand Down
4 changes: 2 additions & 2 deletions test/mcmcchains.jl
@@ -1,4 +1,4 @@
using ReTest, Random, AdvancedHMC, ForwardDiff, AbstractMCMC, MCMCChains
using ReTest, Random, AdvancedHMC, ForwardDiff, AbstractMCMC, MCMCChains, ADTypes
using Statistics: mean

@testset "MCMCChains w/ gdemo" begin
@@ -10,7 +10,7 @@ using Statistics: mean
θ_init = randn(rng, 2)

model = AdvancedHMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ_gdemo)
LogDensityProblemsAD.ADgradient(AutoForwardDiff(), ℓπ_gdemo)
)
integrator = Leapfrog(1e-3)
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
4 changes: 2 additions & 2 deletions test/models.jl
@@ -1,4 +1,4 @@
using ReTest, Random, AdvancedHMC, ForwardDiff
using ReTest, Random, AdvancedHMC, ForwardDiff, ADTypes
using Statistics: mean

@testset "Models" begin
@@ -11,7 +11,7 @@ using Statistics: mean
θ_init = randn(rng, 2)

metric = DiagEuclideanMetric(2)
h = Hamiltonian(metric, ℓπ_gdemo, ForwardDiff)
h = Hamiltonian(metric, ℓπ_gdemo, AutoForwardDiff())
integrator = Leapfrog(0.1)
κ = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
adaptor = StanHMCAdaptor(