diff --git a/Project.toml b/Project.toml index e49d11908..8cf8a9a51 100644 --- a/Project.toml +++ b/Project.toml @@ -26,6 +26,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" @@ -35,6 +36,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] +DynamicPPLDifferentiationInterfaceTestExt = ["DifferentiationInterfaceTest"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLJETExt = ["JET"] @@ -52,6 +54,7 @@ ChainRulesCore = "1" Compat = "4" ConstructionBase = "1.5.4" DifferentiationInterface = "0.6.41" +DifferentiationInterfaceTest = "0.9.6" Distributions = "0.25" DocStringExtensions = "0.9" EnzymeCore = "0.6 - 0.8" diff --git a/ext/DynamicPPLDifferentiationInterfaceTestExt.jl b/ext/DynamicPPLDifferentiationInterfaceTestExt.jl new file mode 100644 index 000000000..431ea3fb8 --- /dev/null +++ b/ext/DynamicPPLDifferentiationInterfaceTestExt.jl @@ -0,0 +1,87 @@ +module DynamicPPLDifferentiationInterfaceTestExt + +using DynamicPPL: + DynamicPPL, + ADTypes, + LogDensityProblems, + Model, + DI, # DifferentiationInterface + AbstractVarInfo, + VarInfo, + LogDensityFunction +import DifferentiationInterfaceTest as DIT + +""" + REFERENCE_ADTYPE + +Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since +it's the default AD backend used in Turing.jl. 
+""" +const REFERENCE_ADTYPE = ADTypes.AutoForwardDiff() + +""" + make_scenario( + model::Model, + adtype::ADTypes.AbstractADType; + varinfo::AbstractVarInfo=VarInfo(model), + params::Vector{<:Real}=varinfo[:], + reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE, + expected_grad::Union{Nothing,Vector{<:Real}}=nothing, + ) + +Construct a DifferentiationInterfaceTest.Scenario for the given `model` and `adtype`. + +The scenario evaluates the gradient of the model's log density at `params`. If `expected_grad` is `nothing`, the expected gradient is computed with `reference_adtype` and used as the ground truth; otherwise `expected_grad` is used directly. +""" +function make_scenario( + model::Model, + adtype::ADTypes.AbstractADType; + varinfo::AbstractVarInfo=VarInfo(model), + params::Vector{<:Real}=varinfo[:], + reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE, + expected_grad::Union{Nothing,Vector{<:Real}}=nothing, +) + params = map(identity, params) + context = DynamicPPL.DefaultContext() + adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo, context) + # Below is a performance optimisation, see: https://github.com/TuringLang/DynamicPPL.jl/pull/806#issuecomment-2658049143 + if DynamicPPL.use_closure(adtype) + f = x -> DynamicPPL.logdensity_at(x, model, varinfo, context) + di_contexts = () + else + f = DynamicPPL.logdensity_at + di_contexts = (DI.Constant(model), DI.Constant(varinfo), DI.Constant(context)) + end + + # Calculate ground truth to compare against + grad_true = if expected_grad === nothing + ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype) + LogDensityProblems.logdensity_and_gradient(ldf_reference, params)[2] + else + expected_grad + end + + return DIT.Scenario{:gradient,:out}( + f, params; contexts=di_contexts, res1=grad_true, name="$(model.f)" + ) +end + +function DynamicPPL.TestUtils.AD.run_ad( + model::Model, + adtype::ADTypes.AbstractADType; + varinfo::AbstractVarInfo=VarInfo(model), + params::Vector{<:Real}=varinfo[:], + reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE, + expected_grad::Union{Nothing,Vector{<:Real}}=nothing, + kwargs..., +) + scen = make_scenario(model, adtype; 
varinfo=varinfo, expected_grad=expected_grad) + tweaked_adtype = DynamicPPL.tweak_adtype( + adtype, model, varinfo, DynamicPPL.DefaultContext() + ) + return DIT.test_differentiation( + tweaked_adtype, [scen]; scenario_intact=false, kwargs... + ) +end + +end diff --git a/src/test_utils.jl b/src/test_utils.jl index c7d12c927..db480eadc 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -19,4 +19,8 @@ include("test_utils/contexts.jl") include("test_utils/varinfo.jl") include("test_utils/sampler.jl") +module AD + function run_ad end +end + end diff --git a/test/Project.toml b/test/Project.toml index 9fa3fd872..fa382b08a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" diff --git a/test/ad.jl b/test/ad.jl index a4f3dbfa7..b5a5a1384 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,4 +1,5 @@ using DynamicPPL: LogDensityFunction +import DifferentiationInterfaceTest as DIT @testset "Automatic differentiation" begin # Used as the ground truth that others are compared against. 
@@ -27,7 +28,7 @@ using DynamicPPL: LogDensityFunction x = DynamicPPL.getparams(f) # Calculate reference logp + gradient of logp using ForwardDiff ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype) - ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) + ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)[2] @testset "$adtype" for adtype in test_adtypes @info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype" @@ -56,10 +57,9 @@ using DynamicPPL: LogDensityFunction ref_ldf, adtype ) else - ldf = DynamicPPL.LogDensityFunction(ref_ldf, adtype) - logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x) - @test grad ≈ ref_grad - @test logp ≈ ref_logp + DynamicPPL.TestUtils.AD.run_ad( + m, adtype; varinfo=varinfo, expected_grad=ref_grad + ) end end end