Skip to content

Commit 438b476

Browse files
E-RumPicoCentauri
andauthored
Implement Magnetostatics in TorchPME (#133)
* Add CalculatorDipole and PotentialDipole classes with tests * Refactor PotentialDipole class to inherit from torch.nn.Module * Refactor potential calculations in CalculatorDipole and PotentialDipole classes for improved performance and clarity * Add lr_from_dist calculation to PotentialDipole class * Implement lr_from_k_sq, self_contribution and background_correction for PotentialDipole class * Finish calculator_dipole.py * Correct energy calculation in CalculatorDipole and fix smearing factor in PotentialDipole * Refactor device handling in CalculatorDipole and PotentialDipole classes; improve performance of potential calculations and enhance test coverage for magnetostatics * Enhance PotentialDipole class by allowing device parameter to accept string type; fix tensor dimension in f_cutoff method and update test assertions for clarity * Fix a bug in k_space calculator; add epsilon parameter to PotentialDipole and update background_correction method * Add dipoles test frames and enhance tests for magnetostatic Ewald calculations * Replace vesin.torch, as it doesn’t support Windows OS. * Refactor CalculatorDipole and PotentialDipole classes to remove device and dtype parameters; update tests accordingly for improved clarity and consistency. * Update __init__.py to include CalculatorDipole and PotentialDipole in exports; enhance compute_distances function to optionally return distance vectors. * restructure some files * Enhance docstrings for CalculatorDipole and PotentialDipole classes; provide detailed parameter descriptions and usage examples for improved clarity. * proofread docs --------- Co-authored-by: Philip Loche <[email protected]>
1 parent 30aec6b commit 438b476

File tree

18 files changed

+654
-28
lines changed

18 files changed

+654
-28
lines changed

README.rst

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@ torch-pme
1111
1212
``torch-pme`` enables efficient and auto-differentiable computation of long-range
1313
interactions in *PyTorch*. Auto-differentiation is supported for particle *positions*,
14-
*charges*, and *cell* parameters, allowing not only the computation of forces but also
15-
enabling general applications in machine learning tasks. The library offers classes for
16-
Particle-Particle Particle-Mesh Ewald (``P3M``), Particle Mesh Ewald (``PME``), standard
17-
``Ewald``, and non-periodic methods, with the flexibility to calculate potentials beyond
18-
:math:`1/r` electrostatics, including arbitrary order :math:`1/r^p` potentials.
14+
15+
*charges*/*dipoles*, and *cell* parameters, allowing not only the automatic computation
16+
of forces but also enabling general applications in machine learning tasks. For
17+
**monopoles** the library offers classes for Particle-Particle Particle-Mesh Ewald
18+
(``P3M``), Particle Mesh Ewald (``PME``), standard ``Ewald``, and non-periodic methods.
19+
The library has the flexibility to calculate potentials beyond :math:`1/r`
20+
electrostatics, including arbitrary order :math:`1/r^p` potentials. For **dipolar**
21+
interaction we offer to calculate the :math:`1/r^3` potential using the standard
22+
``Ewald`` method.
1923

2024
Optimized for both CPU and GPU devices, ``torch-pme`` is fully `TorchScriptable`_,
2125
allowing it to be converted into a format that runs independently of Python, such as in
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
CalculatorDipole
2+
################
3+
4+
.. autoclass:: torchpme.CalculatorDipole
5+
:members:
6+
7+
.. minigallery::
8+
:add-heading:
9+
10+
torchpme.CalculatorDipole

docs/src/references/changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
2727
Added
2828
#####
2929

30+
* Added classes for the calculation of dipole interactions
3031
* Better documentation for for ``cell``, ``charges`` and ``positions`` parameters
3132

3233
Removed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
PotentialDipole
2+
###############
3+
4+
.. autoclass:: torchpme.PotentialDipole
5+
:members:
6+
7+
.. minigallery:: torchpme.PotentialDipole
8+
:add-heading:

examples/dipoles_test_frames.xyz

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
8
2+
Lattice="8.460000038146973 0.0 0.0 0.0 8.460000038146973 0.0 0.0 0.0 8.460000038146973" Properties=species:S:1:pos:R:3:dipoles:R:3:forces:R:3 energy=1.5323538047012266 pbc="T T T"
3+
Na 6.40999985 1.92000008 7.67000008 0.87830178 0.44712111 0.83810067 0.57283619 1.14204645 1.33681563
4+
Na 4.18000031 4.96000004 8.02999973 0.73324216 0.97949808 -0.26166893 -0.03507047 0.24607122 1.13457851
5+
Na 4.30999994 7.90000010 3.68999982 -0.31266377 -0.51245950 -0.18034898 -0.01967271 0.0084999 -0.36205072
6+
Na 8.23999977 2.98000002 3.89999986 -0.12426795 0.51958488 0.66561609 -0.02508925 -0.11864547 -0.19522136
7+
Cl 4.57999992 0.88000000 7.42999983 -0.41149415 0.34275195 0.52485166 -0.51002499 -1.5905667 -1.21172538
8+
Cl 7.16000032 5.13000011 1.26999998 0.93048734 -0.30992820 0.32975380 -0.00883947 -0.11365123 -0.21224626
9+
Cl 8.27999973 8.02999973 5.61999989 0.17311888 -0.83212290 -0.09616741 0.34479173 0.24309117 0.15871536
10+
Cl 4.19000006 3.82999992 5.23000002 0.31179980 0.71290202 -0.87487990 -0.31893102 0.18315466 -0.64886577
11+
8
12+
Lattice="8.0 0.0 0.0 0.0 8.0 0.0 0.0 0.0 8.0" Properties=species:S:1:pos:R:3:dipoles:R:3:forces:R:3 energy=2.005822828064277 pbc="T T T"
13+
Na 7.50000000 1.50999999 0.05000000 -0.85723323 0.56525094 -0.66685193 -1.46823136 1.56818289 1.1018063
14+
Na 6.37000036 6.44000006 5.28999996 0.83409694 -0.89774275 0.74302336 1.85867637 1.43423158 -0.04219309
15+
Na 1.15999997 2.56000018 3.88999987 -0.44331833 0.29179266 -0.94754402 0.06170197 -0.30920673 -0.13200583
16+
Na 5.46000004 7.88000011 6.57999992 0.70186080 0.98230338 -0.06109690 -0.98041385 -1.39992847 1.33130465
17+
Cl 3.61999989 5.91000032 3.55999994 -0.93308686 -0.88801660 -0.49452546 -0.49279598 -0.71824789 -0.67401955
18+
Cl 6.17000008 4.02999973 2.26999998 0.10955457 0.48917852 0.99768683 -0.00349212 0.44927656 -0.29142061
19+
Cl 4.38000011 0.11000000 0.82000005 -0.30364667 0.26808690 -0.68324974 -0.30267466 -0.5973787 -0.8010298
20+
Cl 0.71000004 2.49000001 0.93999994 -0.56214647 0.24127451 0.28667338 1.32722962 -0.42692926 -0.49244207
21+
8
22+
Lattice="10.0 0.0 0.0 0.0 10.0 0.0 0.0 0.0 10.0" Properties=species:S:1:pos:R:3:dipoles:R:3:forces:R:3 energy=2.261072327985546 pbc="T T T"
23+
Na 9.36999989 1.88999999 0.06000000 0.14120932 0.48341832 0.45537058 -2.12882348 -0.97937658 -0.75803323
24+
Na 7.96000004 8.05000019 6.61999989 -0.45102672 -0.03732401 -0.74260234 0.54009361 -0.42119106 -0.60862247
25+
Na 1.44999993 3.19999981 4.86999989 0.25118867 -0.29834069 -0.69578594 0.00429438 0.08823258 -0.29035364
26+
Na 6.83000040 9.85000038 8.23000050 -0.44058746 0.57551755 -0.88325601 -0.0410207 0.05467702 0.84370787
27+
Cl 4.53000021 7.37999964 4.46000004 0.96973183 -0.29250836 0.83184721 -0.02518924 -0.13902849 -0.24158766
28+
Cl 7.71000004 5.03999996 2.84000015 0.24131227 -0.80615721 -0.57806687 -0.26878176 0.42816378 0.41324843
29+
Cl 5.46999979 0.14000000 1.02999997 0.02369047 -0.50338626 0.70727541 -0.37825837 0.20451106 -0.10303815
30+
Cl 0.88000000 3.10999990 1.18000007 -0.74464657 -0.57116991 -0.90829042 2.29768557 0.7640117 0.74467885

src/torchpme/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,19 @@
22

33
from . import calculators, lib, potentials, prefactors, tuning # noqa
44
from ._version import __version__, __version_tuple__ # noqa
5-
from .calculators import Calculator, EwaldCalculator, P3MCalculator, PMECalculator
5+
from .calculators import (
6+
Calculator,
7+
CalculatorDipole,
8+
EwaldCalculator,
9+
P3MCalculator,
10+
PMECalculator,
11+
)
612
from .potentials import (
713
CombinedPotential,
814
CoulombPotential,
915
InversePowerLawPotential,
1016
Potential,
17+
PotentialDipole,
1118
SplinePotential,
1219
)
1320

@@ -24,4 +31,6 @@
2431
"InversePowerLawPotential",
2532
"SplinePotential",
2633
"CombinedPotential",
34+
"PotentialDipole",
35+
"CalculatorDipole",
2736
]

src/torchpme/calculators/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
from .calculator import Calculator
2+
from .calculator_dipole import CalculatorDipole
23
from .ewald import EwaldCalculator
34
from .p3m import P3MCalculator
45
from .pme import PMECalculator
56

6-
__all__ = ["Calculator", "EwaldCalculator", "P3MCalculator", "PMECalculator"]
7+
__all__ = [
8+
"Calculator",
9+
"EwaldCalculator",
10+
"P3MCalculator",
11+
"PMECalculator",
12+
"CalculatorDipole",
13+
]

src/torchpme/calculators/calculator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ def forward(
126126
``n_channels = 1``. If more than one "channel" is provided multiple
127127
potentials for the same position are computed depending on the charges and
128128
the potentials.
129-
:param cell: torch.tensor of shape ``(3, 3)``, where ``cell[i]`` is the i-th basis
130-
vector of the unit cell
129+
:param cell: torch.tensor of shape ``(3, 3)``, where ``cell[i]`` is the i-th
130+
basis vector of the unit cell
131131
:param positions: torch.tensor of shape ``(N, 3)`` containing the Cartesian
132132
coordinates of the ``N`` particles within the supercell.
133133
:param neighbor_indices: torch.tensor with the ``i,j`` indices of neighbors for
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
from typing import Optional
2+
3+
import torch
4+
from torch import profiler
5+
6+
from .._utils import _validate_parameters
7+
from ..lib import generate_kvectors_for_ewald
8+
from ..potentials import PotentialDipole
9+
10+
11+
class CalculatorDipole(torch.nn.Module):
12+
"""
13+
Base calculator for interacting dipoles in the torch interface.
14+
15+
:param potential: a :class:`PotentialDipole` class object containing the functions
16+
that are necessary to compute the various components of the potential, as
17+
well as the parameters that determine the behavior of the potential itself.
18+
:param full_neighbor_list: parameter indicating whether the neighbor information
19+
will come from a full (True) or half (False, default) neighbor list.
20+
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
21+
common values.
22+
:param lr_wavelength: the wavelength of the long-range part of the potential.
23+
"""
24+
25+
def __init__(
26+
self,
27+
potential: PotentialDipole,
28+
full_neighbor_list: bool = False,
29+
prefactor: float = 1.0,
30+
lr_wavelength: Optional[float] = None,
31+
):
32+
super().__init__()
33+
34+
if not isinstance(potential, PotentialDipole):
35+
raise TypeError(
36+
f"Potential must be an instance of PotentialDipole, got {type(potential)}"
37+
)
38+
39+
self.potential = potential
40+
self.lr_wavelength = lr_wavelength
41+
42+
assert (
43+
self.lr_wavelength is not None
44+
and self.potential.smearing is not None
45+
or (self.lr_wavelength is None and self.potential.smearing is None)
46+
), "Either both `lr_wavelength` and `smearing` must be set or both must be None"
47+
48+
self.full_neighbor_list = full_neighbor_list
49+
50+
self.prefactor = prefactor
51+
52+
def _compute_rspace(
53+
self,
54+
dipoles: torch.Tensor,
55+
neighbor_indices: torch.Tensor,
56+
neighbor_vectors: torch.Tensor,
57+
) -> torch.Tensor:
58+
# Compute the pair potential terms V(r_ij) for each pair of atoms (i,j)
59+
# contained in the neighbor list
60+
with profiler.record_function("compute bare potential"):
61+
if self.potential.smearing is None:
62+
potentials_bare = self.potential.from_dist(neighbor_vectors)
63+
else:
64+
potentials_bare = self.potential.sr_from_dist(neighbor_vectors)
65+
66+
# Multiply the bare potential terms V(r_ij) with the corresponding dipoles
67+
# of ``atom j'' to obtain q_j*V(r_ij). Since each atom j can be a neighbor of
68+
# multiple atom i's, we need to access those from neighbor_indices
69+
atom_is = neighbor_indices[:, 0]
70+
atom_js = neighbor_indices[:, 1]
71+
with profiler.record_function("compute real potential"):
72+
contributions_is = torch.bmm(
73+
potentials_bare, dipoles[atom_js].unsqueeze(-1)
74+
).squeeze(-1)
75+
76+
# For each atom i, add up all contributions of the form q_j*V(r_ij) for j
77+
# ranging over all of its neighbors.
78+
with profiler.record_function("assign potential"):
79+
potential = torch.zeros_like(dipoles)
80+
potential.index_add_(0, atom_is, contributions_is)
81+
# If we are using a half neighbor list, we need to add the contributions
82+
# from the "inverse" pairs (j, i) to the atoms i
83+
if not self.full_neighbor_list:
84+
contributions_js = torch.bmm(
85+
potentials_bare, dipoles[atom_is].unsqueeze(-1)
86+
).squeeze(-1)
87+
potential.index_add_(0, atom_js, contributions_js)
88+
89+
# Compensate for double counting of pairs (i,j) and (j,i)
90+
return potential / 2
91+
92+
def _compute_kspace(
93+
self,
94+
dipoles: torch.Tensor,
95+
cell: torch.Tensor,
96+
positions: torch.Tensor,
97+
) -> torch.Tensor:
98+
# Define k-space cutoff from required real-space resolution
99+
k_cutoff = 2 * torch.pi / self.lr_wavelength
100+
101+
# Compute number of times each basis vector of the reciprocal space can be
102+
# scaled until the cutoff is reached
103+
basis_norms = torch.linalg.norm(cell, dim=1)
104+
ns_float = k_cutoff * basis_norms / 2 / torch.pi
105+
ns = torch.ceil(ns_float).long()
106+
107+
# Generate k-vectors and evaluate
108+
kvectors = generate_kvectors_for_ewald(ns=ns, cell=cell)
109+
knorm_sq = torch.sum(kvectors**2, dim=1)
110+
# We remove the singularity at k=0 by explicitly setting its
111+
# value to be equal to zero. This mathematically corresponds
112+
# to the requirement that the net charge of the cell is zero.
113+
# G = 4 * torch.pi * torch.exp(-0.5 * smearing**2 * knorm_sq) / knorm_sq
114+
G = self.potential.lr_from_k_sq(knorm_sq)
115+
116+
# Compute the energy using the explicit method that
117+
# follows directly from the Poisson summation formula.
118+
# For this, we precompute trigonometric factors for optimization, which leads
119+
# to N^2 rather than N^3 scaling.
120+
trig_args = kvectors @ (positions.T) # [k, i]
121+
c = torch.cos(trig_args) # [k, i]
122+
s = torch.sin(trig_args) # [k, i]
123+
sc = torch.stack([c, s], dim=0) # [2 "f", k, i]
124+
mu_k = dipoles @ kvectors.T # [i, k]
125+
sc_summed_G = torch.einsum("fki, ik, k->fk", sc, mu_k, G)
126+
energy = torch.einsum("fk, fki, kc->ic", sc_summed_G, sc, kvectors)
127+
energy /= torch.abs(cell.det())
128+
energy -= dipoles * self.potential.self_contribution()
129+
energy += self.potential.background_correction(
130+
torch.abs(cell.det())
131+
) * dipoles.sum(dim=0)
132+
return energy / 2
133+
134+
def forward(
135+
self,
136+
dipoles: torch.Tensor,
137+
cell: torch.Tensor,
138+
positions: torch.Tensor,
139+
neighbor_indices: torch.Tensor,
140+
neighbor_vectors: torch.Tensor,
141+
):
142+
r"""
143+
Compute the potential "energy".
144+
145+
It is calculated as:
146+
147+
.. math::
148+
149+
V_i = \frac{1}{2} \sum_{j} \boldsymbol{\mu_j} \, \mathbf{v}(\mathbf{r_{ij}})
150+
151+
where :math:`\mathbf{v}(\mathbf{r})` is the pair potential defined by the ``potential``
152+
parameter, and :math:`\boldsymbol{\mu_j}` are atomic "dipoles".
153+
154+
If the ``smearing`` of the ``potential`` is not set, the calculator evaluates
155+
only the real-space part of the potential. Otherwise, provided that the
156+
calculator implements a ``_compute_kspace`` method, it will also evaluate the
157+
long-range part using a Fourier-domain method.
158+
159+
:param dipoles: torch.tensor of shape ``(len(positions), 3)``
160+
containaing the atomic dipoles.
161+
:param cell: torch.tensor of shape ``(3, 3)``, where ``cell[i]`` is the i-th basis
162+
vector of the unit cell
163+
:param positions: torch.tensor of shape ``(N, 3)`` containing the Cartesian
164+
coordinates of the ``N`` particles within the supercell.
165+
:param neighbor_indices: torch.tensor with the ``i,j`` indices of neighbors for
166+
which the potential should be computed in real space.
167+
:param neighbor_vectors: torch.tensor with the pair vectors of the neighbors
168+
for which the potential should be computed in real space.
169+
"""
170+
# TODO: _validate_parameters to allow also dipoles. Temporarily pass the
171+
# distance tensor.
172+
_validate_parameters(
173+
charges=dipoles,
174+
cell=cell,
175+
positions=positions,
176+
neighbor_indices=neighbor_indices,
177+
neighbor_distances=neighbor_vectors.norm(dim=-1),
178+
smearing=self.potential.smearing,
179+
)
180+
181+
# Compute short-range (SR) part using a real space sum
182+
potential_sr = self._compute_rspace(
183+
dipoles=dipoles,
184+
neighbor_indices=neighbor_indices,
185+
neighbor_vectors=neighbor_vectors,
186+
)
187+
188+
if self.potential.smearing is None:
189+
return self.prefactor * potential_sr
190+
# Compute long-range (LR) part using a Fourier / reciprocal space sum
191+
potential_lr = self._compute_kspace(
192+
dipoles=dipoles,
193+
cell=cell,
194+
positions=positions,
195+
)
196+
197+
return self.prefactor * (potential_sr + potential_lr)

src/torchpme/metatensor/calculator.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,10 @@ class Calculator(torch.nn.Module):
1717
"""
1818
Base calculator for the metatensor interface.
1919
20-
This is just a thin wrapper around the corresponding
21-
generic torch :class:`torchpme.calculators.Calculator`.
22-
If you want to wrap a ``metatensor`` interface around another
23-
calculator, you can just define the class and set the static
24-
member ``_base_calculator`` to the corresponding
25-
torch calculator.
20+
This is just a thin wrapper around the corresponding generic torch
21+
:class:`torchpme.calculators.Calculator`. If you want to wrap a ``metatensor``
22+
interface around another calculator, you can just define the class and set the
23+
static member ``_base_calculator`` to the corresponding torch calculator.
2624
"""
2725

2826
_base_calculator: type[torch_calculators.Calculator] = torch_calculators.Calculator

0 commit comments

Comments
 (0)