Skip to content

Commit 5d18cf7

Browse files
committed
feat: implement all bracketing algorithms
1 parent 6e08b09 commit 5d18cf7

File tree

13 files changed

+756
-25
lines changed

13 files changed

+756
-25
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
1-
name = "SimpleBracketingNonlinearSolve"
1+
name = "BracketingNonlinearSolve"
22
uuid = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e"
33
authors = ["Avik Pal <[email protected]> and contributors"]
44
version = "1.0.0"
55

66
[deps]
77
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
8+
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
9+
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
10+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
811
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
912

1013
[compat]
1114
CommonSolve = "0.2.4"
15+
ConcreteStructs = "0.2.3"
16+
NonlinearSolveBase = "1"
17+
PrecompileTools = "1.2.1"
1218
SciMLBase = "2.50"
1319
julia = "1.10"
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
module BracketingNonlinearSolve
2+
3+
using ConcreteStructs: @concrete
4+
5+
using CommonSolve: CommonSolve
6+
using NonlinearSolveBase: NonlinearSolveBase
7+
using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, IntervalNonlinearProblem, ReturnCode
8+
9+
using PrecompileTools: @compile_workload, @setup_workload
10+
11+
abstract type AbstractBracketingAlgorithm <: AbstractNonlinearAlgorithm end
12+
13+
include("common.jl")
14+
15+
include("alefeld.jl")
16+
include("bisection.jl")
17+
include("brent.jl")
18+
include("falsi.jl")
19+
include("itp.jl")
20+
include("ridder.jl")
21+
22+
@setup_workload begin
23+
for T in (Float32, Float64)
24+
prob_brack = IntervalNonlinearProblem{false}(
25+
(u, p) -> u^2 - p, T.((0.0, 2.0)), T(2))
26+
algs = (Alefeld(), Bisection(), Brent(), Falsi(), ITP(), Ridder())
27+
28+
@compile_workload begin
29+
for alg in algs
30+
CommonSolve.solve(prob_brack, alg; abstol = 1e-6)
31+
end
32+
end
33+
end
34+
end
35+
36+
export Alefeld, Bisection, Brent, Falsi, ITP, Ridder
37+
38+
end
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""
2+
Alefeld()
3+
4+
An implementation of algorithm 4.2 from [Alefeld](https://dl.acm.org/doi/10.1145/210089.210111).
5+
6+
The paper brought up two new algorithms. Here choose to implement algorithm 4.2 rather than
7+
algorithm 4.1 because, in certain sense, the second algorithm(4.2) is an optimal procedure.
8+
"""
9+
struct Alefeld <: AbstractBracketingAlgorithm end
10+
11+
function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Alefeld, args...;
12+
maxiters = 1000, abstol = nothing, kwargs...)
13+
f = Base.Fix2(prob.f, prob.p)
14+
a, b = prob.tspan
15+
c = a - (b - a) / (f(b) - f(a)) * f(a)
16+
17+
fc = f(c)
18+
if a == c || b == c
19+
return SciMLBase.build_solution(
20+
prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit, left = a, right = b)
21+
end
22+
23+
if iszero(fc)
24+
return SciMLBase.build_solution(
25+
prob, alg, c, fc; retcode = ReturnCode.Success, left = a, right = b)
26+
end
27+
28+
a, b, d = Impl.bracket(f, a, b, c)
29+
e = zero(a) # Set e as 0 before iteration to avoid a non-value f(e)
30+
31+
for i in 2:maxiters
32+
# The first bracketing block
33+
f₁, f₂, f₃, f₄ = f(a), f(b), f(d), f(e)
34+
if i == 2 || (f₁ == f₂ || f₁ == f₃ || f₁ == f₄ || f₂ == f₃ || f₂ == f₄ || f₃ == f₄)
35+
c = Impl.newton_quadratic(f, a, b, d, 2)
36+
else
37+
c = Impl.ipzero(f, a, b, d, e)
38+
if (c - a) * (c - b) 0
39+
c = Impl.newton_quadratic(f, a, b, d, 2)
40+
end
41+
end
42+
43+
ē, fc = d, f(c)
44+
if a == c || b == c
45+
return SciMLBase.build_solution(
46+
prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit,
47+
left = a, right = b)
48+
end
49+
50+
if iszero(fc)
51+
return SciMLBase.build_solution(
52+
prob, alg, c, fc; retcode = ReturnCode.Success, left = a, right = b)
53+
end
54+
55+
ā, b̄, d̄ = Impl.bracket(f, a, b, c)
56+
57+
# The second bracketing block
58+
f₁, f₂, f₃, f₄ = f(ā), f(b̄), f(d̄), f(ē)
59+
if f₁ == f₂ || f₁ == f₃ || f₁ == f₄ || f₂ == f₃ || f₂ == f₄ || f₃ == f₄
60+
c = Impl.newton_quadratic(f, ā, b̄, d̄, 3)
61+
else
62+
c = Impl.ipzero(f, ā, b̄, d̄, ē)
63+
if (c - ā) * (c - b̄) 0
64+
c = Impl.newton_quadratic(f, ā, b̄, d̄, 3)
65+
end
66+
end
67+
fc = f(c)
68+
69+
if== c ||== c
70+
return SciMLBase.build_solution(
71+
prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit,
72+
left = ā, right = b̄)
73+
end
74+
75+
if iszero(fc)
76+
return SciMLBase.build_solution(
77+
prob, alg, c, fc; retcode = ReturnCode.Success, left = ā, right = b̄)
78+
end
79+
80+
ā, b̄, d̄ = Impl.bracket(f, ā, b̄, c)
81+
82+
# The third bracketing block
83+
u = ifelse(abs(f(ā)) < abs(f(b̄)), ā, b̄)
84+
c = u - 2 * (b̄ - ā) / (f(b̄) - f(ā)) * f(u)
85+
if (abs(c - u)) > 0.5 * (b̄ - ā)
86+
c = 0.5 * (ā + b̄)
87+
end
88+
fc = f(c)
89+
90+
if== c ||== c
91+
return SciMLBase.build_solution(
92+
prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit,
93+
left = ā, right = b̄)
94+
end
95+
96+
if iszero(fc)
97+
return SciMLBase.build_solution(
98+
prob, alg, c, fc; retcode = ReturnCode.Success, left = ā, right = b̄)
99+
end
100+
101+
ā, b̄, d = Impl.bracket(f, ā, b̄, c)
102+
103+
# The last bracketing block
104+
if-< 0.5 * (b - a)
105+
a, b, e = ā, b̄, d̄
106+
else
107+
e = d
108+
c = 0.5 * (ā + b̄)
109+
fc = f(c)
110+
111+
if== c ||== c
112+
return SciMLBase.build_solution(
113+
prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit,
114+
left = ā, right = b̄)
115+
end
116+
if iszero(fc)
117+
return SciMLBase.build_solution(
118+
prob, alg, c, fc; retcode = ReturnCode.Success, left = ā, right = b̄)
119+
end
120+
a, b, d = Impl.bracket(f, ā, b̄, c)
121+
end
122+
end
123+
124+
# Reassign the value a, b, and c
125+
if b == c
126+
b = d
127+
elseif a == c
128+
a = d
129+
end
130+
fc = f(c)
131+
132+
# Reuturn solution when run out of max interation
133+
return SciMLBase.build_solution(
134+
prob, alg, c, fc; retcode = ReturnCode.MaxIters, left = a, right = b)
135+
end
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""
2+
Bisection(; exact_left = false, exact_right = false)
3+
4+
A common bisection method.
5+
6+
### Keyword Arguments
7+
8+
- `exact_left`: whether to enforce whether the left side of the interval must be exactly
9+
zero for the returned result. Defaults to false.
10+
- `exact_right`: whether to enforce whether the right side of the interval must be exactly
11+
zero for the returned result. Defaults to false.
12+
13+
!!! danger "Keyword Arguments"
14+
15+
Currently, the keyword arguments are not implemented.
16+
"""
17+
@kwdef struct Bisection <: AbstractBracketingAlgorithm
18+
exact_left::Bool = false
19+
exact_right::Bool = false
20+
end
21+
22+
function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Bisection,
23+
args...; maxiters = 1000, abstol = nothing, kwargs...)
24+
@assert !SciMLBase.isinplace(prob) "`Bisection` only supports out-of-place problems."
25+
26+
f = Base.Fix2(prob.f, prob.p)
27+
left, right = prob.tspan
28+
fl, fr = f(left), f(right)
29+
30+
abstol = NonlinearSolveBase.get_tolerance(
31+
abstol, promote_type(eltype(left), eltype(right)))
32+
33+
if iszero(fl)
34+
return SciMLBase.build_solution(
35+
prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left, right)
36+
end
37+
38+
if iszero(fr)
39+
return SciMLBase.build_solution(
40+
prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, left, right)
41+
end
42+
43+
i = 1
44+
while i maxiters
45+
mid = (left + right) / 2
46+
47+
if mid == left || mid == right
48+
return SciMLBase.build_solution(
49+
prob, alg, left, fl; retcode = ReturnCode.FloatingPointLimit, left, right)
50+
end
51+
52+
fm = f(mid)
53+
if abs((right - left) / 2) < abstol
54+
return SciMLBase.build_solution(
55+
prob, alg, mid, fm; retcode = ReturnCode.Success, left, right)
56+
end
57+
58+
if iszero(fm)
59+
right = mid
60+
break
61+
end
62+
63+
if sign(fl) == sign(fm)
64+
fl = fm
65+
left = mid
66+
else
67+
fr = fm
68+
right = mid
69+
end
70+
71+
i += 1
72+
end
73+
74+
sol, i, left, right, fl, fr = Impl.bisection(
75+
left, right, fl, fr, f, abstol, maxiters - i, prob, alg)
76+
77+
sol !== nothing && return sol
78+
79+
return SciMLBase.build_solution(
80+
prob, alg, left, fl; retcode = ReturnCode.MaxIters, left, right)
81+
end
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""
2+
Brent()
3+
4+
Left non-allocating Brent method.
5+
"""
6+
struct Brent <: AbstractBracketingAlgorithm end
7+
8+
function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
9+
maxiters = 1000, abstol = nothing, kwargs...)
10+
@assert !SciMLBase.isinplace(prob) "`Brent` only supports out-of-place problems."
11+
12+
f = Base.Fix2(prob.f, prob.p)
13+
left, right = prob.tspan
14+
fl, fr = f(left), f(right)
15+
ϵ = eps(convert(typeof(fl), 1))
16+
17+
abstol = NonlinearSolveBase.get_tolerance(
18+
abstol, promote_type(eltype(left), eltype(right)))
19+
20+
if iszero(fl)
21+
return SciMLBase.build_solution(
22+
prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left, right)
23+
end
24+
25+
if iszero(fr)
26+
return SciMLBase.build_solution(
27+
prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, left, right)
28+
end
29+
30+
if abs(fl) < abs(fr)
31+
left, right = right, left
32+
fl, fr = fr, fl
33+
end
34+
35+
c = left
36+
d = c
37+
i = 1
38+
cond = true
39+
40+
while i < maxiters
41+
fc = f(c)
42+
43+
if fl != fc && fr != fc
44+
# Inverse quadratic interpolation
45+
s = left * fr * fc / ((fl - fr) * (fl - fc)) +
46+
right * fl * fc / ((fr - fl) * (fr - fc)) +
47+
c * fl * fr / ((fc - fl) * (fc - fr))
48+
else
49+
# Secant method
50+
s = right - fr * (right - left) / (fr - fl)
51+
end
52+
53+
if (s < min((3 * left + right) / 4, right) ||
54+
s > max((3 * left + right) / 4, right)) ||
55+
(cond && abs(s - right) abs(right - c) / 2) ||
56+
(!cond && abs(s - right) abs(c - d) / 2) ||
57+
(cond && abs(right - c) ϵ) ||
58+
(!cond && abs(c - d) ϵ)
59+
# Bisection method
60+
s = (left + right) / 2
61+
if s == left || s == right
62+
return SciMLBase.build_solution(prob, alg, left, fl;
63+
retcode = ReturnCode.FloatingPointLimit, left, right)
64+
end
65+
cond = true
66+
else
67+
cond = false
68+
end
69+
70+
fs = f(s)
71+
if abs((right - left) / 2) < abstol
72+
return SciMLBase.build_solution(prob, alg, s, fs;
73+
retcode = ReturnCode.Success, left, right)
74+
end
75+
76+
if iszero(fs)
77+
if right < left
78+
left = right
79+
fl = fr
80+
end
81+
right = s
82+
fr = fs
83+
break
84+
end
85+
86+
if fl * fs < 0
87+
d, c, right = c, right, s
88+
fr = fs
89+
else
90+
left = s
91+
fl = fs
92+
end
93+
94+
if abs(fl) < abs(fr)
95+
d = c
96+
c, right, left = right, left, c
97+
fc, fr, fl = fr, fl, fc
98+
end
99+
i += 1
100+
end
101+
102+
sol, i, left, right, fl, fr = Impl.bisection(
103+
left, right, fl, fr, f, abstol, maxiters - i, prob, alg)
104+
105+
sol !== nothing && return sol
106+
107+
return SciMLBase.build_solution(
108+
prob, alg, left, fl; retcode = ReturnCode.MaxIters, left, right)
109+
end

0 commit comments

Comments
 (0)