@@ -29,6 +29,7 @@ supports_line_search(::HalleyDescent) = true
29
29
δu
30
30
δus
31
31
b
32
+ fu
32
33
lincache
33
34
timer
34
35
end
@@ -41,13 +42,14 @@ function __internal_init(
41
42
reltol = nothing , timer = get_timer_output (), kwargs... ) where {INV, N}
42
43
@bb δu = similar (u)
43
44
@bb b = similar (u)
45
+ @bb fu = similar (fu)
44
46
δus = N ≤ 1 ? nothing : map (2 : N) do i
45
47
@bb δu_ = similar (u)
46
48
end
47
49
INV && return HalleyDescentCache {true} (prob. f, prob. p, δu, δus, b, nothing , timer)
48
50
lincache = LinearSolverCache (
49
51
alg, alg. linsolve, J, _vec (fu), _vec (u); abstol, reltol, linsolve_kwargs... )
50
- return HalleyDescentCache {false} (prob. f, prob. p, δu, δus, b, lincache, timer)
52
+ return HalleyDescentCache {false} (prob. f, prob. p, δu, δus, b, fu, lincache, timer)
51
53
end
52
54
53
55
function __internal_solve! (cache:: HalleyDescentCache{INV} , J, fu, u, idx:: Val = Val (1 );
@@ -67,7 +69,7 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
67
69
end
68
70
b = cache. b
69
71
# compute the hessian-vector-vector product
70
- hvvp = derivative (Base . Fix2 ( cache. f, cache. p) , u, δu, 2 )
72
+ hvvp = evaluate_hvvp ( cache, cache . f, cache. p, u, δu)
71
73
# second linear solve, reuse factorization if possible
72
74
if INV
73
75
@bb b = J × vec (hvvp)
@@ -83,3 +85,14 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
83
85
cache. b = b
84
86
return δu, true , (;)
85
87
end
88
+
89
+ function evaluate_hvvp (
90
+ cache:: HalleyDescentCache , f:: NonlinearFunction{iip} , p, u, δu) where {iip}
91
+ if iip
92
+ binary_f = (y, x) -> f (y, x, p)
93
+ derivative (binary_f, cache. fu, u, δu, Val {3} ())
94
+ else
95
+ unary_f = Base. Fix2 (f, p)
96
+ derivative (unary_f, u, δu, Val {3} ())
97
+ end
98
+ end
0 commit comments