Skip to content

Commit 7de7ab8

Browse files
committed
Fix bug in advection diffusion benchmark when computing adjoint solution with Robin boundary conditions.
1 parent bff552d commit 7de7ab8

File tree

4 files changed

+47
-18
lines changed

4 files changed

+47
-18
lines changed

Diff for: pyapprox/benchmarks/pde_benchmarks.py

+28-5
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ def raw_advection_diffusion_reaction_kle_dRdp(kle, residual, sol, param_vals):
8989
kle_vals = kle(param_vals[:, None])
9090
assert kle_vals.ndim == 2
9191
dkdp = kle_vals*kle.eig_vecs
92+
else:
93+
dkdp = kle.eig_vecs
9294
Du = [torch.linalg.multi_dot((dmats[dd], sol))
9395
for dd in range(mesh.nphys_vars)]
9496
kDu = [Du[dd][:, None]*dkdp for dd in range(mesh.nphys_vars)]
@@ -98,10 +100,32 @@ def raw_advection_diffusion_reaction_kle_dRdp(kle, residual, sol, param_vals):
98100

99101

100102
def advection_diffusion_reaction_kle_dRdp(
101-
bndry_indices, kle, residual, sol, param_vals):
103+
mesh, kle, bndry_conds, residual, sol, param_vals):
102104
dRdp = raw_advection_diffusion_reaction_kle_dRdp(
103105
kle, residual, sol, param_vals)
104-
dRdp[np.hstack(bndry_indices)] = 0.0
106+
for ii, bndry_cond in enumerate(bndry_conds):
107+
idx = mesh._bndry_indices[ii]
108+
if bndry_cond[1] == "D":
109+
dRdp[idx] = 0.0
110+
elif bndry_cond[1] == "R":
111+
mesh_pts_idx = mesh._bndry_slice(mesh.mesh_pts, idx, 1)
112+
normal_vals = mesh._bndrys[ii].normals(mesh_pts_idx)
113+
if kle.use_log:
114+
kle_vals = kle(param_vals[:, None])
115+
dkdp = kle_vals*kle.eig_vecs
116+
else:
117+
dkdp = torch.as_tensor(kle.eig_vecs)
118+
flux_vals = [
119+
(torch.linalg.multi_dot(
120+
(mesh._bndry_slice(mesh._dmat(dd), idx, 0), sol))[:, None]
121+
* mesh._bndry_slice(dkdp, idx, 0))
122+
for dd in range(mesh.nphys_vars)]
123+
flux_normal_vals = [
124+
normal_vals[:, dd:dd+1]*flux_vals[dd]
125+
for dd in range(mesh.nphys_vars)]
126+
dRdp[idx] = sum(flux_normal_vals)
127+
else:
128+
raise NotImplementedError()
105129
return dRdp
106130

107131

@@ -121,8 +145,7 @@ def __init__(self, mesh, bndry_conds, kle, vel_fun, react_funs, forc_fun,
121145
import inspect
122146
if "mesh" == inspect.getfullargspec(functional).args[0]:
123147
if "physics" == inspect.getfullargspec(functional).args[1]:
124-
functional = partial(
125-
functional, mesh, self._fwd_solver.physics)
148+
functional = partial(functional, mesh, self._fwd_solver.physics)
126149
else:
127150
functional = partial(functional, mesh)
128151
for ii in range(len(functional_deriv_funs)):
@@ -141,7 +164,7 @@ def __init__(self, mesh, bndry_conds, kle, vel_fun, react_funs, forc_fun,
141164
if issubclass(type(self._fwd_solver), SteadyStatePDE):
142165
dqdu, dqdp = functional_deriv_funs
143166
dRdp = partial(advection_diffusion_reaction_kle_dRdp,
144-
mesh._bndry_indices, self._kle)
167+
mesh, self._kle, self._fwd_solver.physics._bndry_conds)
145168
# dRdp must be after boundary conditions are applied.
146169
# For now assume that parameters do not effect boundary conditions
147170
# so dRdp at boundary indices is zero

Diff for: pyapprox/benchmarks/tests/test_pde_benchmarks.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -68,21 +68,26 @@ def set_random_sample(physics, sample):
6868
physics._diff_fun = partial(
6969
inv_model.base_model._fast_interpolate,
7070
inv_model.base_model._kle(sample[:, None]))
71-
sample.requires_grad_= True
71+
sample.requires_grad_ = True
7272
dRdp_ad = torch.autograd.functional.jacobian(
7373
partial(inv_model.base_model._adj_solver._parameterized_residual,
7474
fwd_sol, set_random_sample),
7575
sample, strict=True)
7676
dRdp = inv_model.base_model._adj_solver._dRdp(
7777
inv_model.base_model._fwd_solver.physics, fwd_sol.clone(),
7878
sample)
79+
print(dRdp)
80+
print(dRdp_ad)
7981
assert np.allclose(dRdp, dRdp_ad)
8082

8183
# TODO add std to params list
8284
init_guess = true_params + np.random.normal(0, 0.01, true_params.shape)
8385
# init_guess = sample.numpy()[:, None]
84-
85-
# from pyapprox.util.utilities import approx_fprime
86+
# from pyapprox.util.utilities import approx_fprime, approx_jacobian
87+
# fd_jac = approx_jacobian(
88+
# lambda p: inv_model.base_model._adj_solver._parameterized_residual(
89+
# fwd_sol, set_random_sample, torch.as_tensor(p)), sample)
90+
# print(fd_jac)
8691
# print(approx_fprime(init_guess, inv_model), 'fd')
8792
# print(inv_model(init_guess, return_grad=True), 'g')
8893
# assert False

Diff for: pyapprox/pde/autopde/mesh.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -400,10 +400,10 @@ def __init__(self, orders, basis_types=None):
400400
self._flux_islinear = False
401401
self._flux_normal_vals = [None for dd in range(2*self.nphys_vars)]
402402
self._normal_vals = [None for dd in range(2*self.nphys_vars)]
403-
403+
404404
def _clear_flux_normal_vals(self):
405405
self._flux_normal_vals = [None for dd in range(2*self.nphys_vars)]
406-
406+
407407
@staticmethod
408408
def _get_basis_types(nphys_vars, basis_types):
409409
if basis_types is None:
@@ -693,7 +693,7 @@ def _dmat(self, dd):
693693
def _apply_dirichlet_boundary_conditions_special_indexing(
694694
self, bndry_conds, residual, jac, sol):
695695
# special indexing copies data which slows down function
696-
696+
697697
# needs to have indices as argument so this fucntion can be used
698698
# when setting boundary conditions for forward and adjoint solves
699699

@@ -716,12 +716,11 @@ def _apply_dirichlet_boundary_conditions_special_indexing(
716716
residual[idx] = sol[idx]-bndry_vals
717717
return residual, jac
718718

719-
720719
@staticmethod
721720
def _bndry_slice(vec, idx, axis):
722721
# avoid copying data
723722
if len(idx) == 1:
724-
if axis == 0 :
723+
if axis == 0:
725724
return vec[idx]
726725
return vec[:, idx]
727726
stride = idx[1]-idx[0]
@@ -755,7 +754,7 @@ def _apply_neumann_and_robin_boundary_conditions(
755754
if bndry_cond[1] != "N" and bndry_cond[1] != "R":
756755
continue
757756
idx = self._bndry_indices[ii]
758-
mesh_pts_idx = self._bndry_slice(self.mesh_pts, idx, 1)
757+
mesh_pts_idx = self._bndry_slice(self.mesh_pts, idx, 1)
759758
if self._normal_vals[ii] is None:
760759
self._normal_vals[ii] = self._bndrys[ii].normals(mesh_pts_idx)
761760
if not self._flux_islinear or self._flux_normal_vals[ii] is None:
@@ -770,9 +769,10 @@ def _apply_neumann_and_robin_boundary_conditions(
770769
# (D2*u)*n2+D2*u*n2
771770
jac[idx] = sum(flux_normal_vals)
772771
bndry_vals = bndry_cond[0](mesh_pts_idx)[:, 0]
773-
772+
774773
# residual[idx] = torch.linalg.multi_dot((jac[idx], sol))-bndry_vals
775-
residual[idx] = torch.linalg.multi_dot((self._bndry_slice(jac, idx, 0), sol))-bndry_vals
774+
residual[idx] = torch.linalg.multi_dot(
775+
(self._bndry_slice(jac, idx, 0), sol))-bndry_vals
776776
if bndry_cond[1] == "R":
777777
jac[idx, idx] += bndry_cond[2]
778778
# residual[idx] += bndry_cond[2]*sol[idx]
@@ -813,7 +813,7 @@ def _apply_boundary_conditions(self, bndry_conds, residual, jac, sol,
813813
class TransformedCollocationMesh(CanonicalCollocationMesh):
814814
# TODO need to changes weights of _get_quadrature_rule to account
815815
# for any scaling transformations
816-
816+
817817
def __init__(self, orders, transform, basis_types=None):
818818

819819
super().__init__(orders, basis_types)

Diff for: pyapprox/pde/autopde/physics.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ def _transient_residual(self, sol, time):
8484
return res, jac
8585

8686
def _scalar_flux_jac(self, mesh, idx):
87-
return [mesh._bndry_slice(mesh._dmat(dd), idx, 0) for dd in range(mesh.nphys_vars)]
87+
return [mesh._bndry_slice(mesh._dmat(dd), idx, 0)
88+
for dd in range(mesh.nphys_vars)]
8889

8990
def _clear_data(self):
9091
# used for data that is the same for entire transient simulation

0 commit comments

Comments
 (0)