Skip to content

Commit 8607b2a

Browse files
committed
adds pexprb4 impl in python
1 parent 8fbb892 commit 8607b2a

File tree

5 files changed

+206
-33
lines changed

5 files changed

+206
-33
lines changed

ormatex_py/ode_exp.py

Lines changed: 96 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535

3636
class ExpRBIntegrator(IntegrateSys):
3737

38-
_valid_methods = {"exprb2": 2, "exprb3": 3, "epi2": 2, "epi3": 3,
39-
"exprb2_dense": 2, "exprb2_pfd": 2,
38+
_valid_methods = {"exprb2": 2, "exprb3": 3, "pexprb4": 4, "epi2": 2, "epi3": 3,
39+
"exprb2_dense": 2, "exprb2_pfd": 2, "exp_pfd": 1,
4040
"exprb2_pfd_rs": 2, "exp_pfd_rs": 1}
4141

4242
def __init__(self, sys: OdeSys, t0: float, y0: jax.Array, method="epi2", **kwargs):
@@ -78,7 +78,7 @@ def _remf(self, tr: float, yr: jax.Array,
7878
jac_yd = sys_jac_lop_yt(yr - yt)
7979
return frhs_yr - frhs_yt - jac_yd - v*dt
8080

81-
def _phi2v_nonauto(self, sys_jac_lop, dt):
81+
def _phi2v_nonauto(self, sys_jac_lop, dt, c=1.0):
8282
r"""
8383
For rosenbrock exp integrators, this method computes
8484
the correction term for nonautonomous systems.
@@ -96,10 +96,10 @@ def _phi2v_nonauto(self, sys_jac_lop, dt):
9696
fytt = sys_jac_lop._fdt()
9797
# check for nonautonomous system
9898
if jnp.linalg.norm(fytt, ord=jax.numpy.inf) > self.tol_fdt:
99-
return (dt**2.)*phi_linop(sys_jac_lop, dt, fytt, 2, self.max_krylov_dim, self.iom), fytt
99+
return (c**2.)*(dt**2.)*phi_linop(sys_jac_lop, c*dt, fytt, 2, self.max_krylov_dim, self.iom), fytt
100100
return 0., 0.
101101

102-
def _step_exprb2(self, dt: float) -> StepResult:
102+
def _step_exprb2(self, dt: float, frhs_kwargs: dict) -> StepResult:
103103
"""
104104
Computes the solution update by:
105105
y_{t+1} = y_t + dt*\varphi_1(dt*J_t)F(t, y_t)+
@@ -109,7 +109,7 @@ def _step_exprb2(self, dt: float) -> StepResult:
109109
"""
110110
t = self.t
111111
yt = self.y_hist[0]
112-
sys_jac_lop = self.sys.fjac(t, yt)
112+
sys_jac_lop = self.sys.fjac(t, yt, frhs_kwargs=frhs_kwargs)
113113
fyt = sys_jac_lop._frhs_cached()
114114
phi2_v, _0 = self._phi2v_nonauto(sys_jac_lop, dt)
115115
y_new = yt \
@@ -181,19 +181,22 @@ def _step_epi3(self, dt: float, frhs_kwargs: dict) -> StepResult:
181181

182182
return StepResult(t+dt, dt, y_new, y_err)
183183

184-
def _step_exprb3(self, dt: float) -> StepResult:
185-
"""
184+
def _step_exprb3(self, dt: float, frhs_kwargs: dict) -> StepResult:
185+
r"""
186186
Computes the solution update by:
187-
y_{t+1} = y_t + dt*\varphi_1(dt*J_t)F(t, y_t) +
188-
2*dt*\varphi_3(dt*J_t)R_2 +
189-
dt**2*\varphi_2(dt*J_t)F'(t, y_t)
190187
191-
doi: https://doi.org/10.1137/080717717
188+
.. math::
189+
190+
y_{t+1} = y_t + dt*\varphi_1(dt*J_t)F(t, y_t) +
191+
2*dt*\varphi_3(dt*J_t)R_2 +
192+
dt**2*\varphi_2(dt*J_t)F'(t, y_t)
193+
194+
Ref: doi: https://doi.org/10.1137/080717717
192195
"""
193196
t = self.t
194197
yt = self.y_hist[0] # y_t
195198

196-
sys_jac_lop = self.sys.fjac(t, yt)
199+
sys_jac_lop = self.sys.fjac(t, yt, frhs_kwargs=frhs_kwargs)
197200
fyt = sys_jac_lop._frhs_cached()
198201

199202
# phi2_v is nonzero for nonautonomous systems
@@ -215,6 +218,62 @@ def _step_exprb3(self, dt: float) -> StepResult:
215218

216219
return StepResult(t+dt, dt, y_new, y_err)
217220

221+
def _step_pexprb4(self, dt: float, frhs_kwargs: dict) -> StepResult:
222+
r"""
223+
Computes the solution update by:
224+
225+
.. math::
226+
227+
U_{n2} = u_n + (0.5)dt\varphi_1((0.5)dt J_t)F(t, y_t) +
228+
0.5^2 dt^2 \varphi_2(0.5 dt*J_t)F'(t, y_t)
229+
U_{n3} = u_n + (1.0)dt\varphi_1((1.0)dt J_t)F(t, y_t) +
230+
dt^2 \varphi_2(dt*J_t)F'(t, y_t)
231+
232+
y_{t+1} = y_t + dt*\varphi_1(dt*J_t)F(t, y_t) +
233+
dt^2 \varphi_2(dt*J_t)F'(t, y_t) +
234+
dt [16 \varphi_3(dt J_t) - 48 \varphi_4(dt J_t)] R_2(U_{n2}) +
235+
dt [-2 \varphi_3(dt J_t) + 12 \varphi_4(dt J_t)] R_3(U_{n3})
236+
237+
Where $`U_{n2}`$ and $`U_{n3}`$ can be computed in parallel.
238+
239+
Ref: V. T. Luan. and A. Ostermann. Parallel Exponential rosenbrock methods.
240+
Computers and Mathmatics with Applications, v71. 2016.
241+
"""
242+
t = self.t
243+
yt = self.y_hist[0] # y_t
244+
245+
sys_jac_lop = self.sys.fjac(t, yt, frhs_kwargs=frhs_kwargs)
246+
fyt = sys_jac_lop._frhs_cached()
247+
248+
# butcher tableau coeffs
249+
c_2 = 0.5
250+
c_3 = 1.0
251+
252+
# compute U_{n2}
253+
t_2 = t + c_2*dt
254+
phi2_v_2, v_2 = self._phi2v_nonauto(sys_jac_lop, dt, c=c_2)
255+
y_2 = yt \
256+
+ c_2*dt*phi_linop(sys_jac_lop, c_2*dt, fyt, 1, self.max_krylov_dim, self.iom) \
257+
+ phi2_v_2
258+
r_2 = self._remf(t_2, y_2, fyt, sys_jac_lop, v=v_2)
259+
260+
# compute U_{n3}
261+
t_3 = t + c_3*dt
262+
phi2_v, v = self._phi2v_nonauto(sys_jac_lop, dt, c=c_3)
263+
y_3 = yt \
264+
+ c_3*dt*phi_linop(sys_jac_lop, c_3*dt, fyt, 1, self.max_krylov_dim, self.iom) \
265+
+ phi2_v
266+
r_3 = self._remf(t_3, y_3, fyt, sys_jac_lop, v=v)
267+
268+
# compute final update
269+
vb0 = jnp.zeros(yt.shape)
270+
b2_b3 = kiops_fixedsteps(
271+
sys_jac_lop, dt, [vb0, vb0, vb0, dt*(16.0*r_2-2.0*r_3), dt*(-48.0*r_2+12.0*r_3)],
272+
max_krylov_dim=self.max_krylov_dim, iom=self.iom)
273+
y_new = y_3 + b2_b3
274+
275+
y_err = -1.0
276+
return StepResult(t+dt, dt, y_new, y_err)
218277

219278
@jax.jit
220279
def _step_exprb2_jit(t, yt, dt, sys):
@@ -323,22 +382,43 @@ def _step_exp_pfd_rs(self, dt: float) -> StepResult:
323382
y_err = -1.
324383
return StepResult(t+dt, dt, y_new, y_err)
325384

385+
def _step_exp_pfd(self, dt: float, frhs_kwargs: dict) -> StepResult:
386+
r"""
387+
Computes the solution update by:
388+
y_{t+1} = \varphi_0(dt*L)*y0
389+
where L is a dense matrix and computing varphi
390+
using cauchy contour integral approach with quadrature rule
391+
NOTE: Only useful for pure linear systems
392+
"""
393+
t = self.t
394+
yt = self.y_hist[0]
395+
J = np.asarray(self.sys.fjac(t, yt, frhs_kwargs=frhs_kwargs).dense())
396+
397+
phi0J_yt = f_phi_k_pfd(J*dt, yt, 0, self.pfd_method)
398+
y_new = jnp.asarray(phi0J_yt.flatten())
399+
y_err = -1.
400+
return StepResult(t+dt, dt, y_new, y_err)
401+
326402
def step(self, dt: float, frhs_kwargs: dict={}) -> StepResult:
327403
if self.method == "exprb2":
328-
return self._step_exprb2(dt)
404+
return self._step_exprb2(dt, frhs_kwargs)
329405
elif self.method == "exprb3":
330-
return self._step_exprb3(dt)
406+
return self._step_exprb3(dt, frhs_kwargs)
407+
elif self.method == "pexprb4":
408+
return self._step_pexprb4(dt, frhs_kwargs)
331409
elif self.method == "epi2":
332410
return self._step_epi2(dt, frhs_kwargs)
333411
elif self.method == "epi3":
334412
if len(self.y_hist) >= 2:
335413
return self._step_epi3(dt, frhs_kwargs)
336414
else:
337-
return self._step_epi2(dt, frhs_kwargs)
415+
return self._step_exprb3(dt, frhs_kwargs)
338416
elif self.method == "exprb2_dense":
339417
return self._step_exprb2_dense(dt)
340418
elif self.method == "exprb2_pfd":
341419
return self._step_exprb2_pfd(dt, frhs_kwargs)
420+
elif self.method == "exp_pfd":
421+
return self._step_exp_pfd(dt, frhs_kwargs)
342422
elif self.method == "exprb2_pfd_rs" and HAS_ORMATEX_RUST:
343423
return self._step_exprb2_pfd_rs(dt)
344424
elif self.method == "exp_pfd_rs" and HAS_ORMATEX_RUST:

ormatex_py/ode_sys.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,8 @@ class AdJacLinOp(SysJacLinOp):
258258

259259
def __init__(self, t, u, frhs: eqx.Module, frhs_kwargs: dict={}, **kwargs):
260260
super().__init__(t, u, frhs, frhs_kwargs, **kwargs)
261-
self._frhs_u, self._fjac_u = jax.linearize(partial(self._frhs, **self._frhs_kwargs), self._t, self._u)
261+
self._frhs_u, self._fjac_u = jax.linearize(
262+
partial(self._frhs, **self._frhs_kwargs), self._t, self._u)
262263

263264
@jax.jit
264265
def _matvec(self, v: jax.Array) -> jax.Array:

ormatex_py/progression/bateman_sys.py

Lines changed: 79 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import jax
55
import numpy as np
66
from jax import numpy as jnp
7+
from itertools import cycle
78

89
from ormatex_py import integrate_wrapper
910
from ormatex_py.ode_sys import OdeSys, OdeSplitSys, MatrixLinOp, CustomJacLinOp
@@ -173,7 +174,7 @@ def analytic_bateman_single_parent(t, batmat, n0):
173174
return np.asarray(N, dtype=np.float64)
174175

175176

176-
def analytic_bateman_s3(method="epi2", do_plot=True, dt=10.0, tf=1000., pfd_method="cram_16"):
177+
def analytic_bateman_s3(method="epi2", do_plot=True, dt=10.0, tf=1000., pfd_method="cram_16", true_analytic=False):
177178
jax.config.update("jax_enable_x64", True)
178179
keymap = ["c_0", "c_1", "c_2"]
179180
decay_lib_sp = {
@@ -185,15 +186,31 @@ def analytic_bateman_s3(method="epi2", do_plot=True, dt=10.0, tf=1000., pfd_meth
185186
n0 = 1.0
186187
t0 = 0.0
187188
t = np.arange(t0, tf+dt, dt)
189+
# test_ode_sys = TestBatemanSysJac(keymap, decay_lib_sp)
190+
test_ode_sys = TestBatemanSysNonlinFeed(keymap, decay_lib_sp)
191+
y0 = jnp.array([n0, 0.0, 0.0])
192+
188193
# analytic result
189-
y_true = analytic_bateman_single_parent(t, bmat, n0)
194+
if true_analytic:
195+
y_true = analytic_bateman_single_parent(t, bmat, n0)
196+
else:
197+
# use a fine timestep size with exprb3 integrator
198+
dt_fine = 0.001
199+
nsteps_fine = int((tf - t0) / dt_fine)
200+
fine_res = integrate_wrapper.integrate(test_ode_sys, y0, t0, dt_fine, nsteps_fine, 'dopri5', tol_fdt=1e-25)
201+
t_fine = np.asarray(fine_res.t_res)
202+
y_fine = np.asarray(fine_res.y_res)
203+
y_true = np.array([
204+
np.interp(t, t_fine, y_fine[:, 0]),
205+
np.interp(t, t_fine, y_fine[:, 1]),
206+
np.interp(t, t_fine, y_fine[:, 2])
207+
]).T
190208

191209
# compute numerical result
192-
test_ode_sys = TestBatemanSysJac(keymap, decay_lib_sp)
193-
y0 = jnp.array([n0, 0.0, 0.0])
194210
nsteps = int((tf - t0) / dt)
195211
res = integrate_wrapper.integrate(
196-
test_ode_sys, y0, t0, dt, nsteps, method, max_krylov_dim=12, iom=12, pfd_method=pfd_method)
212+
test_ode_sys, y0, t0, dt, nsteps, method, pfd_method=pfd_method, tol_fdt=1e-25,
213+
max_krylov_dim=100, iom=100)
197214
t_res, y_res = res.t_res, res.y_res
198215
t_res = np.asarray(t_res)
199216
y_res = np.asarray(y_res)
@@ -203,7 +220,7 @@ def analytic_bateman_s3(method="epi2", do_plot=True, dt=10.0, tf=1000., pfd_meth
203220
plt.figure()
204221
plt.xscale('log')
205222
plt.yscale('log')
206-
plt.ylim((1e-8, 10.0))
223+
plt.ylim((1e-8, 80.0))
207224
plt.xlim((1.0, tf))
208225
# numerical
209226
plt.plot(t_res+1.0, y_res[:, 0]+1e-16, label="c_0")
@@ -225,9 +242,14 @@ def analytic_bateman_s3(method="epi2", do_plot=True, dt=10.0, tf=1000., pfd_meth
225242

226243

227244
def run_sweep():
228-
methods = ["epi2", "epi3", "exprb3", "exp2_dense", "exp3_dense",
229-
"exprb2_dense", "exprb2_pfd_rs", "exp_pfd_rs",
230-
"implicit_euler", "implicit_esdirk3", "implicit_esdirk4"]
245+
methods = ["epi2_leja_re", "epi3", "exprb3", "exprb2_rs", "exprb3_rs", "epi3_rs",
246+
"exprb2_pfd", "exp_pfd",
247+
"implicit_euler", "implicit_esdirk3",
248+
]
249+
methods = ["exprb2", "exprb3", "exprb3_rs", "epi3", "epi3_rs",
250+
"implicit_euler", "implicit_esdirk3",
251+
]
252+
methods = ["pexprb4", "exprb3", "exprb2"]
231253
dts = [1., 2., 5., 10., 25., 50.]
232254
tf = 100.
233255
nspecies = 3
@@ -250,10 +272,13 @@ def run_sweep():
250272

251273
# error vs time step size for each method
252274
plt.figure()
275+
markers = ['x', 'p']
276+
marker_cycler = cycle(markers)
253277
for method in methods:
254278
dt = mae_dict[method][:, 0]
255-
s3_err = mae_dict[method][:, 2]
256-
plt.plot(dt, s3_err, '-o', label=method)
279+
s3_err = mae_dict[method][:, 3]
280+
mk = next(marker_cycler)
281+
plt.plot(dt, s3_err, marker=mk, ls='--', label=method, alpha=0.75)
257282
plt.yscale("log")
258283
plt.xscale("log")
259284
plt.grid(ls='--')
@@ -265,8 +290,49 @@ def run_sweep():
265290
plt.savefig("bateman_ex_1_converg.png")
266291
for method in methods:
267292
for dt in [10., 25.]:
268-
analytic_bateman_s3(method, dt=dt, tf=500., do_plot=True)
293+
analytic_bateman_s3(method, dt=dt, tf=200., do_plot=True)
294+
295+
class TestBatemanSysNonlinFeed(OdeSys):
296+
bat_mat: jax.Array
297+
feed_period: float
298+
feed_scale: float
299+
300+
def __init__(self, keymap, decay_lib, *args, **kwargs):
301+
bmat = gen_bateman_matrix(keymap, decay_lib)
302+
print("Bateman test system to solve:")
303+
print(bmat)
304+
self.bat_mat = bmat
305+
self.feed_period = 0.2e-1
306+
self.feed_scale = 1.0e-2
307+
super().__init__()
269308

309+
def _nonlin_feed(self, t, u, **kwargs):
310+
# linear-in-time feed example
311+
# feed_rate = jnp.clip(1e-6*t, 0.0, 100.0)
312+
# quadratic-in-time feed example
313+
# feed_rate = -1e-7*(t-500.)**2+1e-7*500**2
314+
# feed_rate = jnp.clip(-1e-7*(t-200.)**2+1e-7*500**2, 0.0, 100.0)
315+
# cubic-in-time feed example
316+
feed_rate = jnp.clip(2e-10*(t-200)**3-1e-7*(t-200.)**2+1e-7*500**2, 0.0, 100.0)
317+
fdr = jnp.zeros_like(u)
318+
fdr = fdr.at[0].set(feed_rate)
319+
return fdr
320+
321+
def _nonlin_sink(self, t, u, **kwargs):
322+
# quadratic-in-time sink example
323+
sink_rate = jnp.clip(-1e-6*((t-100)**2)+0.01, 0.0, 1.0)
324+
# linear-in-time sink example
325+
# sink_rate = jnp.clip(1e-5*t, 0.0, 1.0)
326+
s = jnp.zeros_like(u)
327+
s = s.at[2].set(-sink_rate)
328+
return s
329+
330+
@jax.jit
331+
def _frhs(self, t: float, u: jax.Array, **kwargs) -> jax.Array:
332+
# res = self.bat_mat @ u + self._nonlin_sink(t, u, **kwargs)
333+
# res = self.bat_mat @ u + self._nonlin_feed(t, u, **kwargs)
334+
res = self.bat_mat @ u + self._nonlin_sink(t, u, **kwargs) + self._nonlin_feed(t, u, **kwargs)
335+
return res
270336

271337
class TestBatemanSysJac(OdeSplitSys):
272338
"""
@@ -331,7 +397,7 @@ def _fl(self, t: float, u: jax.Array, **kwargs):
331397
plt.figure()
332398
plt.xscale('log')
333399
plt.yscale('log')
334-
plt.ylim((1e-14, 10.0))
400+
plt.ylim((1e-14, 80.0))
335401
plt.plot(t_res, y_res[:, 0], label="c_0")
336402
plt.plot(t_res, y_res[:, 1], label="c_1")
337403
plt.plot(t_res, y_res[:, 2], label="c_2")

ormatex_py/progression/lotka_volterra.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def gen_sys(autonomous=True):
158158
nsteps = int(t_end/dt)
159159

160160
res = integrate_wrapper.integrate(
161-
lv_sys, y0, t0, dt, nsteps, method=method, max_krylov_dim=10, iom=5, tol_fdt=0)
161+
lv_sys, y0, t0, dt, nsteps, method=method, max_krylov_dim=20, iom=20, tol_fdt=0)
162162
t_res, y_res = res.t_res, res.y_res
163163
t_res = jnp.asarray(t_res)
164164
y_res = jnp.asarray(y_res)
@@ -222,7 +222,7 @@ def gen_sys(autonomous=True):
222222

223223

224224
def sweep_methods(autonomous=False):
225-
methods = ["epi3", "exprb2", "exprb3", "exp3_dense", "exp2_dense", "implicit_euler", "implicit_esdirk3"]
225+
methods = ["epi3", "exprb2", "exprb3", "pexprb4", "implicit_euler", "implicit_esdirk3"]
226226
plt.figure()
227227
for method in methods:
228228
err_dt, err, lb = main(method, do_plot=False, autonomous=autonomous)

0 commit comments

Comments
 (0)