Skip to content

Commit fc9f01e

Browse files
committed
Only focus on axes
1 parent d68d549 commit fc9f01e

File tree

3 files changed

+12
-34
lines changed

3 files changed

+12
-34
lines changed

ext/AdvancedHMCComponentArraysExt.jl

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,38 +2,9 @@ module AdvancedHMCComponentArraysExt
22

33
using LinearAlgebra
44

5-
using AdvancedHMC:
6-
AdvancedHMC,
7-
Hamiltonian,
8-
UnitEuclideanMetric,
9-
DiagEuclideanMetric,
10-
DenseEuclideanMetric,
11-
GaussianKinetic
12-
using ComponentArrays: ComponentVecOrMat, ComponentVector, ComponentMatrix, getaxes
5+
using AdvancedHMC: AdvancedHMC, __axes
6+
using ComponentArrays: ComponentVecOrMat, getaxes
137

14-
function AdvancedHMC.∂H∂r(
15-
h::Hamiltonian{<:DiagEuclideanMetric,<:GaussianKinetic}, r::ComponentVecOrMat
16-
)
17-
(; M⁻¹) = h.metric
18-
(getaxes(M⁻¹) !== getaxes(r)) &&
19-
throw(ArgumentError("Axes of mass matrix and momentum must match"))
20-
return h.metric.M⁻¹ .* r
21-
end
22-
function AdvancedHMC.∂H∂r(
23-
h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::ComponentVector
24-
)
25-
(; M⁻¹) = h.metric
26-
(last(getaxes(M⁻¹)) !== first(getaxes(r))) &&
27-
throw(ArgumentError("Axes of mass matrix and momentum must match"))
28-
return h.metric.M⁻¹ * r
29-
end
30-
function AdvancedHMC.∂H∂r(
31-
h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::ComponentMatrix
32-
)
33-
(; M⁻¹) = h.metric
34-
getaxes(M⁻¹) !== getaxes(r) &&
35-
throw(ArgumentError("Axes of mass matrix and momentum must match"))
36-
return h.metric.M⁻¹ * r
37-
end
8+
AdvancedHMC.__axes(r::ComponentVecOrMat) = getaxes(r)
389

3910
end # module

src/hamiltonian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@ end
4444
∂H∂r(h::Hamiltonian{<:UnitEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat) = copy(r)
4545
function ∂H∂r(h::Hamiltonian{<:DiagEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat)
4646
(; M⁻¹) = h.metric
47-
(first(axes(M⁻¹)) !== first(axes(r))) &&
47+
(first(__axes(M⁻¹)) !== first(__axes(r))) &&
4848
throw(ArgumentError("Axes of mass matrix and momentum must match"))
4949
return h.metric.M⁻¹ .* r
5050
end
5151
function ∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat)
5252
(; M⁻¹) = h.metric
53-
(last(axes(M⁻¹)) !== first(axes(r))) &&
53+
(last(__axes(M⁻¹)) !== first(__axes(r))) &&
5454
throw(ArgumentError("Axes of mass matrix and momentum must match"))
5555
return h.metric.M⁻¹ * r
5656
end

src/utilities.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ function _randn(
2222
return out
2323
end
2424

25+
"""
26+
__axes(r::AbstractVecOrMat)
27+
28+
Return the axes of input `r` where `r` can be generic arrays or custom arrays.
29+
"""
30+
@inline __axes(r::AbstractVecOrMat) = axes(r)
31+
2532
"""
2633
`rand_coupled` produces coupled randomness given a vector of RNGs. For example,
2734
when a vector of RNGs is provided, `rand_coupled` peforms a single `rand` call

0 commit comments

Comments
 (0)