65
65
force_stop:: Bool
66
66
maxiters:: Int
67
67
internalnorm
68
+ u0
69
+ u0_aliased
70
+ alias_u0:: Bool
68
71
end
69
72
70
73
function Base. show (
@@ -91,11 +94,24 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
91
94
@eval begin
92
95
function SciMLBase. __init (
93
96
prob:: $probType , alg:: $algType{N} , args... ; maxtime = nothing ,
94
- maxiters = 1000 , internalnorm = DEFAULT_NORM, kwargs... ) where {N}
97
+ maxiters = 1000 , internalnorm = DEFAULT_NORM,
98
+ alias_u0 = false , verbose = true , kwargs... ) where {N}
99
+ if (alias_u0 && ! ismutable (prob. u0))
100
+ verbose && @warn " `alias_u0` has been set to `true`, but `u0` is \
101
+ immutable (checked using `ArrayInterface.ismutable`)."
102
+ alias_u0 = false # If immutable don't care about aliasing
103
+ end
104
+ u0 = prob. u0
105
+ if alias_u0
106
+ u0_aliased = copy (u0)
107
+ else
108
+ u0_aliased = u0 # Irrelevant
109
+ end
110
+ alias_u0 && (prob = remake (prob; u0 = u0_aliased))
95
111
return NonlinearSolvePolyAlgorithmCache {isinplace(prob), N, maxtime !== nothing} (
96
112
map (
97
- solver -> SciMLBase. __init (
98
- prob, solver, args ... ; maxtime, internalnorm , kwargs... ),
113
+ solver -> SciMLBase. __init (prob, solver, args ... ; maxtime,
114
+ internalnorm, alias_u0, verbose , kwargs... ),
99
115
alg. algs),
100
116
alg,
101
117
- 1 ,
@@ -106,7 +122,10 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
106
122
ReturnCode. Default,
107
123
false ,
108
124
maxiters,
109
- internalnorm)
125
+ internalnorm,
126
+ u0,
127
+ u0_aliased,
128
+ alias_u0)
110
129
end
111
130
end
112
131
end
@@ -120,20 +139,30 @@ end
120
139
121
140
cache_syms = [gensym (" cache" ) for i in 1 : N]
122
141
sol_syms = [gensym (" sol" ) for i in 1 : N]
142
+ u_result_syms = [gensym (" u_result" ) for i in 1 : N]
123
143
for i in 1 : N
124
144
push! (calls,
125
145
quote
126
146
$ (cache_syms[i]) = cache. caches[$ (i)]
127
147
if $ (i) == cache. current
148
+ cache. alias_u0 && copyto! (cache. u0_aliased, cache. u0)
128
149
$ (sol_syms[i]) = SciMLBase. solve! ($ (cache_syms[i]))
129
150
if SciMLBase. successful_retcode ($ (sol_syms[i]))
130
151
stats = $ (sol_syms[i]). stats
131
- u = $ (sol_syms[i]). u
152
+ if cache. alias_u0
153
+ copyto! (cache. u0, $ (sol_syms[i]). u)
154
+ $ (u_result_syms[i]) = cache. u0
155
+ else
156
+ $ (u_result_syms[i]) = $ (sol_syms[i]). u
157
+ end
132
158
fu = get_fu ($ (cache_syms[i]))
133
159
return SciMLBase. build_solution (
134
- $ (sol_syms[i]). prob, cache. alg, u, fu;
135
- retcode = $ (sol_syms[i]). retcode, stats,
160
+ $ (sol_syms[i]). prob, cache. alg, $ (u_result_syms[i]),
161
+ fu; retcode = $ (sol_syms[i]). retcode, stats,
136
162
original = $ (sol_syms[i]), trace = $ (sol_syms[i]). trace)
163
+ elseif cache. alias_u0
164
+ # For safety we need to maintain a copy of the solution
165
+ $ (u_result_syms[i]) = copy ($ (sol_syms[i]). u)
137
166
end
138
167
cache. current = $ (i + 1 )
139
168
end
@@ -144,14 +173,29 @@ end
144
173
for (sym, resid) in zip (cache_syms, resids)
145
174
push! (calls, :($ (resid) = @isdefined ($ (sym)) ? get_fu ($ (sym)) : nothing ))
146
175
end
176
+ push! (calls, quote
177
+ fus = tuple ($ (Tuple (resids)... ))
178
+ minfu, idx = __findmin (cache. internalnorm, fus)
179
+ stats = __compile_stats (cache. caches[idx])
180
+ end )
181
+ for i in 1 : N
182
+ push! (calls, quote
183
+ if idx == $ (i)
184
+ if cache. alias_u0
185
+ u = $ (u_result_syms[i])
186
+ else
187
+ u = get_u (cache. caches[$ i])
188
+ end
189
+ end
190
+ end )
191
+ end
147
192
push! (calls,
148
193
quote
149
- fus = tuple ($ (Tuple (resids)... ))
150
- minfu, idx = __findmin (cache. internalnorm, fus)
151
- stats = __compile_stats (cache. caches[idx])
152
- u = get_u (cache. caches[idx])
153
194
retcode = cache. caches[idx]. retcode
154
-
195
+ if cache. alias_u0
196
+ copyto! (cache. u0, u)
197
+ u = cache. u0
198
+ end
155
199
return SciMLBase. build_solution (cache. caches[idx]. prob, cache. alg, u, fus[idx];
156
200
retcode, stats, cache. caches[idx]. trace)
157
201
end )
@@ -200,22 +244,52 @@ end
200
244
for (probType, pType) in ((:NonlinearProblem , :NLS ), (:NonlinearLeastSquaresProblem , :NLLS ))
201
245
algType = NonlinearSolvePolyAlgorithm{pType}
202
246
@eval begin
203
- @generated function SciMLBase. __solve (
204
- prob:: $probType , alg:: $algType{N} , args... ; kwargs... ) where {N}
205
- calls = [:(current = alg. start_index)]
247
+ @generated function SciMLBase. __solve (prob:: $probType , alg:: $algType{N} , args... ;
248
+ alias_u0 = false , verbose = true , kwargs... ) where {N}
206
249
sol_syms = [gensym (" sol" ) for _ in 1 : N]
250
+ prob_syms = [gensym (" prob" ) for _ in 1 : N]
251
+ u_result_syms = [gensym (" u_result" ) for _ in 1 : N]
252
+ calls = [quote
253
+ current = alg. start_index
254
+ if (alias_u0 && ! ismutable (prob. u0))
255
+ verbose && @warn " `alias_u0` has been set to `true`, but `u0` is \
256
+ immutable (checked using `ArrayInterface.ismutable`)."
257
+ alias_u0 = false # If immutable don't care about aliasing
258
+ end
259
+ u0 = prob. u0
260
+ if alias_u0
261
+ u0_aliased = similar (u0)
262
+ else
263
+ u0_aliased = u0 # Irrelevant
264
+ end
265
+ end ]
207
266
for i in 1 : N
208
267
cur_sol = sol_syms[i]
209
268
push! (calls,
210
269
quote
211
270
if current == $ i
212
- $ (cur_sol) = SciMLBase. __solve (
213
- prob, alg. algs[$ (i)], args... ; kwargs... )
271
+ if alias_u0
272
+ copyto! (u0_aliased, u0)
273
+ $ (prob_syms[i]) = remake (prob; u0 = u0_aliased)
274
+ else
275
+ $ (prob_syms[i]) = prob
276
+ end
277
+ $ (cur_sol) = SciMLBase. __solve ($ (prob_syms[i]), alg. algs[$ (i)],
278
+ args... ; alias_u0, verbose, kwargs... )
214
279
if SciMLBase. successful_retcode ($ (cur_sol))
280
+ if alias_u0
281
+ copyto! (u0, $ (cur_sol). u)
282
+ $ (u_result_syms[i]) = u0
283
+ else
284
+ $ (u_result_syms[i]) = $ (cur_sol). u
285
+ end
215
286
return SciMLBase. build_solution (
216
- prob, alg, $ (cur_sol) . u , $ (cur_sol). resid;
287
+ prob, alg, $ (u_result_syms[i]) , $ (cur_sol). resid;
217
288
$ (cur_sol). retcode, $ (cur_sol). stats,
218
289
original = $ (cur_sol), trace = $ (cur_sol). trace)
290
+ elseif alias_u0
291
+ # For safety we need to maintain a copy of the solution
292
+ $ (u_result_syms[i]) = copy ($ (cur_sol). u)
219
293
end
220
294
current = $ (i + 1 )
221
295
end
@@ -236,9 +310,16 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
236
310
push! (calls,
237
311
quote
238
312
if idx == $ i
239
- return SciMLBase. build_solution (prob, alg, $ (sol_syms[i]). u,
240
- $ (sol_syms[i]). resid; $ (sol_syms[i]). retcode,
241
- $ (sol_syms[i]). stats, $ (sol_syms[i]). trace)
313
+ if alias_u0
314
+ copyto! (u0, $ (u_result_syms[i]))
315
+ $ (u_result_syms[i]) = u0
316
+ else
317
+ $ (u_result_syms[i]) = $ (sol_syms[i]). u
318
+ end
319
+ return SciMLBase. build_solution (
320
+ prob, alg, $ (u_result_syms[i]), $ (sol_syms[i]). resid;
321
+ $ (sol_syms[i]). retcode, $ (sol_syms[i]). stats,
322
+ $ (sol_syms[i]). trace, original = $ (sol_syms[i]))
242
323
end
243
324
end )
244
325
end
0 commit comments