Skip to content

Commit 5802209

Browse files
authored
[ReverseAD] Add support for user-defined hessians (#1819)
1 parent e102d84 commit 5802209

File tree

7 files changed

+236
-24
lines changed

7 files changed

+236
-24
lines changed

docs/src/submodules/Nonlinear/reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ Nonlinear.eval_univariate_gradient
6060
Nonlinear.eval_univariate_hessian
6161
Nonlinear.eval_multivariate_function
6262
Nonlinear.eval_multivariate_gradient
63+
Nonlinear.eval_multivariate_hessian
6364
Nonlinear.eval_logic_function
6465
Nonlinear.eval_comparison_function
6566
```

src/Nonlinear/ReverseAD/forward_over_reverse.jl

+35-3
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ function _hessian_slice_inner(d, ex, input_ϵ, output_ϵ, ::Type{T}) where {T}
126126
for i in ex.dependent_subexpressions
127127
subexpr = d.subexpressions[i]
128128
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
129+
d,
129130
subexpr,
130131
_reinterpret_unsafe(T, subexpr.forward_storage_ϵ),
131132
_reinterpret_unsafe(T, subexpr.partials_storage_ϵ),
@@ -135,6 +136,7 @@ function _hessian_slice_inner(d, ex, input_ϵ, output_ϵ, ::Type{T}) where {T}
135136
)
136137
end
137138
_forward_eval_ϵ(
139+
d,
138140
ex,
139141
_reinterpret_unsafe(T, d.forward_storage_ϵ),
140142
_reinterpret_unsafe(T, d.partials_storage_ϵ),
@@ -178,6 +180,7 @@ end
178180

179181
"""
180182
_forward_eval_ϵ(
183+
d::NLPEvaluator,
181184
ex::Union{_FunctionStorage,_SubexpressionStorage},
182185
storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
183186
partials_storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
@@ -195,6 +198,7 @@ components separate so that we don't need to recompute the real components.
195198
This assumes that `_reverse_model(d, x)` has already been called.
196199
"""
197200
function _forward_eval_ϵ(
201+
d::NLPEvaluator,
198202
ex::Union{_FunctionStorage,_SubexpressionStorage},
199203
storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
200204
partials_storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
@@ -328,10 +332,38 @@ function _forward_eval_ϵ(
328332
recip_denominator,
329333
)
330334
elseif op > 6
331-
error(
332-
"User-defined operators not supported for hessian " *
333-
"computations",
335+
f_input = _UnsafeVectorView(d.jac_storage, n_children)
336+
for (i, c) in enumerate(children_idx)
337+
f_input[i] = ex.forward_storage[children_arr[c]]
338+
end
339+
H = _UnsafeLowerTriangularMatrixView(
340+
d.user_output_buffer,
341+
n_children,
342+
)
343+
has_hessian = Nonlinear.eval_multivariate_hessian(
344+
user_operators,
345+
user_operators.multivariate_operators[node.index],
346+
H,
347+
f_input,
334348
)
349+
# This might be `false` if we extend this code to all
350+
# multivariate functions.
351+
@assert has_hessian
352+
for col in 1:n_children
353+
dual = zero(ForwardDiff.Partials{N,T})
354+
for row in 1:n_children
355+
# Make sure we get the lower-triangular component.
356+
h = row >= col ? H[row, col] : H[col, row]
357+
# Performance optimization: hessians can be quite
358+
# sparse
359+
if !iszero(h)
360+
i = children_arr[children_idx[row]]
361+
dual += h * storage_ϵ[i]
362+
end
363+
end
364+
i = children_arr[children_idx[col]]
365+
partials_storage_ϵ[i] = dual
366+
end
335367
end
336368
elseif node.type == Nonlinear.NODE_CALL_UNIVARIATE
337369
@inbounds child_idx = children_arr[ex.adj.colptr[k]]

src/Nonlinear/ReverseAD/mathoptinterface_api.jl

+9-4
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
66

77
function MOI.features_available(d::NLPEvaluator)
8-
# Check if we have any user-defined multivariate operators, in which case we
9-
# need to disable hessians. The result of features_available depends on this.
10-
d.disable_2ndorder =
11-
length(d.data.operators.registered_multivariate_operators) > 0
8+
# Check if we are missing any hessians for user-defined multivariate
9+
# operators, in which case we need to disable :Hess and :HessVec.
10+
d.disable_2ndorder = any(
11+
op -> op.∇²f === nothing,
12+
d.data.operators.registered_multivariate_operators,
13+
)
1214
if d.disable_2ndorder
1315
return [:Grad, :Jac, :JacVec]
1416
end
@@ -286,6 +288,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
286288
for i in d.subexpression_order
287289
subexpr = d.subexpressions[i]
288290
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
291+
d,
289292
subexpr,
290293
reinterpret(T, subexpr.forward_storage_ϵ),
291294
reinterpret(T, subexpr.partials_storage_ϵ),
@@ -302,6 +305,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
302305
fill!(output_ϵ, zero(T))
303306
if d.objective !== nothing
304307
_forward_eval_ϵ(
308+
d,
305309
d.objective,
306310
reinterpret(T, d.forward_storage_ϵ),
307311
reinterpret(T, d.partials_storage_ϵ),
@@ -322,6 +326,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
322326
end
323327
for (i, con) in enumerate(d.constraints)
324328
_forward_eval_ϵ(
329+
d,
325330
con,
326331
reinterpret(T, d.forward_storage_ϵ),
327332
reinterpret(T, d.partials_storage_ϵ),

src/Nonlinear/ReverseAD/utils.jl

+92
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,105 @@ Base.length(v::_UnsafeVectorView) = v.len
5353

5454
Base.size(v::_UnsafeVectorView) = (v.len,)
5555

56+
"""
57+
_UnsafeVectorView(x::Vector, N::Int)
58+
59+
Create a new [`_UnsafeVectorView`](@ref) from `x`, and resize `x` if needed to
60+
ensure it has a length of at least `N`.
61+
62+
## Unsafe behavior
63+
64+
In addition to the usafe behavior of `_UnsafeVectorView`, this constructor is
65+
additionally unsafe because it may resize `x`. Only call it if you are sure that
66+
the usage of `_UnsafeVectorView(x, N)` is short-lived, and that there are no
67+
other views to `x` while the returned value is within scope.
68+
"""
5669
function _UnsafeVectorView(x::Vector, N::Int)
5770
if length(x) < N
5871
resize!(x, N)
5972
end
6073
return _UnsafeVectorView(0, N, pointer(x))
6174
end
6275

76+
"""
77+
_UnsafeLowerTriangularMatrixView(x, N)
78+
79+
Lightweight unsafe view that converts a vector `x` into the lower-triangular
80+
component of a symmetric `N`-by-`N` matrix.
81+
82+
## Motivation
83+
84+
`_UnsafeLowerTriangularMatrixView` is needed as an allocation-free equivalent of
85+
`view`. Other alternatives, like `reshape(view(x, 1:N^2), N, N)` or a struct
86+
like
87+
```julia
88+
struct _SafeView{T}
89+
x::Vector{T}
90+
len::Int
91+
end
92+
```
93+
will allocate so that `x` can be tracked by Julia's GC.
94+
`_UnsafeLowerTriangularMatrixView` relies on the fact that the use-cases of
95+
`_UnsafeLowerTriangularMatrixView` only temporarily wrap a long-lived vector
96+
like `d.jac_storage` so that we don't have to worry about the GC removing
97+
`d.jac_storage` while `_UnsafeLowerTriangularMatrixView` exists. This lets us
98+
use a `Ptr{T}` and create a struct that is `isbitstype` and therefore does not
99+
allocate.
100+
101+
## Unsafe behavior
102+
103+
`_UnsafeLowerTriangularMatrixView` is unsafe because it assumes that the vector
104+
`x` remains valid during the usage of `_UnsafeLowerTriangularMatrixView`.
105+
"""
106+
struct _UnsafeLowerTriangularMatrixView <: AbstractMatrix{Float64}
107+
N::Int
108+
ptr::Ptr{Float64}
109+
end
110+
111+
Base.size(x::_UnsafeLowerTriangularMatrixView) = (x.N, x.N)
112+
113+
function _linear_index(row, col)
114+
if row < col
115+
error("Unable to access upper-triangular component: ($row, $col)")
116+
end
117+
return div((row - 1) * row, 2) + col
118+
end
119+
120+
function Base.getindex(x::_UnsafeLowerTriangularMatrixView, i, j)
121+
return unsafe_load(x.ptr, _linear_index(i, j))
122+
end
123+
124+
function Base.setindex!(x::_UnsafeLowerTriangularMatrixView, value, i, j)
125+
unsafe_store!(x.ptr, value, _linear_index(i, j))
126+
return value
127+
end
128+
129+
"""
130+
_UnsafeLowerTriangularMatrixView(x::Vector{Float64}, N::Int)
131+
132+
Create a new [`_UnsafeLowerTriangularMatrixView`](@ref) from `x`, zero the
133+
elements in `x`, and resize `x` if needed to ensure it has a length of at least
134+
`N * (N + 1) / 2`.
135+
136+
## Unsafe behavior
137+
138+
In addition to the usafe behavior of `_UnsafeLowerTriangularMatrixView`, this
139+
constructor is additionally unsafe because it may resize `x`. Only call it if
140+
you are sure that the usage of `_UnsafeLowerTriangularMatrixView(x, N)` is
141+
short-lived, and that there are no other views to `x` while the returned value
142+
is within scope.
143+
"""
144+
function _UnsafeLowerTriangularMatrixView(x::Vector{Float64}, N::Int)
145+
z = div(N * (N + 1), 2)
146+
if length(x) < z
147+
resize!(x, z)
148+
end
149+
for i in 1:z
150+
x[i] = 0.0
151+
end
152+
return _UnsafeLowerTriangularMatrixView(N, pointer(x))
153+
end
154+
63155
function _reinterpret_unsafe(::Type{T}, x::Vector{R}) where {T,R}
64156
# how many T's fit into x?
65157
@assert isbitstype(T) && isbitstype(R)

src/Nonlinear/model.jl

+8-1
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,15 @@ Register the user-defined operator `op` with `nargs` input arguments in `model`.
203203
* `∇f(g::AbstractVector{T}, x::T...)::T` is a function that takes a cache vector
204204
`g` of length `length(x)`, and fills each element `g[i]` with the partial
205205
derivative of `f` with respect to `x[i]`.
206+
* `∇²f(H::AbstractMatrix, x::T...)::T` is a function that takes a matrix `H` and
207+
fills the lower-triangular components `H[i, j]` with the Hessian of `f` with
208+
respect to `x[i]` and `x[j]` for `i >= j`.
206209
207-
Hessian are not supported for multivariate functions.
210+
### Notes for multivariate Hessians
211+
212+
* `H` has `size(H) == (length(x), length(x))`, but you must not access
213+
elements `H[i, j]` for `i > j`.
214+
* `H` is dense, but you do not need to fill structural zeros.
208215
"""
209216
function register_operator(model::Model, op::Symbol, nargs::Int, f::Function...)
210217
return register_operator(model.operators, op, nargs, f...)

src/Nonlinear/operators.jl

+18-1
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,24 @@ end
631631

632632
_nan_to_zero(x) = isnan(x) ? 0.0 : x
633633

634-
# No docstring because this function is still a WIP.
634+
"""
635+
eval_multivariate_hessian(
636+
registry::OperatorRegistry,
637+
op::Symbol,
638+
H::AbstractMatrix,
639+
x::AbstractVector{T},
640+
) where {T}
641+
642+
Evaluate the Hessian of operator `∇²op(x)`, where `op` is a multivariate
643+
function in `registry`.
644+
645+
The Hessian is stored in the lower-triangular part of the matrix `H`.
646+
647+
!!! note
648+
Implementations of the Hessian operators will not fill structural zeros.
649+
Therefore, before calling this function you should pre-populate the matrix
650+
`H` with `0`.
651+
"""
635652
function eval_multivariate_hessian(
636653
registry::OperatorRegistry,
637654
op::Symbol,

test/Nonlinear/ReverseAD.jl

+73-15
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,34 @@ function test_objective_quadratic_univariate()
5353
return
5454
end
5555

56+
function test_objective_and_constraints_quadratic_univariate()
57+
x = MOI.VariableIndex(1)
58+
model = Nonlinear.Model()
59+
Nonlinear.set_objective(model, :($x^2 + 1))
60+
Nonlinear.add_constraint(model, :($x^2), MOI.LessThan(2.0))
61+
evaluator = Nonlinear.Evaluator(model, Nonlinear.SparseReverseMode(), [x])
62+
MOI.initialize(evaluator, [:Grad, :Jac, :Hess])
63+
@test MOI.eval_objective(evaluator, [1.2]) == 1.2^2 + 1
64+
g = [NaN]
65+
MOI.eval_objective_gradient(evaluator, g, [1.2])
66+
@test g == [2.4]
67+
@test MOI.hessian_lagrangian_structure(evaluator) == [(1, 1), (1, 1)]
68+
H = [NaN, NaN]
69+
MOI.eval_hessian_lagrangian(evaluator, H, [1.2], 1.5, Float64[1.3])
70+
@test H == [1.5, 1.3] .* [2.0, 2.0]
71+
Hp = [NaN]
72+
MOI.eval_hessian_lagrangian_product(
73+
evaluator,
74+
Hp,
75+
[1.2],
76+
[1.2],
77+
1.5,
78+
Float64[1.3],
79+
)
80+
@test Hp == [1.5 * 2.0 * 1.2 + 1.3 * 2.0 * 1.2]
81+
return
82+
end
83+
5684
function test_objective_quadratic_multivariate()
5785
x = MOI.VariableIndex(1)
5886
y = MOI.VariableIndex(2)
@@ -287,14 +315,13 @@ function test_hessian_sparsity_registered_function()
287315
Nonlinear.set_objective(model, :(f($x, $z) + $y^2))
288316
evaluator =
289317
Nonlinear.Evaluator(model, Nonlinear.SparseReverseMode(), [x, y, z])
290-
@test_broken :Hess in MOI.features_available(evaluator)
291-
# TODO(odow): re-enable these tests when user-defined hessians are supported
292-
# MOI.initialize(evaluator, [:Grad, :Jac, :Hess])
293-
# @test MOI.hessian_lagrangian_structure(evaluator) ==
294-
# [(1, 1), (2, 2), (3, 3), (3, 1)]
295-
# H = fill(NaN, 4)
296-
# MOI.eval_hessian_lagrangian(evaluator, H, rand(3), 1.5, Float64[])
297-
# @test H == 1.5 .* [2.0, 2.0, 2.0, 0.0]
318+
@test :Hess in MOI.features_available(evaluator)
319+
MOI.initialize(evaluator, [:Grad, :Jac, :Hess])
320+
@test MOI.hessian_lagrangian_structure(evaluator) ==
321+
[(1, 1), (2, 2), (3, 3), (3, 1)]
322+
H = fill(NaN, 4)
323+
MOI.eval_hessian_lagrangian(evaluator, H, rand(3), 1.5, Float64[])
324+
@test H == 1.5 .* [2.0, 2.0, 2.0, 0.0]
298325
return
299326
end
300327

@@ -308,6 +335,7 @@ function test_hessian_sparsity_registered_rosenbrock()
308335
return
309336
end
310337
function ∇²f(H, x...)
338+
@assert size(H) == (2, 2)
311339
H[1, 1] = 1200 * x[1]^2 - 400 * x[2] + 2
312340
H[2, 1] = -400 * x[1]
313341
H[2, 2] = 200.0
@@ -318,13 +346,43 @@ function test_hessian_sparsity_registered_rosenbrock()
318346
Nonlinear.set_objective(model, :(rosenbrock($x, $y)))
319347
evaluator =
320348
Nonlinear.Evaluator(model, Nonlinear.SparseReverseMode(), [x, y])
321-
@test_broken :Hess in MOI.features_available(evaluator)
322-
# TODO(odow): re-enable these tests when user-defined hessians are supported
323-
# MOI.initialize(evaluator, [:Grad, :Jac, :Hess])
324-
# @test MOI.hessian_lagrangian_structure(evaluator) == [(1, 1), (2, 2), (2, 1)]
325-
# H = fill(NaN, 3)
326-
# MOI.eval_hessian_lagrangian(evaluator, H, [1.0, 1.0], 1.5, Float64[])
327-
# @test H == 1.5 .* [802, 200, -400]
349+
@test :Hess in MOI.features_available(evaluator)
350+
MOI.initialize(evaluator, [:Grad, :Jac, :Hess])
351+
@test MOI.hessian_lagrangian_structure(evaluator) ==
352+
[(1, 1), (2, 2), (2, 1)]
353+
H = fill(NaN, 3)
354+
MOI.eval_hessian_lagrangian(evaluator, H, [1.0, 1.0], 1.5, Float64[])
355+
@test H == 1.5 .* [802.0, 200.0, -400.0]
356+
return
357+
end
358+
359+
function test_hessian_registered_error()
360+
x = MOI.VariableIndex(1)
361+
y = MOI.VariableIndex(2)
362+
f(x...) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
363+
function ∇f(g, x...)
364+
g[1] = 400 * x[1]^3 - 400 * x[1] * x[2] + 2 * x[1] - 2
365+
g[2] = 200 * (x[2] - x[1]^2)
366+
return
367+
end
368+
function ∇²f(H, x...)
369+
H[1, 1] = 1200 * x[1]^2 - 400 * x[2] + 2
370+
# Wrong index! Should be [2, 1]
371+
H[1, 2] = -400 * x[1]
372+
H[2, 2] = 200.0
373+
return
374+
end
375+
model = Nonlinear.Model()
376+
Nonlinear.register_operator(model, :rosenbrock, 2, f, ∇f, ∇²f)
377+
Nonlinear.set_objective(model, :(rosenbrock($x, $y)))
378+
evaluator =
379+
Nonlinear.Evaluator(model, Nonlinear.SparseReverseMode(), [x, y])
380+
MOI.initialize(evaluator, [:Grad, :Jac, :Hess])
381+
H = fill(NaN, 3)
382+
@test_throws(
383+
ErrorException("Unable to access upper-triangular component: (1, 2)"),
384+
MOI.eval_hessian_lagrangian(evaluator, H, [1.0, 1.0], 1.5, Float64[]),
385+
)
328386
return
329387
end
330388

0 commit comments

Comments
 (0)