|
| 1 | +from typing import Dict, List, Optional |
| 2 | + |
| 3 | +import torch |
| 4 | +from metatensor.torch import Labels, TensorBlock, TensorMap, sum_over_samples |
| 5 | +from metatensor.torch.atomistic import ModelOutput, NeighborListOptions, System |
| 6 | +from torchpme import Calculator, CoulombPotential, P3MCalculator |
| 7 | + |
| 8 | + |
| 9 | +def lennard_jones_pair( |
| 10 | + distances: torch.Tensor, |
| 11 | + sigma: torch.Tensor, |
| 12 | + epsilon: torch.Tensor, |
| 13 | + cutoff: torch.Tensor, |
| 14 | +): |
| 15 | + """Lennard-Jones potential for pair terms.""" |
| 16 | + c6 = (sigma**2 / distances**2) ** 3 |
| 17 | + c12 = c6**2 |
| 18 | + lj = 4 * epsilon * (c12 - c6) |
| 19 | + |
| 20 | + cutoff = 1 / cutoff |
| 21 | + offset = 4 * epsilon * (sigma**12 * cutoff**12 - sigma**6 * cutoff**6) |
| 22 | + |
| 23 | + return lj - offset |
| 24 | + |
| 25 | + |
| 26 | +def harmonic_distance_pair( |
| 27 | + distances: torch.Tensor, |
| 28 | + coefficient: torch.Tensor, |
| 29 | + equilibrium_distance: torch.Tensor, |
| 30 | +): |
| 31 | + """Harmonic potential for bond terms.""" |
| 32 | + r2 = (distances - equilibrium_distance) ** 2 |
| 33 | + |
| 34 | + return 0.5 * coefficient * r2 |
| 35 | + |
| 36 | + |
| 37 | +def harmonic_angular( |
| 38 | + angles: torch.Tensor, coefficient: torch.Tensor, equilibrium_angle: torch.Tensor |
| 39 | +): |
| 40 | + """Harmonic potential for angular terms.""" |
| 41 | + return 0.5 * coefficient * (angles - equilibrium_angle) ** 2 |
| 42 | + |
| 43 | + |
| 44 | +def compute_angles(positions: torch.Tensor, neighbor_indices: torch.Tensor): |
| 45 | + """Compute the angles formed by triplet of atoms based on their positions.""" |
| 46 | + atom_is = neighbor_indices[:, 0] |
| 47 | + atom_js = neighbor_indices[:, 1] |
| 48 | + atom_ks = neighbor_indices[:, 2] |
| 49 | + |
| 50 | + pos_is = positions[atom_is] |
| 51 | + pos_js = positions[atom_js] |
| 52 | + pos_ks = positions[atom_ks] |
| 53 | + |
| 54 | + R_ij = pos_js - pos_is |
| 55 | + R_ik = pos_ks - pos_is |
| 56 | + |
| 57 | + return angle(R_ij, R_ik) |
| 58 | + |
| 59 | + |
| 60 | +def angle(a: torch.Tensor, b: torch.Tensor, dim: int = -1): |
| 61 | + """Compute the angle between two vectors a and b. |
| 62 | +
|
| 63 | + Code is taken from https://github.com/pytorch/pytorch/issues/59194 |
| 64 | + """ |
| 65 | + a_norm = a.norm(p=2, dim=dim, keepdim=True) |
| 66 | + b_norm = a.norm(p=2, dim=dim, keepdim=True) |
| 67 | + angles = 2 * torch.atan2( |
| 68 | + (a * b_norm - a_norm * b).norm(p=2, dim=dim), |
| 69 | + (a * b_norm + a_norm * b).norm(p=2, dim=dim), |
| 70 | + ) |
| 71 | + |
| 72 | + return angles |
| 73 | + |
| 74 | + |
| 75 | +class WaterModel(torch.nn.Module): |
| 76 | + def __init__( |
| 77 | + self, |
| 78 | + cutoff: float, |
| 79 | + O_sigma: float, |
| 80 | + O_epsilon: float, |
| 81 | + O_charge: float, |
| 82 | + OH_bond_coefficient: float, |
| 83 | + OH_equilibrium_distance: float, |
| 84 | + HOH_angle_coefficient: float, |
| 85 | + HOH_equilibrium_angle: float, |
| 86 | + pme_smearing: float, |
| 87 | + pme_mesh_spacing: float, |
| 88 | + pme_interpolation_nodes: int = 4, |
| 89 | + pme_prefactor: float = 1, |
| 90 | + four_point_model: bool = False, |
| 91 | + dtype: Optional[float] = None, |
| 92 | + ): |
| 93 | + """ |
| 94 | + Flexible water model for three and four point models. |
| 95 | +
|
| 96 | + The model contains Lennard Jones interactions between the oxygens as well as |
| 97 | + intra molecular bond and angle terms. The electrostatics are computed using the |
| 98 | + P3M method. For a four point model the fourth side for the charge interaction is |
| 99 | + computed implicitly based on the position of the other atoms. |
| 100 | +
|
| 101 | + :param cutoff: Cutoff for the Lennard-Jones interactions. |
| 102 | + :param O_sigma: Sigma parameter for the oxygen Lennard-Jones interactions. |
| 103 | + :param O_epsilon: Epsilon parameter for the oxygen Lennard-Jones interactions. |
| 104 | + :param O_charge: Oxygen's atom charge; hydrogen is computed accordingly. |
| 105 | + :param OH_bond_coefficient: Harmonic coefficient for the OH bond. |
| 106 | + :param OH_equilibrium_distance: Equilibrium distance for the OH bond. |
| 107 | + :param HOH_angle_coefficient: Harmonic coefficient for the HOH angle. |
| 108 | + :param HOH_equilibrium_angle: Equilibrium angle for the HOH angle in degrees. |
| 109 | + :param pme_smearing: Smearing parameter for the PME. |
| 110 | + :param pme_mesh_spacing: Mesh spacing for the PME. |
| 111 | + :param pme_interpolation_nodes: Number of interpolation nodes for the PME. |
| 112 | + :param pme_prefactor: Prefactor for the PME. |
| 113 | + :param four_point_model: If :py:obj:`True`, use the four-point model for the |
| 114 | + electrostatics. The fourth point M is implicitly derived from the other |
| 115 | + atoms or each water molecule and used during the force computation. See |
| 116 | + 10.1063/1.3167790 for details on its derivation. |
| 117 | + :param dtype: Floating point precision for the model. If :py:obj:`None`, the |
| 118 | + :param dtype: default |
| 119 | + dtype is used. |
| 120 | + """ |
| 121 | + super().__init__() |
| 122 | + |
| 123 | + self.dtype = dtype if dtype is not None else torch.get_default_dtype() |
| 124 | + self.four_point_model = four_point_model |
| 125 | + |
| 126 | + self.register_buffer("cutoff", torch.tensor(cutoff, dtype=self.dtype)) |
| 127 | + self.register_buffer("O_sigma", torch.tensor(O_sigma, dtype=self.dtype)) |
| 128 | + self.register_buffer("O_epsilon", torch.tensor(O_epsilon, dtype=self.dtype)) |
| 129 | + self.register_buffer( |
| 130 | + "OH_bond_coefficient", torch.tensor(OH_bond_coefficient, dtype=self.dtype) |
| 131 | + ) |
| 132 | + self.register_buffer( |
| 133 | + "OH_equilibrium_distance", |
| 134 | + torch.tensor(OH_equilibrium_distance, dtype=self.dtype), |
| 135 | + ) |
| 136 | + self.register_buffer( |
| 137 | + "HOH_angle_coefficient", |
| 138 | + torch.tensor(HOH_angle_coefficient, dtype=self.dtype), |
| 139 | + ) |
| 140 | + |
| 141 | + # Convert degree angle to radians |
| 142 | + self.register_buffer( |
| 143 | + "HOH_equilibrium_angle", |
| 144 | + torch.tensor(HOH_equilibrium_angle * torch.pi / 180, dtype=self.dtype), |
| 145 | + ) |
| 146 | + |
| 147 | + # Register charges for water model |
| 148 | + H_charge = -O_charge / 2 |
| 149 | + self.register_buffer( |
| 150 | + "OHH_charges", |
| 151 | + torch.tensor([O_charge, H_charge, H_charge], dtype=self.dtype), |
| 152 | + ) |
| 153 | + |
| 154 | + self.pme_calculator = P3MCalculator( |
| 155 | + potential=CoulombPotential(pme_smearing), |
| 156 | + mesh_spacing=pme_mesh_spacing, |
| 157 | + interpolation_nodes=pme_interpolation_nodes, |
| 158 | + prefactor=pme_prefactor, |
| 159 | + ) |
| 160 | + self.coulomb_calculator = Calculator( |
| 161 | + CoulombPotential(), prefactor=pme_prefactor |
| 162 | + ) |
| 163 | + |
| 164 | + self.nl = NeighborListOptions(cutoff=cutoff, full_list=False, strict=False) |
| 165 | + |
| 166 | + def requested_neighbor_lists(self): |
| 167 | + return [self.nl] |
| 168 | + |
| 169 | + def _setup_systems( |
| 170 | + self, |
| 171 | + systems: list[System], |
| 172 | + selected_atoms: Optional[Labels] = None, |
| 173 | + ) -> tuple[System, TensorBlock]: |
| 174 | + """Remove possible ghost atoms and add charges to the system.""" |
| 175 | + if len(systems) > 1: |
| 176 | + raise ValueError(f"only one system supported, got {len(systems)}") |
| 177 | + |
| 178 | + system_i = 0 |
| 179 | + system = systems[system_i] |
| 180 | + |
| 181 | + # select only real atoms and discard ghosts |
| 182 | + if selected_atoms is not None: |
| 183 | + current_system_mask = selected_atoms.column("system") == system_i |
| 184 | + current_atoms = selected_atoms.column("atom") |
| 185 | + current_atoms = current_atoms[current_system_mask].to(torch.long) |
| 186 | + |
| 187 | + types = system.types[current_atoms] |
| 188 | + positions = system.positions[current_atoms] |
| 189 | + else: |
| 190 | + types = system.types |
| 191 | + positions = system.positions |
| 192 | + |
| 193 | + system_final = System(types, positions, system.cell, system.pbc) |
| 194 | + |
| 195 | + return system_final, system.get_neighbor_list(self.nl) |
| 196 | + |
| 197 | + def forward( |
| 198 | + self, |
| 199 | + systems: List[System], # noqa |
| 200 | + outputs: Dict[str, ModelOutput], # noqa |
| 201 | + selected_atoms: Optional[Labels] = None, |
| 202 | + ) -> Dict[str, TensorMap]: # noqa |
| 203 | + """ |
| 204 | + Compute the energy of the water model. |
| 205 | +
|
| 206 | + Water molecules have to be in order OHH and whole across the system. |
| 207 | + """ |
| 208 | + if list(outputs.keys()) != ["energy"]: |
| 209 | + raise ValueError( |
| 210 | + f"`outputs` keys ({', '.join(outputs.keys())}) contain unsupported " |
| 211 | + "keys. Only 'energy' is supported." |
| 212 | + ) |
| 213 | + |
| 214 | + system, neighbors = self._setup_systems(systems, selected_atoms) |
| 215 | + species = system.types |
| 216 | + |
| 217 | + if system.positions.dtype != self.dtype: |
| 218 | + raise ValueError( |
| 219 | + f"system.positions.dtype ({system.positions.dtype}) must be " |
| 220 | + f"equal to dtype at initilization ({self.dtype})" |
| 221 | + ) |
| 222 | + |
| 223 | + device = system.positions.device |
| 224 | + n_atoms = len(system) |
| 225 | + |
| 226 | + neighbor_indices = neighbors.samples.view(["first_atom", "second_atom"]).values |
| 227 | + |
| 228 | + if device == "cpu": |
| 229 | + # move data to 64-bit integers, for some reason indexing with 64-bit |
| 230 | + # is a lot faster than using 32-bit integers on CPU. CUDA seems fine |
| 231 | + # with either types |
| 232 | + neighbor_indices = neighbor_indices.to( |
| 233 | + torch.int64, memory_format=torch.contiguous_format |
| 234 | + ) |
| 235 | + |
| 236 | + neighbor_distances = torch.linalg.norm(neighbors.values, dim=1).squeeze(1) |
| 237 | + |
| 238 | + # Verify that system only contains water molecules in the correct order |
| 239 | + if n_atoms % 3 != 0: |
| 240 | + raise ValueError( |
| 241 | + "system must be water containing a multiple of 3 atoms. " |
| 242 | + f"Found {n_atoms} atoms!" |
| 243 | + ) |
| 244 | + |
| 245 | + reference_types = torch.tensor( |
| 246 | + [8, 1, 1], dtype=self.dtype, device=device |
| 247 | + ).repeat(n_atoms // 3) |
| 248 | + if not torch.all(system.types == reference_types): |
| 249 | + raise ValueError( |
| 250 | + "system must contain only water molecules in the order OHH" |
| 251 | + ) |
| 252 | + |
| 253 | + energies = torch.zeros(n_atoms, dtype=self.dtype, device=device) |
| 254 | + ################### |
| 255 | + # O-O Lennard-Jones |
| 256 | + ################### |
| 257 | + i = neighbor_indices[:, 0] |
| 258 | + j = neighbor_indices[:, 1] |
| 259 | + lj_mask = (species[i] == 8) & (species[j] == 8) |
| 260 | + lj_neighbor_indices = neighbor_indices[lj_mask] |
| 261 | + lj_neighbor_distances = neighbor_distances[lj_mask] |
| 262 | + |
| 263 | + lj = lennard_jones_pair( |
| 264 | + distances=lj_neighbor_distances, |
| 265 | + sigma=self.O_sigma, |
| 266 | + epsilon=self.O_epsilon, |
| 267 | + cutoff=self.cutoff, |
| 268 | + ) |
| 269 | + |
| 270 | + energies.index_add_(0, lj_neighbor_indices[:, 0], lj) |
| 271 | + energies.index_add_(0, lj_neighbor_indices[:, 1], lj) |
| 272 | + |
| 273 | + ########## |
| 274 | + # O-H bond |
| 275 | + ########## |
| 276 | + # select pairs within the same molecule |
| 277 | + mol_mask = (i // 3) == (j // 3) |
| 278 | + |
| 279 | + bond_mask = mol_mask & (species[i] == 8) |
| 280 | + bond_neighbor_indices = neighbor_indices[bond_mask] |
| 281 | + bond_neighbor_distances = neighbor_distances[bond_mask] |
| 282 | + |
| 283 | + cell_dimensions = torch.linalg.norm(system.cell, dim=1) |
| 284 | + min_dimension = float(torch.min(cell_dimensions)) |
| 285 | + half_cell = min_dimension / 2.0 |
| 286 | + |
| 287 | + if torch.any(bond_neighbor_distances > half_cell): |
| 288 | + raise ValueError( |
| 289 | + "Bond distances are larger than half of the cell size. " |
| 290 | + "Most likely molecules are not whole." |
| 291 | + "This is not supported by the model." |
| 292 | + ) |
| 293 | + |
| 294 | + bond = harmonic_distance_pair( |
| 295 | + distances=bond_neighbor_distances, |
| 296 | + coefficient=self.OH_bond_coefficient, |
| 297 | + equilibrium_distance=self.OH_equilibrium_distance, |
| 298 | + ) |
| 299 | + |
| 300 | + energies.index_add_(0, bond_neighbor_indices[:, 0], bond) |
| 301 | + energies.index_add_(0, bond_neighbor_indices[:, 1], bond) |
| 302 | + |
| 303 | + ############# |
| 304 | + # H-O-H angle |
| 305 | + ############# |
| 306 | + all_idx = torch.arange(n_atoms, device=device) |
| 307 | + angle_indices = torch.vstack([all_idx[0::3], all_idx[1::3], all_idx[2::3]]).T |
| 308 | + angle_values = compute_angles(system.positions, angle_indices) |
| 309 | + |
| 310 | + angles = harmonic_angular( |
| 311 | + angles=angle_values, |
| 312 | + coefficient=self.HOH_angle_coefficient, |
| 313 | + equilibrium_angle=self.HOH_equilibrium_angle, |
| 314 | + ) |
| 315 | + |
| 316 | + energies.index_add_(0, angle_indices[:, 0], angles) |
| 317 | + energies.index_add_(0, angle_indices[:, 1], angles) |
| 318 | + energies.index_add_(0, angle_indices[:, 2], angles) |
| 319 | + |
| 320 | + ################ |
| 321 | + # Electrostatics |
| 322 | + ################ |
| 323 | + |
| 324 | + # fourth point is computed according to eq. 2 in 10.1063/1.3167790 |
| 325 | + if self.four_point_model: |
| 326 | + positions_coul = torch.vstack( |
| 327 | + [ |
| 328 | + ( |
| 329 | + (system.positions[1::3] + system.positions[2::3]) * 0.5 |
| 330 | + + system.positions[0::3] * 3 |
| 331 | + ) |
| 332 | + / 4, |
| 333 | + system.positions[1::3], |
| 334 | + system.positions[2::3], |
| 335 | + ] |
| 336 | + ) |
| 337 | + else: |
| 338 | + positions_coul = system.positions |
| 339 | + |
| 340 | + charges = self.OHH_charges.tile((n_atoms // 3,)).unsqueeze(-1) |
| 341 | + |
| 342 | + # all to all interactions |
| 343 | + potential = self.pme_calculator( |
| 344 | + positions=positions_coul, |
| 345 | + cell=system.cell, |
| 346 | + charges=charges, |
| 347 | + neighbor_indices=neighbor_indices, |
| 348 | + neighbor_distances=neighbor_distances, |
| 349 | + ) |
| 350 | + |
| 351 | + potential_exclusion = self.coulomb_calculator( |
| 352 | + positions=positions_coul, |
| 353 | + cell=system.cell, |
| 354 | + charges=charges, |
| 355 | + neighbor_indices=bond_neighbor_indices, |
| 356 | + neighbor_distances=bond_neighbor_distances, |
| 357 | + ) |
| 358 | + |
| 359 | + potential -= potential_exclusion |
| 360 | + energies += (potential * charges).flatten() |
| 361 | + |
| 362 | + ##################### |
| 363 | + # Wrap into TensorMap |
| 364 | + ##################### |
| 365 | + samples = torch.zeros((n_atoms, 2), device=device, dtype=torch.int32) |
| 366 | + samples[:, 0] = 0 |
| 367 | + samples[:, 1] = torch.arange(n_atoms, device=device, dtype=torch.int32) |
| 368 | + |
| 369 | + properties = torch.tensor([[0]], device=device, dtype=torch.int32) |
| 370 | + |
| 371 | + block = TensorBlock( |
| 372 | + values=energies.unsqueeze(-1), |
| 373 | + samples=Labels(["system", "atom"], samples), |
| 374 | + components=[], |
| 375 | + properties=Labels("energy", properties), |
| 376 | + ) |
| 377 | + |
| 378 | + keys = Labels("_", torch.zeros(1, 1, dtype=torch.int32, device=device)) |
| 379 | + |
| 380 | + energy_tensor = TensorMap(keys=keys, blocks=[block]) |
| 381 | + |
| 382 | + if outputs["energy"].per_atom: |
| 383 | + energy = energy_tensor |
| 384 | + else: |
| 385 | + energy = sum_over_samples(energy_tensor, sample_names="atom") |
| 386 | + |
| 387 | + return {"energy": energy} |
0 commit comments