Skip to content

Commit 00bded1

Browse files
committed
Make it a ext
1 parent 5dbd10d commit 00bded1

File tree

8 files changed

+74
-50
lines changed

8 files changed

+74
-50
lines changed

Project.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2525
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2626
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2727
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
28+
SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e"
2829
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
2930
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3031
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
@@ -113,7 +114,6 @@ StaticArrays = "1.9"
113114
StaticArraysCore = "1.4"
114115
Sundials = "4.23.1"
115116
SymbolicIndexingInterface = "0.3.31"
116-
Symbolics = "6"
117117
TaylorDiff = "0.3"
118118
Test = "1.10"
119119
Zygote = "0.6.69"
@@ -148,8 +148,9 @@ SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
148148
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
149149
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
150150
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
151+
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
151152
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
152153
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
153154

154155
[targets]
155-
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote"]
156+
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "TaylorDiff", "Test", "Zygote"]

lib/NonlinearSolveBase/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
3535
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
3636
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3737
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
38+
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
3839

3940
[extensions]
4041
NonlinearSolveBaseBandedMatricesExt = "BandedMatrices"
@@ -44,6 +45,7 @@ NonlinearSolveBaseLineSearchExt = "LineSearch"
4445
NonlinearSolveBaseLinearSolveExt = "LinearSolve"
4546
NonlinearSolveBaseSparseArraysExt = "SparseArrays"
4647
NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings"
48+
NonlinearSolveBaseTaylorDiffExt = "TaylorDiff"
4749

4850
[compat]
4951
ADTypes = "1.9"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
module NonlinearSolveBaseTaylorDiffExt
2+
using SciMLBase: NonlinearFunction
3+
using NonlinearSolveBase: HalleyDescentCache
4+
import NonlinearSolveBase: evaluate_hvvp
5+
using TaylorDiff: derivative, derivative!
6+
using FastClosures: @closure
7+
8+
function evaluate_hvvp(
9+
hvvp, cache::HalleyDescentCache, f::NonlinearFunction{iip}, p, u, δu) where {iip}
10+
if iip
11+
binary_f = @closure (y, x) -> f(y, x, p)
12+
derivative!(hvvp, binary_f, cache.fu, u, δu, Val(2))
13+
else
14+
unary_f = Base.Fix2(f, p)
15+
hvvp = derivative(unary_f, u, δu, Val(2))
16+
end
17+
hvvp
18+
end
19+
20+
end

lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ include("wrappers.jl")
5050

5151
include("descent/common.jl")
5252
include("descent/newton.jl")
53+
include("descent/halley.jl")
5354
include("descent/steepest.jl")
5455
include("descent/damped_newton.jl")
5556
include("descent/dogleg.jl")
+36-42
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,71 @@
11
"""
2-
HalleyDescent(; linsolve = nothing, precs = DEFAULT_PRECS)
2+
HalleyDescent(; linsolve = nothing)
33
44
Improve the NewtonDescent with higher-order terms. First compute the descent direction as ``J a = -fu``.
55
Then compute the hessian-vector-vector product and solve for the second-order correction term as ``J b = H a a``.
66
Finally, compute the descent direction as ``δu = a * a / (b / 2 - a)``.
77
8+
Note that `import TaylorDiff` is required to use this descent algorithm.
9+
810
See also [`NewtonDescent`](@ref).
911
"""
10-
@kwdef @concrete struct HalleyDescent <: AbstractDescentAlgorithm
12+
@kwdef @concrete struct HalleyDescent <: AbstractDescentDirection
1113
linsolve = nothing
12-
precs = DEFAULT_PRECS
13-
end
14-
15-
using TaylorDiff: derivative
16-
17-
function Base.show(io::IO, d::HalleyDescent)
18-
modifiers = String[]
19-
d.linsolve !== nothing && push!(modifiers, "linsolve = $(d.linsolve)")
20-
d.precs !== DEFAULT_PRECS && push!(modifiers, "precs = $(d.precs)")
21-
print(io, "HalleyDescent($(join(modifiers, ", ")))")
2214
end
2315

2416
supports_line_search(::HalleyDescent) = true
2517

26-
@concrete mutable struct HalleyDescentCache{pre_inverted} <: AbstractDescentCache
18+
@concrete mutable struct HalleyDescentCache <: AbstractDescentCache
2719
f
2820
p
2921
δu
3022
δus
3123
b
3224
fu
25+
hvvp
3326
lincache
3427
timer
28+
preinverted_jacobian <: Union{Val{false}, Val{true}}
3529
end
3630

3731
@internal_caches HalleyDescentCache :lincache
3832

39-
function __internal_init(prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; stats,
40-
shared::Val{N} = Val(1), pre_inverted::Val{INV} = False,
33+
function InternalAPI.init(
34+
prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; stats,
35+
shared = Val(1), pre_inverted::Val = Val(false),
4136
linsolve_kwargs = (;), abstol = nothing, reltol = nothing,
42-
timer = get_timer_output(), kwargs...) where {INV, N}
37+
timer = get_timer_output(), kwargs...)
4338
@bb δu = similar(u)
4439
@bb b = similar(u)
4540
@bb fu = similar(fu)
46-
δus = N 1 ? nothing : map(2:N) do i
41+
@bb hvvp = similar(fu)
42+
δus = Utils.unwrap_val(shared) 1 ? nothing : map(2:Utils.unwrap_val(shared)) do i
4743
@bb δu_ = similar(u)
4844
end
49-
INV && return HalleyDescentCache{true}(prob.f, prob.p, δu, δus, b, nothing, timer)
50-
lincache = LinearSolverCache(
51-
alg, alg.linsolve, J, _vec(fu), _vec(u); stats, abstol, reltol, linsolve_kwargs...)
52-
return HalleyDescentCache{false}(prob.f, prob.p, δu, δus, b, fu, lincache, timer)
45+
lincache = Utils.unwrap_val(pre_inverted) ? nothing :
46+
construct_linear_solver(
47+
alg, alg.linsolve, J, Utils.safe_vec(fu), Utils.safe_vec(u);
48+
stats, abstol, reltol, linsolve_kwargs...
49+
)
50+
return HalleyDescentCache(
51+
prob.f, prob.p, δu, δus, b, fu, hvvp, lincache, timer, pre_inverted)
5352
end
5453

55-
function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val = Val(1);
56-
skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...) where {INV}
57-
δu = get_du(cache, idx)
54+
function InternalAPI.solve!(
55+
cache::HalleyDescentCache, J, fu, u, idx::Val = Val(1);
56+
skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...)
57+
δu = SciMLBase.get_du(cache, idx)
5858
skip_solve && return DescentResult(; δu)
59-
if INV
59+
if preinverted_jacobian(cache)
6060
@assert J!==nothing "`J` must be provided when `pre_inverted = Val(true)`."
6161
@bb δu = J × vec(fu)
6262
else
6363
@static_timeit cache.timer "linear solve 1" begin
6464
linres = cache.lincache(;
65-
A = J, b = _vec(fu), kwargs..., linu = _vec(δu), du = _vec(δu),
65+
A = J, b = Utils.safe_vec(fu),
66+
kwargs..., linu = Utils.safe_vec(δu),
6667
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1)))
67-
δu = _restructure(get_du(cache, idx), linres.u)
68+
δu = Utils.restructure(SciMLBase.get_du(cache, idx), linres.u)
6869
if !linres.success
6970
set_du!(cache, δu, idx)
7071
return DescentResult(; δu, success = false, linsolve_success = false)
@@ -73,15 +74,17 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
7374
end
7475
b = cache.b
7576
# compute the hessian-vector-vector product
76-
hvvp = evaluate_hvvp(cache, cache.f, cache.p, u, δu)
77+
hvvp = evaluate_hvvp(cache.hvvp, cache, cache.f, cache.p, u, δu)
7778
# second linear solve, reuse factorization if possible
78-
if INV
79+
if preinverted_jacobian(cache)
7980
@bb b = J × vec(hvvp)
8081
else
8182
@static_timeit cache.timer "linear solve 2" begin
82-
linres = cache.lincache(; A = J, b = _vec(hvvp), kwargs..., linu = _vec(b),
83-
du = _vec(b), reuse_A_if_factorization = true)
84-
b = _restructure(cache.b, linres.u)
83+
linres = cache.lincache(;
84+
A = J, b = Utils.safe_vec(hvvp),
85+
kwargs..., linu = Utils.safe_vec(b),
86+
reuse_A_if_factorization = true)
87+
b = Utils.restructure(cache.b, linres.u)
8588
if !linres.success
8689
set_du!(cache, δu, idx)
8790
return DescentResult(; δu, success = false, linsolve_success = false)
@@ -94,13 +97,4 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
9497
return DescentResult(; δu)
9598
end
9699

97-
function evaluate_hvvp(
98-
cache::HalleyDescentCache, f::NonlinearFunction{iip}, p, u, δu) where {iip}
99-
if iip
100-
binary_f = @closure (y, x) -> f(y, x, p)
101-
derivative(binary_f, cache.fu, u, δu, Val{3}())
102-
else
103-
unary_f = Base.Fix2(f, p)
104-
derivative(unary_f, u, δu, Val{3}())
105-
end
106-
end
100+
evaluate_hvvp(hvvp, cache, f, p, u, δu) = error("not implemented. please import TaylorDiff")

lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm,
2020
AbstractTrustRegionMethodCache,
2121
Utils, InternalAPI, get_timer_output, @static_timeit,
2222
update_trace!, L2_NORM,
23-
NewtonDescent, DampedNewtonDescent, GeodesicAcceleration,
23+
NewtonDescent, DampedNewtonDescent, HalleyDescent, GeodesicAcceleration,
2424
Dogleg
2525
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode,
2626
NonlinearFunction,
@@ -31,6 +31,7 @@ using FiniteDiff: FiniteDiff # Default Finite Difference Method
3131
using ForwardDiff: ForwardDiff # Default Forward Mode AD
3232

3333
include("raphson.jl")
34+
include("halley.jl")
3435
include("gauss_newton.jl")
3536
include("levenberg_marquardt.jl")
3637
include("trust_region.jl")
@@ -93,7 +94,7 @@ end
9394

9495
@reexport using SciMLBase, NonlinearSolveBase
9596

96-
export NewtonRaphson, PseudoTransient
97+
export NewtonRaphson, Halley, PseudoTransient
9798
export GaussNewton, LevenbergMarquardt, TrustRegion
9899

99100
export RadiusUpdateSchemes

lib/NonlinearSolveFirstOrder/src/halley.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
2-
Halley(; concrete_jac = nothing, linsolve = nothing, linesearch = NoLineSearch(),
3-
precs = DEFAULT_PRECS, autodiff = nothing)
2+
Halley(; concrete_jac = nothing, linsolve = nothing, linesearch = missing,
3+
autodiff = nothing)
44
55
An experimental Halley's method implementation. Improves the convergence rate of Newton's method by using second-order derivative information to correct the descent direction.
66

test/23_test_problems_tests.jl

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
@testsetup module RobustnessTesting
22
using NonlinearSolve, LinearAlgebra, LinearSolve, NonlinearProblemLibrary, Test
3+
import TaylorDiff
34

45
problems = NonlinearProblemLibrary.problems
56
dicts = NonlinearProblemLibrary.dicts
@@ -61,10 +62,14 @@ end
6162
end
6263

6364
@testitem "23 Test Problems: Halley" setup=[RobustnessTesting] tags=[:core] begin
64-
alg_ops = (SimpleHalley(; autodiff = AutoForwardDiff()),)
65+
alg_ops = (
66+
Halley(),
67+
SimpleHalley(; autodiff = AutoForwardDiff())
68+
)
6569

6670
broken_tests = Dict(alg => Int[] for alg in alg_ops)
67-
broken_tests[alg_ops[1]] = [1, 5, 15, 16, 18]
71+
broken_tests[alg_ops[1]] = [1, 5, 15, 16]
72+
broken_tests[alg_ops[2]] = [1, 5, 15, 16, 18]
6873

6974
test_on_library(problems, dicts, alg_ops, broken_tests)
7075
end

0 commit comments

Comments
 (0)