diff --git a/research/src/riemannian_hmc_utility.jl b/research/src/riemannian_hmc_utility.jl index 4e82b3d5..8ceab303 100644 --- a/research/src/riemannian_hmc_utility.jl +++ b/research/src/riemannian_hmc_utility.jl @@ -1,8 +1,8 @@ -using Random, LinearAlgebra, ReverseDiff, ForwardDiff, VecTargets +using Random, LinearAlgebra, ReverseDiff, ForwardDiff, MCMCLogDensityProblems # Fisher information metric function gen_∂G∂θ_rev(Vfunc, x; f=identity) - _Hfunc = VecTargets.gen_hess(Vfunc, ReverseDiff.track.(x)) + _Hfunc = MCMCLogDensityProblems.gen_hess(Vfunc, ReverseDiff.track.(x)) Hfunc = x -> _Hfunc(x)[3] # QUES What's the best output format of this function? return x -> ReverseDiff.jacobian(x -> f(Hfunc(x)), x) # default output shape [∂H∂x₁; ∂H∂x₂; ...] @@ -37,7 +37,7 @@ end function prepare_sample_target(hps, θ₀, ℓπ) Vfunc = x -> -ℓπ(x) # potential energy is the negative log-probability - _Hfunc = VecTargets.gen_hess(Vfunc, θ₀) # x -> (value, gradient, hessian) + _Hfunc = MCMCLogDensityProblems.gen_hess(Vfunc, θ₀) # x -> (value, gradient, hessian) Hfunc = x -> copy.(_Hfunc(x)) # _Hfunc do in-place computation, copy to avoid bug fstabilize = H -> H + hps.λ * I @@ -70,8 +70,8 @@ function prepare_sample(hps; rng=MersenneTwister(1110)) θ₀ = rand(rng, dim(target)) - ℓπ = VecTargets.gen_logpdf(target) - ∂ℓπ∂θ = VecTargets.gen_logpdf_grad(target, θ₀) + ℓπ = MCMCLogDensityProblems.gen_logpdf(target) + ∂ℓπ∂θ = MCMCLogDensityProblems.gen_logpdf_grad(target, θ₀) _, _, Gfunc, ∂G∂θfunc = prepare_sample_target(hps, θ₀, ℓπ) diff --git a/research/tests/riemannian_hmc.jl b/research/tests/riemannian_hmc.jl index b0db74b9..67b1cad0 100644 --- a/research/tests/riemannian_hmc.jl +++ b/research/tests/riemannian_hmc.jl @@ -19,8 +19,8 @@ using AdvancedHMC: neg_energy, energy θ₀ = rand(rng, dim(target)) - ℓπ = VecTargets.gen_logpdf(target) - ∂ℓπ∂θ = VecTargets.gen_logpdf_grad(target, θ₀) + ℓπ = MCMCLogDensityProblems.gen_logpdf(target) + ∂ℓπ∂θ = MCMCLogDensityProblems.gen_logpdf_grad(target, θ₀) Vfunc, Hfunc, Gfunc, ∂G∂θfunc = prepare_sample_target(hps, θ₀, ℓπ) diff --git a/research/tests/runtests.jl b/research/tests/runtests.jl index d90c3599..da95548d 100644 --- a/research/tests/runtests.jl +++ b/research/tests/runtests.jl @@ -1,7 +1,7 @@ using Comonicon, ReTest using Pkg; -Pkg.add(; url="https://github.com/xukai92/VecTargets.jl.git"); +Pkg.add(; url="https://github.com/chalk-lab/MCMCLogDensityProblems.jl.git"); # include the source code for experimental HMC include("../src/relativistic_hmc.jl") diff --git a/test/CUDA/cuda.jl b/test/CUDA/cuda.jl index 4e376d47..4cce17ce 100644 --- a/test/CUDA/cuda.jl +++ b/test/CUDA/cuda.jl @@ -2,12 +2,13 @@ using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..")) -include(joinpath(@__DIR__, "..", "common.jl")) - using Test using AdvancedHMC using AdvancedHMC: DualValue, PhasePoint using CUDA +using LogDensityProblems + +include(joinpath(@__DIR__, "..", "common.jl")) @testset "AdvancedHMC GPU" begin n_chains = 1000