Skip to content

Commit a8c0ee7

Browse files
authored
Add quality tests (#431)
1 parent e3a56ed commit a8c0ee7

File tree

9 files changed

+31
-18
lines changed

9 files changed

+31
-18
lines changed

src/AdvancedHMC.jl

+6-3
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,12 @@ function Hamiltonian(metric::AbstractMetric, ℓπ, kind::Union{Symbol,Val,Modul
172172
),
173173
)
174174
end
175-
= LogDensityProblemsAD.ADgradient(
176-
kind isa Val ? kind : Val(Symbol(kind)), ℓπ; kwargs...
177-
)
175+
_kind = if kind isa Val || kind isa Symbol
176+
kind
177+
else
178+
Symbol(kind)
179+
end
180+
= LogDensityProblemsAD.ADgradient(_kind, ℓπ; kwargs...)
178181
return Hamiltonian(metric, ℓ)
179182
end
180183

src/adaptation/massmatrix.jl

-2
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,6 @@ end
171171

172172
NaiveCov{T}(sz::Tuple{Int}) where {T<:AbstractFloat} = NaiveCov(Vector{Vector{T}}())
173173

174-
NaiveCov(sz::Union{Tuple{Int},Tuple{Int,Int}}; kwargs...) = NaiveCov{Float64}(sz; kwargs...)
175-
176174
Base.push!(nc::NaiveCov, s::AbstractVector) = push!(nc.S, s)
177175

178176
reset!(nc::NaiveCov{T}) where {T} = resize!(nc.S, 0)

src/integrator.jl

+7-3
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ function step(
218218
res = if FullTraj
219219
Vector{P}(undef, n_steps)
220220
else
221-
z
221+
Vector{P}(undef, 1)
222222
end
223223

224224
(; θ, r) = z
@@ -242,7 +242,7 @@ function step(
242242
if FullTraj
243243
res[i] = z
244244
else
245-
res = z
245+
res[1] = z
246246
end
247247
if !isfinite(z)
248248
# Remove undef
@@ -252,5 +252,9 @@ function step(
252252
break
253253
end
254254
end
255-
return res
255+
return if FullTraj
256+
res
257+
else
258+
first(res)
259+
end
256260
end

src/metric.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ end
8080
function DenseEuclideanMetric(
8181
M⁻¹::Union{AbstractMatrix{T},AbstractArray{T,3}}
8282
) where {T<:AbstractFloat}
83-
_temp = Vector{T}(undef, Base.front(size(M⁻¹)))
83+
_temp = Vector{T}(undef, first(size(M⁻¹)))
8484
return DenseEuclideanMetric(M⁻¹, cholesky(Symmetric(M⁻¹)).U, _temp)
8585
end
8686
DenseEuclideanMetric(::Type{T}, D::Int) where {T} = DenseEuclideanMetric(Matrix{T}(I, D, D))

src/trajectory.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ end
444444
TurnStatistic() = TurnStatistic(undef)
445445

446446
TurnStatistic(::ClassicNoUTurn, ::PhasePoint) = TurnStatistic()
447+
TurnStatistic(::ClassicNoUTurn, ::Vector{<:PhasePoint}) = TurnStatistic()
447448
function TurnStatistic(::Union{GeneralisedNoUTurn,StrictGeneralisedNoUTurn}, z::PhasePoint)
448449
return TurnStatistic(z.r)
449450
end
@@ -776,7 +777,7 @@ function find_good_stepsize(
776777
ϵ′ = ratio_too_high ? d * ϵ : invd * ϵ
777778
_, H′ = A(h, z, ϵ)
778779
ΔH = H - H′
779-
@debug "Crossing step" direction H′ ϵ α = min(1, exp(ΔH))
780+
@debug "Crossing step" H′ ϵ α = min(1, exp(ΔH))
780781
# stop if there is no crossing; otherwise, continue to half or double stepsize.
781782
if xor(ratio_too_high, ΔH > log_a_cross)
782783
break

test/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
88
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
99
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1010
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
11+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1213
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1314
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"

test/aqua.jl

-7
This file was deleted.

test/quality.jl

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using AdvancedHMC
2+
using ReTest
3+
using Aqua: Aqua
4+
using JET
5+
using ForwardDiff
6+
7+
@testset "Aqua" begin
8+
Aqua.test_all(AdvancedHMC)
9+
end
10+
11+
@testset "JET" begin
12+
JET.test_package(AdvancedHMC; target_defined_modules=true)
13+
end

test/runtests.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ include("common.jl")
1919
if GROUP == "All" || GROUP == "AdvancedHMC"
2020
using ReTest
2121

22-
include("aqua.jl")
22+
include("quality.jl")
2323
include("metric.jl")
2424
include("hamiltonian.jl")
2525
include("integrator.jl")

0 commit comments

Comments
 (0)