36
36
37
37
@internal_caches HalleyDescentCache :lincache
38
38
39
- function __internal_init (
40
- prob :: NonlinearProblem , alg :: HalleyDescent , J, fu, u; shared:: Val{N} = Val (1 ),
41
- pre_inverted :: Val{INV} = False, linsolve_kwargs = (;), abstol = nothing ,
42
- reltol = nothing , timer = get_timer_output (), kwargs... ) where {INV, N}
39
+ function __internal_init (prob :: NonlinearProblem , alg :: HalleyDescent , J, fu, u; stats,
40
+ shared:: Val{N} = Val (1 ), pre_inverted :: Val{INV} = False ,
41
+ linsolve_kwargs = (;), abstol = nothing , reltol = nothing ,
42
+ timer = get_timer_output (), kwargs... ) where {INV, N}
43
43
@bb δu = similar (u)
44
44
@bb b = similar (u)
45
45
@bb fu = similar (fu)
@@ -48,23 +48,27 @@ function __internal_init(
48
48
end
49
49
INV && return HalleyDescentCache {true} (prob. f, prob. p, δu, δus, b, nothing , timer)
50
50
lincache = LinearSolverCache (
51
- alg, alg. linsolve, J, _vec (fu), _vec (u); abstol, reltol, linsolve_kwargs... )
51
+ alg, alg. linsolve, J, _vec (fu), _vec (u); stats, abstol, reltol, linsolve_kwargs... )
52
52
return HalleyDescentCache {false} (prob. f, prob. p, δu, δus, b, fu, lincache, timer)
53
53
end
54
54
55
55
function __internal_solve! (cache:: HalleyDescentCache{INV} , J, fu, u, idx:: Val = Val (1 );
56
56
skip_solve:: Bool = false , new_jacobian:: Bool = true , kwargs... ) where {INV}
57
57
δu = get_du (cache, idx)
58
- skip_solve && return δu, true , (; )
58
+ skip_solve && return DescentResult (; δu )
59
59
if INV
60
60
@assert J!= = nothing " `J` must be provided when `pre_inverted = Val(true)`."
61
61
@bb δu = J × vec (fu)
62
62
else
63
63
@static_timeit cache. timer " linear solve 1" begin
64
- δu = cache. lincache (;
64
+ linres = cache. lincache (;
65
65
A = J, b = _vec (fu), kwargs... , linu = _vec (δu), du = _vec (δu),
66
66
reuse_A_if_factorization = ! new_jacobian || (idx != = Val (1 )))
67
- δu = _restructure (get_du (cache, idx), δu)
67
+ δu = _restructure (get_du (cache, idx), linres. u)
68
+ if ! linres. success
69
+ set_du! (cache, δu, idx)
70
+ return DescentResult (; δu, success = false , linsolve_success = false )
71
+ end
68
72
end
69
73
end
70
74
b = cache. b
@@ -75,15 +79,19 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
75
79
@bb b = J × vec (hvvp)
76
80
else
77
81
@static_timeit cache. timer " linear solve 2" begin
78
- b = cache. lincache (; A = J, b = _vec (hvvp), kwargs... , linu = _vec (b),
82
+ linres = cache. lincache (; A = J, b = _vec (hvvp), kwargs... , linu = _vec (b),
79
83
du = _vec (b), reuse_A_if_factorization = true )
80
- b = _restructure (cache. b, b)
84
+ b = _restructure (cache. b, linres. u)
85
+ if ! linres. success
86
+ set_du! (cache, δu, idx)
87
+ return DescentResult (; δu, success = false , linsolve_success = false )
88
+ end
81
89
end
82
90
end
83
91
@bb @. δu = δu * δu / (b / 2 - δu)
84
92
set_du! (cache, δu, idx)
85
93
cache. b = b
86
- return δu, true , (; )
94
+ return DescentResult (; δu )
87
95
end
88
96
89
97
function evaluate_hvvp (
0 commit comments