Skip to content

Commit 73f1f04

Browse files
committed
[ReverseAD] Add support for user-defined hessians
1 parent 1ab5078 commit 73f1f04

File tree

4 files changed

+114
-20
lines changed

4 files changed

+114
-20
lines changed

src/Nonlinear/ReverseAD/forward_over_reverse.jl

+32-3
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ function _hessian_slice_inner(d, ex, input_ϵ, output_ϵ, ::Type{T}) where {T}
126126
for i in ex.dependent_subexpressions
127127
subexpr = d.subexpressions[i]
128128
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
129+
d,
129130
subexpr,
130131
_reinterpret_unsafe(T, subexpr.forward_storage_ϵ),
131132
_reinterpret_unsafe(T, subexpr.partials_storage_ϵ),
@@ -135,6 +136,7 @@ function _hessian_slice_inner(d, ex, input_ϵ, output_ϵ, ::Type{T}) where {T}
135136
)
136137
end
137138
_forward_eval_ϵ(
139+
d,
138140
ex,
139141
_reinterpret_unsafe(T, d.forward_storage_ϵ),
140142
_reinterpret_unsafe(T, d.partials_storage_ϵ),
@@ -178,6 +180,7 @@ end
178180

179181
"""
180182
_forward_eval_ϵ(
183+
d::NLPEvaluator,
181184
ex::Union{_FunctionStorage,_SubexpressionStorage},
182185
storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
183186
partials_storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
@@ -195,6 +198,7 @@ components separate so that we don't need to recompute the real components.
195198
This assumes that `_reverse_model(d, x)` has already been called.
196199
"""
197200
function _forward_eval_ϵ(
201+
d::NLPEvaluator,
198202
ex::Union{_FunctionStorage,_SubexpressionStorage},
199203
storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
200204
partials_storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
@@ -328,10 +332,35 @@ function _forward_eval_ϵ(
328332
recip_denominator,
329333
)
330334
elseif op > 6
331-
error(
332-
"User-defined operators not supported for hessian " *
333-
"computations",
335+
f_input = _UnsafeVectorView(d.jac_storage, n_children)
336+
for (i, c) in enumerate(children_idx)
337+
f_input[i] = ex.forward_storage[children_arr[c]]
338+
end
339+
H = _UnsafeHessianView(d.user_output_buffer, n_children)
340+
has_hessian = Nonlinear.eval_multivariate_hessian(
341+
user_operators,
342+
user_operators.multivariate_operators[node.index],
343+
H,
344+
f_input,
334345
)
346+
if !has_hessian
347+
continue
348+
end
349+
for col in 1:n_children
350+
dual = zero(ForwardDiff.Partials{N,T})
351+
for row in 1:n_children
352+
# Make sure we get the lower-triangular component.
353+
h = row >= col ? H[row, col] : H[col, row]
354+
# Performance optimization: hessians can be quite
355+
# sparse
356+
if !iszero(h)
357+
i = children_arr[children_idx[row]]
358+
dual += h * storage_ϵ[i]
359+
end
360+
end
361+
i = children_arr[children_idx[col]]
362+
partials_storage_ϵ[i] = dual
363+
end
335364
end
336365
elseif node.type == Nonlinear.NODE_CALL_UNIVARIATE
337366
@inbounds child_idx = children_arr[ex.adj.colptr[k]]

src/Nonlinear/ReverseAD/mathoptinterface_api.jl

+7-2
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
function MOI.features_available(d::NLPEvaluator)
88
# Check if any user-defined multivariate operator is missing a hessian, in
99
# which case we disable hessians. The result of features_available depends on this.
10-
d.disable_2ndorder =
11-
length(d.data.operators.registered_multivariate_operators) > 0
10+
d.disable_2ndorder = any(
11+
op -> op.∇²f === nothing,
12+
d.data.operators.registered_multivariate_operators,
13+
)
1214
if d.disable_2ndorder
1315
return [:Grad, :Jac, :JacVec]
1416
end
@@ -286,6 +288,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
286288
for i in d.subexpression_order
287289
subexpr = d.subexpressions[i]
288290
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
291+
d,
289292
subexpr,
290293
reinterpret(T, subexpr.forward_storage_ϵ),
291294
reinterpret(T, subexpr.partials_storage_ϵ),
@@ -302,6 +305,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
302305
fill!(output_ϵ, zero(T))
303306
if d.objective !== nothing
304307
_forward_eval_ϵ(
308+
d,
305309
d.objective,
306310
reinterpret(T, d.forward_storage_ϵ),
307311
reinterpret(T, d.partials_storage_ϵ),
@@ -322,6 +326,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
322326
end
323327
for (i, con) in enumerate(d.constraints)
324328
_forward_eval_ϵ(
329+
d,
325330
con,
326331
reinterpret(T, d.forward_storage_ϵ),
327332
reinterpret(T, d.partials_storage_ϵ),

src/Nonlinear/ReverseAD/utils.jl

+62
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,68 @@ function _UnsafeVectorView(x::Vector, N::Int)
6060
return _UnsafeVectorView(0, N, pointer(x))
6161
end
6262

63+
"""
64+
_UnsafeHessianView(x, N)
65+
66+
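Lightweight unsafe view that presents the leading `N * (N + 1) / 2` elements of
67+
a vector `x` as the lower-triangular component of a symmetric `N`-by-`N` matrix.
The constructor grows `x` if it is too short and resets those elements to zero.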
68+
69+
## Motivation
70+
71+
`_UnsafeHessianView` is needed as an allocation-free equivalent of `view`. Other
72+
alternatives, like `reshape(view(x, 1:N^2), N, N)` or a struct like
73+
```julia
74+
struct _SafeView{T}
75+
x::Vector{T}
76+
len::Int
77+
end
78+
```
79+
will allocate so that `x` can be tracked by Julia's GC.
80+
`_UnsafeHessianView` relies on the fact that the use-cases of
81+
`_UnsafeHessianView` only temporarily wrap a long-lived vector like
82+
`d.user_output_buffer` so that we don't have to worry about the GC removing
83+
`d.user_output_buffer` while `_UnsafeHessianView` exists. This lets us use a
84+
`Ptr{Float64}` and create a struct that is `isbitstype` and therefore does not allocate.
85+
86+
## Unsafe behavior
87+
88+
`_UnsafeHessianView` is unsafe because it assumes that the vector `x` remains
89+
valid during the usage of `_UnsafeHessianView`.
90+
"""
91+
struct _UnsafeHessianView <: AbstractMatrix{Float64}
92+
N::Int  # side length; `Base.size` reports `(N, N)`
93+
ptr::Ptr{Float64}  # raw pointer to the first packed element (the constructor passes `pointer(x)`)
94+
end
95+
96+
# Report the view as a square matrix whose side length is the stored `N`.
function Base.size(H::_UnsafeHessianView)
    n = H.N
    return (n, n)
end
97+
98+
# Map a `(row, col)` position to its 1-based offset in row-packed
# lower-triangular storage. Positions with `row < col` are rejected with an
# error; otherwise the offset is `div((row - 1) * row, 2) + col`, i.e. the
# number of entries stored for the rows above `row`, plus the column.
function _linear_index(row, col)
    row < col &&
        error("Unable to access upper-triangular component: ($row, $col)")
    entries_above = ((row - 1) * row) ÷ 2
    return entries_above + col
end
104+
105+
# Read the entry at `(row, col)` straight from the wrapped pointer. The offset
# comes from `_linear_index`, which errors when `row < col`. No bounds check is
# performed against `N`.
function Base.getindex(H::_UnsafeHessianView, row, col)
    offset = _linear_index(row, col)
    return unsafe_load(H.ptr, offset)
end
108+
109+
# Write `value` to the entry at `(row, col)` through the wrapped pointer and
# return `value`. The offset comes from `_linear_index`, which errors when
# `row < col`. No bounds check is performed against `N`.
function Base.setindex!(H::_UnsafeHessianView, value, row, col)
    offset = _linear_index(row, col)
    unsafe_store!(H.ptr, value, offset)
    return value
end
113+
114+
# Wrap the leading `N * (N + 1) / 2` elements of `x` as an `_UnsafeHessianView`.
#
# Side effects on `x`: it is grown with `resize!` when it is shorter than the
# required length, and the leading `N * (N + 1) / 2` elements are overwritten
# with `0.0`. Elements past that prefix are left untouched.
#
# `x` must be a `Vector{Float64}`: the struct stores a `Ptr{Float64}`, so a
# vector with any other element type would have its memory reinterpreted
# through the pointer instead of raising an error.
function _UnsafeHessianView(x::Vector{Float64}, N::Int)
    z = div(N * (N + 1), 2)
    if length(x) < z
        resize!(x, z)
    end
    for i in 1:z
        x[i] = 0.0
    end
    # `pointer(x)` is taken after the optional `resize!` so that it refers to
    # the buffer in its final, resized state.
    return _UnsafeHessianView(N, pointer(x))
end
124+
63125
function _reinterpret_unsafe(::Type{T}, x::Vector{R}) where {T,R}
64126
# how many T's fit into x?
65127
@assert isbitstype(T) && isbitstype(R)

test/Nonlinear/ReverseAD.jl

+13-15
Original file line numberDiff line numberDiff line change
@@ -287,14 +287,13 @@ function test_hessian_sparsity_registered_function()
287287
Nonlinear.set_objective(model, :(f($x, $z) + $y^2))
288288
evaluator =
289289
Nonlinear.Evaluator(model, Nonlinear.SparseReverseMode(), [x, y, z])
290-
@test_broken :Hess in MOI.features_available(evaluator)
291-
# TODO(odow): re-enable these tests when user-defined hessians are supported
292-
# MOI.initialize(evaluator, [:Grad, :Jac, :Hess])
293-
# @test MOI.hessian_lagrangian_structure(evaluator) ==
294-
# [(1, 1), (2, 2), (3, 3), (3, 1)]
295-
# H = fill(NaN, 4)
296-
# MOI.eval_hessian_lagrangian(evaluator, H, rand(3), 1.5, Float64[])
297-
# @test H == 1.5 .* [2.0, 2.0, 2.0, 0.0]
290+
@test :Hess in MOI.features_available(evaluator)
291+
MOI.initialize(evaluator, [:Grad, :Jac, :Hess])
292+
@test MOI.hessian_lagrangian_structure(evaluator) ==
293+
[(1, 1), (2, 2), (3, 3), (3, 1)]
294+
H = fill(NaN, 4)
295+
MOI.eval_hessian_lagrangian(evaluator, H, rand(3), 1.5, Float64[])
296+
@test H == 1.5 .* [2.0, 2.0, 2.0, 0.0]
298297
return
299298
end
300299

@@ -318,13 +317,12 @@ function test_hessian_sparsity_registered_rosenbrock()
318317
Nonlinear.set_objective(model, :(rosenbrock($x, $y)))
319318
evaluator =
320319
Nonlinear.Evaluator(model, Nonlinear.SparseReverseMode(), [x, y])
321-
@test_broken :Hess in MOI.features_available(evaluator)
322-
# TODO(odow): re-enable these tests when user-defined hessians are supported
323-
# MOI.initialize(evaluator, [:Grad, :Jac, :Hess])
324-
# @test MOI.hessian_lagrangian_structure(evaluator) == [(1, 1), (2, 2), (2, 1)]
325-
# H = fill(NaN, 3)
326-
# MOI.eval_hessian_lagrangian(evaluator, H, [1.0, 1.0], 1.5, Float64[])
327-
# @test H == 1.5 .* [802, 200, -400]
320+
@test :Hess in MOI.features_available(evaluator)
321+
MOI.initialize(evaluator, [:Grad, :Jac, :Hess])
322+
@test MOI.hessian_lagrangian_structure(evaluator) == [(1, 1), (2, 2), (2, 1)]
323+
H = fill(NaN, 3)
324+
MOI.eval_hessian_lagrangian(evaluator, H, [1.0, 1.0], 1.5, Float64[])
325+
@test H == 1.5 .* [802, 200, -400]
328326
return
329327
end
330328

0 commit comments

Comments
 (0)