Skip to content

Commit d158696

Browse files
committed
Add support for in-place functions; add tests
1 parent c6eba21 commit d158696

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

src/descent/halley.jl

+15-2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ supports_line_search(::HalleyDescent) = true
2929
δu
3030
δus
3131
b
32+
fu
3233
lincache
3334
timer
3435
end
@@ -41,13 +42,14 @@ function __internal_init(
4142
reltol = nothing, timer = get_timer_output(), kwargs...) where {INV, N}
4243
@bb δu = similar(u)
4344
@bb b = similar(u)
45+
@bb fu = similar(fu)
4446
δus = N 1 ? nothing : map(2:N) do i
4547
@bb δu_ = similar(u)
4648
end
4749
INV && return HalleyDescentCache{true}(prob.f, prob.p, δu, δus, b, nothing, timer)
4850
lincache = LinearSolverCache(
4951
alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol, linsolve_kwargs...)
50-
return HalleyDescentCache{false}(prob.f, prob.p, δu, δus, b, lincache, timer)
52+
return HalleyDescentCache{false}(prob.f, prob.p, δu, δus, b, fu, lincache, timer)
5153
end
5254

5355
function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val = Val(1);
@@ -67,7 +69,7 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
6769
end
6870
b = cache.b
6971
# compute the hessian-vector-vector product
70-
hvvp = derivative(Base.Fix2(cache.f, cache.p), u, δu, 2)
72+
hvvp = evaluate_hvvp(cache, cache.f, cache.p, u, δu)
7173
# second linear solve, reuse factorization if possible
7274
if INV
7375
@bb b = J × vec(hvvp)
@@ -83,3 +85,14 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
8385
cache.b = b
8486
return δu, true, (;)
8587
end
88+
89+
function evaluate_hvvp(
90+
cache::HalleyDescentCache, f::NonlinearFunction{iip}, p, u, δu) where {iip}
91+
if iip
92+
binary_f = (y, x) -> f(y, x, p)
93+
derivative(binary_f, cache.fu, u, δu, Val{3}())
94+
else
95+
unary_f = Base.Fix2(f, p)
96+
derivative(unary_f, u, δu, Val{3}())
97+
end
98+
end

test/misc/halley_tests.jl

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
@testitem "Halley method" begin
2+
f(u, p) = u .* u .- p
3+
f!(fu, u, p) = fu .= u .* u .- p
4+
u0 = [1.0, 1.0]
5+
p = 2.0
6+
7+
# out-of-place
8+
prob1 = NonlinearProblem(f, u0, p)
9+
sol1 = solve(prob1, Halley())
10+
@test sol1.u [sqrt(2.0), sqrt(2.0)]
11+
12+
# in-place
13+
prob2 = NonlinearProblem(f!, u0, p)
14+
sol2 = solve(prob2, Halley())
15+
@test sol2.u [sqrt(2.0), sqrt(2.0)]
16+
end

0 commit comments

Comments
 (0)