|
1 | 1 | module TestReverseAD
|
2 | 2 |
|
3 | 3 | using Test
|
| 4 | +import LinearAlgebra |
4 | 5 | import MathOptInterface
|
| 6 | +import SparseArrays |
5 | 7 |
|
6 | 8 | const MOI = MathOptInterface
|
7 | 9 | const Nonlinear = MOI.Nonlinear
|
@@ -798,6 +800,49 @@ function test_gradient_nested_subexpressions()
|
798 | 800 | return
|
799 | 801 | end
|
800 | 802 |
|
| 803 | +function _dense_hessian(hessian_sparsity, V, n) |
| 804 | + I = [i for (i, _) in hessian_sparsity] |
| 805 | + J = [j for (_, j) in hessian_sparsity] |
| 806 | + raw = SparseArrays.sparse(I, J, V, n, n) |
| 807 | + return Matrix( |
| 808 | + raw + raw' - |
| 809 | + SparseArrays.sparse(LinearAlgebra.diagm(0 => LinearAlgebra.diag(raw))), |
| 810 | + ) |
| 811 | +end |
| 812 | + |
| 813 | +# This covers the code that computes Hessians in odd chunks of Hess-vec |
| 814 | +# products. |
| 815 | +function test_odd_chunks_Hessian_products() |
| 816 | + for i in 1:18 |
| 817 | + _test_odd_chunks_Hessian_products(i) |
| 818 | + end |
| 819 | + return |
| 820 | +end |
| 821 | + |
| 822 | +function _test_odd_chunks_Hessian_products(N) |
| 823 | + data = Nonlinear.NonlinearData() |
| 824 | + x = MOI.VariableIndex.(1:N) |
| 825 | + Nonlinear.set_objective(data, Expr(:call, :*, x...)) |
| 826 | + Nonlinear.set_differentiation_backend( |
| 827 | + data, |
| 828 | + Nonlinear.SparseReverseMode(), |
| 829 | + x, |
| 830 | + ) |
| 831 | + MOI.initialize(data, [:Hess]) |
| 832 | + hessian_sparsity = MOI.hessian_lagrangian_structure(data) |
| 833 | + V = zeros(length(hessian_sparsity)) |
| 834 | + values = ones(N) |
| 835 | + MOI.eval_hessian_lagrangian(data, V, values, 1.0, Float64[]) |
| 836 | + H = _dense_hessian(hessian_sparsity, V, N) |
| 837 | + @test H ≈ (ones(N, N) - LinearAlgebra.diagm(0 => ones(N))) |
| 838 | + values[1] = 0.5 |
| 839 | + MOI.eval_hessian_lagrangian(data, V, values, 1.0, Float64[]) |
| 840 | + H = _dense_hessian(hessian_sparsity, V, N) |
| 841 | + H_22 = (ones(N - 1, N - 1) - LinearAlgebra.diagm(0 => ones(N - 1))) / 2 |
| 842 | + @test H ≈ [0 ones(N - 1)'; ones(N - 1) H_22] |
| 843 | + return |
| 844 | +end |
| 845 | + |
801 | 846 | end # module
|
802 | 847 |
|
803 | 848 | TestReverseAD.runtests()
|
0 commit comments