From 0004dd088f4c45c7b2b77fe2eefdf6352b60795e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 10 Feb 2025 13:18:33 +0000 Subject: [PATCH 01/11] Remove LogDensityProblemsAD --- Project.toml | 6 ++-- ext/DynamicPPLForwardDiffExt.jl | 31 ++---------------- src/DynamicPPL.jl | 1 - src/contexts.jl | 1 + src/logdensityfunction.jl | 56 ++++++++++++--------------------- src/sampler.jl | 3 ++ test/Project.toml | 2 -- test/ad.jl | 20 ++++++------ test/logdensityfunction.jl | 13 +------- test/runtests.jl | 2 +- 10 files changed, 41 insertions(+), 94 deletions(-) diff --git a/Project.toml b/Project.toml index 7cd47fdbb..ed5a4b5dd 100644 --- a/Project.toml +++ b/Project.toml @@ -12,13 +12,13 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -51,15 +51,15 @@ Bijectors = "0.13.18, 0.14, 0.15" ChainRulesCore = "1" Compat = "4" ConstructionBase = "1.5.4" +DifferentiationInterface = "0.6.39" Distributions = "0.25" DocStringExtensions = "0.9" -KernelAbstractions = "0.9.33" EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10" JET = "0.9" +KernelAbstractions = "0.9.33" LinearAlgebra = "1.6" LogDensityProblems = "2" -LogDensityProblemsAD = "1.7.0" MCMCChains = "6" MacroTools = "0.5.6" Mooncake 
= "0.4.95" diff --git a/ext/DynamicPPLForwardDiffExt.jl b/ext/DynamicPPLForwardDiffExt.jl index 4bc33e217..a10e9eae6 100644 --- a/ext/DynamicPPLForwardDiffExt.jl +++ b/ext/DynamicPPLForwardDiffExt.jl @@ -1,40 +1,13 @@ module DynamicPPLForwardDiffExt -if isdefined(Base, :get_extension) - using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD - using ForwardDiff -else - using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD - using ..ForwardDiff -end +using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems +using ForwardDiff getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true standardtag(::ADTypes.AutoForwardDiff) = false -function LogDensityProblemsAD.ADgradient( - ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction -) - θ = DynamicPPL.getparams(ℓ) - f = Base.Fix1(LogDensityProblems.logdensity, ℓ) - - # Define configuration for ForwardDiff. - tag = if standardtag(ad) - ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ)) - else - ForwardDiff.Tag(f, eltype(θ)) - end - chunk_size = getchunksize(ad) - chunk = if chunk_size == 0 || chunk_size === nothing - ForwardDiff.Chunk(θ) - else - ForwardDiff.Chunk(length(θ), chunk_size) - end - - return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ) -end - # Allow Turing tag in gradient etc. 
calls of the log density function function ForwardDiff.checktag( ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f0d42f98c..b413017cf 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -14,7 +14,6 @@ using MacroTools: MacroTools using ConstructionBase: ConstructionBase using Accessors: Accessors using LogDensityProblems: LogDensityProblems -using LogDensityProblemsAD: LogDensityProblemsAD using LinearAlgebra: LinearAlgebra, Cholesky diff --git a/src/contexts.jl b/src/contexts.jl index 0b4633283..87ad8df0b 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -184,6 +184,7 @@ at which point it will return the sampler of that context. getsampler(context::SamplingContext) = context.sampler getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context) getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context)) +getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context") """ struct DefaultContext <: AbstractContext end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 29f591cc3..426312c88 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -1,3 +1,5 @@ +import DifferentiationInterface as DI + """ LogDensityFunction @@ -81,37 +83,13 @@ end Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. """ -getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = - getmodel(LogDensityProblemsAD.parent(f)) getmodel(f::DynamicPPL.LogDensityFunction) = f.model """ setmodel(f, model[, adtype]) Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. - -!!! warning - Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a - `DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f` - might require recompilation of the gradient tape, depending on the AD backend. 
""" -function setmodel( - f::LogDensityProblemsAD.ADGradientWrapper, - model::DynamicPPL.Model, - adtype::ADTypes.AbstractADType, -) - # TODO: Should we handle `SciMLBase.NoAD`? - # For an `ADGradientWrapper` we do the following: - # 1. Update the `Model` in the underlying `LogDensityFunction`. - # 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype` - # to ensure that the recompilation of gradient tapes, etc. also occur. For example, - # ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just - # replacing the corresponding field with the new model won't be sufficient to obtain - # the correct gradients. - return LogDensityProblemsAD.ADgradient( - adtype, setmodel(LogDensityProblemsAD.parent(f), model) - ) -end function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) return Accessors.@set f.model = model end @@ -140,18 +118,24 @@ end # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) -# This is important for performance -- one needs to provide `ADGradient` with a vector of -# parameters, or DifferentiationInterface will not have sufficient information to e.g. -# compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate -# a tape when using ReverseDiff.jl. -function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction) - x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params - return LogDensityProblemsAD.ADgradient(ad, ℓ; x) -end +_flipped_logdensity(θ, f) = LogDensityProblems.logdensity(f, θ) -function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction) - return _make_ad_gradient(ad, f) +# By default, the AD backend to use is inferred from the context, which would +# typically be a SamplingContext which contains a sampler. 
+function LogDensityProblems.logdensity_and_gradient( + f::LogDensityFunction, θ::AbstractVector +) + adtype = getadtype(getsampler(getcontext(f))) + return LogDensityProblems.logdensity_and_gradient(f, θ, adtype) end -function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction) - return _make_ad_gradient(ad, f) + +# Extra method allowing one to manually specify the AD backend to use, thus +# overriding the default AD backend inferred from the sampler. +function LogDensityProblems.logdensity_and_gradient( + f::LogDensityFunction, θ::AbstractVector, adtype::ADTypes.AbstractADType +) + # Ensure we concretise the elements of the params. + θ = map(identity, θ) # TODO: Is this needed? + prep = DI.prepare_gradient(_flipped_logdensity, adtype, θ, DI.Constant(f)) + return DI.value_and_gradient(_flipped_logdensity, prep, adtype, θ, DI.Constant(f)) end diff --git a/src/sampler.jl b/src/sampler.jl index aa3a637ee..fe49ee45f 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -49,6 +49,9 @@ struct Sampler{T} <: AbstractSampler alg::T end +# Extract the AD type from the underlying algorithm +getadtype(s::Sampler) = getadtype(s.alg) + # AbstractMCMC interface for SampleFromUniform and SampleFromPrior function AbstractMCMC.step( rng::Random.AbstractRNG, diff --git a/test/Project.toml b/test/Project.toml index c7583c672..420edba94 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -16,7 +16,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" @@ -46,7 +45,6 @@ EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10.12" JET = "0.9" LogDensityProblems = "2" -LogDensityProblemsAD = 
"1.7.0" MCMCChains = "6.0.4" MacroTools = "0.5.6" Mooncake = "0.4.59" diff --git a/test/ad.jl b/test/ad.jl index 87c7f22c3..354a8bd0c 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,21 +1,19 @@ @testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - f = DynamicPPL.LogDensityFunction(m) rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) vns = DynamicPPL.TestUtils.varnames(m) varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos f = DynamicPPL.LogDensityFunction(m, varinfo) - - # use ForwardDiff result as reference - ad_forwarddiff_f = LogDensityProblemsAD.ADgradient( - ADTypes.AutoForwardDiff(; chunksize=0), f - ) # convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0 # reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489 θ = convert(Vector{Float64}, varinfo[:]) - logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ) + # Calculate reference logp + gradient of logp using ForwardDiff + default_adtype = ADTypes.AutoForwardDiff() + ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient( + f, θ, default_adtype + ) @testset "$adtype" for adtype in [ ADTypes.AutoReverseDiff(; compile=false), @@ -27,9 +25,9 @@ if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo @test_broken 1 == 0 else - ad_f = LogDensityProblemsAD.ADgradient(adtype, f) - _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ) + logp, grad = LogDensityProblems.logdensity_and_gradient(f, θ, adtype) @test grad ≈ ref_grad + @test logp ≈ ref_logp end end end @@ -70,6 +68,8 @@ spl = Sampler(MyEmptyAlg()) vi = VarInfo(model) ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl)) - @test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any + @test 
LogDensityProblems.logdensity_and_gradient( + ldf, vi[:], AutoReverseDiff(; compile=true) + ) isa Any end end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index beda767e6..319371609 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -1,4 +1,4 @@ -using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff +using Test, DynamicPPL, ADTypes, LogDensityProblems, ReverseDiff @testset "`getmodel` and `setmodel`" begin @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS @@ -6,17 +6,6 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, Rever ℓ = DynamicPPL.LogDensityFunction(model) @test DynamicPPL.getmodel(ℓ) == model @test DynamicPPL.setmodel(ℓ, model).model == model - - # ReverseDiff related - ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(false)) - @test DynamicPPL.getmodel(∇ℓ) == model - @test DynamicPPL.getmodel(DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())) == - model - ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(true)) - new_∇ℓ = DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff()) - @test DynamicPPL.getmodel(new_∇ℓ) == model - # HACK(sunxd): rely on internal implementation detail, i.e., naming of `compiledtape` - @test new_∇ℓ.compiledtape != ∇ℓ.compiledtape end end diff --git a/test/runtests.jl b/test/runtests.jl index 25cd2fb40..caddef5f9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,7 +9,7 @@ using Distributions using DistributionsAD using Documenter using ForwardDiff -using LogDensityProblems, LogDensityProblemsAD +using LogDensityProblems using MacroTools using MCMCChains using Mooncake: Mooncake From e8979b09498d30d1f2c3c615434f586775f47d08 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 13 Feb 2025 20:33:35 +0000 Subject: [PATCH 02/11] Implement LogDensityFunctionWithGrad in place of ADgradient --- HISTORY.md | 9 +++ Project.toml | 3 - docs/src/api.md | 3 +- 
ext/DynamicPPLForwardDiffExt.jl | 27 --------- src/logdensityfunction.jl | 88 +++++++++++++++++----------- src/sampler.jl | 3 - test/ad.jl | 59 +++++++++++++------ test/ext/DynamicPPLForwardDiffExt.jl | 14 ----- test/runtests.jl | 1 - test/test_util.jl | 9 +++ 10 files changed, 115 insertions(+), 101 deletions(-) delete mode 100644 ext/DynamicPPLForwardDiffExt.jl delete mode 100644 test/ext/DynamicPPLForwardDiffExt.jl diff --git a/HISTORY.md b/HISTORY.md index fa9e58e99..36d827e92 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -119,6 +119,15 @@ This release removes the feature of `VarInfo` where it kept track of which varia This change also affects sampling in Turing.jl. +**Other changes** + +LogDensityProblemsAD is now removed as a dependency. +Instead of constructing a `LogDensityProblemsAD.ADgradient` object, we now directly use `DifferentiationInterface` to calculate the gradient of the log density with respect to model parameters. + +In practice, this means that if you want to calculate the gradient for a model, you can do: + +TODO(penelopeysm): Finish this + ## 0.34.2 - Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied.
diff --git a/Project.toml b/Project.toml index ed5a4b5dd..b6c36a05c 100644 --- a/Project.toml +++ b/Project.toml @@ -28,7 +28,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" @@ -36,7 +35,6 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] -DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLJETExt = ["JET"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLMooncakeExt = ["Mooncake"] @@ -55,7 +53,6 @@ DifferentiationInterface = "0.6.39" Distributions = "0.25" DocStringExtensions = "0.9" EnzymeCore = "0.6 - 0.8" -ForwardDiff = "0.10" JET = "0.9" KernelAbstractions = "0.9.33" LinearAlgebra = "1.6" diff --git a/docs/src/api.md b/docs/src/api.md index 4d3c6bc97..2ff170c8e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -54,10 +54,11 @@ logjoint ### LogDensityProblems.jl interface -The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by simply wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`: +The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction` or `DynamicPPL.LogDensityFunctionWithGrad`. 
```@docs DynamicPPL.LogDensityFunction +DynamicPPL.LogDensityFunctionWithGrad ``` ## Condition and decondition diff --git a/ext/DynamicPPLForwardDiffExt.jl b/ext/DynamicPPLForwardDiffExt.jl deleted file mode 100644 index a10e9eae6..000000000 --- a/ext/DynamicPPLForwardDiffExt.jl +++ /dev/null @@ -1,27 +0,0 @@ -module DynamicPPLForwardDiffExt - -using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems -using ForwardDiff - -getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk - -standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true -standardtag(::ADTypes.AutoForwardDiff) = false - -# Allow Turing tag in gradient etc. calls of the log density function -function ForwardDiff.checktag( - ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, - ::DynamicPPL.LogDensityFunction, - ::AbstractArray{W}, -) where {V,W} - return true -end -function ForwardDiff.checktag( - ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, - ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:DynamicPPL.LogDensityFunction}, - ::AbstractArray{W}, -) where {V,W} - return true -end - -end # module diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 426312c88..ecb78218b 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -4,6 +4,10 @@ import DifferentiationInterface as DI LogDensityFunction A callable representing a log density function of a `model`. +`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface, +but only to 0th-order, i.e. it is only possible to calculate the log density, +and not its gradient. If you need to calculate the gradient as well, you have +to construct a [`DynamicPPL.LogDensityFunctionWithGrad`](@ref) object. # Fields $(FIELDS) @@ -55,16 +59,6 @@ struct LogDensityFunction{V,M,C} context::C end -# TODO: Deprecate. 
-function LogDensityFunction( - varinfo::AbstractVarInfo, - model::Model, - sampler::AbstractSampler, - context::AbstractContext, -) - return LogDensityFunction(varinfo, model, SamplingContext(sampler, context)) -end - function LogDensityFunction( model::Model, varinfo::AbstractVarInfo=VarInfo(model), @@ -94,11 +88,6 @@ function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) return Accessors.@set f.model = model end -# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time -# we need to define these annoying methods to ensure that we stay compatible with everything. -getsampler(f::LogDensityFunction) = getsampler(getcontext(f)) -hassampler(f::LogDensityFunction) = hassampler(getcontext(f)) - """ getparams(f::LogDensityFunction) @@ -106,36 +95,69 @@ Return the parameters of the wrapped varinfo as a vector. """ getparams(f::LogDensityFunction) = f.varinfo[:] -# LogDensityProblems interface -function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) +# LogDensityProblems interface: logp (0th order) +function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) context = getcontext(f) - vi_new = unflatten(f.varinfo, θ) + vi_new = unflatten(f.varinfo, x) return getlogp(last(evaluate!!(f.model, vi_new, context))) end +function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction) + return LogDensityProblems.logdensity(f, x) +end function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) return LogDensityProblems.LogDensityOrder{0}() end # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? 
LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) -_flipped_logdensity(θ, f) = LogDensityProblems.logdensity(f, θ) +# LogDensityProblems interface: gradient (1st order) +""" + LogDensityFunctionWithGrad(ldf::DynamicPPL.LogDensityFunction, adtype::ADTypes.AbstractADType) + +A callable representing a log density function of a `model`. +`DynamicPPL.LogDensityFunctionWithGrad` implements the LogDensityProblems.jl +interface to 1st-order, meaning that you can both calculate the log density +using + + LogDensityProblems.logdensity(f, x) + +and its gradient using + LogDensityProblems.logdensity_and_gradient(f, x) + +where `f` is a `LogDensityFunctionWithGrad` object and `x` is a vector of parameters. + +# Fields +$(FIELDS) +""" +struct LogDensityFunctionWithGrad{V,M,C,TAD<:ADTypes.AbstractADType} + ldf::LogDensityFunction{V,M,C} + adtype::TAD + prep::DI.GradientPrep + + function LogDensityFunctionWithGrad( + ldf::LogDensityFunction{V,M,C}, adtype::TAD + ) where {V,M,C,TAD} + # Get a set of dummy params to use for prep and concretise type + x = map(identity, getparams(ldf)) + prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf)) + # Store the prep with the struct + return new{V,M,C,TAD}(ldf, adtype, prep) + end +end +function LogDensityProblems.logdensity(f::LogDensityFunctionWithGrad) + return LogDensityProblems.logdensity(f.ldf) +end +function LogDensityProblems.capabilities(::Type{<:LogDensityFunctionWithGrad}) + return LogDensityProblems.LogDensityOrder{1}() +end # By default, the AD backend to use is inferred from the context, which would # typically be a SamplingContext which contains a sampler. 
function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction, θ::AbstractVector -) - adtype = getadtype(getsampler(getcontext(f))) - return LogDensityProblems.logdensity_and_gradient(f, θ, adtype) -end - -# Extra method allowing one to manually specify the AD backend to use, thus -# overriding the default AD backend inferred from the sampler. -function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction, θ::AbstractVector, adtype::ADTypes.AbstractADType + f::LogDensityFunctionWithGrad, x::AbstractVector ) - # Ensure we concretise the elements of the params. - θ = map(identity, θ) # TODO: Is this needed? - prep = DI.prepare_gradient(_flipped_logdensity, adtype, θ, DI.Constant(f)) - return DI.value_and_gradient(_flipped_logdensity, prep, adtype, θ, DI.Constant(f)) + x = map(identity, x) # Concretise type + return DI.value_and_gradient( + _flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf) + ) end diff --git a/src/sampler.jl b/src/sampler.jl index fe49ee45f..aa3a637ee 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -49,9 +49,6 @@ struct Sampler{T} <: AbstractSampler alg::T end -# Extract the AD type from the underlying algorithm -getadtype(s::Sampler) = getadtype(s.alg) - # AbstractMCMC interface for SampleFromUniform and SampleFromPrior function AbstractMCMC.step( rng::Random.AbstractRNG, diff --git a/test/ad.jl b/test/ad.jl index 354a8bd0c..fc5e32109 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,3 +1,5 @@ +using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad + @testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) @@ -5,27 +7,46 @@ varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - f = DynamicPPL.LogDensityFunction(m, varinfo) - # convert to `Vector{Float64}` to avoid `ReverseDiff` initializing 
the gradients to Integer 0 - # reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489 - θ = convert(Vector{Float64}, varinfo[:]) + f = LogDensityFunction(m, varinfo) + x = DynamicPPL.getparams(f) # Calculate reference logp + gradient of logp using ForwardDiff default_adtype = ADTypes.AutoForwardDiff() + ldf_with_grad = LogDensityFunctionWithGrad(f, default_adtype) ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient( - f, θ, default_adtype + ldf_with_grad, x ) @testset "$adtype" for adtype in [ - ADTypes.AutoReverseDiff(; compile=false), - ADTypes.AutoReverseDiff(; compile=true), - ADTypes.AutoMooncake(; config=nothing), + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), ] - # Mooncake can't currently handle something that is going on in - # SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now. - if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo - @test_broken 1 == 0 + @info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype" + + # Put predicates here to avoid long lines + is_mooncake = adtype isa AutoMooncake + is_1_10 = v"1.10" <= VERSION < v"1.11" + is_1_11 = v"1.11" <= VERSION < v"1.12" + is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} + is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict} + + # Mooncake doesn't work with several combinations of SimpleVarInfo. 
+ if is_mooncake && is_1_11 && is_svi_vnv + # https://github.com/compintell/Mooncake.jl/issues/470 + @test_throws ArgumentError LogDensityFunctionWithGrad(f, adtype) + elseif is_mooncake && is_1_10 && is_svi_vnv + # TODO: report upstream + @test_throws UndefRefError LogDensityFunctionWithGrad(f, adtype) + elseif is_mooncake && is_1_10 && is_svi_od + # TODO: report upstream + @test_throws Mooncake.MooncakeRuleCompilationError LogDensityFunctionWithGrad( + f, adtype + ) else - logp, grad = LogDensityProblems.logdensity_and_gradient(f, θ, adtype) + ldf_with_grad = LogDensityFunctionWithGrad(f, adtype) + logp, grad = LogDensityProblems.logdensity_and_gradient( + ldf_with_grad, x + ) @test grad ≈ ref_grad @test logp ≈ ref_logp end @@ -61,15 +82,15 @@ # overload assume so that model evaluation doesn't fail due to a lack # of implementation struct MyEmptyAlg end - DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) = - DynamicPPL.assume(dist, vn, vi) + DynamicPPL.assume( + ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi + ) = DynamicPPL.assume(dist, vn, vi) # Compiling the ReverseDiff tape used to fail here spl = Sampler(MyEmptyAlg()) vi = VarInfo(model) - ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl)) - @test LogDensityProblems.logdensity_and_gradient( - ldf, vi[:], AutoReverseDiff(; compile=true) - ) isa Any + ldf = LogDensityFunction(vi, model, SamplingContext(spl)) + ldf_grad = LogDensityFunctionWithGrad(ldf, AutoReverseDiff(; compile=true)) + @test LogDensityProblems.logdensity_and_gradient(ldf_grad, vi[:]) isa Any end end diff --git a/test/ext/DynamicPPLForwardDiffExt.jl b/test/ext/DynamicPPLForwardDiffExt.jl deleted file mode 100644 index 8de28046b..000000000 --- a/test/ext/DynamicPPLForwardDiffExt.jl +++ /dev/null @@ -1,14 +0,0 @@ -@testset "tag" begin - for chunksize in (nothing, 0, 1, 10) - ad = ADTypes.AutoForwardDiff(; chunksize=chunksize) - standardtag = if !isdefined(Base, :get_extension) - 
DynamicPPL.DynamicPPLForwardDiffExt.standardtag - else - Base.get_extension(DynamicPPL, :DynamicPPLForwardDiffExt).standardtag - end - @test standardtag(ad) - for tag in (false, 0, 1) - @test !standardtag(AutoForwardDiff(; chunksize=chunksize, tag=tag)) - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index caddef5f9..3f065b731 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -75,7 +75,6 @@ include("test_util.jl") include("ext/DynamicPPLJETExt.jl") end @testset "ad" begin - include("ext/DynamicPPLForwardDiffExt.jl") include("ext/DynamicPPLMooncakeExt.jl") include("ad.jl") end diff --git a/test/test_util.jl b/test/test_util.jl index 27a68456c..d831a5ea6 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -56,6 +56,15 @@ function short_varinfo_name(vi::TypedVarInfo) end short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo" +function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref}) + return "SimpleVarInfo{<:NamedTuple,<:Ref}" +end +function short_varinfo_name(::SimpleVarInfo{<:OrderedDict,<:Ref}) + return "SimpleVarInfo{<:OrderedDict,<:Ref}" +end +function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector,<:Ref}) + return "SimpleVarInfo{<:VarNamedVector,<:Ref}" +end short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) From bb832ab19a9280e21ec1135f5bb2a45856aa0280 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 14 Feb 2025 11:31:20 +0000 Subject: [PATCH 03/11] Dynamically decide whether to use closure vs constant --- src/logdensityfunction.jl | 74 ++++++++++++++++++++++++++++++++------- 1 file changed, 62 insertions(+), 12 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index ecb78218b..867e8ed80 100644 --- 
a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -101,9 +101,6 @@ function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) vi_new = unflatten(f.varinfo, x) return getlogp(last(evaluate!!(f.model, vi_new, context))) end -function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction) - return LogDensityProblems.logdensity(f, x) -end function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) return LogDensityProblems.LogDensityOrder{0}() end @@ -111,6 +108,47 @@ end LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) # LogDensityProblems interface: gradient (1st order) +""" + use_closure(adtype::ADTypes.AbstractADType) + +In LogDensityProblems, we want to calculate the derivative of logdensity(f, x) +with respect to x, where f is the model (in our case LogDensityFunction) and is +a constant. However, DifferentiationInterface generally expects a +single-argument function g(x) to differentiate. + +There are two ways of dealing with this: + +1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f) + +2. Use a constant context. This lets us pass a two-argument function to + DifferentiationInterface, as long as we also give it the 'inactive argument' + (i.e. the model) wrapped in `DI.Constant`. + +The relative performance of the two approaches, however, depends on the AD +backend used. Some benchmarks are provided here: +https://github.com/TuringLang/DynamicPPL.jl/pull/806#issuecomment-2658061480 + +This function is used to determine whether a given AD backend should use a +closure or a constant. If `use_closure(adtype)` returns `true`, then the +closure approach will be used. By default, this function returns `false`, i.e. +the constant approach will be used. 
+""" +use_closure(::ADTypes.AbstractADType) = false +use_closure(::ADTypes.AutoForwardDiff) = false +use_closure(::ADTypes.AutoMooncake) = false +use_closure(::ADTypes.AutoReverseDiff) = true + +""" + _flipped_logdensity(f::LogDensityFunction, x::AbstractVector) + +This function is the same as `LogDensityProblems.logdensity(f, x)` but with the +arguments flipped. It is used in the 'constant' approach to DifferentiationInterface +(see `use_closure` for more information). +""" +function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction) + return LogDensityProblems.logdensity(f, x) +end + """ LogDensityFunctionWithGrad(ldf::DynamicPPL.LogDensityFunction, adtype::ADTypes.AbstractADType) @@ -134,15 +172,25 @@ struct LogDensityFunctionWithGrad{V,M,C,TAD<:ADTypes.AbstractADType} ldf::LogDensityFunction{V,M,C} adtype::TAD prep::DI.GradientPrep + with_closure::Bool function LogDensityFunctionWithGrad( ldf::LogDensityFunction{V,M,C}, adtype::TAD ) where {V,M,C,TAD} - # Get a set of dummy params to use for prep and concretise type + # Get a set of dummy params to use for prep x = map(identity, getparams(ldf)) - prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf)) - # Store the prep with the struct - return new{V,M,C,TAD}(ldf, adtype, prep) + with_closure = use_closure(adtype) + if with_closure + prep = DI.prepare_gradient( + Base.Fix1(LogDensityProblems.logdensity, ldf), adtype, x + ) + else + prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf)) + end + # Store the prep with the struct. We also store whether a closure was used because + # we need to know this when calling `DI.value_and_gradient`. In practice we could + # recalculate it, but this runs the risk of introducing inconsistencies. 
+ return new{V,M,C,TAD}(ldf, adtype, prep, with_closure) end end function LogDensityProblems.logdensity(f::LogDensityFunctionWithGrad) @@ -151,13 +199,15 @@ end function LogDensityProblems.capabilities(::Type{<:LogDensityFunctionWithGrad}) return LogDensityProblems.LogDensityOrder{1}() end -# By default, the AD backend to use is inferred from the context, which would -# typically be a SamplingContext which contains a sampler. function LogDensityProblems.logdensity_and_gradient( f::LogDensityFunctionWithGrad, x::AbstractVector ) x = map(identity, x) # Concretise type - return DI.value_and_gradient( - _flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf) - ) + return if f.with_closure + DI.value_and_gradient( + Base.Fix1(LogDensityProblems.logdensity, f.ldf), f.prep, f.adtype, x + ) + else + DI.value_and_gradient(_flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf)) + end end From 74c9f12e10c5f4e0001659370d04ca25fa6837ec Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 15 Feb 2025 16:57:33 +0000 Subject: [PATCH 04/11] Combine LogDensityFunction{,WithGrad} into one (#811) --- docs/src/api.md | 3 +- src/logdensityfunction.jl | 266 +++++++++++++++++++++++--------------- test/ad.jl | 31 ++--- 3 files changed, 175 insertions(+), 125 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 2ff170c8e..60bdc05d9 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -54,11 +54,10 @@ logjoint ### LogDensityProblems.jl interface -The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction` or `DynamicPPL.LogDensityFunctionWithGrad`. +The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`. 
```@docs DynamicPPL.LogDensityFunction -DynamicPPL.LogDensityFunctionWithGrad ``` ## Condition and decondition diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 867e8ed80..dd048aad5 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -1,18 +1,39 @@ import DifferentiationInterface as DI """ - LogDensityFunction - -A callable representing a log density function of a `model`. -`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface, -but only to 0th-order, i.e. it is only possible to calculate the log density, -and not its gradient. If you need to calculate the gradient as well, you have -to construct a [`DynamicPPL.LogDensityFunctionWithGrad`](@ref) object. + LogDensityFunction( + model::Model, + varinfo::AbstractVarInfo=VarInfo(model), + context::AbstractContext=DefaultContext(); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing + ) + +A struct which contains a model, along with all the information necessary to: + + - calculate its log density at a given point; + - and if `adtype` is provided, calculate the gradient of the log density at + that point. + +At its most basic level, a LogDensityFunction wraps the model together with its +the type of varinfo to be used, as well as the evaluation context. These must +be known in order to calculate the log density (using +[`DynamicPPL.evaluate!!`](@ref)). + +If the `adtype` keyword argument is provided, then this struct will also store +the adtype along with other information for efficient calculation of the +gradient of the log density. Note that preparing a `LogDensityFunction` with an +AD type `AutoBackend()` requires the AD backend itself to have been loaded +(e.g. with `import Backend`). + +`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface. +If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a +concrete AD backend type, then `logdensity_and_gradient` is also implemented. 
# Fields $(FIELDS) # Examples + ```jldoctest julia> using Distributions @@ -48,66 +69,150 @@ julia> # This also respects the context in `model`. julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) true + +julia> # If we also need to calculate the gradient, we can specify an AD backend. + import ForwardDiff, ADTypes + +julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff()); + +julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) +(-2.3378770664093453, [1.0]) ``` """ -struct LogDensityFunction{V,M,C} - "varinfo used for evaluation" - varinfo::V +struct LogDensityFunction{ + M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType} +} "model used for evaluation" model::M + "varinfo used for evaluation" + varinfo::V "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable" context::C -end - -function LogDensityFunction( - model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::Union{Nothing,AbstractContext}=nothing, -) - return LogDensityFunction(varinfo, model, context) -end + "AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated" + adtype::AD + "(internal use only) gradient preparation object for the model" + prep::Union{Nothing,DI.GradientPrep} + "(internal use only) whether a closure was used for the gradient preparation" + with_closure::Bool -# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`. -function getcontext(f::LogDensityFunction) - return f.context === nothing ? 
leafcontext(f.model.context) : f.context + function LogDensityFunction( + model::Model, + varinfo::AbstractVarInfo=VarInfo(model), + context::AbstractContext=leafcontext(model.context); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + ) + if adtype === nothing + prep = nothing + with_closure = false + else + # Get a set of dummy params to use for prep + x = map(identity, varinfo[:]) + with_closure = use_closure(adtype) + if with_closure + prep = DI.prepare_gradient( + x -> logdensity_at(x, model, varinfo, context), adtype, x + ) + else + prep = DI.prepare_gradient( + logdensity_at, + adtype, + x, + DI.Constant(model), + DI.Constant(varinfo), + DI.Constant(context), + ) + end + with_closure = with_closure + end + return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}( + model, varinfo, context, adtype, prep, with_closure + ) + end end """ - getmodel(f) + setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType}) -Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. -""" -getmodel(f::DynamicPPL.LogDensityFunction) = f.model +Set the AD type used for evaluation of log density gradient in the given LogDensityFunction. +This function also performs preparation of the gradient, and sets the `prep` +and `with_closure` fields of the LogDensityFunction. -""" - setmodel(f, model[, adtype]) +If `adtype` is `nothing`, the `prep` field will be set to `nothing` as well. -Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. +This function returns a new LogDensityFunction with the updated AD type, i.e. it does +not mutate the input LogDensityFunction. 
""" -function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return Accessors.@set f.model = model +function setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType}) + return if adtype === f.adtype + f # Avoid recomputing prep if not needed + else + LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype) + end end """ - getparams(f::LogDensityFunction) - -Return the parameters of the wrapped varinfo as a vector. + logdensity_at( + x::AbstractVector, + model::Model, + varinfo::AbstractVarInfo, + context::AbstractContext + ) + +Evaluate the log density of the given `model` at the given parameter values `x`, +using the given `varinfo` and `context`. Note that the `varinfo` argument is provided +only for its structure, in the sense that the parameters from the vector `x` are inserted into +it, and its own parameters are discarded. """ -getparams(f::LogDensityFunction) = f.varinfo[:] +function logdensity_at( + x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext +) + varinfo_new = unflatten(varinfo, x) + return getlogp(last(evaluate!!(model, varinfo_new, context))) +end + +### LogDensityProblems interface -# LogDensityProblems interface: logp (0th order) +function LogDensityProblems.capabilities( + ::Type{<:LogDensityFunction{M,V,C,Nothing}} +) where {M,V,C} + return LogDensityProblems.LogDensityOrder{0}() +end +function LogDensityProblems.capabilities( + ::Type{<:LogDensityFunction{M,V,C,AD}} +) where {M,V,C,AD<:ADTypes.AbstractADType} + return LogDensityProblems.LogDensityOrder{1}() +end function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) - context = getcontext(f) - vi_new = unflatten(f.varinfo, x) - return getlogp(last(evaluate!!(f.model, vi_new, context))) + return logdensity_at(x, f.model, f.varinfo, f.context) end -function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) - return LogDensityProblems.LogDensityOrder{0}() +function 
LogDensityProblems.logdensity_and_gradient( + f::LogDensityFunction{M,V,C,AD}, x::AbstractVector +) where {M,V,C,AD<:ADTypes.AbstractADType} + f.prep === nothing && + error("Gradient preparation not available; this should not happen") + x = map(identity, x) # Concretise type + return if f.with_closure + DI.value_and_gradient( + x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x + ) + else + DI.value_and_gradient( + logdensity_at, + f.prep, + f.adtype, + x, + DI.Constant(f.model), + DI.Constant(f.varinfo), + DI.Constant(f.context), + ) + end end + # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) -# LogDensityProblems interface: gradient (1st order) +### Utils + """ use_closure(adtype::ADTypes.AbstractADType) @@ -139,75 +244,24 @@ use_closure(::ADTypes.AutoMooncake) = false use_closure(::ADTypes.AutoReverseDiff) = true """ - _flipped_logdensity(f::LogDensityFunction, x::AbstractVector) + getmodel(f) -This function is the same as `LogDensityProblems.logdensity(f, x)` but with the -arguments flipped. It is used in the 'constant' approach to DifferentiationInterface -(see `use_closure` for more information). +Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. """ -function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction) - return LogDensityProblems.logdensity(f, x) -end +getmodel(f::DynamicPPL.LogDensityFunction) = f.model """ - LogDensityFunctionWithGrad(ldf::DynamicPPL.LogDensityFunction, adtype::ADTypes.AbstractADType) - -A callable representing a log density function of a `model`. 
-`DynamicPPL.LogDensityFunctionWithGrad` implements the LogDensityProblems.jl -interface to 1st-order, meaning that you can both calculate the log density -using - - LogDensityProblems.logdensity(f, x) - -and its gradient using - - LogDensityProblems.logdensity_and_gradient(f, x) + setmodel(f, model[, adtype]) -where `f` is a `LogDensityFunctionWithGrad` object and `x` is a vector of parameters. +Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. +""" +function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) + return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype) +end -# Fields -$(FIELDS) """ -struct LogDensityFunctionWithGrad{V,M,C,TAD<:ADTypes.AbstractADType} - ldf::LogDensityFunction{V,M,C} - adtype::TAD - prep::DI.GradientPrep - with_closure::Bool + getparams(f::LogDensityFunction) - function LogDensityFunctionWithGrad( - ldf::LogDensityFunction{V,M,C}, adtype::TAD - ) where {V,M,C,TAD} - # Get a set of dummy params to use for prep - x = map(identity, getparams(ldf)) - with_closure = use_closure(adtype) - if with_closure - prep = DI.prepare_gradient( - Base.Fix1(LogDensityProblems.logdensity, ldf), adtype, x - ) - else - prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf)) - end - # Store the prep with the struct. We also store whether a closure was used because - # we need to know this when calling `DI.value_and_gradient`. In practice we could - # recalculate it, but this runs the risk of introducing inconsistencies. 
- return new{V,M,C,TAD}(ldf, adtype, prep, with_closure) - end -end -function LogDensityProblems.logdensity(f::LogDensityFunctionWithGrad) - return LogDensityProblems.logdensity(f.ldf) -end -function LogDensityProblems.capabilities(::Type{<:LogDensityFunctionWithGrad}) - return LogDensityProblems.LogDensityOrder{1}() -end -function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunctionWithGrad, x::AbstractVector -) - x = map(identity, x) # Concretise type - return if f.with_closure - DI.value_and_gradient( - Base.Fix1(LogDensityProblems.logdensity, f.ldf), f.prep, f.adtype, x - ) - else - DI.value_and_gradient(_flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf)) - end -end +Return the parameters of the wrapped varinfo as a vector. +""" +getparams(f::LogDensityFunction) = f.varinfo[:] diff --git a/test/ad.jl b/test/ad.jl index fc5e32109..d0d2a7b8a 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,4 +1,4 @@ -using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad +using DynamicPPL: LogDensityFunction @testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS @@ -10,11 +10,9 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad f = LogDensityFunction(m, varinfo) x = DynamicPPL.getparams(f) # Calculate reference logp + gradient of logp using ForwardDiff - default_adtype = ADTypes.AutoForwardDiff() - ldf_with_grad = LogDensityFunctionWithGrad(f, default_adtype) - ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient( - ldf_with_grad, x - ) + ref_adtype = ADTypes.AutoForwardDiff() + ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype) + ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) @testset "$adtype" for adtype in [ AutoReverseDiff(; compile=false), @@ -33,20 +31,18 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad # Mooncake doesn't work with several combinations of SimpleVarInfo. 
if is_mooncake && is_1_11 && is_svi_vnv # https://github.com/compintell/Mooncake.jl/issues/470 - @test_throws ArgumentError LogDensityFunctionWithGrad(f, adtype) + @test_throws ArgumentError DynamicPPL.setadtype(ref_ldf, adtype) elseif is_mooncake && is_1_10 && is_svi_vnv # TODO: report upstream - @test_throws UndefRefError LogDensityFunctionWithGrad(f, adtype) + @test_throws UndefRefError DynamicPPL.setadtype(ref_ldf, adtype) elseif is_mooncake && is_1_10 && is_svi_od # TODO: report upstream - @test_throws Mooncake.MooncakeRuleCompilationError LogDensityFunctionWithGrad( - f, adtype + @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.setadtype( + ref_ldf, adtype ) else - ldf_with_grad = LogDensityFunctionWithGrad(f, adtype) - logp, grad = LogDensityProblems.logdensity_and_gradient( - ldf_with_grad, x - ) + ldf = DynamicPPL.setadtype(ref_ldf, adtype) + logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x) @test grad ≈ ref_grad @test logp ≈ ref_logp end @@ -89,8 +85,9 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad # Compiling the ReverseDiff tape used to fail here spl = Sampler(MyEmptyAlg()) vi = VarInfo(model) - ldf = LogDensityFunction(vi, model, SamplingContext(spl)) - ldf_grad = LogDensityFunctionWithGrad(ldf, AutoReverseDiff(; compile=true)) - @test LogDensityProblems.logdensity_and_gradient(ldf_grad, vi[:]) isa Any + ldf = LogDensityFunction( + model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true) + ) + @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any end end From 8d10fc1833c8a41ded6875b0a0c5d8d02b0737ec Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 15 Feb 2025 20:18:17 +0000 Subject: [PATCH 05/11] Warn if unsupported AD type is used --- src/logdensityfunction.jl | 17 ++++++++ test/ad.jl | 91 +++++++++++++++++++++------------------ 2 files changed, 67 insertions(+), 41 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 
dd048aad5..77cc21475 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -1,5 +1,19 @@ import DifferentiationInterface as DI +""" + is_supported(adtype::AbstractADType) + +Check if the given AD type is formally supported by DynamicPPL. + +AD backends that are not formally supported can still be used for gradient +calculation; it is just that the DynamicPPL developers do not commit to +maintaining compatibility with them. +""" +is_supported(::ADTypes.AbstractADType) = false +is_supported(::ADTypes.AutoForwardDiff) = true +is_supported(::ADTypes.AutoMooncake) = true +is_supported(::ADTypes.AutoReverseDiff) = true + """ LogDensityFunction( model::Model, @@ -105,6 +119,9 @@ struct LogDensityFunction{ prep = nothing with_closure = false else + # Check support + is_supported(adtype) || + @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." # Get a set of dummy params to use for prep x = map(identity, varinfo[:]) with_closure = use_closure(adtype) diff --git a/test/ad.jl b/test/ad.jl index d0d2a7b8a..9dfef55c9 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,50 +1,59 @@ using DynamicPPL: LogDensityFunction -@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) +@testset "Automatic differentiation" begin + @testset "Unsupported backends" begin + @model demo() = x ~ Normal() + @test_logs (:warn, r"not officially supported") LogDensityFunction( + demo(); adtype=AutoZygote() + ) + end + + @testset "Correctness: ForwardDiff, ReverseDiff, and Mooncake" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) + vns = DynamicPPL.TestUtils.varnames(m) + 
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - f = LogDensityFunction(m, varinfo) - x = DynamicPPL.getparams(f) - # Calculate reference logp + gradient of logp using ForwardDiff - ref_adtype = ADTypes.AutoForwardDiff() - ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype) - ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + f = LogDensityFunction(m, varinfo) + x = DynamicPPL.getparams(f) + # Calculate reference logp + gradient of logp using ForwardDiff + ref_adtype = ADTypes.AutoForwardDiff() + ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype) + ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) - @testset "$adtype" for adtype in [ - AutoReverseDiff(; compile=false), - AutoReverseDiff(; compile=true), - AutoMooncake(; config=nothing), - ] - @info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype" + @testset "$adtype" for adtype in [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] + @info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype" - # Put predicates here to avoid long lines - is_mooncake = adtype isa AutoMooncake - is_1_10 = v"1.10" <= VERSION < v"1.11" - is_1_11 = v"1.11" <= VERSION < v"1.12" - is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} - is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict} + # Put predicates here to avoid long lines + is_mooncake = adtype isa AutoMooncake + is_1_10 = v"1.10" <= VERSION < v"1.11" + is_1_11 = v"1.11" <= VERSION < v"1.12" + is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} + is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict} - # Mooncake doesn't work with several combinations of SimpleVarInfo. 
- if is_mooncake && is_1_11 && is_svi_vnv - # https://github.com/compintell/Mooncake.jl/issues/470 - @test_throws ArgumentError DynamicPPL.setadtype(ref_ldf, adtype) - elseif is_mooncake && is_1_10 && is_svi_vnv - # TODO: report upstream - @test_throws UndefRefError DynamicPPL.setadtype(ref_ldf, adtype) - elseif is_mooncake && is_1_10 && is_svi_od - # TODO: report upstream - @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.setadtype( - ref_ldf, adtype - ) - else - ldf = DynamicPPL.setadtype(ref_ldf, adtype) - logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x) - @test grad ≈ ref_grad - @test logp ≈ ref_logp + # Mooncake doesn't work with several combinations of SimpleVarInfo. + if is_mooncake && is_1_11 && is_svi_vnv + # https://github.com/compintell/Mooncake.jl/issues/470 + @test_throws ArgumentError DynamicPPL.setadtype(ref_ldf, adtype) + elseif is_mooncake && is_1_10 && is_svi_vnv + # TODO: report upstream + @test_throws UndefRefError DynamicPPL.setadtype(ref_ldf, adtype) + elseif is_mooncake && is_1_10 && is_svi_od + # TODO: report upstream + @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.setadtype( + ref_ldf, adtype + ) + else + ldf = DynamicPPL.setadtype(ref_ldf, adtype) + logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x) + @test grad ≈ ref_grad + @test logp ≈ ref_logp + end end end end From e6513a0d8cd47f0f4ddbad1c3fc49b1bf40a3de5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 15 Feb 2025 19:41:19 +0000 Subject: [PATCH 06/11] Update changelog --- HISTORY.md | 42 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 36d827e92..abb64d13f 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,7 +2,7 @@ ## 0.35.0 -**Breaking** +**Breaking changes** ### `.~` right hand side must be a univariate distribution @@ -119,14 +119,50 @@ This release removes the feature of `VarInfo` where it kept track of which varia This change also 
affects sampling in Turing.jl. +### `LogDensityFunction` argument order + + - The method `LogDensityFunction(varinfo, model, context)` has been removed. + The only accepted order is `LogDensityFunction(model, varinfo, context; adtype)`. + (For an explanation of `adtype`, see below.) + The varinfo and context arguments are both still optional. + **Other changes** +### `LogDensityProblems` interface + LogDensityProblemsAD is now removed as a dependency. Instead of constructing a `LogDensityProblemAD.ADgradient` object, we now directly use `DifferentiationInterface` to calculate the gradient of the log density with respect to model parameters. -In practice, this means that if you want to calculate the gradient for a model, you can do: +Note that if you wish, you can still construct an `ADgradient` out of a `LogDensityFunction` object (there is nothing preventing this). + +However, in this version, `LogDensityFunction` now takes an extra AD type argument. +If this argument is not provided, the behaviour is exactly the same as before, i.e. you can calculate `logdensity` but not its gradient. +However, if you do pass an AD type, that will allow you to calculate the gradient as well. +You may thus find that it is easier to instead do this: + +```julia +@model f() = ... + +ldf = LogDensityFunction(f(); adtype=AutoForwardDiff()) +``` + +This will return an object which satisfies the `LogDensityProblems` interface to first-order, i.e. you can now directly call both + +``` +LogDensityProblems.logdensity(ldf, params) +LogDensityProblems.logdensity_and_gradient(ldf, params) +``` + +without having to construct a separate `ADgradient` object. -TODO(penelopeysm): Finish this +If you prefer, you can also use `setadtype` to tack on the AD type afterwards: + +```julia +@model f() = ... 
+ +ldf = LogDensityFunction(f()) # by default, no adtype set +ldf_with_ad = setadtype(ldf, AutoForwardDiff()) +``` ## 0.34.2 From 74fbad24aee81ed51e6a9db6faa17854e84905e4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 17 Feb 2025 14:16:40 +0000 Subject: [PATCH 07/11] Update DI compat bound Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b6c36a05c..ab8c6b5ec 100644 --- a/Project.toml +++ b/Project.toml @@ -49,7 +49,7 @@ Bijectors = "0.13.18, 0.14, 0.15" ChainRulesCore = "1" Compat = "4" ConstructionBase = "1.5.4" -DifferentiationInterface = "0.6.39" +DifferentiationInterface = "0.6.41" Distributions = "0.25" DocStringExtensions = "0.9" EnzymeCore = "0.6 - 0.8" From f76bb3d4f1589a40ddd4955dec15855eacc0f509 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 18 Feb 2025 21:13:00 +0000 Subject: [PATCH 08/11] Don't store with_closure inside LogDensityFunction Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> --- ext/DynamicPPLForwardDiffExt.jl | 54 +++++++++++++++++++++++++++++++++ src/logdensityfunction.jl | 19 +++++------- 2 files changed, 62 insertions(+), 11 deletions(-) create mode 100644 ext/DynamicPPLForwardDiffExt.jl diff --git a/ext/DynamicPPLForwardDiffExt.jl b/ext/DynamicPPLForwardDiffExt.jl new file mode 100644 index 000000000..4bc33e217 --- /dev/null +++ b/ext/DynamicPPLForwardDiffExt.jl @@ -0,0 +1,54 @@ +module DynamicPPLForwardDiffExt + +if isdefined(Base, :get_extension) + using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD + using ForwardDiff +else + using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD + using ..ForwardDiff +end + +getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk + +standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true +standardtag(::ADTypes.AutoForwardDiff) = false + 
+function LogDensityProblemsAD.ADgradient( + ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction +) + θ = DynamicPPL.getparams(ℓ) + f = Base.Fix1(LogDensityProblems.logdensity, ℓ) + + # Define configuration for ForwardDiff. + tag = if standardtag(ad) + ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ)) + else + ForwardDiff.Tag(f, eltype(θ)) + end + chunk_size = getchunksize(ad) + chunk = if chunk_size == 0 || chunk_size === nothing + ForwardDiff.Chunk(θ) + else + ForwardDiff.Chunk(length(θ), chunk_size) + end + + return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ) +end + +# Allow Turing tag in gradient etc. calls of the log density function +function ForwardDiff.checktag( + ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, + ::DynamicPPL.LogDensityFunction, + ::AbstractArray{W}, +) where {V,W} + return true +end +function ForwardDiff.checktag( + ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, + ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:DynamicPPL.LogDensityFunction}, + ::AbstractArray{W}, +) where {V,W} + return true +end + +end # module diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 77cc21475..9b6f93070 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -106,8 +106,6 @@ struct LogDensityFunction{ adtype::AD "(internal use only) gradient preparation object for the model" prep::Union{Nothing,DI.GradientPrep} - "(internal use only) whether a closure was used for the gradient preparation" - with_closure::Bool function LogDensityFunction( model::Model, @@ -117,15 +115,13 @@ struct LogDensityFunction{ ) if adtype === nothing prep = nothing - with_closure = false else # Check support is_supported(adtype) || @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." 
# Get a set of dummy params to use for prep x = map(identity, varinfo[:]) - with_closure = use_closure(adtype) - if with_closure + if use_closure(adtype) prep = DI.prepare_gradient( x -> logdensity_at(x, model, varinfo, context), adtype, x ) @@ -139,10 +135,9 @@ struct LogDensityFunction{ DI.Constant(context), ) end - with_closure = with_closure end return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}( - model, varinfo, context, adtype, prep, with_closure + model, varinfo, context, adtype, prep ) end end @@ -150,9 +145,9 @@ end """ setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType}) -Set the AD type used for evaluation of log density gradient in the given LogDensityFunction. -This function also performs preparation of the gradient, and sets the `prep` -and `with_closure` fields of the LogDensityFunction. +Set the AD type used for evaluation of log density gradient in the given +LogDensityFunction. This function also performs preparation of the gradient, +and sets the `prep` field of the LogDensityFunction. If `adtype` is `nothing`, the `prep` field will be set to `nothing` as well. @@ -208,7 +203,9 @@ function LogDensityProblems.logdensity_and_gradient( f.prep === nothing && error("Gradient preparation not available; this should not happen") x = map(identity, x) # Concretise type - return if f.with_closure + # Make branching statically inferrable, i.e. 
type-stable (even if the two + # branches happen to return different types) + return if use_closure(f.adtype) DI.value_and_gradient( x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x ) From 05f1bce85ad79a1863d557957a92eb575cd07384 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 18 Feb 2025 21:19:03 +0000 Subject: [PATCH 09/11] setadtype --> LogDensityFunction --- HISTORY.md | 5 +++-- src/logdensityfunction.jl | 20 ++++++++++---------- test/ad.jl | 12 ++++++++---- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index abb64d13f..3f999ccab 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -155,13 +155,14 @@ LogDensityProblems.logdensity_and_gradient(ldf, params) without having to construct a separate `ADgradient` object. -If you prefer, you can also use `setadtype` to tack on the AD type afterwards: +If you prefer, you can also construct a new `LogDensityFunction` with a new AD type afterwards. +The model, varinfo, and context will be taken from the original `LogDensityFunction`: ```julia @model f() = ... ldf = LogDensityFunction(f()) # by default, no adtype set -ldf_with_ad = setadtype(ldf, AutoForwardDiff()) +ldf_with_ad = LogDensityFunction(ldf, AutoForwardDiff()) ``` ## 0.34.2 diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 9b6f93070..57a5a4d91 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -143,18 +143,18 @@ struct LogDensityFunction{ end """ - setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType}) - -Set the AD type used for evaluation of log density gradient in the given -LogDensityFunction. This function also performs preparation of the gradient, -and sets the `prep` field of the LogDensityFunction. - -If `adtype` is `nothing`, the `prep` field will be set to `nothing` as well. 
+ LogDensityFunction( + ldf::LogDensityFunction, + adtype::Union{Nothing,ADTypes.AbstractADType} + ) -This function returns a new LogDensityFunction with the updated AD type, i.e. it does -not mutate the input LogDensityFunction. +Create a new LogDensityFunction using the model, varinfo, and context from the given +`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, pass +`nothing` as the second argument. """ -function setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType}) +function LogDensityFunction( + f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType} +) return if adtype === f.adtype f # Avoid recomputing prep if not needed else diff --git a/test/ad.jl b/test/ad.jl index 9dfef55c9..73519c3f5 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -39,17 +39,21 @@ using DynamicPPL: LogDensityFunction # Mooncake doesn't work with several combinations of SimpleVarInfo. if is_mooncake && is_1_11 && is_svi_vnv # https://github.com/compintell/Mooncake.jl/issues/470 - @test_throws ArgumentError DynamicPPL.setadtype(ref_ldf, adtype) + @test_throws ArgumentError DynamicPPL.LogDensityFunction( + ref_ldf, adtype + ) elseif is_mooncake && is_1_10 && is_svi_vnv # TODO: report upstream - @test_throws UndefRefError DynamicPPL.setadtype(ref_ldf, adtype) + @test_throws UndefRefError DynamicPPL.LogDensityFunction( + ref_ldf, adtype + ) elseif is_mooncake && is_1_10 && is_svi_od # TODO: report upstream - @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.setadtype( + @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( ref_ldf, adtype ) else - ldf = DynamicPPL.setadtype(ref_ldf, adtype) + ldf = DynamicPPL.LogDensityFunction(ref_ldf, adtype) logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x) @test grad ≈ ref_grad @test logp ≈ ref_logp From 566257e84fcb39f5b9d3a3251082e27742671b37 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 18 Feb 2025 22:05:32 +0000 Subject: 
[PATCH 10/11] Re-add ForwardDiffExt (including tests) --- Project.toml | 2 + ext/DynamicPPLForwardDiffExt.jl | 72 +++++++++++----------------- src/logdensityfunction.jl | 24 +++++++++- test/ext/DynamicPPLForwardDiffExt.jl | 32 +++++++++++++ test/runtests.jl | 1 + 5 files changed, 87 insertions(+), 44 deletions(-) create mode 100644 test/ext/DynamicPPLForwardDiffExt.jl diff --git a/Project.toml b/Project.toml index ab8c6b5ec..26ab45425 100644 --- a/Project.toml +++ b/Project.toml @@ -28,6 +28,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" @@ -35,6 +36,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] +DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLJETExt = ["JET"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLMooncakeExt = ["Mooncake"] diff --git a/ext/DynamicPPLForwardDiffExt.jl b/ext/DynamicPPLForwardDiffExt.jl index 4bc33e217..6bd7a5d94 100644 --- a/ext/DynamicPPLForwardDiffExt.jl +++ b/ext/DynamicPPLForwardDiffExt.jl @@ -1,54 +1,40 @@ module DynamicPPLForwardDiffExt -if isdefined(Base, :get_extension) - using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD - using ForwardDiff -else - using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD - using ..ForwardDiff -end - -getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk - -standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true -standardtag(::ADTypes.AutoForwardDiff) = false - -function LogDensityProblemsAD.ADgradient( - ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction -) - θ = 
DynamicPPL.getparams(ℓ) - f = Base.Fix1(LogDensityProblems.logdensity, ℓ) - - # Define configuration for ForwardDiff. - tag = if standardtag(ad) - ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ)) +using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems +using ForwardDiff + +# check if the AD type already has a tag +use_dynamicppl_tag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true +use_dynamicppl_tag(::ADTypes.AutoForwardDiff) = false + +function DynamicPPL.tweak_adtype( + ad::ADTypes.AutoForwardDiff{chunk_size}, + ::DynamicPPL.Model, + vi::DynamicPPL.AbstractVarInfo, + ::DynamicPPL.AbstractContext, +) where {chunk_size} + params = vi[:] + + # Use DynamicPPL tag to improve stack traces + # https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/ + # NOTE: DifferentiationInterface disables tag checking if the + # tag inside the AutoForwardDiff type is not nothing. See + # https://github.com/JuliaDiff/DifferentiationInterface.jl/blob/1df562180bdcc3e91c885aa5f4162a0be2ced850/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl#L338-L350. + # So we don't currently need to override ForwardDiff.checktag as well. + tag = if use_dynamicppl_tag(ad) + ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(params)) else - ForwardDiff.Tag(f, eltype(θ)) + ad.tag end - chunk_size = getchunksize(ad) + + # Optimise chunk size according to size of model chunk = if chunk_size == 0 || chunk_size === nothing - ForwardDiff.Chunk(θ) + ForwardDiff.Chunk(params) else - ForwardDiff.Chunk(length(θ), chunk_size) + ForwardDiff.Chunk(length(params), chunk_size) end - return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ) -end - -# Allow Turing tag in gradient etc. 
calls of the log density function -function ForwardDiff.checktag( - ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, - ::DynamicPPL.LogDensityFunction, - ::AbstractArray{W}, -) where {V,W} - return true -end -function ForwardDiff.checktag( - ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, - ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:DynamicPPL.LogDensityFunction}, - ::AbstractArray{W}, -) where {V,W} - return true + return ADTypes.AutoForwardDiff(; chunksize=ForwardDiff.chunksize(chunk), tag=tag) end end # module diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 57a5a4d91..a42855f05 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -116,7 +116,9 @@ struct LogDensityFunction{ if adtype === nothing prep = nothing else - # Check support + # Make backend-specific tweaks to the adtype + adtype = tweak_adtype(adtype, model, varinfo, context) + # Check whether it is supported is_supported(adtype) || @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." # Get a set of dummy params to use for prep @@ -227,6 +229,26 @@ LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) ### Utils +""" + tweak_adtype( + adtype::ADTypes.AbstractADType, + model::Model, + varinfo::AbstractVarInfo, + context::AbstractContext + ) + +Return an 'optimised' form of the adtype. This is useful for doing +backend-specific optimisation of the adtype (e.g., for ForwardDiff, calculating +the chunk size: see the method override in `ext/DynamicPPLForwardDiffExt.jl`). +The model is passed as a parameter in case the optimisation depends on the +model. + +By default, this just returns the input unchanged. 
+""" +tweak_adtype( + adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo, ::AbstractContext +) = adtype + """ use_closure(adtype::ADTypes.AbstractADType) diff --git a/test/ext/DynamicPPLForwardDiffExt.jl b/test/ext/DynamicPPLForwardDiffExt.jl new file mode 100644 index 000000000..73a0510e9 --- /dev/null +++ b/test/ext/DynamicPPLForwardDiffExt.jl @@ -0,0 +1,32 @@ +module DynamicPPLForwardDiffExtTests + +using DynamicPPL +using ADTypes: AutoForwardDiff +using ForwardDiff: ForwardDiff +using Distributions: MvNormal +using LinearAlgebra: I +using Test: @test, @testset + +# get_chunksize(ad::AutoForwardDiff{chunk}) where {chunk} = chunk + +@testset "ForwardDiff tweak_adtype" begin + MODEL_SIZE = 10 + @model f() = x ~ MvNormal(zeros(MODEL_SIZE), I) + model = f() + varinfo = VarInfo(model) + context = DefaultContext() + + @testset "Chunk size setting" for chunksize in (nothing, 0) + base_adtype = AutoForwardDiff(; chunksize=chunksize) + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context) + @test new_adtype isa AutoForwardDiff{MODEL_SIZE} + end + + @testset "Tag setting" begin + base_adtype = AutoForwardDiff() + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context) + @test new_adtype.tag isa ForwardDiff.Tag{DynamicPPL.DynamicPPLTag} + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 3f065b731..caddef5f9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -75,6 +75,7 @@ include("test_util.jl") include("ext/DynamicPPLJETExt.jl") end @testset "ad" begin + include("ext/DynamicPPLForwardDiffExt.jl") include("ext/DynamicPPLMooncakeExt.jl") include("ad.jl") end From 04f640d02e6c93de3051fbf0979b6eb115efd0e2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 19 Feb 2025 16:33:43 +0000 Subject: [PATCH 11/11] Add more tests for coverage --- test/logdensityfunction.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl 
index 319371609..d6e66ec59 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -1,4 +1,4 @@ -using Test, DynamicPPL, ADTypes, LogDensityProblems, ReverseDiff +using Test, DynamicPPL, ADTypes, LogDensityProblems, ForwardDiff @testset "`getmodel` and `setmodel`" begin @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS @@ -22,4 +22,15 @@ end @test LogDensityProblems.dimension(logdensity) == length(θ) end end + + @testset "capabilities" begin + model = DynamicPPL.TestUtils.DEMO_MODELS[1] + ldf = DynamicPPL.LogDensityFunction(model) + @test LogDensityProblems.capabilities(typeof(ldf)) == + LogDensityProblems.LogDensityOrder{0}() + + ldf_with_ad = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff()) + @test LogDensityProblems.capabilities(typeof(ldf_with_ad)) == + LogDensityProblems.LogDensityOrder{1}() + end end