Skip to content

Commit 3925219

Browse files
Merge pull request #392 from SciML/ap/reliable_aliasing
Handle polyalgorithm aliasing correctly
2 parents e1c0528 + 45e0716 commit 3925219

File tree

5 files changed

+137
-29
lines changed

5 files changed

+137
-29
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.8.0"
4+
version = "3.8.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/NonlinearSolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
1212
LinearAlgebra, LinearSolve, MaybeInplace, Preferences, Printf, SciMLBase,
1313
SimpleNonlinearSolve, SparseArrays, SparseDiffTools
1414

15-
import ArrayInterface: undefmatrix, can_setindex, restructure, fast_scalar_indexing
15+
import ArrayInterface: undefmatrix, can_setindex, restructure, fast_scalar_indexing,
16+
ismutable
1617
import DiffEqBase: AbstractNonlinearTerminationMode,
1718
AbstractSafeNonlinearTerminationMode,
1819
AbstractSafeBestNonlinearTerminationMode,

src/default.jl

Lines changed: 102 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ end
6565
force_stop::Bool
6666
maxiters::Int
6767
internalnorm
68+
u0
69+
u0_aliased
70+
alias_u0::Bool
6871
end
6972

7073
function Base.show(
@@ -91,11 +94,24 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
9194
@eval begin
9295
function SciMLBase.__init(
9396
prob::$probType, alg::$algType{N}, args...; maxtime = nothing,
94-
maxiters = 1000, internalnorm = DEFAULT_NORM, kwargs...) where {N}
97+
maxiters = 1000, internalnorm = DEFAULT_NORM,
98+
alias_u0 = false, verbose = true, kwargs...) where {N}
99+
if (alias_u0 && !ismutable(prob.u0))
100+
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
101+
immutable (checked using `ArrayInterface.ismutable`)."
102+
alias_u0 = false # If immutable don't care about aliasing
103+
end
104+
u0 = prob.u0
105+
if alias_u0
106+
u0_aliased = copy(u0)
107+
else
108+
u0_aliased = u0 # Irrelevant
109+
end
110+
alias_u0 && (prob = remake(prob; u0 = u0_aliased))
95111
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}(
96112
map(
97-
solver -> SciMLBase.__init(
98-
prob, solver, args...; maxtime, internalnorm, kwargs...),
113+
solver -> SciMLBase.__init(prob, solver, args...; maxtime,
114+
internalnorm, alias_u0, verbose, kwargs...),
99115
alg.algs),
100116
alg,
101117
-1,
@@ -106,7 +122,10 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
106122
ReturnCode.Default,
107123
false,
108124
maxiters,
109-
internalnorm)
125+
internalnorm,
126+
u0,
127+
u0_aliased,
128+
alias_u0)
110129
end
111130
end
112131
end
@@ -120,20 +139,30 @@ end
120139

121140
cache_syms = [gensym("cache") for i in 1:N]
122141
sol_syms = [gensym("sol") for i in 1:N]
142+
u_result_syms = [gensym("u_result") for i in 1:N]
123143
for i in 1:N
124144
push!(calls,
125145
quote
126146
$(cache_syms[i]) = cache.caches[$(i)]
127147
if $(i) == cache.current
148+
cache.alias_u0 && copyto!(cache.u0_aliased, cache.u0)
128149
$(sol_syms[i]) = SciMLBase.solve!($(cache_syms[i]))
129150
if SciMLBase.successful_retcode($(sol_syms[i]))
130151
stats = $(sol_syms[i]).stats
131-
u = $(sol_syms[i]).u
152+
if cache.alias_u0
153+
copyto!(cache.u0, $(sol_syms[i]).u)
154+
$(u_result_syms[i]) = cache.u0
155+
else
156+
$(u_result_syms[i]) = $(sol_syms[i]).u
157+
end
132158
fu = get_fu($(cache_syms[i]))
133159
return SciMLBase.build_solution(
134-
$(sol_syms[i]).prob, cache.alg, u, fu;
135-
retcode = $(sol_syms[i]).retcode, stats,
160+
$(sol_syms[i]).prob, cache.alg, $(u_result_syms[i]),
161+
fu; retcode = $(sol_syms[i]).retcode, stats,
136162
original = $(sol_syms[i]), trace = $(sol_syms[i]).trace)
163+
elseif cache.alias_u0
164+
# For safety we need to maintain a copy of the solution
165+
$(u_result_syms[i]) = copy($(sol_syms[i]).u)
137166
end
138167
cache.current = $(i + 1)
139168
end
@@ -144,14 +173,29 @@ end
144173
for (sym, resid) in zip(cache_syms, resids)
145174
push!(calls, :($(resid) = @isdefined($(sym)) ? get_fu($(sym)) : nothing))
146175
end
176+
push!(calls, quote
177+
fus = tuple($(Tuple(resids)...))
178+
minfu, idx = __findmin(cache.internalnorm, fus)
179+
stats = __compile_stats(cache.caches[idx])
180+
end)
181+
for i in 1:N
182+
push!(calls, quote
183+
if idx == $(i)
184+
if cache.alias_u0
185+
u = $(u_result_syms[i])
186+
else
187+
u = get_u(cache.caches[$i])
188+
end
189+
end
190+
end)
191+
end
147192
push!(calls,
148193
quote
149-
fus = tuple($(Tuple(resids)...))
150-
minfu, idx = __findmin(cache.internalnorm, fus)
151-
stats = __compile_stats(cache.caches[idx])
152-
u = get_u(cache.caches[idx])
153194
retcode = cache.caches[idx].retcode
154-
195+
if cache.alias_u0
196+
copyto!(cache.u0, u)
197+
u = cache.u0
198+
end
155199
return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u, fus[idx];
156200
retcode, stats, cache.caches[idx].trace)
157201
end)
@@ -200,22 +244,52 @@ end
200244
for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS))
201245
algType = NonlinearSolvePolyAlgorithm{pType}
202246
@eval begin
203-
@generated function SciMLBase.__solve(
204-
prob::$probType, alg::$algType{N}, args...; kwargs...) where {N}
205-
calls = [:(current = alg.start_index)]
247+
@generated function SciMLBase.__solve(prob::$probType, alg::$algType{N}, args...;
248+
alias_u0 = false, verbose = true, kwargs...) where {N}
206249
sol_syms = [gensym("sol") for _ in 1:N]
250+
prob_syms = [gensym("prob") for _ in 1:N]
251+
u_result_syms = [gensym("u_result") for _ in 1:N]
252+
calls = [quote
253+
current = alg.start_index
254+
if (alias_u0 && !ismutable(prob.u0))
255+
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
256+
immutable (checked using `ArrayInterface.ismutable`)."
257+
alias_u0 = false # If immutable don't care about aliasing
258+
end
259+
u0 = prob.u0
260+
if alias_u0
261+
u0_aliased = similar(u0)
262+
else
263+
u0_aliased = u0 # Irrelevant
264+
end
265+
end]
207266
for i in 1:N
208267
cur_sol = sol_syms[i]
209268
push!(calls,
210269
quote
211270
if current == $i
212-
$(cur_sol) = SciMLBase.__solve(
213-
prob, alg.algs[$(i)], args...; kwargs...)
271+
if alias_u0
272+
copyto!(u0_aliased, u0)
273+
$(prob_syms[i]) = remake(prob; u0 = u0_aliased)
274+
else
275+
$(prob_syms[i]) = prob
276+
end
277+
$(cur_sol) = SciMLBase.__solve($(prob_syms[i]), alg.algs[$(i)],
278+
args...; alias_u0, verbose, kwargs...)
214279
if SciMLBase.successful_retcode($(cur_sol))
280+
if alias_u0
281+
copyto!(u0, $(cur_sol).u)
282+
$(u_result_syms[i]) = u0
283+
else
284+
$(u_result_syms[i]) = $(cur_sol).u
285+
end
215286
return SciMLBase.build_solution(
216-
prob, alg, $(cur_sol).u, $(cur_sol).resid;
287+
prob, alg, $(u_result_syms[i]), $(cur_sol).resid;
217288
$(cur_sol).retcode, $(cur_sol).stats,
218289
original = $(cur_sol), trace = $(cur_sol).trace)
290+
elseif alias_u0
291+
# For safety we need to maintain a copy of the solution
292+
$(u_result_syms[i]) = copy($(cur_sol).u)
219293
end
220294
current = $(i + 1)
221295
end
@@ -236,9 +310,16 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
236310
push!(calls,
237311
quote
238312
if idx == $i
239-
return SciMLBase.build_solution(prob, alg, $(sol_syms[i]).u,
240-
$(sol_syms[i]).resid; $(sol_syms[i]).retcode,
241-
$(sol_syms[i]).stats, $(sol_syms[i]).trace)
313+
if alias_u0
314+
copyto!(u0, $(u_result_syms[i]))
315+
$(u_result_syms[i]) = u0
316+
else
317+
$(u_result_syms[i]) = $(sol_syms[i]).u
318+
end
319+
return SciMLBase.build_solution(
320+
prob, alg, $(u_result_syms[i]), $(sol_syms[i]).resid;
321+
$(sol_syms[i]).retcode, $(sol_syms[i]).stats,
322+
$(sol_syms[i]).trace, original = $(sol_syms[i]))
242323
end
243324
end)
244325
end

src/utils.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,16 @@ LazyArrays.applied_axes(::typeof(__zero), x) = axes(x)
9494
@inline __is_complex(::Type{Complex}) = true
9595
@inline __is_complex(::Type{T}) where {T} = false
9696

97-
function __findmin_caches(f, caches)
98-
return __findmin(f get_fu, caches)
99-
end
100-
function __findmin(f, x)
101-
return findmin(x) do xᵢ
97+
@inline __findmin_caches(f, caches) = __findmin(f get_fu, caches)
98+
# FIXME: DEFAULT_NORM makes an Array of NaNs not a NaN (atleast according to `isnan`)
99+
@inline __findmin(::typeof(DEFAULT_NORM), x) = __findmin(Base.Fix1(maximum, abs), x)
100+
@inline function __findmin(f, x)
101+
fmin = @closure xᵢ -> begin
102102
xᵢ === nothing && return Inf
103103
fx = f(xᵢ)
104-
return isnan(fx) ? Inf : fx
104+
return ifelse(isnan(fx), Inf, fx)
105105
end
106+
return findmin(fmin, x)
106107
end
107108

108109
@inline __can_setindex(x) = can_setindex(x)

test/misc/aliasing_tests.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
@testitem "PolyAlgorithm Aliasing" begin
2+
using NonlinearProblemLibrary
3+
4+
# Use a problem that the initial solvers cannot solve and cause the initial value to
5+
# diverge. If we don't alias correctly, all the subsequent algorithms will also fail.
6+
prob = NonlinearProblemLibrary.nlprob_23_testcases["Generalized Rosenbrock function"].prob
7+
u0 = copy(prob.u0)
8+
prob = remake(prob; u0 = copy(u0))
9+
10+
# If aliasing is not handled properly this will diverge
11+
sol = solve(prob; abstol = 1e-6, alias_u0 = true,
12+
termination_condition = AbsNormTerminationMode())
13+
14+
@test sol.u === prob.u0
15+
@test SciMLBase.successful_retcode(sol.retcode)
16+
17+
prob = remake(prob; u0 = copy(u0))
18+
19+
cache = init(prob; abstol = 1e-6, alias_u0 = true,
20+
termination_condition = AbsNormTerminationMode())
21+
sol = solve!(cache)
22+
23+
@test sol.u === prob.u0
24+
@test SciMLBase.successful_retcode(sol.retcode)
25+
end

0 commit comments

Comments
 (0)