Skip to content

Commit c6eba21

Browse files
committed
Improve docs and format
1 parent a64cd35 commit c6eba21

File tree

3 files changed

+17
-14
lines changed

3 files changed

+17
-14
lines changed

src/NonlinearSolve.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ include("default.jl")
146146
end
147147

148148
# Core Algorithms
149-
export NewtonRaphson, PseudoTransient, Klement, Broyden, LimitedMemoryBroyden, DFSane, Halley
149+
export NewtonRaphson, PseudoTransient, Klement, Broyden, LimitedMemoryBroyden, DFSane,
150+
Halley
150151
export GaussNewton, LevenbergMarquardt, TrustRegion
151152
export NonlinearSolvePolyAlgorithm, RobustMultiNewton, FastShortcutNonlinearPolyalg,
152153
FastShortcutNLLSPolyalg
@@ -159,7 +160,8 @@ export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSol
159160
export GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm, GeneralizedDFSane
160161

161162
# Descent Algorithms
162-
export NewtonDescent, SteepestDescent, Dogleg, DampedNewtonDescent, GeodesicAcceleration, HalleyDescent
163+
export NewtonDescent, SteepestDescent, Dogleg, DampedNewtonDescent, GeodesicAcceleration,
164+
HalleyDescent
163165

164166
# Globalization
165167
## Line Search Algorithms

src/algorithms/halley.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
Halley(; concrete_jac = nothing, linsolve = nothing, linesearch = NoLineSearch(),
33
precs = DEFAULT_PRECS, autodiff = nothing)
44
5-
An experimental Halley's method implementation.
5+
An experimental Halley's method implementation. Improves the convergence rate of Newton's method by using second-order derivative information to correct the descent direction.
6+
7+
Currently depends on TaylorDiff.jl to handle the correction terms,
8+
might have more general implementation in the future.
69
"""
710
function Halley(; concrete_jac = nothing, linsolve = nothing,
811
linesearch = NoLineSearch(), precs = DEFAULT_PRECS, autodiff = nothing)

src/descent/halley.jl

+9-11
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""
22
HalleyDescent(; linsolve = nothing, precs = DEFAULT_PRECS)
33
4-
Compute the descent direction as ``J δu = -fu``. For non-square Jacobian problems, this is
5-
commonly referred to as the Gauss-Newton Descent.
4+
Improve the NewtonDescent with higher-order terms. First compute the descent direction as ``J a = -fu``.
5+
Then compute the hessian-vector-vector product and solve for the second-order correction term as ``J b = H a a``.
6+
Finally, compute the descent direction as ``δu = a * a / (b / 2 - a)``.
67
7-
See also [`Dogleg`](@ref), [`SteepestDescent`](@ref), [`DampedNewtonDescent`](@ref).
8+
See also [`NewtonDescent`](@ref).
89
"""
910
@kwdef @concrete struct HalleyDescent <: AbstractDescentAlgorithm
1011
linsolve = nothing
@@ -22,8 +23,7 @@ end
2223

2324
supports_line_search(::HalleyDescent) = true
2425

25-
@concrete mutable struct HalleyDescentCache{pre_inverted} <:
26-
AbstractDescentCache
26+
@concrete mutable struct HalleyDescentCache{pre_inverted} <: AbstractDescentCache
2727
f
2828
p
2929
δu
@@ -50,8 +50,7 @@ function __internal_init(
5050
return HalleyDescentCache{false}(prob.f, prob.p, δu, δus, b, lincache, timer)
5151
end
5252

53-
function __internal_solve!(
54-
cache::HalleyDescentCache{INV}, J, fu, u, idx::Val = Val(1);
53+
function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val = Val(1);
5554
skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...) where {INV}
5655
δu = get_du(cache, idx)
5756
skip_solve && return δu, true, (;)
@@ -68,15 +67,14 @@ function __internal_solve!(
6867
end
6968
b = cache.b
7069
# compute the hessian-vector-vector product
71-
hvvp = derivative(x -> cache.f(x, cache.p), u, δu, 2)
70+
hvvp = derivative(Base.Fix2(cache.f, cache.p), u, δu, 2)
7271
# second linear solve, reuse factorization if possible
7372
if INV
7473
@bb b = J × vec(hvvp)
7574
else
7675
@static_timeit cache.timer "linear solve 2" begin
77-
b = cache.lincache(;
78-
A = J, b = _vec(hvvp), kwargs..., linu = _vec(b), du = _vec(b),
79-
reuse_A_if_factorization = true)
76+
b = cache.lincache(; A = J, b = _vec(hvvp), kwargs..., linu = _vec(b),
77+
du = _vec(b), reuse_A_if_factorization = true)
8078
b = _restructure(cache.b, b)
8179
end
8280
end

0 commit comments

Comments
 (0)