Skip to content

Commit 888fa4d

Browse files
authored
Fix gradient length mismatch in PhasePoint (#426)
1 parent 2c7eab6 commit 888fa4d

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

src/hamiltonian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ struct PhasePoint{T<:AbstractVecOrMat{<:AbstractFloat},V<:DualValue}
5959
θ::T # Position variables / model parameters.
6060
r::T # Momentum variables
6161
ℓπ::V # Cached neg potential energy for the current θ.
62-
ℓκ::V # Cached neg kinect energy for the current r.
62+
ℓκ::V # Cached neg kinetic energy for the current r.
6363
function PhasePoint::T, r::T, ℓπ::V, ℓκ::V) where {T,V}
64-
@argcheck length(θ) == length(r) == length(ℓπ.gradient) == length(ℓπ.gradient)
64+
@argcheck length(θ) == length(r) == length(ℓπ.gradient) == length(ℓκ.gradient)
6565
if !isfinite(ℓπ)
6666
ℓπ = DualValue(
6767
map(v -> isfinite(v) ? v : oftype(v, -Inf), ℓπ.value), ℓπ.gradient

test/hamiltonian.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@ end
3939

4040
@test z1.ℓπ.value == z2.ℓπ.value
4141
@test z1.ℓκ.value == z2.ℓκ.value
42+
43+
# Test gradient length mismatch of neg potential and kinetic energy in PhasePoint
44+
@test_throws ArgumentError PhasePoint(
45+
[T(Inf)],
46+
[T(Inf)],
47+
DualValue(zero(T), [zero(T)]),
48+
DualValue(zero(T), zeros(T, 2)),
49+
)
4250
end
4351
end
4452

0 commit comments

Comments
 (0)