Skip to content

Commit 7067a90

Browse files
ErikQQYyebai
andauthored
Add ComponentArrays extension (#407)
* Add ComponentArrays extension * Change to more general arrays * Fix stale compat * More general similar * Add proper errors and correct test cases * Fix errors in mass matrix adaption * Remove PhasePoint fix * Correct use ComponentMatrix * Dont use rand matrix in dense metric * Apply reviews * Relax type strict * Ensure axes matching of general arrays * Only focus on axes * Clean up * Clean extension * More informative error message --------- Co-authored-by: Hong Ge <[email protected]>
1 parent 567b996 commit 7067a90

File tree

7 files changed

+145
-36
lines changed

7 files changed

+145
-36
lines changed

Project.toml

+3
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1818

1919
[weakdeps]
2020
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
21+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
2122
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2223
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
2324
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2425

2526
[extensions]
2627
AdvancedHMCADTypesExt = "ADTypes"
28+
AdvancedHMCComponentArraysExt = "ComponentArrays"
2729
AdvancedHMCCUDAExt = "CUDA"
2830
AdvancedHMCMCMCChainsExt = "MCMCChains"
2931
AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"
@@ -32,6 +34,7 @@ AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"
3234
ADTypes = "1"
3335
AbstractMCMC = "5.6"
3436
ArgCheck = "1, 2"
37+
ComponentArrays = "0.15"
3538
CUDA = "3, 4, 5"
3639
DocStringExtensions = "0.8, 0.9"
3740
LinearAlgebra = "<0.1, 1"

ext/AdvancedHMCComponentArraysExt.jl

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module AdvancedHMCComponentArraysExt
2+
3+
using AdvancedHMC: AdvancedHMC, __axes
4+
using ComponentArrays: ComponentVecOrMat, getaxes
5+
6+
AdvancedHMC.__axes(r::ComponentVecOrMat) = getaxes(r)
7+
8+
end # module

src/adaptation/massmatrix.jl

+7-6
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,16 @@ function get_estimation(nv::NaiveVar)
7878
end
7979

8080
# Ref: https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/welford_var_estimator.hpp
81-
mutable struct WelfordVar{T<:AbstractFloat,E<:AbstractVecOrMat{T}} <: DiagMatrixEstimator{T}
81+
mutable struct WelfordVar{T<:AbstractFloat,E<:AbstractVecOrMat{T},V<:AbstractVecOrMat{T}} <:
82+
DiagMatrixEstimator{T}
8283
n::Int
8384
n_min::Int
8485
μ::E
8586
M::E
8687
δ::E # cache for diff
87-
var::E # cache for variance
88-
function WelfordVar(n::Int, n_min::Int, μ::E, M::E, δ::E, var::E) where {E}
89-
return new{eltype(E),E}(n, n_min, μ, M, δ, var)
88+
var::V # cache for variance
89+
function WelfordVar(n::Int, n_min::Int, μ::E, M::E, δ::E, var::V) where {E,V}
90+
return new{eltype(E),E,V}(n, n_min, μ, M, δ, var)
9091
end
9192
end
9293

@@ -182,13 +183,13 @@ function get_estimation(nc::NaiveCov)
182183
end
183184

184185
# Ref: https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/welford_covar_estimator.hpp
185-
mutable struct WelfordCov{F<:AbstractFloat} <: DenseMatrixEstimator{F}
186+
mutable struct WelfordCov{F<:AbstractFloat,C<:AbstractMatrix{F}} <: DenseMatrixEstimator{F}
186187
n::Int
187188
n_min::Int
188189
μ::Vector{F}
189190
M::Matrix{F}
190191
δ::Vector{F} # cache for diff
191-
cov::Matrix{F}
192+
cov::C
192193
end
193194

194195
Base.show(io::IO, ::WelfordCov) = print(io, "WelfordCov")

src/hamiltonian.jl

+14-2
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,22 @@ end
4343

4444
∂H∂r(h::Hamiltonian{<:UnitEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat) = copy(r)
4545
function ∂H∂r(h::Hamiltonian{<:DiagEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat)
46-
return h.metric.M⁻¹ .* r
46+
(; M⁻¹) = h.metric
47+
axes_M⁻¹ = __axes(M⁻¹)
48+
axes_r = __axes(r)
49+
(first(axes_M⁻¹) !== first(axes_r)) && throw(
50+
ArgumentError("AxesMismatch: M⁻¹ has axes $(axes_M⁻¹) but r has axes $(axes_r)")
51+
)
52+
return M⁻¹ .* r
4753
end
4854
function ∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat)
49-
return h.metric.M⁻¹ * r
55+
(; M⁻¹) = h.metric
56+
axes_M⁻¹ = __axes(M⁻¹)
57+
axes_r = __axes(r)
58+
(last(axes_M⁻¹) !== first(axes_r)) && throw(
59+
ArgumentError("AxesMismatch: M⁻¹ has axes $(axes_M⁻¹) but r has axes $(axes_r)")
60+
)
61+
return M⁻¹ * r
5062
end
5163

5264
# TODO (kai) make the order of θ and r consistent with neg_energy

src/utilities.jl

+7
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ function _randn(
2222
return out
2323
end
2424

25+
"""
26+
__axes(r::AbstractVecOrMat)
27+
28+
Return the axes of input `r`, where `r` can be `AbstractArrays`, `ComponentArrays` or other custom arrays.
29+
"""
30+
@inline __axes(r::AbstractVecOrMat) = axes(r)
31+
2532
"""
2633
`rand_coupled` produces coupled randomness given a vector of RNGs. For example,
2734
when a vector of RNGs is provided, `rand_coupled` peforms a single `rand` call

test/demo.jl

+73-28
Original file line numberDiff line numberDiff line change
@@ -71,39 +71,84 @@ end
7171

7272
ℓπ = DemoProblemComponentArrays()
7373

74-
# Define a Hamiltonian system
75-
D = length(p1) # number of parameters
76-
metric = DiagEuclideanMetric(D)
74+
@testset "Test Diagonal ComponentArray metric" begin
7775

78-
# choose AD framework or provide a function manually
79-
hamiltonian = Hamiltonian(metric, ℓπ, Val(:ForwardDiff); x=p1)
76+
# Define a Hamiltonian system
77+
M⁻¹ = ComponentArray(; μ=1.0, σ=1.0)
78+
metric = DiagEuclideanMetric(M⁻¹)
8079

81-
# Define a leapfrog solver, with initial step size chosen heuristically
82-
initial_ϵ = find_good_stepsize(hamiltonian, p1)
83-
integrator = Leapfrog(initial_ϵ)
80+
# choose AD framework or provide a function manually
81+
hamiltonian = Hamiltonian(metric, ℓπ, Val(:ForwardDiff))
8482

85-
# Define an HMC sampler, with the following components
86-
proposal = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
87-
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))
83+
# Define a leapfrog solver, with initial step size chosen heuristically
84+
initial_ϵ = find_good_stepsize(hamiltonian, p1)
85+
integrator = Leapfrog(initial_ϵ)
8886

89-
# -- run sampler
90-
n_samples, n_adapts = 100, 50
91-
samples, stats = sample(
92-
hamiltonian,
93-
proposal,
94-
p1,
95-
n_samples,
96-
adaptor,
97-
n_adapts;
98-
progress=false,
99-
verbose=false,
100-
)
87+
# Define an HMC sampler, with the following components
88+
proposal = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
89+
adaptor = StanHMCAdaptor(
90+
MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator)
91+
)
10192

102-
@test length(samples) == n_samples
103-
@test length(stats) == n_samples
104-
labels = ComponentArrays.labels(samples[1])
105-
@test "μ" labels
106-
@test "σ" labels
93+
# -- run sampler
94+
n_samples, n_adapts = 100, 50
95+
samples, stats = sample(
96+
hamiltonian,
97+
proposal,
98+
p1,
99+
n_samples,
100+
adaptor,
101+
n_adapts;
102+
progress=false,
103+
verbose=false,
104+
)
105+
106+
@test length(samples) == n_samples
107+
@test length(stats) == n_samples
108+
lab = ComponentArrays.labels(samples[1])
109+
@test "μ" lab
110+
@test "σ" lab
111+
end
112+
113+
@testset "Test Dense ComponentArray metric" begin
114+
115+
# Define a Hamiltonian system
116+
ax = getaxes(p1)[1]
117+
M⁻¹ = ComponentArray([2.0 1.0; 1.0 2.0], ax, ax)
118+
metric = DenseEuclideanMetric(M⁻¹)
119+
120+
# choose AD framework or provide a function manually
121+
hamiltonian = Hamiltonian(metric, ℓπ, Val(:ForwardDiff))
122+
123+
# Define a leapfrog solver, with initial step size chosen heuristically
124+
initial_ϵ = find_good_stepsize(hamiltonian, p1)
125+
integrator = Leapfrog(initial_ϵ)
126+
127+
# Define an HMC sampler, with the following components
128+
proposal = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
129+
adaptor = StanHMCAdaptor(
130+
MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator)
131+
)
132+
133+
# -- run sampler
134+
n_samples, n_adapts = 100, 50
135+
samples, stats = sample(
136+
hamiltonian,
137+
proposal,
138+
p1,
139+
n_samples,
140+
adaptor,
141+
n_adapts;
142+
progress=false,
143+
verbose=false,
144+
)
145+
146+
@test length(samples) == n_samples
147+
@test length(stats) == n_samples
148+
lab = ComponentArrays.labels(samples[1])
149+
@test "μ" lab
150+
@test "σ" lab
151+
end
107152
end
108153

109154
@testset "ADTypes" begin

test/hamiltonian.jl

+33
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using ReTest, AdvancedHMC
22
using AdvancedHMC: GaussianKinetic, DualValue, PhasePoint
33
using LinearAlgebra: dot, diagm
4+
using ComponentArrays
45

56
@testset "Hamiltonian" begin
67
f = x -> dot(x, x)
@@ -76,3 +77,35 @@ end
7677
end
7778
end
7879
end
80+
81+
@testset "Energy with ComponentArrays" begin
82+
n_tests = 10
83+
for T in [Float32, Float64]
84+
for _ in 1:n_tests
85+
θ_init = ComponentArray(; a=randn(T, D), b=randn(T, D))
86+
r_init = ComponentArray(; a=randn(T, D), b=randn(T, D))
87+
88+
h = Hamiltonian(UnitEuclideanMetric(T, 2 * D), ℓπ, ∂ℓπ∂θ)
89+
@test -AdvancedHMC.neg_energy(h, r_init, θ_init) == sum(abs2, r_init) / 2
90+
@test AdvancedHMC.∂H∂r(h, r_init) == r_init
91+
@test typeof(AdvancedHMC.∂H∂r(h, r_init)) == typeof(r_init)
92+
93+
M⁻¹ = ComponentArray(;
94+
a=ones(T, D) + abs.(randn(T, D)), b=ones(T, D) + abs.(randn(T, D))
95+
)
96+
h = Hamiltonian(DiagEuclideanMetric(M⁻¹), ℓπ, ∂ℓπ∂θ)
97+
@test -AdvancedHMC.neg_energy(h, r_init, θ_init)
98+
r_init' * diagm(0 => M⁻¹) * r_init / 2
99+
@test AdvancedHMC.∂H∂r(h, r_init) == M⁻¹ .* r_init
100+
@test typeof(AdvancedHMC.∂H∂r(h, r_init)) == typeof(r_init)
101+
102+
m = randn(T, 2 * D, 2 * D)
103+
ax = getaxes(r_init)[1]
104+
M⁻¹ = ComponentArray(m' * m, ax, ax)
105+
h = Hamiltonian(DenseEuclideanMetric(M⁻¹), ℓπ, ∂ℓπ∂θ)
106+
@test -AdvancedHMC.neg_energy(h, r_init, θ_init) r_init' * M⁻¹ * r_init / 2
107+
@test all(AdvancedHMC.∂H∂r(h, r_init) .== M⁻¹ * r_init)
108+
@test typeof(AdvancedHMC.∂H∂r(h, r_init)) == typeof(r_init)
109+
end
110+
end
111+
end

0 commit comments

Comments
 (0)