From 0d0d87b0af856a4bd09aba67ebc5fff163fac1dc Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 19 Mar 2025 23:11:59 +0000 Subject: [PATCH 1/3] Test Enzyme --- test/ad.jl | 234 ++++++++++++++++++++++++----------------------- test/runtests.jl | 98 ++++++++++---------- 2 files changed, 168 insertions(+), 164 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index a4f3dbfa7..6e6db594f 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,26 +1,30 @@ using DynamicPPL: LogDensityFunction +using EnzymeCore: set_runtime_activity, Forward, Reverse @testset "Automatic differentiation" begin # Used as the ground truth that others are compared against. ref_adtype = AutoForwardDiff() test_adtypes = [ - AutoReverseDiff(; compile=false), - AutoReverseDiff(; compile=true), - AutoMooncake(; config=nothing), + # AutoReverseDiff(; compile=false), + # AutoReverseDiff(; compile=true), + # AutoMooncake(; config=nothing), + AutoEnzyme(; mode=set_runtime_activity(Forward, true)), + AutoEnzyme(; mode=set_runtime_activity(Reverse, true)), ] - @testset "Unsupported backends" begin - @model demo() = x ~ Normal() - @test_logs (:warn, r"not officially supported") LogDensityFunction( - demo(); adtype=AutoZygote() - ) - end + # @testset "Unsupported backends" begin + # @model demo() = x ~ Normal() + # @test_logs (:warn, r"not officially supported") LogDensityFunction( + # demo(); adtype=AutoZygote() + # ) + # end - @testset "Correctness: ForwardDiff, ReverseDiff, and Mooncake" begin + @testset "Correctness on supported AD backends" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) + # rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) + # vns = DynamicPPL.TestUtils.varnames(m) + # varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) + varinfos = [VarInfo(m)] @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos f = LogDensityFunction(m, varinfo) @@ -66,106 +70,106 @@ using DynamicPPL: LogDensityFunction end end - @testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin - # Failing model - t = 1:0.05:8 - σ = 0.3 - y = @. rand(sin(t) + Normal(0, σ)) - @model function state_space(y, TT, ::Type{T}=Float64) where {T} - # Priors - α ~ Normal(y[1], 0.001) - τ ~ Exponential(1) - η ~ filldist(Normal(0, 1), TT - 1) - σ ~ Exponential(1) - # create latent variable - x = Vector{T}(undef, TT) - x[1] = α - for t in 2:TT - x[t] = x[t - 1] + η[t - 1] * τ - end - # measurement model - y ~ MvNormal(x, σ^2 * I) - return x - end - model = state_space(y, length(t)) - - # Dummy sampling algorithm for testing. The test case can only be replicated - # with a custom sampler, it doesn't work with SampleFromPrior(). We need to - # overload assume so that model evaluation doesn't fail due to a lack - # of implementation - struct MyEmptyAlg end - DynamicPPL.assume( - ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi - ) = DynamicPPL.assume(dist, vn, vi) - - # Compiling the ReverseDiff tape used to fail here - spl = Sampler(MyEmptyAlg()) - vi = VarInfo(model) - ldf = LogDensityFunction( - model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true) - ) - @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any - end - - # Test that various different ways of specifying array types as arguments work with all - # ADTypes. - @testset "Array argument types" begin - test_m = randn(2, 3) - - function eval_logp_and_grad(model, m, adtype) - ldf = LogDensityFunction(model(); adtype=adtype) - return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) - end - - @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} - m = Matrix{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_matrix_model_reference = eval_logp_and_grad( - scalar_matrix_model, test_m, ref_adtype - ) - - @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) - - @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} - m = Array{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_array_model_reference = eval_logp_and_grad( - scalar_array_model, test_m, ref_adtype - ) - - @model function array_model(::Type{T}=Array{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) - - @testset "$adtype" for adtype in test_adtypes - scalar_matrix_model_logp_and_grad = eval_logp_and_grad( - scalar_matrix_model, test_m, adtype - ) - @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] - @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] - matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) - @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] - @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] - scalar_array_model_logp_and_grad = eval_logp_and_grad( - scalar_array_model, test_m, adtype - ) - @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] - @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] - array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) - @test array_model_logp_and_grad[1] ≈ array_model_reference[1] - @test array_model_logp_and_grad[2] ≈ array_model_reference[2] - end - end + # @testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin + # # Failing model + # t = 1:0.05:8 + # σ = 0.3 + # y = @. rand(sin(t) + Normal(0, σ)) + # @model function state_space(y, TT, ::Type{T}=Float64) where {T} + # # Priors + # α ~ Normal(y[1], 0.001) + # τ ~ Exponential(1) + # η ~ filldist(Normal(0, 1), TT - 1) + # σ ~ Exponential(1) + # # create latent variable + # x = Vector{T}(undef, TT) + # x[1] = α + # for t in 2:TT + # x[t] = x[t - 1] + η[t - 1] * τ + # end + # # measurement model + # y ~ MvNormal(x, σ^2 * I) + # return x + # end + # model = state_space(y, length(t)) + # + # # Dummy sampling algorithm for testing. The test case can only be replicated + # # with a custom sampler, it doesn't work with SampleFromPrior(). We need to + # # overload assume so that model evaluation doesn't fail due to a lack + # # of implementation + # struct MyEmptyAlg end + # DynamicPPL.assume( + # ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi + # ) = DynamicPPL.assume(dist, vn, vi) + # + # # Compiling the ReverseDiff tape used to fail here + # spl = Sampler(MyEmptyAlg()) + # vi = VarInfo(model) + # ldf = LogDensityFunction( + # model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true) + # ) + # @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any + # end + # + # # Test that various different ways of specifying array types as arguments work with all + # # ADTypes. + # @testset "Array argument types" begin + # test_m = randn(2, 3) + # + # function eval_logp_and_grad(model, m, adtype) + # ldf = LogDensityFunction(model(); adtype=adtype) + # return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) + # end + # + # @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} + # m = Matrix{T}(undef, 2, 3) + # return m ~ filldist(MvNormal(zeros(2), I), 3) + # end + # + # scalar_matrix_model_reference = eval_logp_and_grad( + # scalar_matrix_model, test_m, ref_adtype + # ) + # + # @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} + # m = T(undef, 2, 3) + # return m ~ filldist(MvNormal(zeros(2), I), 3) + # end + # + # matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) + # + # @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} + # m = Array{T}(undef, 2, 3) + # return m ~ filldist(MvNormal(zeros(2), I), 3) + # end + # + # scalar_array_model_reference = eval_logp_and_grad( + # scalar_array_model, test_m, ref_adtype + # ) + # + # @model function array_model(::Type{T}=Array{Float64}) where {T} + # m = T(undef, 2, 3) + # return m ~ filldist(MvNormal(zeros(2), I), 3) + # end + # + # array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) + # + # @testset "$adtype" for adtype in test_adtypes + # scalar_matrix_model_logp_and_grad = eval_logp_and_grad( + # scalar_matrix_model, test_m, adtype + # ) + # @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] + # @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] + # matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) + # @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] + # @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] + # scalar_array_model_logp_and_grad = eval_logp_and_grad( + # scalar_array_model, test_m, adtype + # ) + # @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] + # @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] + # array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) + # @test array_model_logp_and_grad[1] ≈ array_model_reference[1] + # @test array_model_logp_and_grad[2] ≈ array_model_reference[2] + # end + # end end diff --git a/test/runtests.jl b/test/runtests.jl index 3473d5594..0676c037b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,60 +45,60 @@ include("test_util.jl") # groups are chosen to make both groups take roughly the same amount of # time, but beyond that there is no particular reason for the split. if GROUP == "All" || GROUP == "Group1" - if AQUA - include("Aqua.jl") - end - include("utils.jl") - include("compiler.jl") - include("varnamedvector.jl") - include("varinfo.jl") - include("simple_varinfo.jl") - include("model.jl") - include("sampler.jl") - include("independence.jl") - include("distribution_wrappers.jl") - include("logdensityfunction.jl") - include("linking.jl") - include("serialization.jl") - include("pointwise_logdensities.jl") - include("lkj.jl") - include("contexts.jl") - include("context_implementations.jl") - include("threadsafe.jl") - include("debug_utils.jl") - include("deprecated.jl") + # if AQUA + # include("Aqua.jl") + # end + # include("utils.jl") + # include("compiler.jl") + # include("varnamedvector.jl") + # include("varinfo.jl") + # include("simple_varinfo.jl") + # include("model.jl") + # include("sampler.jl") + # include("independence.jl") + # include("distribution_wrappers.jl") + # include("logdensityfunction.jl") + # include("linking.jl") + # include("serialization.jl") + # include("pointwise_logdensities.jl") + # include("lkj.jl") + # include("contexts.jl") + # include("context_implementations.jl") + # include("threadsafe.jl") + # include("debug_utils.jl") + # include("deprecated.jl") end if GROUP == "All" || GROUP == "Group2" - @testset "compat" begin - include(joinpath("compat", "ad.jl")) - end - @testset "extensions" begin - include("ext/DynamicPPLMCMCChainsExt.jl") - include("ext/DynamicPPLJETExt.jl") - end + # @testset "compat" begin + # include(joinpath("compat", "ad.jl")) + # end + # @testset "extensions" begin + # include("ext/DynamicPPLMCMCChainsExt.jl") + # include("ext/DynamicPPLJETExt.jl") + # end @testset "ad" begin - include("ext/DynamicPPLForwardDiffExt.jl") - include("ext/DynamicPPLMooncakeExt.jl") + # include("ext/DynamicPPLForwardDiffExt.jl") + # include("ext/DynamicPPLMooncakeExt.jl") include("ad.jl") end - @testset "prob and logprob macro" begin - @test_throws ErrorException prob"..." - @test_throws ErrorException logprob"..." - end - @testset "doctests" begin - DocMeta.setdocmeta!( - DynamicPPL, - :DocTestSetup, - :(using DynamicPPL, Distributions); - recursive=true, - ) - doctestfilters = [ - # Ignore the source of a warning in the doctest output, since this is dependent on host. - # This is a line that starts with "└ @ " and ends with the line number. - r"└ @ .+:[0-9]+", - ] - doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters) - end + # @testset "prob and logprob macro" begin + # @test_throws ErrorException prob"..." + # @test_throws ErrorException logprob"..." + # end + # @testset "doctests" begin + # DocMeta.setdocmeta!( + # DynamicPPL, + # :DocTestSetup, + # :(using DynamicPPL, Distributions); + # recursive=true, + # ) + # doctestfilters = [ + # # Ignore the source of a warning in the doctest output, since this is dependent on host. + # # This is a line that starts with "└ @ " and ends with the line number. + # r"└ @ .+:[0-9]+", + # ] + # doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters) + # end end end From bb6c91177a326d472afd98c57e894406bfa5ce5a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 19 Mar 2025 23:20:05 +0000 Subject: [PATCH 2/3] import enzyme --- test/Project.toml | 1 + test/runtests.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index 9fa3fd872..4c8866e31 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,6 +12,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" diff --git a/test/runtests.jl b/test/runtests.jl index 0676c037b..d88aaa20e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ using DifferentiationInterface using Distributions using DistributionsAD using Documenter +using Enzyme: Enzyme using ForwardDiff using LogDensityProblems using MacroTools From d6bf74a7846be978e13625907d2e19db86e713f3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 19 Mar 2025 23:41:07 +0000 Subject: [PATCH 3/3] test on all varinfos --- test/ad.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index 6e6db594f..6091be834 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -21,10 +21,9 @@ using EnzymeCore: set_runtime_activity, Forward, Reverse @testset "Correctness on supported AD backends" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - # rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - # vns = DynamicPPL.TestUtils.varnames(m) - # varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) - varinfos = [VarInfo(m)] + rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) + vns = DynamicPPL.TestUtils.varnames(m) + varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos f = LogDensityFunction(m, varinfo)