Skip to content

Commit 40c4df9

Browse files
committed
refactor: migrate to LineSearch.jl
1 parent 05aa3db commit 40c4df9

11 files changed

+324
-540
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1313
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1414
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1515
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
16+
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
1617
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
1718
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1819
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
@@ -78,6 +79,7 @@ Hwloc = "3"
7879
InteractiveUtils = "<0.0.1, 1"
7980
LazyArrays = "1.8.2, 2"
8081
LeastSquaresOptim = "0.8.5"
82+
LineSearch = "0.1"
8183
LineSearches = "7.2"
8284
LinearAlgebra = "1.10"
8385
LinearSolve = "2.30"

docs/src/devdocs/internal_interfaces.md

-7
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,6 @@ NonlinearSolve.AbstractDampingFunction
3838
NonlinearSolve.AbstractDampingFunctionCache
3939
```
4040

41-
## Line Search
42-
43-
```@docs
44-
NonlinearSolve.AbstractNonlinearSolveLineSearchAlgorithm
45-
NonlinearSolve.AbstractNonlinearSolveLineSearchCache
46-
```
47-
4841
## Trust Region
4942

5043
```@docs

src/NonlinearSolve.jl

+50-48
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ using LazyArrays: LazyArrays, ApplyArray, cache
3030
using LinearAlgebra: LinearAlgebra, ColumnNorm, Diagonal, I, LowerTriangular, Symmetric,
3131
UpperTriangular, axpy!, cond, diag, diagind, dot, issuccess, istril,
3232
istriu, lu, mul!, norm, pinv, tril!, triu!
33+
using LineSearch: LineSearch, AbstractLineSearchAlgorithm, AbstractLineSearchCache,
34+
NoLineSearch
3335
using LineSearches: LineSearches
3436
using LinearSolve: LinearSolve, LUFactorization, QRFactorization, ComposePreconditioner,
3537
InvPreconditioner, needs_concrete_A, AbstractFactorization,
@@ -103,54 +105,54 @@ include("algorithms/extension_algs.jl")
103105
include("utils.jl")
104106
include("default.jl")
105107

106-
@setup_workload begin
107-
nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
108-
(NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
109-
probs_nls = NonlinearProblem[]
110-
for (fn, u0) in nlfuncs
111-
push!(probs_nls, NonlinearProblem(fn, u0, 2.0))
112-
end
113-
114-
nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
115-
PseudoTransient(), Broyden(), Klement(), DFSane(), nothing)
116-
117-
probs_nlls = NonlinearLeastSquaresProblem[]
118-
nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
119-
(NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
120-
(
121-
NonlinearFunction{true}(
122-
(du, u, p) -> du[1] = u[1] * u[1] - p, resid_prototype = zeros(1)),
123-
[0.1, 0.0]),
124-
(
125-
NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
126-
resid_prototype = zeros(4)),
127-
[0.1, 0.1]))
128-
for (fn, u0) in nlfuncs
129-
push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
130-
end
131-
132-
nlls_algs = (LevenbergMarquardt(), GaussNewton(), TrustRegion(),
133-
LevenbergMarquardt(; linsolve = LUFactorization()),
134-
GaussNewton(; linsolve = LUFactorization()),
135-
TrustRegion(; linsolve = LUFactorization()), nothing)
136-
137-
@compile_workload begin
138-
@sync begin
139-
for T in (Float32, Float64), (fn, u0) in nlfuncs
140-
Threads.@spawn NonlinearProblem(fn, T.(u0), T(2))
141-
end
142-
for (fn, u0) in nlfuncs
143-
Threads.@spawn NonlinearLeastSquaresProblem(fn, u0, 2.0)
144-
end
145-
for prob in probs_nls, alg in nls_algs
146-
Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
147-
end
148-
for prob in probs_nlls, alg in nlls_algs
149-
Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
150-
end
151-
end
152-
end
153-
end
108+
# @setup_workload begin
109+
# nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
110+
# (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
111+
# probs_nls = NonlinearProblem[]
112+
# for (fn, u0) in nlfuncs
113+
# push!(probs_nls, NonlinearProblem(fn, u0, 2.0))
114+
# end
115+
116+
# nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
117+
# PseudoTransient(), Broyden(), Klement(), DFSane(), nothing)
118+
119+
# probs_nlls = NonlinearLeastSquaresProblem[]
120+
# nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
121+
# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
122+
# (
123+
# NonlinearFunction{true}(
124+
# (du, u, p) -> du[1] = u[1] * u[1] - p, resid_prototype = zeros(1)),
125+
# [0.1, 0.0]),
126+
# (
127+
# NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
128+
# resid_prototype = zeros(4)),
129+
# [0.1, 0.1]))
130+
# for (fn, u0) in nlfuncs
131+
# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
132+
# end
133+
134+
# nlls_algs = (LevenbergMarquardt(), GaussNewton(), TrustRegion(),
135+
# LevenbergMarquardt(; linsolve = LUFactorization()),
136+
# GaussNewton(; linsolve = LUFactorization()),
137+
# TrustRegion(; linsolve = LUFactorization()), nothing)
138+
139+
# @compile_workload begin
140+
# @sync begin
141+
# for T in (Float32, Float64), (fn, u0) in nlfuncs
142+
# Threads.@spawn NonlinearProblem(fn, T.(u0), T(2))
143+
# end
144+
# for (fn, u0) in nlfuncs
145+
# Threads.@spawn NonlinearLeastSquaresProblem(fn, u0, 2.0)
146+
# end
147+
# for prob in probs_nls, alg in nls_algs
148+
# Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
149+
# end
150+
# for prob in probs_nlls, alg in nlls_algs
151+
# Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
152+
# end
153+
# end
154+
# end
155+
# end
154156

155157
# Core Algorithms
156158
export NewtonRaphson, PseudoTransient, Klement, Broyden, LimitedMemoryBroyden, DFSane

src/abstract_types.jl

+3-19
Original file line numberDiff line numberDiff line change
@@ -106,22 +106,6 @@ function last_step_accepted(cache::AbstractDescentCache)
106106
return true
107107
end
108108

109-
"""
110-
AbstractNonlinearSolveLineSearchAlgorithm
111-
112-
Abstract Type for all Line Search Algorithms used in NonlinearSolve.jl.
113-
114-
### `__internal_init` specification
115-
116-
```julia
117-
__internal_init(
118-
prob::AbstractNonlinearProblem, alg::AbstractNonlinearSolveLineSearchAlgorithm, f::F,
119-
fu, u, p, args...; internalnorm::IN = DEFAULT_NORM, kwargs...) where {F, IN} -->
120-
AbstractNonlinearSolveLineSearchCache
121-
```
122-
"""
123-
abstract type AbstractNonlinearSolveLineSearchAlgorithm end
124-
125109
"""
126110
AbstractNonlinearSolveLineSearchCache
127111
@@ -512,9 +496,9 @@ SciMLBase.isinplace(::AbstractNonlinearSolveJacobianCache{iip}) where {iip} = ii
512496
abstract type AbstractNonlinearSolveTraceLevel end
513497

514498
# Default Printing
515-
for aType in (AbstractTrustRegionMethod, AbstractNonlinearSolveLineSearchAlgorithm,
516-
AbstractResetCondition, AbstractApproximateJacobianUpdateRule,
517-
AbstractDampingFunction, AbstractNonlinearSolveExtensionAlgorithm)
499+
for aType in (AbstractTrustRegionMethod, AbstractResetCondition,
500+
AbstractApproximateJacobianUpdateRule, AbstractDampingFunction,
501+
AbstractNonlinearSolveExtensionAlgorithm)
518502
@eval function Base.show(io::IO, alg::$(aType))
519503
print(io, "$(nameof(typeof(alg)))()")
520504
end

src/algorithms/dfsane.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ For other keyword arguments, see [`RobustNonMonotoneLineSearch`](@ref).
1919
function DFSane(; σ_min = 1 // 10^10, σ_max = 1e10, σ_1 = 1, M::Int = 10, γ = 1 // 10^4,
2020
τ_min = 1 // 10, τ_max = 1 // 2, n_exp::Int = 2, max_inner_iterations::Int = 100,
2121
η_strategy::ETA = (fn_1, n, x_n, f_n) -> fn_1 / n^2) where {ETA}
22-
linesearch = RobustNonMonotoneLineSearch(;
23-
gamma = γ, sigma_1 = σ_1, M, tau_min = τ_min, tau_max = τ_max,
24-
n_exp, η_strategy, maxiters = max_inner_iterations)
22+
# linesearch = RobustNonMonotoneLineSearch(;
23+
# gamma = γ, sigma_1 = σ_1, M, tau_min = τ_min, tau_max = τ_max,
24+
# n_exp, η_strategy, maxiters = max_inner_iterations)
25+
linesearch = NoLineSearch()
2526
return GeneralizedDFSane{:DFSane}(linesearch, σ_min, σ_max, nothing)
2627
end

src/algorithms/klement.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ over this.
2727
function Klement(; max_resets::Int = 100, linsolve = nothing, alpha = nothing,
2828
linesearch = NoLineSearch(), precs = DEFAULT_PRECS,
2929
autodiff = nothing, init_jacobian::Val{IJ} = Val(:identity)) where {IJ}
30-
if !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
31-
Base.depwarn(
32-
"Passing in a `LineSearches.jl` algorithm directly is deprecated. \
33-
Please use `LineSearchesJL` instead.", :Klement)
34-
linesearch = LineSearchesJL(; method = linesearch)
35-
end
30+
# if !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
31+
# Base.depwarn(
32+
# "Passing in a `LineSearches.jl` algorithm directly is deprecated. \
33+
# Please use `LineSearchesJL` instead.", :Klement)
34+
# linesearch = LineSearchesJL(; method = linesearch)
35+
# end
3636

3737
if IJ === :identity
3838
initialization = IdentityInitialization(alpha, DiagonalStructure())

src/algorithms/pseudo_transient.jl

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
3-
linesearch::AbstractNonlinearSolveLineSearchAlgorithm = NoLineSearch(),
4-
precs = DEFAULT_PRECS, autodiff = nothing)
3+
linesearch = NoLineSearch(), precs = DEFAULT_PRECS, autodiff = nothing)
54
65
An implementation of PseudoTransient Method [coffey2003pseudotransient](@cite) that is used
76
to solve steady state problems in an accelerated manner. It uses an adaptive time-stepping
@@ -16,8 +15,8 @@ This implementation specifically uses "switched evolution relaxation"
1615
you are going to need more iterations to converge but it can be more stable.
1716
"""
1817
function PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
19-
linesearch::AbstractNonlinearSolveLineSearchAlgorithm = NoLineSearch(),
20-
precs = DEFAULT_PRECS, autodiff = nothing, alpha_initial = 1e-3)
18+
linesearch = NoLineSearch(), precs = DEFAULT_PRECS, autodiff = nothing,
19+
alpha_initial = 1e-3)
2120
descent = DampedNewtonDescent(; linsolve, precs, initial_damping = alpha_initial,
2221
damping_fn = SwitchedEvolutionRelaxation())
2322
return GeneralizedFirstOrderAlgorithm(;

src/core/approximate_jacobian.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,12 @@ function ApproximateJacobianSolveAlgorithm{concrete_jac, name}(;
5959
linesearch = missing, trustregion = missing, descent, update_rule,
6060
reinit_rule, initialization, max_resets::Int = typemax(Int),
6161
max_shrink_times::Int = typemax(Int)) where {concrete_jac, name}
62-
if linesearch !== missing && !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
63-
Base.depwarn("Passing in a `LineSearches.jl` algorithm directly is deprecated. \
64-
Please use `LineSearchesJL` instead.",
65-
:GeneralizedFirstOrderAlgorithm)
66-
linesearch = LineSearchesJL(; method = linesearch)
67-
end
62+
# if linesearch !== missing && !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
63+
# Base.depwarn("Passing in a `LineSearches.jl` algorithm directly is deprecated. \
64+
# Please use `LineSearchesJL` instead.",
65+
# :GeneralizedFirstOrderAlgorithm)
66+
# linesearch = LineSearchesJL(; method = linesearch)
67+
# end
6868
return ApproximateJacobianSolveAlgorithm{concrete_jac, name}(
6969
linesearch, trustregion, descent, update_rule,
7070
reinit_rule, max_resets, max_shrink_times, initialization)

src/core/generalized_first_order.jl

+14-10
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,12 @@ function GeneralizedFirstOrderAlgorithm{concrete_jac, name}(;
6666
jacobian_ad !== nothing && ADTypes.mode(jacobian_ad) isa ADTypes.ReverseMode,
6767
jacobian_ad, nothing))
6868

69-
if linesearch !== missing && !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
70-
Base.depwarn("Passing in a `LineSearches.jl` algorithm directly is deprecated. \
71-
Please use `LineSearchesJL` instead.",
72-
:GeneralizedFirstOrderAlgorithm)
73-
linesearch = LineSearchesJL(; method = linesearch)
74-
end
69+
# if linesearch !== missing && !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
70+
# Base.depwarn("Passing in a `LineSearches.jl` algorithm directly is deprecated. \
71+
# Please use `LineSearchesJL` instead.",
72+
# :GeneralizedFirstOrderAlgorithm)
73+
# linesearch = LineSearchesJL(; method = linesearch)
74+
# end
7575

7676
return GeneralizedFirstOrderAlgorithm{concrete_jac, name}(
7777
linesearch, trustregion, descent, max_shrink_times,
@@ -199,8 +199,11 @@ function SciMLBase.__init(
199199
if alg.linesearch !== missing
200200
supports_line_search(alg.descent) || error("Line Search not supported by \
201201
$(alg.descent).")
202-
linesearch_cache = __internal_init(
203-
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
202+
linesearch_ad = alg.forward_ad === nothing ?
203+
(alg.reverse_ad === nothing ? alg.jacobian_ad :
204+
alg.reverse_ad) : alg.forward_ad
205+
linesearch_cache = init(
206+
prob, alg.linesearch, fu, u; stats, autodiff = linesearch_ad, kwargs...)
204207
GB = :LineSearch
205208
end
206209

@@ -264,8 +267,9 @@ function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB};
264267
cache.make_new_jacobian = true
265268
if GB === :LineSearch
266269
@static_timeit cache.timer "linesearch" begin
267-
linesearch_failed, α = __internal_solve!(
268-
cache.linesearch_cache, cache.u, δu)
270+
linesearch_sol = solve!(cache.linesearch_cache, cache.u, δu)
271+
linesearch_failed = !SciMLBase.successful_retcode(linesearch_sol.retcode)
272+
α = linesearch_sol.step_size
269273
end
270274
if linesearch_failed
271275
cache.retcode = ReturnCode.InternalLineSearchFailed

src/core/spectral_methods.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@ Method.
99
1010
### Arguments
1111
12-
- `linesearch`: Globalization using a Line Search Method. This needs to follow the
13-
[`NonlinearSolve.AbstractNonlinearSolveLineSearchAlgorithm`](@ref) interface. This
14-
is not optional currently, but that restriction might be lifted in the future.
12+
- `linesearch`: Globalization using a Line Search Method. This is not optional currently,
13+
but that restriction might be lifted in the future.
1514
- `σ_min`: The minimum spectral parameter allowed. This is used to ensure that the
1615
spectral parameter is not too small.
1716
- `σ_max`: The maximum spectral parameter allowed. This is used to ensure that the

0 commit comments

Comments
 (0)