Skip to content

Commit e7b4f4b

Browse files
committed
More test coverage
1 parent 907080a commit e7b4f4b

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

test/Nonlinear/ReverseAD.jl

+45
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
module TestReverseAD
22

33
using Test
4+
import LinearAlgebra
45
import MathOptInterface
6+
import SparseArrays
57

68
const MOI = MathOptInterface
79
const Nonlinear = MOI.Nonlinear
@@ -798,6 +800,49 @@ function test_gradient_nested_subexpressions()
798800
return
799801
end
800802

803+
function _dense_hessian(hessian_sparsity, V, n)
804+
I = [i for (i, _) in hessian_sparsity]
805+
J = [j for (_, j) in hessian_sparsity]
806+
raw = SparseArrays.sparse(I, J, V, n, n)
807+
return Matrix(
808+
raw + raw' -
809+
SparseArrays.sparse(LinearAlgebra.diagm(0 => LinearAlgebra.diag(raw))),
810+
)
811+
end
812+
813+
# This covers the code that computes Hessians in odd chunks of Hess-vec
814+
# products.
815+
function test_odd_chunks_Hessian_products()
816+
for i in 1:18
817+
_test_odd_chunks_Hessian_products(i)
818+
end
819+
return
820+
end
821+
822+
function _test_odd_chunks_Hessian_products(N)
823+
data = Nonlinear.NonlinearData()
824+
x = MOI.VariableIndex.(1:N)
825+
Nonlinear.set_objective(data, Expr(:call, :*, x...))
826+
Nonlinear.set_differentiation_backend(
827+
data,
828+
Nonlinear.SparseReverseMode(),
829+
x,
830+
)
831+
MOI.initialize(data, [:Hess])
832+
hessian_sparsity = MOI.hessian_lagrangian_structure(data)
833+
V = zeros(length(hessian_sparsity))
834+
values = ones(N)
835+
MOI.eval_hessian_lagrangian(data, V, values, 1.0, Float64[])
836+
H = _dense_hessian(hessian_sparsity, V, N)
837+
@test H (ones(N, N) - LinearAlgebra.diagm(0 => ones(N)))
838+
values[1] = 0.5
839+
MOI.eval_hessian_lagrangian(data, V, values, 1.0, Float64[])
840+
H = _dense_hessian(hessian_sparsity, V, N)
841+
H_22 = (ones(N - 1, N - 1) - LinearAlgebra.diagm(0 => ones(N - 1))) / 2
842+
@test H [0 ones(N - 1)'; ones(N - 1) H_22]
843+
return
844+
end
845+
801846
end # module
802847

803848
TestReverseAD.runtests()

0 commit comments

Comments
 (0)