diff --git a/docs/src/api.md b/docs/src/api.md
index 54b5939d..0812007e 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -8,6 +8,7 @@ This modularity means that different HMC variants can be easily constructed by c

   - Unit metric: `UnitEuclideanMetric(dim)`
   - Diagonal metric: `DiagEuclideanMetric(dim)`
   - Dense metric: `DenseEuclideanMetric(dim)`
+  - Rank update metric: `RankUpdateEuclideanMetric(dim)`

 where `dim` is the dimensionality of the sampling space.
diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl
index b25710d5..090553a8 100644
--- a/src/AdvancedHMC.jl
+++ b/src/AdvancedHMC.jl
@@ -2,7 +2,18 @@ module AdvancedHMC

 using Statistics: mean, var, middle
 using LinearAlgebra:
-    Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling
+    Symmetric,
+    UpperTriangular,
+    mul!,
+    ldiv!,
+    dot,
+    I,
+    diag,
+    cholesky,
+    UniformScaling,
+    Diagonal,
+    qr,
+    lmul!
 using StatsFuns: logaddexp, logsumexp, loghalf
 using Random: Random, AbstractRNG
 using ProgressMeter: ProgressMeter
@@ -40,7 +51,8 @@ struct GaussianKinetic <: AbstractKinetic end
 export GaussianKinetic

 include("metric.jl")
-export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric
+export UnitEuclideanMetric,
+    DiagEuclideanMetric, DenseEuclideanMetric, RankUpdateEuclideanMetric

 include("hamiltonian.jl")
 export Hamiltonian
diff --git a/src/hamiltonian.jl b/src/hamiltonian.jl
index ffcaee89..2ca100fb 100644
--- a/src/hamiltonian.jl
+++ b/src/hamiltonian.jl
@@ -61,6 +61,18 @@ function ∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::A
     return M⁻¹ * r
 end

+function ∂H∂r(
+    h::Hamiltonian{<:RankUpdateEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat
+)
+    (; M⁻¹) = h.metric
+    axes_M⁻¹ = __axes(M⁻¹)
+    axes_r = __axes(r)
+    (first(axes_M⁻¹) !== first(axes_r)) && throw(
+        ArgumentError("AxesMismatch: M⁻¹ has axes $(axes_M⁻¹) but r has axes $(axes_r)")
+    )
+    return M⁻¹ * r
+end
+
 # TODO (kai) make the order of θ and r consistent with neg_energy
 # TODO (kai) add stricter types to block hamiltonian.jl#L37 from working on unknown metric/kinetic
 # The gradient of a position-dependent Hamiltonian system depends on both θ and r.
@@ -165,6 +177,13 @@ function neg_energy(
     return -dot(r, h.metric._temp) / 2
 end

+function neg_energy(
+    h::Hamiltonian{<:RankUpdateEuclideanMetric,<:GaussianKinetic}, r::T, θ::T
+) where {T<:AbstractVecOrMat}
+    M⁻¹ = h.metric.M⁻¹
+    return -r' * M⁻¹ * r / 2
+end
+
 energy(args...) = -neg_energy(args...)

 ####
diff --git a/src/metric.jl b/src/metric.jl
index 9b3bf9c5..45312f1e 100644
--- a/src/metric.jl
+++ b/src/metric.jl
@@ -98,6 +98,84 @@ function Base.show(io::IO, dem::DenseEuclideanMetric)
     return print(io, "DenseEuclideanMetric(diag=$(_string_M⁻¹(dem.M⁻¹)))")
 end

+"""
+    RankUpdateEuclideanMetric{T,AM,AB,AD,F} <: AbstractMetric
+
+A Gaussian Euclidean metric whose inverse is constructed by rank updates.
+
+# Fields
+
+$(TYPEDFIELDS)
+
+# Constructors
+
+    RankUpdateEuclideanMetric(n::Int)
+    RankUpdateEuclideanMetric(M⁻¹, B, D)
+
+  - Construct a Gaussian Euclidean metric of size `(n, n)` with `M⁻¹` being the identity (diagonal) matrix.
+  - Construct a Gaussian Euclidean metric from `M⁻¹`, where `M⁻¹` should be a full-rank positive definite matrix, and `B` and `D` must be chosen so that the Woodbury matrix `W = M⁻¹ + B D B^\\mathrm{T}` is positive definite.
+
+# Example
+
+```julia
+julia> RankUpdateEuclideanMetric(3)
+RankUpdateEuclideanMetric(diag=[1.0, 1.0, 1.0])
+```
+
+# References
+
+  - Ben Bales, Arya Pourzanjani, Aki Vehtari, Linda Petzold (2019). Selecting the Metric in Hamiltonian Monte Carlo.
+"""
+struct RankUpdateEuclideanMetric{T,AM<:AbstractVecOrMat{T},AB,AD,F} <: AbstractMetric
+    "Diagonal component of the inverse of the mass matrix"
+    M⁻¹::AM
+    B::AB
+    D::AD
+    factorization::F
+end
+
+function woodbury_factorize(A, B, D)
+    cholA = cholesky(A isa Diagonal ? A : Symmetric(A))
+    U = cholA.U
+    Q, R = qr(U' \ B)
+    V = cholesky(Symmetric(muladd(R, D * R', I))).U
+    return (U=U, Q=Q, V=V)
+end
+
+function RankUpdateEuclideanMetric(n::Int)
+    M⁻¹ = Diagonal(ones(n))
+    B = zeros(n, 0)
+    D = zeros(0, 0)
+    factorization = woodbury_factorize(M⁻¹, B, D)
+    return RankUpdateEuclideanMetric(M⁻¹, B, D, factorization)
+end
+function RankUpdateEuclideanMetric(::Type{T}, n::Int) where {T}
+    M⁻¹ = Diagonal(ones(T, n))
+    B = Matrix{T}(undef, n, 0)
+    D = Matrix{T}(undef, 0, 0)
+    factorization = woodbury_factorize(M⁻¹, B, D)
+    return RankUpdateEuclideanMetric(M⁻¹, B, D, factorization)
+end
+
+function RankUpdateEuclideanMetric(M⁻¹, B, D)
+    factorization = woodbury_factorize(M⁻¹, B, D)
+    return RankUpdateEuclideanMetric(M⁻¹, B, D, factorization)
+end
+
+function RankUpdateEuclideanMetric(::Type{T}, sz::Tuple{Int}) where {T}
+    return RankUpdateEuclideanMetric(T, first(sz))
+end
+RankUpdateEuclideanMetric(sz::Tuple{Int}) = RankUpdateEuclideanMetric(Float64, sz)
+
+renew(::RankUpdateEuclideanMetric, (M⁻¹, B, D)) = RankUpdateEuclideanMetric(M⁻¹, B, D)
+
+Base.size(metric::RankUpdateEuclideanMetric, dim...) = size(metric.M⁻¹.diag, dim...)
+
+function Base.show(io::IO, ::MIME"text/plain", metric::RankUpdateEuclideanMetric)
+    return print(io, "RankUpdateEuclideanMetric(diag=$(diag(metric.M⁻¹)))")
+end
+
 # `rand` functions for `metric` types.

 function rand_momentum(
@@ -131,3 +209,19 @@ function rand_momentum(
     ldiv!(metric.cholM⁻¹, r)
     return r
 end
+
+function rand_momentum(
+    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
+    metric::RankUpdateEuclideanMetric{T},
+    kinetic::GaussianKinetic,
+    ::AbstractVecOrMat,
+) where {T}
+    M⁻¹ = metric.M⁻¹
+    r = _randn(rng, T, size(M⁻¹.diag)...)
+    F = metric.factorization
+    k = min(size(F.U, 1), size(F.V, 1))
+    @views ldiv!(F.V, r isa AbstractVector ? r[1:k] : r[1:k, :])
+    lmul!(F.Q, r)
+    ldiv!(F.U, r)
+    return r
+end
diff --git a/test/metric.jl b/test/metric.jl
index c05078c6..a359e33c 100644
--- a/test/metric.jl
+++ b/test/metric.jl
@@ -10,6 +10,7 @@ using ReTest, Random, AdvancedHMC
             UnitEuclideanMetric((D, n_chains)),
             DiagEuclideanMetric((D, n_chains)),
             # DenseEuclideanMetric((D, n_chains)) # not supported ATM
+            # RankUpdateEuclideanMetric((D, n_chains)) # not supported ATM
         ]
             r = AdvancedHMC.rand_momentum(rng, metric, GaussianKinetic(), θ)
             all_same = true
@@ -25,8 +26,12 @@
         rng = MersenneTwister(1)
         θ = randn(rng, D)
         ℓπ(θ) = 1
-        for metric in
-            [UnitEuclideanMetric(1), DiagEuclideanMetric(1), DenseEuclideanMetric(1)]
+        for metric in [
+            UnitEuclideanMetric(1),
+            DiagEuclideanMetric(1),
+            DenseEuclideanMetric(1),
+            RankUpdateEuclideanMetric(1),
+        ]
             h = Hamiltonian(metric, ℓπ, ℓπ)
             h = AdvancedHMC.resize(h, θ)
             @test size(h.metric) == size(θ)
diff --git a/test/sampler.jl b/test/sampler.jl
index 919edd9c..84c487bb 100644
--- a/test/sampler.jl
+++ b/test/sampler.jl
@@ -62,6 +62,7 @@ end
        :UnitEuclideanMetric => UnitEuclideanMetric(D),
        :DiagEuclideanMetric => DiagEuclideanMetric(D),
        :DenseEuclideanMetric => DenseEuclideanMetric(D),
+       :RankUpdateEuclideanMetric => RankUpdateEuclideanMetric(D),
     )
        h = Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
        @testset "$lfsym" for (lfsym, lf) in Dict(
@@ -104,6 +105,11 @@ end
            @test mean(samples) ≈ zeros(D) atol = RNDATOL
        end

+       if metricsym == :RankUpdateEuclideanMetric
+           # Skip tests with `RankUpdateEuclideanMetric` for `MassMatrixAdaptor`
+           continue
+       end
+
        @testset "$adaptorsym" for (adaptorsym, adaptor) in Dict(
            :MassMatrixAdaptorOnly => MassMatrixAdaptor(metric),
            :StepSizeAdaptorOnly => StepSizeAdaptor(0.8, τ.integrator),
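
As a review aid (not part of the diff), here is a minimal sketch of how the new metric composes with the `Hamiltonian`, `rand_momentum`, and `neg_energy` methods added above. The target density, the low-rank factors `B` and `Dmat`, and all numeric values are illustrative placeholders, and the gradient function is assumed to follow the package's `(logdensity, gradient)` return convention.

```julia
using AdvancedHMC, LinearAlgebra, Random

D = 5
rng = MersenneTwister(1)

# Diagonal part of the inverse mass matrix plus an illustrative rank-2 update.
M⁻¹ = Diagonal(ones(D))
B = randn(rng, D, 2)
Dmat = Diagonal(fill(0.1, 2))

# The Woodbury matrix W = M⁻¹ + B * Dmat * B' must be positive definite.
@assert isposdef(Symmetric(M⁻¹ + B * Dmat * B'))

metric = RankUpdateEuclideanMetric(M⁻¹, B, Dmat)

# Toy standard-normal target; the gradient function returns (logdensity, gradient).
ℓπ(θ) = -sum(abs2, θ) / 2
∂ℓπ∂θ(θ) = (ℓπ(θ), -θ)

h = Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
θ = randn(rng, D)

# The momentum draw uses the metric's Woodbury factorization; the kinetic-energy
# term follows the `neg_energy` method defined in this diff.
r = AdvancedHMC.rand_momentum(rng, metric, GaussianKinetic(), θ)
AdvancedHMC.neg_energy(h, r, θ)
```

With the single-argument constructor (`RankUpdateEuclideanMetric(D)`), `B` and `D` are empty and the metric reduces to a unit diagonal metric, which is the case the sampler tests above exercise.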