Skip to content

Commit 4e84b48

Browse files
authored
Fix CUDA tests (#433)
* Fix CUDA tests * Fix renaming errors
1 parent a8c0ee7 commit 4e84b48

File tree

4 files changed

+11
-10
lines changed

4 files changed

+11
-10
lines changed

research/src/riemannian_hmc_utility.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
using Random, LinearAlgebra, ReverseDiff, ForwardDiff, VecTargets
1+
using Random, LinearAlgebra, ReverseDiff, ForwardDiff, MCMCLogDensityProblems
22

33
# Fisher information metric
44
function gen_∂G∂θ_rev(Vfunc, x; f=identity)
5-
_Hfunc = VecTargets.gen_hess(Vfunc, ReverseDiff.track.(x))
5+
_Hfunc = MCMCLogDensityProblems.gen_hess(Vfunc, ReverseDiff.track.(x))
66
Hfunc = x -> _Hfunc(x)[3]
77
# QUES What's the best output format of this function?
88
return x -> ReverseDiff.jacobian(x -> f(Hfunc(x)), x) # default output shape [∂H∂x₁; ∂H∂x₂; ...]
@@ -37,7 +37,7 @@ end
3737

3838
function prepare_sample_target(hps, θ₀, ℓπ)
3939
Vfunc = x -> -ℓπ(x) # potential energy is the negative log-probability
40-
_Hfunc = VecTargets.gen_hess(Vfunc, θ₀) # x -> (value, gradient, hessian)
40+
_Hfunc = MCMCLogDensityProblems.gen_hess(Vfunc, θ₀) # x -> (value, gradient, hessian)
4141
Hfunc = x -> copy.(_Hfunc(x)) # _Hfunc do in-place computation, copy to avoid bug
4242

4343
fstabilize = H -> H + hps.λ * I
@@ -70,8 +70,8 @@ function prepare_sample(hps; rng=MersenneTwister(1110))
7070

7171
θ₀ = rand(rng, dim(target))
7272

73-
ℓπ = VecTargets.gen_logpdf(target)
74-
∂ℓπ∂θ = VecTargets.gen_logpdf_grad(target, θ₀)
73+
ℓπ = MCMCLogDensityProblems.gen_logpdf(target)
74+
∂ℓπ∂θ = MCMCLogDensityProblems.gen_logpdf_grad(target, θ₀)
7575

7676
_, _, Gfunc, ∂G∂θfunc = prepare_sample_target(hps, θ₀, ℓπ)
7777

research/tests/riemannian_hmc.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ using AdvancedHMC: neg_energy, energy
1919

2020
θ₀ = rand(rng, dim(target))
2121

22-
ℓπ = VecTargets.gen_logpdf(target)
23-
∂ℓπ∂θ = VecTargets.gen_logpdf_grad(target, θ₀)
22+
ℓπ = MCMCLogDensityProblems.gen_logpdf(target)
23+
∂ℓπ∂θ = MCMCLogDensityProblems.gen_logpdf_grad(target, θ₀)
2424

2525
Vfunc, Hfunc, Gfunc, ∂G∂θfunc = prepare_sample_target(hps, θ₀, ℓπ)
2626

research/tests/runtests.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Comonicon, ReTest
22

33
using Pkg;
4-
Pkg.add(; url="https://github.com/xukai92/VecTargets.jl.git");
4+
Pkg.add(; url="https://github.com/chalk-lab/MCMCLogDensityProblems.jl.git");
55

66
# include the source code for experimental HMC
77
include("../src/relativistic_hmc.jl")

test/CUDA/cuda.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@ using Pkg
22
Pkg.activate(@__DIR__)
33
Pkg.develop(; path=joinpath(@__DIR__, "..", ".."))
44

5-
include(joinpath(@__DIR__, "..", "common.jl"))
6-
75
using Test
86
using AdvancedHMC
97
using AdvancedHMC: DualValue, PhasePoint
108
using CUDA
9+
using LogDensityProblems
10+
11+
include(joinpath(@__DIR__, "..", "common.jl"))
1112

1213
@testset "AdvancedHMC GPU" begin
1314
n_chains = 1000

0 commit comments

Comments
 (0)