Skip to content

Commit f68fc0a

Browse files
committed
Add full test suite
1 parent 3b7f949 commit f68fc0a

File tree

8 files changed

+790
-266
lines changed

8 files changed

+790
-266
lines changed

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,17 @@ ignore_missing_imports = true
153153
no_strict_optional = true
154154

155155
[tool.pytest.ini_options]
156-
addopts = "-p no:warnings --import-mode=importlib --cov-config=pyproject.toml"
156+
addopts = "-p no:warnings --import-mode=importlib --cov-config=pyproject.toml -m 'not openmm_mace'"
157157
filterwarnings = [
158158
"ignore:.*POTCAR.*:UserWarning",
159159
"ignore:.*input structure.*:UserWarning",
160160
"ignore:.*is not gzipped.*:UserWarning",
161161
"ignore:.*magmom.*:UserWarning",
162162
"ignore::DeprecationWarning",
163163
]
164+
markers = [
165+
"openmm_mace: tests marked openmm_mace are skipped by default because they are very slow (unskip with pytest -m openmm_mace)",
166+
]
164167

165168
[tool.coverage.run]
166169
include = ["src/*"]

src/atomate2/openmm/jobs/mace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import numpy as np
88
import openmm
99
import openmm.unit as omm_unit
10-
from atoms2.comp.md.atomate2.utils import structure_to_topology
1110
from emmet.core.openmm import OpenMMInterchange, OpenMMTaskDocument
1211
from emmet.core.vasp.task_valid import TaskState
1312
from jobflow import Response
@@ -18,7 +17,8 @@
1817
from pymatgen.core import Structure
1918

2019
from atomate2.openmm.jobs.base import openmm_job
21-
from atomate2.openmm.mace_force import MacePotential
20+
from atomate2.openmm.mace_utils import MacePotential
21+
from atomate2.openmm.utils import structure_to_topology
2222

2323

2424
@openmm_job

src/atomate2/openmm/mace_force.py renamed to src/atomate2/openmm/mace_utils.py

Lines changed: 256 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,12 @@ class written by Harry Moore. The nnpops_nl function in the neighbors file
4141
from e3nn.util import jit
4242
from mace.tools import atomic_numbers_to_indices, to_one_hot, utils
4343

44-
from atomate2.openmm.neighbors import nnpops_nl, wrapping_nl
44+
try:
45+
from NNPOps.neighbors import getNeighborPairs
46+
except ImportError as err:
47+
raise ImportError(
48+
"NNPOps is not installed. Please install it from conda-forge."
49+
) from err
4550

4651

4752
class MaceForce(torch.nn.Module):
@@ -313,3 +318,253 @@ def add_forces(
313318
force.setForceGroup(0)
314319
force.setUsesPeriodicBoundaryConditions(periodic)
315320
system.addForce(force)
321+
322+
323+
def nnpops_nl(
324+
positions: torch.Tensor,
325+
cell: torch.Tensor,
326+
pbc: bool,
327+
cutoff: float,
328+
sorti: bool = False,
329+
) -> tuple[torch.Tensor, torch.Tensor]:
330+
"""Run a neighbor list computation using NNPOps.
331+
332+
It outputs neighbors and shifts in the same format as ASE
333+
https://wiki.fysik.dtu.dk/ase/ase/neighborlist.html#ase.neighborlist.primitive_neighbor_list
334+
335+
neighbors, shifts = nnpops_nl(..)
336+
is equivalent to
337+
338+
[i, j], S = primitive_neighbor_list( quantities="ijS", ...)
339+
340+
Parameters
341+
----------
342+
positions : torch.Tensor
343+
Atom positions, shape (num_atoms, 3)
344+
cell : torch.Tensor
345+
Unit cell, shape (3, 3)
346+
pbc : bool
347+
Whether to use periodic boundary conditions
348+
cutoff : float
349+
Cutoff distance for neighbors
350+
sorti : bool, optional
351+
Whether to sort the neighbor list by the first index.
352+
Defaults to False.
353+
354+
Returns
355+
-------
356+
tuple[torch.Tensor, torch.Tensor]
357+
A tuple containing:
358+
- neighbors (torch.Tensor): Neighbor list, shape (2, num_neighbors)
359+
- shifts (torch.Tensor): Shift vectors, shape (num_neighbors, 3)
360+
"""
361+
device = positions.device
362+
neighbors, deltas, _, _ = getNeighborPairs(
363+
positions,
364+
cutoff=cutoff,
365+
max_num_pairs=-1,
366+
box_vectors=cell if pbc else None,
367+
check_errors=False,
368+
)
369+
370+
neighbors = neighbors.to(dtype=torch.long)
371+
372+
# remove empty neighbors
373+
mask = neighbors[0] > -1
374+
neighbors = neighbors[:, mask]
375+
deltas = deltas[mask, :]
376+
377+
# compute shifts TODO: pass deltas and distance directly to model
378+
# From ASE docs:
379+
# wrapped_delta = pos[i] - pos[j] - shift.cell
380+
# => shift = ((pos[i]-pos[j]) - wrapped_delta).cell^-1
381+
if pbc:
382+
shifts = torch.mm(
383+
(positions[neighbors[0]] - positions[neighbors[1]]) - deltas,
384+
torch.linalg.inv(cell),
385+
)
386+
else:
387+
shifts = torch.zeros(deltas.shape, device=device)
388+
389+
# we have i<j, get also i>j
390+
neighbors = torch.hstack((neighbors, torch.stack((neighbors[1], neighbors[0]))))
391+
shifts = torch.vstack((shifts, -shifts))
392+
393+
if sorti:
394+
idx = torch.argsort(neighbors[0])
395+
neighbors = neighbors[:, idx]
396+
shifts = shifts[idx, :]
397+
398+
return neighbors, shifts
399+
400+
401+
@torch.jit.script
402+
def wrapping_nl(
403+
positions: torch.Tensor,
404+
cell: torch.Tensor,
405+
pbc: bool,
406+
cutoff: float,
407+
sorti: bool = False,
408+
) -> tuple[torch.Tensor, torch.Tensor]:
409+
"""Neighbor list including self-interactions across periodic boundaries.
410+
411+
Parameters
412+
----------
413+
positions : torch.Tensor
414+
Atom positions, shape (num_atoms, 3)
415+
cell : torch.Tensor
416+
Unit cell, shape (3, 3)
417+
pbc : bool
418+
Whether to use periodic boundary conditions
419+
cutoff : float
420+
Cutoff distance for neighbors
421+
sorti : bool, optional
422+
Whether to sort the neighbor list by the first index.
423+
Defaults to False.
424+
425+
Returns
426+
-------
427+
tuple[torch.Tensor, torch.Tensor]
428+
A tuple containing:
429+
- neighbors (torch.Tensor): Neighbor list, shape (2, num_neighbors)
430+
- shifts (torch.Tensor): Shift vectors, shape (num_neighbors, 3)
431+
"""
432+
num_atoms = positions.shape[0]
433+
device = positions.device
434+
dtype = positions.dtype
435+
436+
# Get all unique pairs including self-pairs (i <= j)
437+
uij = torch.triu_indices(num_atoms, num_atoms, offset=0, device=device)
438+
i_indices = uij[0]
439+
j_indices = uij[1]
440+
441+
if pbc:
442+
# Compute displacement vectors between atom pairs
443+
deltas = positions[i_indices] - positions[j_indices]
444+
445+
# Compute inverse cell matrix
446+
inv_cell = torch.linalg.inv(cell)
447+
448+
# Compute fractional coordinates of displacement vectors
449+
frac_deltas = torch.matmul(deltas, inv_cell)
450+
451+
# Determine the maximum number of shifts needed along each axis
452+
cell_lengths = torch.linalg.norm(cell, dim=0)
453+
n_max = torch.ceil(cutoff / cell_lengths).to(torch.int32)
454+
455+
# Extract scalar values from n_max
456+
n_max0 = int(n_max[0])
457+
n_max1 = int(n_max[1])
458+
n_max2 = int(n_max[2])
459+
460+
# Generate shift ranges
461+
shift_range_x = torch.arange(-n_max0, n_max0 + 1, device=device, dtype=dtype)
462+
shift_range_y = torch.arange(-n_max1, n_max1 + 1, device=device, dtype=dtype)
463+
shift_range_z = torch.arange(-n_max2, n_max2 + 1, device=device, dtype=dtype)
464+
465+
# Generate all combinations of shifts within the range [-n_max, n_max]
466+
shift_x, shift_y, shift_z = torch.meshgrid(
467+
shift_range_x, shift_range_y, shift_range_z, indexing="ij"
468+
)
469+
470+
shifts_list = torch.stack(
471+
(shift_x.reshape(-1), shift_y.reshape(-1), shift_z.reshape(-1)), dim=1
472+
)
473+
474+
# Total number of shifts
475+
num_shifts = shifts_list.shape[0]
476+
477+
# Expand atom pairs and shifts
478+
num_pairs = i_indices.shape[0]
479+
i_indices_expanded = i_indices.repeat_interleave(num_shifts)
480+
j_indices_expanded = j_indices.repeat_interleave(num_shifts)
481+
shifts_expanded = shifts_list.repeat(num_pairs, 1)
482+
483+
# Expand fractional displacements
484+
frac_deltas_expanded = frac_deltas.repeat_interleave(num_shifts, dim=0)
485+
486+
# Apply shifts to fractional displacements
487+
shifted_frac_deltas = frac_deltas_expanded - shifts_expanded
488+
489+
# Convert back to Cartesian coordinates
490+
shifted_deltas = torch.matmul(shifted_frac_deltas, cell)
491+
492+
# Compute distances
493+
distances = torch.linalg.norm(shifted_deltas, dim=1)
494+
495+
# Apply cutoff filter
496+
within_cutoff = distances <= cutoff
497+
498+
# Exclude self-pairs where shift is zero (no periodic boundary crossing)
499+
shift_zero = (shifts_expanded == 0).all(dim=1)
500+
i_eq_j = i_indices_expanded == j_indices_expanded
501+
exclude_self_zero_shift = i_eq_j & shift_zero
502+
within_cutoff = within_cutoff & (~exclude_self_zero_shift)
503+
504+
num_within_cutoff = int(within_cutoff.sum())
505+
506+
i_indices_final = i_indices_expanded[within_cutoff]
507+
j_indices_final = j_indices_expanded[within_cutoff]
508+
shifts_final = shifts_expanded[within_cutoff]
509+
510+
# Generate neighbor pairs and shifts
511+
neighbors = torch.stack((i_indices_final, j_indices_final), dim=0)
512+
shifts = shifts_final
513+
514+
# Add symmetric pairs (j, i) and negate shifts,
515+
# but avoid duplicates for self-pairs
516+
i_neq_j = i_indices_final != j_indices_final
517+
neighbors_sym = torch.stack(
518+
(j_indices_final[i_neq_j], i_indices_final[i_neq_j]), dim=0
519+
)
520+
shifts_sym = -shifts_final[i_neq_j]
521+
522+
neighbors = torch.cat((neighbors, neighbors_sym), dim=1)
523+
shifts = torch.cat((shifts, shifts_sym), dim=0)
524+
525+
if sorti:
526+
idx = torch.argsort(neighbors[0])
527+
neighbors = neighbors[:, idx]
528+
shifts = shifts[idx, :]
529+
530+
return neighbors, shifts
531+
532+
# Non-periodic case
533+
deltas = positions[i_indices] - positions[j_indices]
534+
distances = torch.linalg.norm(deltas, dim=1)
535+
536+
# Apply cutoff filter
537+
within_cutoff = distances <= cutoff
538+
539+
# Exclude self-pairs where distance is zero
540+
i_eq_j = i_indices == j_indices
541+
exclude_self_zero_distance = i_eq_j & (distances == 0)
542+
within_cutoff = within_cutoff & (~exclude_self_zero_distance)
543+
544+
num_within_cutoff = int(within_cutoff.sum())
545+
546+
i_indices_final = i_indices[within_cutoff]
547+
j_indices_final = j_indices[within_cutoff]
548+
549+
shifts_final = torch.zeros((num_within_cutoff, 3), device=device, dtype=dtype)
550+
551+
# Generate neighbor pairs and shifts
552+
neighbors = torch.stack((i_indices_final, j_indices_final), dim=0)
553+
shifts = shifts_final
554+
555+
# Add symmetric pairs (j, i) and shifts (only if i != j)
556+
i_neq_j = i_indices_final != j_indices_final
557+
neighbors_sym = torch.stack(
558+
(j_indices_final[i_neq_j], i_indices_final[i_neq_j]), dim=0
559+
)
560+
shifts_sym = shifts_final[i_neq_j] # shifts are zero
561+
562+
neighbors = torch.cat((neighbors, neighbors_sym), dim=1)
563+
shifts = torch.cat((shifts, shifts_sym), dim=0)
564+
565+
if sorti:
566+
idx = torch.argsort(neighbors[0])
567+
neighbors = neighbors[:, idx]
568+
shifts = shifts[idx, :]
569+
570+
return neighbors, shifts

0 commit comments

Comments
 (0)