Skip to content

Commit 11c53df

Browse files
tansongchenavik-pal
authored andcommitted
Rebase
1 parent c2edf17 commit 11c53df

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

src/descent/halley.jl

+19-11
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ end
3636

3737
@internal_caches HalleyDescentCache :lincache
3838

39-
function __internal_init(
40-
prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; shared::Val{N} = Val(1),
41-
pre_inverted::Val{INV} = False, linsolve_kwargs = (;), abstol = nothing,
42-
reltol = nothing, timer = get_timer_output(), kwargs...) where {INV, N}
39+
function __internal_init(prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; stats,
40+
shared::Val{N} = Val(1), pre_inverted::Val{INV} = False,
41+
linsolve_kwargs = (;), abstol = nothing, reltol = nothing,
42+
timer = get_timer_output(), kwargs...) where {INV, N}
4343
@bb δu = similar(u)
4444
@bb b = similar(u)
4545
@bb fu = similar(fu)
@@ -48,23 +48,27 @@ function __internal_init(
4848
end
4949
INV && return HalleyDescentCache{true}(prob.f, prob.p, δu, δus, b, nothing, timer)
5050
lincache = LinearSolverCache(
51-
alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol, linsolve_kwargs...)
51+
alg, alg.linsolve, J, _vec(fu), _vec(u); stats, abstol, reltol, linsolve_kwargs...)
5252
return HalleyDescentCache{false}(prob.f, prob.p, δu, δus, b, fu, lincache, timer)
5353
end
5454

5555
function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val = Val(1);
5656
skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...) where {INV}
5757
δu = get_du(cache, idx)
58-
skip_solve && return δu, true, (;)
58+
skip_solve && return DescentResult(; δu)
5959
if INV
6060
@assert J!==nothing "`J` must be provided when `pre_inverted = Val(true)`."
6161
@bb δu = J × vec(fu)
6262
else
6363
@static_timeit cache.timer "linear solve 1" begin
64-
δu = cache.lincache(;
64+
linres = cache.lincache(;
6565
A = J, b = _vec(fu), kwargs..., linu = _vec(δu), du = _vec(δu),
6666
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1)))
67-
δu = _restructure(get_du(cache, idx), δu)
67+
δu = _restructure(get_du(cache, idx), linres.u)
68+
if !linres.success
69+
set_du!(cache, δu, idx)
70+
return DescentResult(; δu, success = false, linsolve_success = false)
71+
end
6872
end
6973
end
7074
b = cache.b
@@ -75,15 +79,19 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
7579
@bb b = J × vec(hvvp)
7680
else
7781
@static_timeit cache.timer "linear solve 2" begin
78-
b = cache.lincache(; A = J, b = _vec(hvvp), kwargs..., linu = _vec(b),
82+
linres = cache.lincache(; A = J, b = _vec(hvvp), kwargs..., linu = _vec(b),
7983
du = _vec(b), reuse_A_if_factorization = true)
80-
b = _restructure(cache.b, b)
84+
b = _restructure(cache.b, linres.u)
85+
if !linres.success
86+
set_du!(cache, δu, idx)
87+
return DescentResult(; δu, success = false, linsolve_success = false)
88+
end
8189
end
8290
end
8391
@bb @. δu = δu * δu / (b / 2 - δu)
8492
set_du!(cache, δu, idx)
8593
cache.b = b
86-
return δu, true, (;)
94+
return DescentResult(; δu)
8795
end
8896

8997
function evaluate_hvvp(

test/core/23_test_problems_tests.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ end
5959
test_on_library(problems, dicts, alg_ops, broken_tests)
6060
end
6161

62-
@testitem "Halley" setup=[RobustnessTesting] begin
62+
@testitem "Halley" setup=[RobustnessTesting] tags=[:core] begin
6363
alg_ops = (Halley(),)
6464

6565
broken_tests = Dict(alg => Int[] for alg in alg_ops)

0 commit comments

Comments
 (0)