|
| 1 | +from typing import Optional |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch import profiler |
| 5 | + |
| 6 | +from .._utils import _validate_parameters |
| 7 | +from ..lib import generate_kvectors_for_ewald |
| 8 | +from ..potentials import PotentialDipole |
| 9 | + |
| 10 | + |
| 11 | +class CalculatorDipole(torch.nn.Module): |
| 12 | + """ |
| 13 | + Base calculator for interacting dipoles in the torch interface. |
| 14 | +
|
| 15 | + :param potential: a :class:`PotentialDipole` class object containing the functions |
| 16 | + that are necessary to compute the various components of the potential, as |
| 17 | + well as the parameters that determine the behavior of the potential itself. |
| 18 | + :param full_neighbor_list: parameter indicating whether the neighbor information |
| 19 | + will come from a full (True) or half (False, default) neighbor list. |
| 20 | + :param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and |
| 21 | + common values. |
| 22 | + :param lr_wavelength: the wavelength of the long-range part of the potential. |
| 23 | + """ |
| 24 | + |
| 25 | + def __init__( |
| 26 | + self, |
| 27 | + potential: PotentialDipole, |
| 28 | + full_neighbor_list: bool = False, |
| 29 | + prefactor: float = 1.0, |
| 30 | + lr_wavelength: Optional[float] = None, |
| 31 | + ): |
| 32 | + super().__init__() |
| 33 | + |
| 34 | + if not isinstance(potential, PotentialDipole): |
| 35 | + raise TypeError( |
| 36 | + f"Potential must be an instance of PotentialDipole, got {type(potential)}" |
| 37 | + ) |
| 38 | + |
| 39 | + self.potential = potential |
| 40 | + self.lr_wavelength = lr_wavelength |
| 41 | + |
| 42 | + assert ( |
| 43 | + self.lr_wavelength is not None |
| 44 | + and self.potential.smearing is not None |
| 45 | + or (self.lr_wavelength is None and self.potential.smearing is None) |
| 46 | + ), "Either both `lr_wavelength` and `smearing` must be set or both must be None" |
| 47 | + |
| 48 | + self.full_neighbor_list = full_neighbor_list |
| 49 | + |
| 50 | + self.prefactor = prefactor |
| 51 | + |
| 52 | + def _compute_rspace( |
| 53 | + self, |
| 54 | + dipoles: torch.Tensor, |
| 55 | + neighbor_indices: torch.Tensor, |
| 56 | + neighbor_vectors: torch.Tensor, |
| 57 | + ) -> torch.Tensor: |
| 58 | + # Compute the pair potential terms V(r_ij) for each pair of atoms (i,j) |
| 59 | + # contained in the neighbor list |
| 60 | + with profiler.record_function("compute bare potential"): |
| 61 | + if self.potential.smearing is None: |
| 62 | + potentials_bare = self.potential.from_dist(neighbor_vectors) |
| 63 | + else: |
| 64 | + potentials_bare = self.potential.sr_from_dist(neighbor_vectors) |
| 65 | + |
| 66 | + # Multiply the bare potential terms V(r_ij) with the corresponding dipoles |
| 67 | + # of ``atom j'' to obtain q_j*V(r_ij). Since each atom j can be a neighbor of |
| 68 | + # multiple atom i's, we need to access those from neighbor_indices |
| 69 | + atom_is = neighbor_indices[:, 0] |
| 70 | + atom_js = neighbor_indices[:, 1] |
| 71 | + with profiler.record_function("compute real potential"): |
| 72 | + contributions_is = torch.bmm( |
| 73 | + potentials_bare, dipoles[atom_js].unsqueeze(-1) |
| 74 | + ).squeeze(-1) |
| 75 | + |
| 76 | + # For each atom i, add up all contributions of the form q_j*V(r_ij) for j |
| 77 | + # ranging over all of its neighbors. |
| 78 | + with profiler.record_function("assign potential"): |
| 79 | + potential = torch.zeros_like(dipoles) |
| 80 | + potential.index_add_(0, atom_is, contributions_is) |
| 81 | + # If we are using a half neighbor list, we need to add the contributions |
| 82 | + # from the "inverse" pairs (j, i) to the atoms i |
| 83 | + if not self.full_neighbor_list: |
| 84 | + contributions_js = torch.bmm( |
| 85 | + potentials_bare, dipoles[atom_is].unsqueeze(-1) |
| 86 | + ).squeeze(-1) |
| 87 | + potential.index_add_(0, atom_js, contributions_js) |
| 88 | + |
| 89 | + # Compensate for double counting of pairs (i,j) and (j,i) |
| 90 | + return potential / 2 |
| 91 | + |
| 92 | + def _compute_kspace( |
| 93 | + self, |
| 94 | + dipoles: torch.Tensor, |
| 95 | + cell: torch.Tensor, |
| 96 | + positions: torch.Tensor, |
| 97 | + ) -> torch.Tensor: |
| 98 | + # Define k-space cutoff from required real-space resolution |
| 99 | + k_cutoff = 2 * torch.pi / self.lr_wavelength |
| 100 | + |
| 101 | + # Compute number of times each basis vector of the reciprocal space can be |
| 102 | + # scaled until the cutoff is reached |
| 103 | + basis_norms = torch.linalg.norm(cell, dim=1) |
| 104 | + ns_float = k_cutoff * basis_norms / 2 / torch.pi |
| 105 | + ns = torch.ceil(ns_float).long() |
| 106 | + |
| 107 | + # Generate k-vectors and evaluate |
| 108 | + kvectors = generate_kvectors_for_ewald(ns=ns, cell=cell) |
| 109 | + knorm_sq = torch.sum(kvectors**2, dim=1) |
| 110 | + # We remove the singularity at k=0 by explicitly setting its |
| 111 | + # value to be equal to zero. This mathematically corresponds |
| 112 | + # to the requirement that the net charge of the cell is zero. |
| 113 | + # G = 4 * torch.pi * torch.exp(-0.5 * smearing**2 * knorm_sq) / knorm_sq |
| 114 | + G = self.potential.lr_from_k_sq(knorm_sq) |
| 115 | + |
| 116 | + # Compute the energy using the explicit method that |
| 117 | + # follows directly from the Poisson summation formula. |
| 118 | + # For this, we precompute trigonometric factors for optimization, which leads |
| 119 | + # to N^2 rather than N^3 scaling. |
| 120 | + trig_args = kvectors @ (positions.T) # [k, i] |
| 121 | + c = torch.cos(trig_args) # [k, i] |
| 122 | + s = torch.sin(trig_args) # [k, i] |
| 123 | + sc = torch.stack([c, s], dim=0) # [2 "f", k, i] |
| 124 | + mu_k = dipoles @ kvectors.T # [i, k] |
| 125 | + sc_summed_G = torch.einsum("fki, ik, k->fk", sc, mu_k, G) |
| 126 | + energy = torch.einsum("fk, fki, kc->ic", sc_summed_G, sc, kvectors) |
| 127 | + energy /= torch.abs(cell.det()) |
| 128 | + energy -= dipoles * self.potential.self_contribution() |
| 129 | + energy += self.potential.background_correction( |
| 130 | + torch.abs(cell.det()) |
| 131 | + ) * dipoles.sum(dim=0) |
| 132 | + return energy / 2 |
| 133 | + |
| 134 | + def forward( |
| 135 | + self, |
| 136 | + dipoles: torch.Tensor, |
| 137 | + cell: torch.Tensor, |
| 138 | + positions: torch.Tensor, |
| 139 | + neighbor_indices: torch.Tensor, |
| 140 | + neighbor_vectors: torch.Tensor, |
| 141 | + ): |
| 142 | + r""" |
| 143 | + Compute the potential "energy". |
| 144 | +
|
| 145 | + It is calculated as: |
| 146 | +
|
| 147 | + .. math:: |
| 148 | +
|
| 149 | + V_i = \frac{1}{2} \sum_{j} \boldsymbol{\mu_j} \, \mathbf{v}(\mathbf{r_{ij}}) |
| 150 | +
|
| 151 | + where :math:`\mathbf{v}(\mathbf{r})` is the pair potential defined by the ``potential`` |
| 152 | + parameter, and :math:`\boldsymbol{\mu_j}` are atomic "dipoles". |
| 153 | +
|
| 154 | + If the ``smearing`` of the ``potential`` is not set, the calculator evaluates |
| 155 | + only the real-space part of the potential. Otherwise, provided that the |
| 156 | + calculator implements a ``_compute_kspace`` method, it will also evaluate the |
| 157 | + long-range part using a Fourier-domain method. |
| 158 | +
|
| 159 | + :param dipoles: torch.tensor of shape ``(len(positions), 3)`` |
| 160 | + containaing the atomic dipoles. |
| 161 | + :param cell: torch.tensor of shape ``(3, 3)``, where ``cell[i]`` is the i-th basis |
| 162 | + vector of the unit cell |
| 163 | + :param positions: torch.tensor of shape ``(N, 3)`` containing the Cartesian |
| 164 | + coordinates of the ``N`` particles within the supercell. |
| 165 | + :param neighbor_indices: torch.tensor with the ``i,j`` indices of neighbors for |
| 166 | + which the potential should be computed in real space. |
| 167 | + :param neighbor_vectors: torch.tensor with the pair vectors of the neighbors |
| 168 | + for which the potential should be computed in real space. |
| 169 | + """ |
| 170 | + # TODO: _validate_parameters to allow also dipoles. Temporarily pass the |
| 171 | + # distance tensor. |
| 172 | + _validate_parameters( |
| 173 | + charges=dipoles, |
| 174 | + cell=cell, |
| 175 | + positions=positions, |
| 176 | + neighbor_indices=neighbor_indices, |
| 177 | + neighbor_distances=neighbor_vectors.norm(dim=-1), |
| 178 | + smearing=self.potential.smearing, |
| 179 | + ) |
| 180 | + |
| 181 | + # Compute short-range (SR) part using a real space sum |
| 182 | + potential_sr = self._compute_rspace( |
| 183 | + dipoles=dipoles, |
| 184 | + neighbor_indices=neighbor_indices, |
| 185 | + neighbor_vectors=neighbor_vectors, |
| 186 | + ) |
| 187 | + |
| 188 | + if self.potential.smearing is None: |
| 189 | + return self.prefactor * potential_sr |
| 190 | + # Compute long-range (LR) part using a Fourier / reciprocal space sum |
| 191 | + potential_lr = self._compute_kspace( |
| 192 | + dipoles=dipoles, |
| 193 | + cell=cell, |
| 194 | + positions=positions, |
| 195 | + ) |
| 196 | + |
| 197 | + return self.prefactor * (potential_sr + potential_lr) |
0 commit comments