Skip to content

Commit 567b996

Browse files
authored
Reduce allocation in mass matrix adaptor (#427)
* Reduce allocation in mass matrix adaptor * More simplification * Update docstrings * Use simpler resize * Use better resize * Use resize_adaptor! for resizing
1 parent 888fa4d commit 567b996

File tree

2 files changed

+31
-19
lines changed

2 files changed

+31
-19
lines changed

src/adaptation/massmatrix.jl

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function adapt!(
1313
α::AbstractScalarOrVec{<:AbstractFloat},
1414
is_update::Bool=true,
1515
)
16-
resize!(adaptor, θ)
16+
resize_adaptor!(adaptor, size(θ))
1717
push!(adaptor, θ)
1818
is_update && update!(adaptor)
1919
return nothing
@@ -29,7 +29,7 @@ UnitMassMatrix() = UnitMassMatrix{Float64}()
2929

3030
Base.string(::UnitMassMatrix) = "I"
3131

32-
Base.resize!(pc::UnitMassMatrix, θ::AbstractVecOrMat) = nothing
32+
resize_adaptor!(pc::UnitMassMatrix, size_θ::Tuple) = nothing
3333

3434
reset!(::UnitMassMatrix) = nothing
3535

@@ -102,20 +102,31 @@ function WelfordVar(sz::Union{Tuple{Int},Tuple{Int,Int}}; kwargs...)
102102
return WelfordVar{Float64}(sz; kwargs...)
103103
end
104104

105-
function Base.resize!(wv::WelfordVar, θ::AbstractVecOrMat{T}) where {T<:AbstractFloat}
106-
if size(θ) != size(wv.var)
105+
function resize_adaptor!(wv::WelfordVar{T}, size_θ::Tuple{Int,Int}) where {T<:AbstractFloat}
106+
if size_θ != size(wv.var)
107107
@assert wv.n == 0 "Cannot resize a var estimator when it contains samples."
108-
wv.μ = zeros(T, size(θ))
109-
wv.M = zeros(T, size(θ))
110-
wv.δ = zeros(T, size(θ))
111-
wv.var = ones(T, size(θ))
108+
wv.μ = zeros(T, size_θ)
109+
wv.M = zeros(T, size_θ)
110+
wv.δ = zeros(T, size_θ)
111+
wv.var = ones(T, size_θ)
112+
end
113+
end
114+
115+
function resize_adaptor!(wv::WelfordVar{T}, size_θ::Tuple{Int}) where {T<:AbstractFloat}
116+
length_θ = first(size_θ)
117+
if length_θ != size(wv.var, 1)
118+
@assert wv.n == 0 "Cannot resize a var estimator when it contains samples."
119+
fill!(resize!(wv.μ, length_θ), T(0))
120+
fill!(resize!(wv.M, length_θ), T(0))
121+
fill!(resize!(wv.δ, length_θ), T(0))
122+
fill!(resize!(wv.var, length_θ), T(1))
112123
end
113124
end
114125

115126
function reset!(wv::WelfordVar{T}) where {T<:AbstractFloat}
116127
wv.n = 0
117-
wv.μ .= zero(T)
118-
wv.M .= zero(T)
128+
fill!(wv.μ, zero(T))
129+
fill!(wv.M, zero(T))
119130
return nothing
120131
end
121132

@@ -191,20 +202,21 @@ end
191202

192203
WelfordCov(sz::Tuple{Int}; kwargs...) = WelfordCov{Float64}(sz; kwargs...)
193204

194-
function Base.resize!(wc::WelfordCov, θ::AbstractVector{T}) where {T<:AbstractFloat}
195-
if length(θ) != size(wc.cov, 1)
205+
function resize_adaptor!(wc::WelfordCov{T}, size_θ::Tuple{Int}) where {T<:AbstractFloat}
206+
length_θ = first(size_θ)
207+
if length_θ != size(wc.cov, 1)
196208
@assert wc.n == 0 "Cannot resize a var estimator when it contains samples."
197-
wc.μ = zeros(T, length))
198-
wc.δ = zeros(T, length))
199-
wc.M = zeros(T, length(θ), length(θ))
200-
wc.cov = LinearAlgebra.diagm(0 => ones(T, length(θ)))
209+
fill!(resize!(wc.μ, length_θ), T(0))
210+
fill!(resize!(wc.δ, length_θ), T(0))
211+
wc.M = zeros(T, length_θ, length_θ)
212+
wc.cov = LinearAlgebra.diagm(0 => ones(T, length_θ))
201213
end
202214
end
203215

204216
function reset!(wc::WelfordCov{T}) where {T<:AbstractFloat}
205217
wc.n = 0
206-
wc.μ .= zero(T)
207-
wc.M .= zero(T)
218+
fill!(wc.μ, zero(T))
219+
fill!(wc.M, zero(T))
208220
return nothing
209221
end
210222

src/adaptation/stan_adaptor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function adapt!(
117117

118118
adapt!(tp.ssa, θ, α)
119119

120-
resize!(tp.pc, θ) # Resize pre-conditioner if necessary.
120+
resize_adaptor!(tp.pc, size(θ)) # Resize pre-conditioner if necessary.
121121

122122
# Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp
123123
if is_in_window(tp)

0 commit comments

Comments
 (0)