Skip to content

Commit e69e5d0

Browse files
committed
Add example for constructing flexible water model
1 parent cec4b5c commit e69e5d0

File tree

6 files changed

+1700
-0
lines changed

6 files changed

+1700
-0
lines changed

examples/water-model/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.pt

examples/water-model/README.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Empircal water models
2+
=====================
3+
4+
This example shows implemenetations of three and four point flexible water models and
5+
uses them to run molecular dynamics simulations of water.

examples/water-model/_model.py

Lines changed: 387 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,387 @@
1+
from typing import Dict, List, Optional
2+
3+
import torch
4+
from metatensor.torch import Labels, TensorBlock, TensorMap, sum_over_samples
5+
from metatensor.torch.atomistic import ModelOutput, NeighborListOptions, System
6+
from torchpme import Calculator, CoulombPotential, P3MCalculator
7+
8+
9+
def lennard_jones_pair(
10+
distances: torch.Tensor,
11+
sigma: torch.Tensor,
12+
epsilon: torch.Tensor,
13+
cutoff: torch.Tensor,
14+
):
15+
"""Lennard-Jones potential for pair terms."""
16+
c6 = (sigma**2 / distances**2) ** 3
17+
c12 = c6**2
18+
lj = 4 * epsilon * (c12 - c6)
19+
20+
cutoff = 1 / cutoff
21+
offset = 4 * epsilon * (sigma**12 * cutoff**12 - sigma**6 * cutoff**6)
22+
23+
return lj - offset
24+
25+
26+
def harmonic_distance_pair(
27+
distances: torch.Tensor,
28+
coefficient: torch.Tensor,
29+
equilibrium_distance: torch.Tensor,
30+
):
31+
"""Harmonic potential for bond terms."""
32+
r2 = (distances - equilibrium_distance) ** 2
33+
34+
return 0.5 * coefficient * r2
35+
36+
37+
def harmonic_angular(
38+
angles: torch.Tensor, coefficient: torch.Tensor, equilibrium_angle: torch.Tensor
39+
):
40+
"""Harmonic potential for angular terms."""
41+
return 0.5 * coefficient * (angles - equilibrium_angle) ** 2
42+
43+
44+
def compute_angles(positions: torch.Tensor, neighbor_indices: torch.Tensor):
45+
"""Compute the angles formed by triplet of atoms based on their positions."""
46+
atom_is = neighbor_indices[:, 0]
47+
atom_js = neighbor_indices[:, 1]
48+
atom_ks = neighbor_indices[:, 2]
49+
50+
pos_is = positions[atom_is]
51+
pos_js = positions[atom_js]
52+
pos_ks = positions[atom_ks]
53+
54+
R_ij = pos_js - pos_is
55+
R_ik = pos_ks - pos_is
56+
57+
return angle(R_ij, R_ik)
58+
59+
60+
def angle(a: torch.Tensor, b: torch.Tensor, dim: int = -1):
61+
"""Compute the angle between two vectors a and b.
62+
63+
Code is taken from https://github.com/pytorch/pytorch/issues/59194
64+
"""
65+
a_norm = a.norm(p=2, dim=dim, keepdim=True)
66+
b_norm = a.norm(p=2, dim=dim, keepdim=True)
67+
angles = 2 * torch.atan2(
68+
(a * b_norm - a_norm * b).norm(p=2, dim=dim),
69+
(a * b_norm + a_norm * b).norm(p=2, dim=dim),
70+
)
71+
72+
return angles
73+
74+
75+
class WaterModel(torch.nn.Module):
76+
def __init__(
77+
self,
78+
cutoff: float,
79+
O_sigma: float,
80+
O_epsilon: float,
81+
O_charge: float,
82+
OH_bond_coefficient: float,
83+
OH_equilibrium_distance: float,
84+
HOH_angle_coefficient: float,
85+
HOH_equilibrium_angle: float,
86+
pme_smearing: float,
87+
pme_mesh_spacing: float,
88+
pme_interpolation_nodes: int = 4,
89+
pme_prefactor: float = 1,
90+
four_point_model: bool = False,
91+
dtype: Optional[float] = None,
92+
):
93+
"""
94+
Flexible water model for three and four point models.
95+
96+
The model contains Lennard Jones interactions between the oxygens as well as
97+
intra molecular bond and angle terms. The electrostatics are computed using the
98+
P3M method. For a four point model the fourth side for the charge interaction is
99+
computed implicitly based on the position of the other atoms.
100+
101+
:param cutoff: Cutoff for the Lennard-Jones interactions.
102+
:param O_sigma: Sigma parameter for the oxygen Lennard-Jones interactions.
103+
:param O_epsilon: Epsilon parameter for the oxygen Lennard-Jones interactions.
104+
:param O_charge: Oxygen's atom charge; hydrogen is computed accordingly.
105+
:param OH_bond_coefficient: Harmonic coefficient for the OH bond.
106+
:param OH_equilibrium_distance: Equilibrium distance for the OH bond.
107+
:param HOH_angle_coefficient: Harmonic coefficient for the HOH angle.
108+
:param HOH_equilibrium_angle: Equilibrium angle for the HOH angle in degrees.
109+
:param pme_smearing: Smearing parameter for the PME.
110+
:param pme_mesh_spacing: Mesh spacing for the PME.
111+
:param pme_interpolation_nodes: Number of interpolation nodes for the PME.
112+
:param pme_prefactor: Prefactor for the PME.
113+
:param four_point_model: If :py:obj:`True`, use the four-point model for the
114+
electrostatics. The fourth point M is implicitly derived from the other
115+
atoms or each water molecule and used during the force computation. See
116+
10.1063/1.3167790 for details on its derivation.
117+
:param dtype: Floating point precision for the model. If :py:obj:`None`, the
118+
:param dtype: default
119+
dtype is used.
120+
"""
121+
super().__init__()
122+
123+
self.dtype = dtype if dtype is not None else torch.get_default_dtype()
124+
self.four_point_model = four_point_model
125+
126+
self.register_buffer("cutoff", torch.tensor(cutoff, dtype=self.dtype))
127+
self.register_buffer("O_sigma", torch.tensor(O_sigma, dtype=self.dtype))
128+
self.register_buffer("O_epsilon", torch.tensor(O_epsilon, dtype=self.dtype))
129+
self.register_buffer(
130+
"OH_bond_coefficient", torch.tensor(OH_bond_coefficient, dtype=self.dtype)
131+
)
132+
self.register_buffer(
133+
"OH_equilibrium_distance",
134+
torch.tensor(OH_equilibrium_distance, dtype=self.dtype),
135+
)
136+
self.register_buffer(
137+
"HOH_angle_coefficient",
138+
torch.tensor(HOH_angle_coefficient, dtype=self.dtype),
139+
)
140+
141+
# Convert degree angle to radians
142+
self.register_buffer(
143+
"HOH_equilibrium_angle",
144+
torch.tensor(HOH_equilibrium_angle * torch.pi / 180, dtype=self.dtype),
145+
)
146+
147+
# Register charges for water model
148+
H_charge = -O_charge / 2
149+
self.register_buffer(
150+
"OHH_charges",
151+
torch.tensor([O_charge, H_charge, H_charge], dtype=self.dtype),
152+
)
153+
154+
self.pme_calculator = P3MCalculator(
155+
potential=CoulombPotential(pme_smearing),
156+
mesh_spacing=pme_mesh_spacing,
157+
interpolation_nodes=pme_interpolation_nodes,
158+
prefactor=pme_prefactor,
159+
)
160+
self.coulomb_calculator = Calculator(
161+
CoulombPotential(), prefactor=pme_prefactor
162+
)
163+
164+
self.nl = NeighborListOptions(cutoff=cutoff, full_list=False, strict=False)
165+
166+
def requested_neighbor_lists(self):
167+
return [self.nl]
168+
169+
def _setup_systems(
170+
self,
171+
systems: list[System],
172+
selected_atoms: Optional[Labels] = None,
173+
) -> tuple[System, TensorBlock]:
174+
"""Remove possible ghost atoms and add charges to the system."""
175+
if len(systems) > 1:
176+
raise ValueError(f"only one system supported, got {len(systems)}")
177+
178+
system_i = 0
179+
system = systems[system_i]
180+
181+
# select only real atoms and discard ghosts
182+
if selected_atoms is not None:
183+
current_system_mask = selected_atoms.column("system") == system_i
184+
current_atoms = selected_atoms.column("atom")
185+
current_atoms = current_atoms[current_system_mask].to(torch.long)
186+
187+
types = system.types[current_atoms]
188+
positions = system.positions[current_atoms]
189+
else:
190+
types = system.types
191+
positions = system.positions
192+
193+
system_final = System(types, positions, system.cell, system.pbc)
194+
195+
return system_final, system.get_neighbor_list(self.nl)
196+
197+
def forward(
198+
self,
199+
systems: List[System], # noqa
200+
outputs: Dict[str, ModelOutput], # noqa
201+
selected_atoms: Optional[Labels] = None,
202+
) -> Dict[str, TensorMap]: # noqa
203+
"""
204+
Compute the energy of the water model.
205+
206+
Water molecules have to be in order OHH and whole across the system.
207+
"""
208+
if list(outputs.keys()) != ["energy"]:
209+
raise ValueError(
210+
f"`outputs` keys ({', '.join(outputs.keys())}) contain unsupported "
211+
"keys. Only 'energy' is supported."
212+
)
213+
214+
system, neighbors = self._setup_systems(systems, selected_atoms)
215+
species = system.types
216+
217+
if system.positions.dtype != self.dtype:
218+
raise ValueError(
219+
f"system.positions.dtype ({system.positions.dtype}) must be "
220+
f"equal to dtype at initilization ({self.dtype})"
221+
)
222+
223+
device = system.positions.device
224+
n_atoms = len(system)
225+
226+
neighbor_indices = neighbors.samples.view(["first_atom", "second_atom"]).values
227+
228+
if device == "cpu":
229+
# move data to 64-bit integers, for some reason indexing with 64-bit
230+
# is a lot faster than using 32-bit integers on CPU. CUDA seems fine
231+
# with either types
232+
neighbor_indices = neighbor_indices.to(
233+
torch.int64, memory_format=torch.contiguous_format
234+
)
235+
236+
neighbor_distances = torch.linalg.norm(neighbors.values, dim=1).squeeze(1)
237+
238+
# Verify that system only contains water molecules in the correct order
239+
if n_atoms % 3 != 0:
240+
raise ValueError(
241+
"system must be water containing a multiple of 3 atoms. "
242+
f"Found {n_atoms} atoms!"
243+
)
244+
245+
reference_types = torch.tensor(
246+
[8, 1, 1], dtype=self.dtype, device=device
247+
).repeat(n_atoms // 3)
248+
if not torch.all(system.types == reference_types):
249+
raise ValueError(
250+
"system must contain only water molecules in the order OHH"
251+
)
252+
253+
energies = torch.zeros(n_atoms, dtype=self.dtype, device=device)
254+
###################
255+
# O-O Lennard-Jones
256+
###################
257+
i = neighbor_indices[:, 0]
258+
j = neighbor_indices[:, 1]
259+
lj_mask = (species[i] == 8) & (species[j] == 8)
260+
lj_neighbor_indices = neighbor_indices[lj_mask]
261+
lj_neighbor_distances = neighbor_distances[lj_mask]
262+
263+
lj = lennard_jones_pair(
264+
distances=lj_neighbor_distances,
265+
sigma=self.O_sigma,
266+
epsilon=self.O_epsilon,
267+
cutoff=self.cutoff,
268+
)
269+
270+
energies.index_add_(0, lj_neighbor_indices[:, 0], lj)
271+
energies.index_add_(0, lj_neighbor_indices[:, 1], lj)
272+
273+
##########
274+
# O-H bond
275+
##########
276+
# select pairs within the same molecule
277+
mol_mask = (i // 3) == (j // 3)
278+
279+
bond_mask = mol_mask & (species[i] == 8)
280+
bond_neighbor_indices = neighbor_indices[bond_mask]
281+
bond_neighbor_distances = neighbor_distances[bond_mask]
282+
283+
cell_dimensions = torch.linalg.norm(system.cell, dim=1)
284+
min_dimension = float(torch.min(cell_dimensions))
285+
half_cell = min_dimension / 2.0
286+
287+
if torch.any(bond_neighbor_distances > half_cell):
288+
raise ValueError(
289+
"Bond distances are larger than half of the cell size. "
290+
"Most likely molecules are not whole."
291+
"This is not supported by the model."
292+
)
293+
294+
bond = harmonic_distance_pair(
295+
distances=bond_neighbor_distances,
296+
coefficient=self.OH_bond_coefficient,
297+
equilibrium_distance=self.OH_equilibrium_distance,
298+
)
299+
300+
energies.index_add_(0, bond_neighbor_indices[:, 0], bond)
301+
energies.index_add_(0, bond_neighbor_indices[:, 1], bond)
302+
303+
#############
304+
# H-O-H angle
305+
#############
306+
all_idx = torch.arange(n_atoms, device=device)
307+
angle_indices = torch.vstack([all_idx[0::3], all_idx[1::3], all_idx[2::3]]).T
308+
angle_values = compute_angles(system.positions, angle_indices)
309+
310+
angles = harmonic_angular(
311+
angles=angle_values,
312+
coefficient=self.HOH_angle_coefficient,
313+
equilibrium_angle=self.HOH_equilibrium_angle,
314+
)
315+
316+
energies.index_add_(0, angle_indices[:, 0], angles)
317+
energies.index_add_(0, angle_indices[:, 1], angles)
318+
energies.index_add_(0, angle_indices[:, 2], angles)
319+
320+
################
321+
# Electrostatics
322+
################
323+
324+
# fourth point is computed according to eq. 2 in 10.1063/1.3167790
325+
if self.four_point_model:
326+
positions_coul = torch.vstack(
327+
[
328+
(
329+
(system.positions[1::3] + system.positions[2::3]) * 0.5
330+
+ system.positions[0::3] * 3
331+
)
332+
/ 4,
333+
system.positions[1::3],
334+
system.positions[2::3],
335+
]
336+
)
337+
else:
338+
positions_coul = system.positions
339+
340+
charges = self.OHH_charges.tile((n_atoms // 3,)).unsqueeze(-1)
341+
342+
# all to all interactions
343+
potential = self.pme_calculator(
344+
positions=positions_coul,
345+
cell=system.cell,
346+
charges=charges,
347+
neighbor_indices=neighbor_indices,
348+
neighbor_distances=neighbor_distances,
349+
)
350+
351+
potential_exclusion = self.coulomb_calculator(
352+
positions=positions_coul,
353+
cell=system.cell,
354+
charges=charges,
355+
neighbor_indices=bond_neighbor_indices,
356+
neighbor_distances=bond_neighbor_distances,
357+
)
358+
359+
potential -= potential_exclusion
360+
energies += (potential * charges).flatten()
361+
362+
#####################
363+
# Wrap into TensorMap
364+
#####################
365+
samples = torch.zeros((n_atoms, 2), device=device, dtype=torch.int32)
366+
samples[:, 0] = 0
367+
samples[:, 1] = torch.arange(n_atoms, device=device, dtype=torch.int32)
368+
369+
properties = torch.tensor([[0]], device=device, dtype=torch.int32)
370+
371+
block = TensorBlock(
372+
values=energies.unsqueeze(-1),
373+
samples=Labels(["system", "atom"], samples),
374+
components=[],
375+
properties=Labels("energy", properties),
376+
)
377+
378+
keys = Labels("_", torch.zeros(1, 1, dtype=torch.int32, device=device))
379+
380+
energy_tensor = TensorMap(keys=keys, blocks=[block])
381+
382+
if outputs["energy"].per_atom:
383+
energy = energy_tensor
384+
else:
385+
energy = sum_over_samples(energy_tensor, sample_names="atom")
386+
387+
return {"energy": energy}

0 commit comments

Comments
 (0)