Skip to content

Commit d669bb8

Browse files
committed
test: add tests for the bracketing methods
1 parent 56c6d4d commit d669bb8

File tree

7 files changed

+115
-7
lines changed

7 files changed

+115
-7
lines changed

lib/BracketingNonlinearSolve/Project.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,11 @@ NonlinearSolveBase = "1"
1717
PrecompileTools = "1.2.1"
1818
SciMLBase = "2.50"
1919
julia = "1.10"
20+
21+
[extras]
22+
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
23+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
24+
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
25+
26+
[targets]
27+
test = ["InteractiveUtils", "Test", "TestItemRunner"]

lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module BracketingNonlinearSolve
22

33
using ConcreteStructs: @concrete
44

5-
using CommonSolve: CommonSolve
5+
using CommonSolve: CommonSolve, solve
66
using NonlinearSolveBase: NonlinearSolveBase
77
using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, IntervalNonlinearProblem, ReturnCode
88

@@ -33,6 +33,9 @@ include("ridder.jl")
3333
end
3434
end
3535

36+
export IntervalNonlinearProblem
37+
export solve
38+
3639
export Alefeld, Bisection, Brent, Falsi, ITP, Ridder
3740

3841
end

lib/BracketingNonlinearSolve/src/brent.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,10 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
9393

9494
if abs(fl) < abs(fr)
9595
d = c
96-
c, right, left = right, left, c
97-
fc, fr, fl = fr, fl, fc
96+
c, right = right, left
97+
left = c
98+
fc, fr = fr, fl
99+
fl = fc
98100
end
99101
i += 1
100102
end
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
@testsnippet RootfindingTestSnippet begin
2+
using NonlinearSolveBase, BracketingNonlinearSolve
3+
4+
quadratic_f(u, p) = u .* u .- p
5+
quadratic_f!(du, u, p) = (du .= u .* u .- p)
6+
quadratic_f2(u, p) = @. p[1] * u * u - p[2]
7+
end
8+
9+
@testitem "Interval Nonlinear Problems" setup=[RootfindingTestSnippet] tags=[:core] begin
10+
@testset for alg in (Bisection(), Falsi(), Ridder(), Brent(), ITP(), Alefeld())
11+
tspan = (1.0, 20.0)
12+
13+
function g(p)
14+
probN = IntervalNonlinearProblem{false}(quadratic_f, typeof(p).(tspan), p)
15+
return solve(probN, alg; abstol = 1e-9).left
16+
end
17+
18+
@testset for p in 1.1:0.1:100.0
19+
@test g(p)sqrt(p) atol=1e-3 rtol=1e-3
20+
# @test ForwardDiff.derivative(g, p)≈1 / (2 * sqrt(p)) atol=1e-3 rtol=1e-3
21+
end
22+
23+
t = (p) -> [sqrt(p[2] / p[1])]
24+
p = [0.9, 50.0]
25+
26+
function g2(p)
27+
probN = IntervalNonlinearProblem{false}(quadratic_f2, tspan, p)
28+
sol = solve(probN, alg; abstol = 1e-9)
29+
return [sol.u]
30+
end
31+
32+
@test g2(p)[sqrt(p[2] / p[1])] atol=1e-3 rtol=1e-3
33+
# @test ForwardDiff.jacobian(g2, p)≈ForwardDiff.jacobian(t, p) atol=1e-3 rtol=1e-3
34+
35+
probB = IntervalNonlinearProblem{false}(quadratic_f, (1.0, 2.0), 2.0)
36+
sol = solve(probB, alg; abstol = 1e-9)
37+
@test sol.leftsqrt(2.0) atol=1e-3 rtol=1e-3
38+
39+
if !(alg isa Bisection || alg isa Falsi)
40+
probB = IntervalNonlinearProblem{false}(quadratic_f, (sqrt(2.0), 10.0), 2.0)
41+
sol = solve(probB, alg; abstol = 1e-9)
42+
@test sol.leftsqrt(2.0) atol=1e-3 rtol=1e-3
43+
44+
probB = IntervalNonlinearProblem{false}(quadratic_f, (0.0, sqrt(2.0)), 2.0)
45+
sol = solve(probB, alg; abstol = 1e-9)
46+
@test sol.leftsqrt(2.0) atol=1e-3 rtol=1e-3
47+
end
48+
end
49+
end
50+
51+
@testitem "Tolerance Tests Interval Methods" setup=[RootfindingTestSnippet] tags=[:core] begin
52+
prob = IntervalNonlinearProblem(quadratic_f, (1.0, 20.0), 2.0)
53+
ϵ = eps(Float64) # least possible tol for all methods
54+
55+
@testset for alg in (Bisection(), Falsi(), ITP())
56+
@testset for abstol in [0.1, 0.01, 0.001, 0.0001, 1e-5, 1e-6, 1e-7]
57+
sol = solve(prob, alg; abstol)
58+
result_tol = abs(sol.u - sqrt(2))
59+
@test result_tol < abstol
60+
# test that the solution is not calculated upto max precision
61+
@test result_tol > ϵ
62+
end
63+
end
64+
65+
@testset for alg in (Ridder(), Brent())
66+
# Ridder and Brent converge rapidly so as we lower tolerance below 0.01, it
67+
# converges with max precision to the solution
68+
@testset for abstol in [0.1]
69+
sol = solve(prob, alg; abstol)
70+
result_tol = abs(sol.u - sqrt(2))
71+
@test result_tol < abstol
72+
# test that the solution is not calculated upto max precision
73+
@test result_tol > ϵ
74+
end
75+
end
76+
end
77+
78+
@testitem "Flipped Signs and Reversed Tspan" setup=[RootfindingTestSnippet] tags=[:core] begin
79+
@testset for alg in (Alefeld(), Bisection(), Falsi(), Brent(), ITP(), Ridder())
80+
f1(u, p) = u * u - p
81+
f2(u, p) = p - u * u
82+
83+
for p in 1:4
84+
inp1 = IntervalNonlinearProblem(f1, (1.0, 2.0), p)
85+
inp2 = IntervalNonlinearProblem(f2, (1.0, 2.0), p)
86+
inp3 = IntervalNonlinearProblem(f1, (2.0, 1.0), p)
87+
inp4 = IntervalNonlinearProblem(f2, (2.0, 1.0), p)
88+
@test abs.(solve(inp1, alg).u) sqrt.(p)
89+
@test abs.(solve(inp2, alg).u) sqrt.(p)
90+
@test abs.(solve(inp3, alg).u) sqrt.(p)
91+
@test abs.(solve(inp4, alg).u) sqrt.(p)
92+
end
93+
end
94+
end
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1+
using TestItemRunner, InteractiveUtils
12

3+
@info sprint(InteractiveUtils.versioninfo)
4+
5+
@run_package_tests

lib/NonlinearSolveBase/Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
1313
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1414
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1515
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
16-
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"
1716

1817
[weakdeps]
1918
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -35,5 +34,4 @@ RecursiveArrayTools = "3"
3534
SciMLBase = "2.50"
3635
SparseArrays = "1.10"
3736
StaticArraysCore = "1.4"
38-
UnrolledUtilities = "0.1"
3937
julia = "1.10"

lib/NonlinearSolveBase/src/utils.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@ using ArrayInterface: ArrayInterface
44
using FastClosures: @closure
55
using LinearAlgebra: norm
66
using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
7-
using UnrolledUtilities: unrolled_all
87

98
using ..NonlinearSolveBase: L2_NORM, Linf_NORM
109

11-
fast_scalar_indexing(xs...) = unrolled_all(ArrayInterface.fast_scalar_indexing, xs)
10+
fast_scalar_indexing(xs...) = all(ArrayInterface.fast_scalar_indexing, xs)
1211

1312
function nonallocating_isapprox(x::Number, y::Number; atol = false,
1413
rtol = atol > 0 ? false : sqrt(eps(promote_type(typeof(x), typeof(y)))))

0 commit comments

Comments
 (0)