1
1
"""
2
- HalleyDescent(; linsolve = nothing, precs = DEFAULT_PRECS )
2
+ HalleyDescent(; linsolve = nothing)
3
3
4
4
Improve on `NewtonDescent` by including higher-order terms. First compute the Newton step `a` by solving ``J a = -fu``.
5
5
Then compute the Hessian-vector-vector product ``H a a`` and solve ``J b = H a a`` for the second-order correction term `b`.
6
6
Finally, compute the descent direction as ``δu = a * a / (b / 2 - a)``.
7
7
8
+ Note that `import TaylorDiff` is required to use this descent algorithm.
9
+
8
10
See also [`NewtonDescent`](@ref).
9
11
"""
10
- @kwdef @concrete struct HalleyDescent <: AbstractDescentAlgorithm
12
+ @kwdef @concrete struct HalleyDescent <: AbstractDescentDirection
11
13
linsolve = nothing
12
- precs = DEFAULT_PRECS
13
- end
14
-
15
- using TaylorDiff: derivative
16
-
17
- function Base. show (io:: IO , d:: HalleyDescent )
18
- modifiers = String[]
19
- d. linsolve != = nothing && push! (modifiers, " linsolve = $(d. linsolve) " )
20
- d. precs != = DEFAULT_PRECS && push! (modifiers, " precs = $(d. precs) " )
21
- print (io, " HalleyDescent($(join (modifiers, " , " )) )" )
22
14
end
23
15
24
16
supports_line_search (:: HalleyDescent ) = true
25
17
26
- @concrete mutable struct HalleyDescentCache{pre_inverted} <: AbstractDescentCache
18
+ @concrete mutable struct HalleyDescentCache <: AbstractDescentCache
27
19
f
28
20
p
29
21
δu
30
22
δus
31
23
b
32
24
fu
25
+ hvvp
33
26
lincache
34
27
timer
28
+ preinverted_jacobian <: Union{Val{false}, Val{true}}
35
29
end
36
30
37
31
@internal_caches HalleyDescentCache :lincache
38
32
39
- function __internal_init (prob:: NonlinearProblem , alg:: HalleyDescent , J, fu, u; stats,
40
- shared:: Val{N} = Val (1 ), pre_inverted:: Val{INV} = False,
33
+ function InternalAPI. init (
34
+ prob:: NonlinearProblem , alg:: HalleyDescent , J, fu, u; stats,
35
+ shared = Val (1 ), pre_inverted:: Val = Val (false ),
41
36
linsolve_kwargs = (;), abstol = nothing , reltol = nothing ,
42
- timer = get_timer_output (), kwargs... ) where {INV, N}
37
+ timer = get_timer_output (), kwargs... )
43
38
@bb δu = similar (u)
44
39
@bb b = similar (u)
45
40
@bb fu = similar (fu)
46
- δus = N ≤ 1 ? nothing : map (2 : N) do i
41
+ @bb hvvp = similar (fu)
42
+ δus = Utils. unwrap_val (shared) ≤ 1 ? nothing : map (2 : Utils. unwrap_val (shared)) do i
47
43
@bb δu_ = similar (u)
48
44
end
49
- INV && return HalleyDescentCache {true} (prob. f, prob. p, δu, δus, b, nothing , timer)
50
- lincache = LinearSolverCache (
51
- alg, alg. linsolve, J, _vec (fu), _vec (u); stats, abstol, reltol, linsolve_kwargs... )
52
- return HalleyDescentCache {false} (prob. f, prob. p, δu, δus, b, fu, lincache, timer)
45
+ lincache = Utils. unwrap_val (pre_inverted) ? nothing :
46
+ construct_linear_solver (
47
+ alg, alg. linsolve, J, Utils. safe_vec (fu), Utils. safe_vec (u);
48
+ stats, abstol, reltol, linsolve_kwargs...
49
+ )
50
+ return HalleyDescentCache (
51
+ prob. f, prob. p, δu, δus, b, fu, hvvp, lincache, timer, pre_inverted)
53
52
end
54
53
55
- function __internal_solve! (cache:: HalleyDescentCache{INV} , J, fu, u, idx:: Val = Val (1 );
56
- skip_solve:: Bool = false , new_jacobian:: Bool = true , kwargs... ) where {INV}
57
- δu = get_du (cache, idx)
54
+ function InternalAPI. solve! (
55
+ cache:: HalleyDescentCache , J, fu, u, idx:: Val = Val (1 );
56
+ skip_solve:: Bool = false , new_jacobian:: Bool = true , kwargs... )
57
+ δu = SciMLBase. get_du (cache, idx)
58
58
skip_solve && return DescentResult (; δu)
59
- if INV
59
+ if preinverted_jacobian (cache)
60
60
@assert J!= = nothing " `J` must be provided when `pre_inverted = Val(true)`."
61
61
@bb δu = J × vec (fu)
62
62
else
63
63
@static_timeit cache. timer " linear solve 1" begin
64
64
linres = cache. lincache (;
65
- A = J, b = _vec (fu), kwargs... , linu = _vec (δu), du = _vec (δu),
65
+ A = J, b = Utils. safe_vec (fu),
66
+ kwargs... , linu = Utils. safe_vec (δu),
66
67
reuse_A_if_factorization = ! new_jacobian || (idx != = Val (1 )))
67
- δu = _restructure ( get_du (cache, idx), linres. u)
68
+ δu = Utils . restructure (SciMLBase . get_du (cache, idx), linres. u)
68
69
if ! linres. success
69
70
set_du! (cache, δu, idx)
70
71
return DescentResult (; δu, success = false , linsolve_success = false )
@@ -73,15 +74,17 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
73
74
end
74
75
b = cache. b
75
76
# compute the hessian-vector-vector product
76
- hvvp = evaluate_hvvp (cache, cache. f, cache. p, u, δu)
77
+ hvvp = evaluate_hvvp (cache. hvvp, cache , cache. f, cache. p, u, δu)
77
78
# second linear solve, reuse factorization if possible
78
- if INV
79
+ if preinverted_jacobian (cache)
79
80
@bb b = J × vec (hvvp)
80
81
else
81
82
@static_timeit cache. timer " linear solve 2" begin
82
- linres = cache. lincache (; A = J, b = _vec (hvvp), kwargs... , linu = _vec (b),
83
- du = _vec (b), reuse_A_if_factorization = true )
84
- b = _restructure (cache. b, linres. u)
83
+ linres = cache. lincache (;
84
+ A = J, b = Utils. safe_vec (hvvp),
85
+ kwargs... , linu = Utils. safe_vec (b),
86
+ reuse_A_if_factorization = true )
87
+ b = Utils. restructure (cache. b, linres. u)
85
88
if ! linres. success
86
89
set_du! (cache, δu, idx)
87
90
return DescentResult (; δu, success = false , linsolve_success = false )
@@ -94,13 +97,4 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
94
97
return DescentResult (; δu)
95
98
end
96
99
97
- function evaluate_hvvp (
98
- cache:: HalleyDescentCache , f:: NonlinearFunction{iip} , p, u, δu) where {iip}
99
- if iip
100
- binary_f = @closure (y, x) -> f (y, x, p)
101
- derivative (binary_f, cache. fu, u, δu, Val {3} ())
102
- else
103
- unary_f = Base. Fix2 (f, p)
104
- derivative (unary_f, u, δu, Val {3} ())
105
- end
106
- end
100
+ evaluate_hvvp (hvvp, cache, f, p, u, δu) = error (" not implemented. please import TaylorDiff" )
0 commit comments