3535
3636class ExpRBIntegrator (IntegrateSys ):
3737
38- _valid_methods = {"exprb2" : 2 , "exprb3" : 3 , "epi2" : 2 , "epi3" : 3 ,
39- "exprb2_dense" : 2 , "exprb2_pfd" : 2 ,
38+ _valid_methods = {"exprb2" : 2 , "exprb3" : 3 , "pexprb4" : 4 , " epi2" : 2 , "epi3" : 3 ,
39+ "exprb2_dense" : 2 , "exprb2_pfd" : 2 , "exp_pfd" : 1 ,
4040 "exprb2_pfd_rs" : 2 , "exp_pfd_rs" : 1 }
4141
4242 def __init__ (self , sys : OdeSys , t0 : float , y0 : jax .Array , method = "epi2" , ** kwargs ):
@@ -78,7 +78,7 @@ def _remf(self, tr: float, yr: jax.Array,
7878 jac_yd = sys_jac_lop_yt (yr - yt )
7979 return frhs_yr - frhs_yt - jac_yd - v * dt
8080
81- def _phi2v_nonauto (self , sys_jac_lop , dt ):
81+ def _phi2v_nonauto (self , sys_jac_lop , dt , c = 1.0 ):
8282 r"""
8383 For rosenbrock exp integrators, this method computes
8484 the correction term for nonautonomous systems.
@@ -96,10 +96,10 @@ def _phi2v_nonauto(self, sys_jac_lop, dt):
9696 fytt = sys_jac_lop ._fdt ()
9797 # check for nonautonomous system
9898 if jnp .linalg .norm (fytt , ord = jax .numpy .inf ) > self .tol_fdt :
99- return (dt ** 2. )* phi_linop (sys_jac_lop , dt , fytt , 2 , self .max_krylov_dim , self .iom ), fytt
99+ return (c ** 2. ) * ( dt ** 2. )* phi_linop (sys_jac_lop , c * dt , fytt , 2 , self .max_krylov_dim , self .iom ), fytt
100100 return 0. , 0.
101101
102- def _step_exprb2 (self , dt : float ) -> StepResult :
102+ def _step_exprb2 (self , dt : float , frhs_kwargs : dict ) -> StepResult :
103103 """
104104 Computes the solution update by:
105105 y_{t+1} = y_t + dt*\v arphi_1(dt*J_t)F(t, y_t)+
@@ -109,7 +109,7 @@ def _step_exprb2(self, dt: float) -> StepResult:
109109 """
110110 t = self .t
111111 yt = self .y_hist [0 ]
112- sys_jac_lop = self .sys .fjac (t , yt )
112+ sys_jac_lop = self .sys .fjac (t , yt , frhs_kwargs = frhs_kwargs )
113113 fyt = sys_jac_lop ._frhs_cached ()
114114 phi2_v , _0 = self ._phi2v_nonauto (sys_jac_lop , dt )
115115 y_new = yt \
@@ -181,19 +181,22 @@ def _step_epi3(self, dt: float, frhs_kwargs: dict) -> StepResult:
181181
182182 return StepResult (t + dt , dt , y_new , y_err )
183183
184- def _step_exprb3 (self , dt : float ) -> StepResult :
185- """
184+ def _step_exprb3 (self , dt : float , frhs_kwargs : dict ) -> StepResult :
185+ r """
186186 Computes the solution update by:
187- y_{t+1} = y_t + dt*\v arphi_1(dt*J_t)F(t, y_t) +
188- 2*dt*\v arphi_3(dt*J_t)R_2 +
189- dt**2*\v arphi_2(dt*J_t)F'(t, y_t)
190187
191- doi: https://doi.org/10.1137/080717717
188+ .. math::
189+
190+ y_{t+1} = y_t + dt*\varphi_1(dt*J_t)F(t, y_t) +
191+ 2*dt*\varphi_3(dt*J_t)R_2 +
192+ dt**2*\varphi_2(dt*J_t)F'(t, y_t)
193+
194+ Ref: doi: https://doi.org/10.1137/080717717
192195 """
193196 t = self .t
194197 yt = self .y_hist [0 ] # y_t
195198
196- sys_jac_lop = self .sys .fjac (t , yt )
199+ sys_jac_lop = self .sys .fjac (t , yt , frhs_kwargs = frhs_kwargs )
197200 fyt = sys_jac_lop ._frhs_cached ()
198201
199202 # phi2_v is nonzero for nonautonomous systems
@@ -215,6 +218,62 @@ def _step_exprb3(self, dt: float) -> StepResult:
215218
216219 return StepResult (t + dt , dt , y_new , y_err )
217220
221+ def _step_pexprb4 (self , dt : float , frhs_kwargs : dict ) -> StepResult :
222+ r"""
223+ Computes the solution update by:
224+
225+ .. math::
226+
227+ U_{n2} = u_n + (0.5)dt\varphi_1((0.5)dt J_t)F(t, y_t) +
228+ 0.5^2 dt^2 \varphi_2(0.5 dt*J_t)F'(t, y_t)
229+ U_{n3} = u_n + (1.0)dt\varphi_1((1.0)dt J_t)F(t, y_t) +
230+ dt^2 \varphi_2(dt*J_t)F'(t, y_t)
231+
232+ y_{t+1} = y_t + dt*\varphi_1(dt*J_t)F(t, y_t) +
233+ dt^2 \varphi_2(dt*J_t)F'(t, y_t) +
234+ dt [16 \varphi_3(dt J_t) - 48 \varphi_4(dt J_t)] R_2(U_{n2}) +
235+ dt [-2 \varphi_3(dt J_t) + 12 \varphi_4(dt J_t)] R_3(U_{n3})
236+
237+ Where $`U_{n2}`$ and $`U_{n3}`$ can be computed in parallel.
238+
239+ Ref: V. T. Luan. and A. Ostermann. Parallel Exponential rosenbrock methods.
240+ Computers and Mathmatics with Applications, v71. 2016.
241+ """
242+ t = self .t
243+ yt = self .y_hist [0 ] # y_t
244+
245+ sys_jac_lop = self .sys .fjac (t , yt , frhs_kwargs = frhs_kwargs )
246+ fyt = sys_jac_lop ._frhs_cached ()
247+
248+ # butcher tableau coeffs
249+ c_2 = 0.5
250+ c_3 = 1.0
251+
252+ # compute U_{n2}
253+ t_2 = t + c_2 * dt
254+ phi2_v_2 , v_2 = self ._phi2v_nonauto (sys_jac_lop , dt , c = c_2 )
255+ y_2 = yt \
256+ + c_2 * dt * phi_linop (sys_jac_lop , c_2 * dt , fyt , 1 , self .max_krylov_dim , self .iom ) \
257+ + phi2_v_2
258+ r_2 = self ._remf (t_2 , y_2 , fyt , sys_jac_lop , v = v_2 )
259+
260+ # compute U_{n3}
261+ t_3 = t + c_3 * dt
262+ phi2_v , v = self ._phi2v_nonauto (sys_jac_lop , dt , c = c_3 )
263+ y_3 = yt \
264+ + c_3 * dt * phi_linop (sys_jac_lop , c_3 * dt , fyt , 1 , self .max_krylov_dim , self .iom ) \
265+ + phi2_v
266+ r_3 = self ._remf (t_3 , y_3 , fyt , sys_jac_lop , v = v )
267+
268+ # compute final update
269+ vb0 = jnp .zeros (yt .shape )
270+ b2_b3 = kiops_fixedsteps (
271+ sys_jac_lop , dt , [vb0 , vb0 , vb0 , dt * (16.0 * r_2 - 2.0 * r_3 ), dt * (- 48.0 * r_2 + 12.0 * r_3 )],
272+ max_krylov_dim = self .max_krylov_dim , iom = self .iom )
273+ y_new = y_3 + b2_b3
274+
275+ y_err = - 1.0
276+ return StepResult (t + dt , dt , y_new , y_err )
218277
219278 @jax .jit
220279 def _step_exprb2_jit (t , yt , dt , sys ):
@@ -323,22 +382,43 @@ def _step_exp_pfd_rs(self, dt: float) -> StepResult:
323382 y_err = - 1.
324383 return StepResult (t + dt , dt , y_new , y_err )
325384
385+ def _step_exp_pfd (self , dt : float , frhs_kwargs : dict ) -> StepResult :
386+ r"""
387+ Computes the solution update by:
388+ y_{t+1} = \varphi_0(dt*L)*y0
389+ where L is a dense matrix and computing varphi
390+ using cauchy contour integral approach with quadrature rule
391+ NOTE: Only useful for pure linear systems
392+ """
393+ t = self .t
394+ yt = self .y_hist [0 ]
395+ J = np .asarray (self .sys .fjac (t , yt , frhs_kwargs = frhs_kwargs ).dense ())
396+
397+ phi0J_yt = f_phi_k_pfd (J * dt , yt , 0 , self .pfd_method )
398+ y_new = jnp .asarray (phi0J_yt .flatten ())
399+ y_err = - 1.
400+ return StepResult (t + dt , dt , y_new , y_err )
401+
326402 def step (self , dt : float , frhs_kwargs : dict = {}) -> StepResult :
327403 if self .method == "exprb2" :
328- return self ._step_exprb2 (dt )
404+ return self ._step_exprb2 (dt , frhs_kwargs )
329405 elif self .method == "exprb3" :
330- return self ._step_exprb3 (dt )
406+ return self ._step_exprb3 (dt , frhs_kwargs )
407+ elif self .method == "pexprb4" :
408+ return self ._step_pexprb4 (dt , frhs_kwargs )
331409 elif self .method == "epi2" :
332410 return self ._step_epi2 (dt , frhs_kwargs )
333411 elif self .method == "epi3" :
334412 if len (self .y_hist ) >= 2 :
335413 return self ._step_epi3 (dt , frhs_kwargs )
336414 else :
337- return self ._step_epi2 (dt , frhs_kwargs )
415+ return self ._step_exprb3 (dt , frhs_kwargs )
338416 elif self .method == "exprb2_dense" :
339417 return self ._step_exprb2_dense (dt )
340418 elif self .method == "exprb2_pfd" :
341419 return self ._step_exprb2_pfd (dt , frhs_kwargs )
420+ elif self .method == "exp_pfd" :
421+ return self ._step_exp_pfd (dt , frhs_kwargs )
342422 elif self .method == "exprb2_pfd_rs" and HAS_ORMATEX_RUST :
343423 return self ._step_exprb2_pfd_rs (dt )
344424 elif self .method == "exp_pfd_rs" and HAS_ORMATEX_RUST :
0 commit comments