@@ -41,7 +41,12 @@ class written by Harry Moore. The nnpops_nl function in the neighbors file
41
41
from e3nn .util import jit
42
42
from mace .tools import atomic_numbers_to_indices , to_one_hot , utils
43
43
44
- from atomate2 .openmm .neighbors import nnpops_nl , wrapping_nl
44
+ try :
45
+ from NNPOps .neighbors import getNeighborPairs
46
+ except ImportError as err :
47
+ raise ImportError (
48
+ "NNPOps is not installed. Please install it from conda-forge."
49
+ ) from err
45
50
46
51
47
52
class MaceForce (torch .nn .Module ):
@@ -313,3 +318,253 @@ def add_forces(
313
318
force .setForceGroup (0 )
314
319
force .setUsesPeriodicBoundaryConditions (periodic )
315
320
system .addForce (force )
321
+
322
+
323
+ def nnpops_nl (
324
+ positions : torch .Tensor ,
325
+ cell : torch .Tensor ,
326
+ pbc : bool ,
327
+ cutoff : float ,
328
+ sorti : bool = False ,
329
+ ) -> tuple [torch .Tensor , torch .Tensor ]:
330
+ """Run a neighbor list computation using NNPOps.
331
+
332
+ It outputs neighbors and shifts in the same format as ASE
333
+ https://wiki.fysik.dtu.dk/ase/ase/neighborlist.html#ase.neighborlist.primitive_neighbor_list
334
+
335
+ neighbors, shifts = nnpops_nl(..)
336
+ is equivalent to
337
+
338
+ [i, j], S = primitive_neighbor_list( quantities="ijS", ...)
339
+
340
+ Parameters
341
+ ----------
342
+ positions : torch.Tensor
343
+ Atom positions, shape (num_atoms, 3)
344
+ cell : torch.Tensor
345
+ Unit cell, shape (3, 3)
346
+ pbc : bool
347
+ Whether to use periodic boundary conditions
348
+ cutoff : float
349
+ Cutoff distance for neighbors
350
+ sorti : bool, optional
351
+ Whether to sort the neighbor list by the first index.
352
+ Defaults to False.
353
+
354
+ Returns
355
+ -------
356
+ tuple[torch.Tensor, torch.Tensor]
357
+ A tuple containing:
358
+ - neighbors (torch.Tensor): Neighbor list, shape (2, num_neighbors)
359
+ - shifts (torch.Tensor): Shift vectors, shape (num_neighbors, 3)
360
+ """
361
+ device = positions .device
362
+ neighbors , deltas , _ , _ = getNeighborPairs (
363
+ positions ,
364
+ cutoff = cutoff ,
365
+ max_num_pairs = - 1 ,
366
+ box_vectors = cell if pbc else None ,
367
+ check_errors = False ,
368
+ )
369
+
370
+ neighbors = neighbors .to (dtype = torch .long )
371
+
372
+ # remove empty neighbors
373
+ mask = neighbors [0 ] > - 1
374
+ neighbors = neighbors [:, mask ]
375
+ deltas = deltas [mask , :]
376
+
377
+ # compute shifts TODO: pass deltas and distance directly to model
378
+ # From ASE docs:
379
+ # wrapped_delta = pos[i] - pos[j] - shift.cell
380
+ # => shift = ((pos[i]-pos[j]) - wrapped_delta).cell^-1
381
+ if pbc :
382
+ shifts = torch .mm (
383
+ (positions [neighbors [0 ]] - positions [neighbors [1 ]]) - deltas ,
384
+ torch .linalg .inv (cell ),
385
+ )
386
+ else :
387
+ shifts = torch .zeros (deltas .shape , device = device )
388
+
389
+ # we have i<j, get also i>j
390
+ neighbors = torch .hstack ((neighbors , torch .stack ((neighbors [1 ], neighbors [0 ]))))
391
+ shifts = torch .vstack ((shifts , - shifts ))
392
+
393
+ if sorti :
394
+ idx = torch .argsort (neighbors [0 ])
395
+ neighbors = neighbors [:, idx ]
396
+ shifts = shifts [idx , :]
397
+
398
+ return neighbors , shifts
399
+
400
+
401
+ @torch .jit .script
402
+ def wrapping_nl (
403
+ positions : torch .Tensor ,
404
+ cell : torch .Tensor ,
405
+ pbc : bool ,
406
+ cutoff : float ,
407
+ sorti : bool = False ,
408
+ ) -> tuple [torch .Tensor , torch .Tensor ]:
409
+ """Neighbor list including self-interactions across periodic boundaries.
410
+
411
+ Parameters
412
+ ----------
413
+ positions : torch.Tensor
414
+ Atom positions, shape (num_atoms, 3)
415
+ cell : torch.Tensor
416
+ Unit cell, shape (3, 3)
417
+ pbc : bool
418
+ Whether to use periodic boundary conditions
419
+ cutoff : float
420
+ Cutoff distance for neighbors
421
+ sorti : bool, optional
422
+ Whether to sort the neighbor list by the first index.
423
+ Defaults to False.
424
+
425
+ Returns
426
+ -------
427
+ tuple[torch.Tensor, torch.Tensor]
428
+ A tuple containing:
429
+ - neighbors (torch.Tensor): Neighbor list, shape (2, num_neighbors)
430
+ - shifts (torch.Tensor): Shift vectors, shape (num_neighbors, 3)
431
+ """
432
+ num_atoms = positions .shape [0 ]
433
+ device = positions .device
434
+ dtype = positions .dtype
435
+
436
+ # Get all unique pairs including self-pairs (i <= j)
437
+ uij = torch .triu_indices (num_atoms , num_atoms , offset = 0 , device = device )
438
+ i_indices = uij [0 ]
439
+ j_indices = uij [1 ]
440
+
441
+ if pbc :
442
+ # Compute displacement vectors between atom pairs
443
+ deltas = positions [i_indices ] - positions [j_indices ]
444
+
445
+ # Compute inverse cell matrix
446
+ inv_cell = torch .linalg .inv (cell )
447
+
448
+ # Compute fractional coordinates of displacement vectors
449
+ frac_deltas = torch .matmul (deltas , inv_cell )
450
+
451
+ # Determine the maximum number of shifts needed along each axis
452
+ cell_lengths = torch .linalg .norm (cell , dim = 0 )
453
+ n_max = torch .ceil (cutoff / cell_lengths ).to (torch .int32 )
454
+
455
+ # Extract scalar values from n_max
456
+ n_max0 = int (n_max [0 ])
457
+ n_max1 = int (n_max [1 ])
458
+ n_max2 = int (n_max [2 ])
459
+
460
+ # Generate shift ranges
461
+ shift_range_x = torch .arange (- n_max0 , n_max0 + 1 , device = device , dtype = dtype )
462
+ shift_range_y = torch .arange (- n_max1 , n_max1 + 1 , device = device , dtype = dtype )
463
+ shift_range_z = torch .arange (- n_max2 , n_max2 + 1 , device = device , dtype = dtype )
464
+
465
+ # Generate all combinations of shifts within the range [-n_max, n_max]
466
+ shift_x , shift_y , shift_z = torch .meshgrid (
467
+ shift_range_x , shift_range_y , shift_range_z , indexing = "ij"
468
+ )
469
+
470
+ shifts_list = torch .stack (
471
+ (shift_x .reshape (- 1 ), shift_y .reshape (- 1 ), shift_z .reshape (- 1 )), dim = 1
472
+ )
473
+
474
+ # Total number of shifts
475
+ num_shifts = shifts_list .shape [0 ]
476
+
477
+ # Expand atom pairs and shifts
478
+ num_pairs = i_indices .shape [0 ]
479
+ i_indices_expanded = i_indices .repeat_interleave (num_shifts )
480
+ j_indices_expanded = j_indices .repeat_interleave (num_shifts )
481
+ shifts_expanded = shifts_list .repeat (num_pairs , 1 )
482
+
483
+ # Expand fractional displacements
484
+ frac_deltas_expanded = frac_deltas .repeat_interleave (num_shifts , dim = 0 )
485
+
486
+ # Apply shifts to fractional displacements
487
+ shifted_frac_deltas = frac_deltas_expanded - shifts_expanded
488
+
489
+ # Convert back to Cartesian coordinates
490
+ shifted_deltas = torch .matmul (shifted_frac_deltas , cell )
491
+
492
+ # Compute distances
493
+ distances = torch .linalg .norm (shifted_deltas , dim = 1 )
494
+
495
+ # Apply cutoff filter
496
+ within_cutoff = distances <= cutoff
497
+
498
+ # Exclude self-pairs where shift is zero (no periodic boundary crossing)
499
+ shift_zero = (shifts_expanded == 0 ).all (dim = 1 )
500
+ i_eq_j = i_indices_expanded == j_indices_expanded
501
+ exclude_self_zero_shift = i_eq_j & shift_zero
502
+ within_cutoff = within_cutoff & (~ exclude_self_zero_shift )
503
+
504
+ num_within_cutoff = int (within_cutoff .sum ())
505
+
506
+ i_indices_final = i_indices_expanded [within_cutoff ]
507
+ j_indices_final = j_indices_expanded [within_cutoff ]
508
+ shifts_final = shifts_expanded [within_cutoff ]
509
+
510
+ # Generate neighbor pairs and shifts
511
+ neighbors = torch .stack ((i_indices_final , j_indices_final ), dim = 0 )
512
+ shifts = shifts_final
513
+
514
+ # Add symmetric pairs (j, i) and negate shifts,
515
+ # but avoid duplicates for self-pairs
516
+ i_neq_j = i_indices_final != j_indices_final
517
+ neighbors_sym = torch .stack (
518
+ (j_indices_final [i_neq_j ], i_indices_final [i_neq_j ]), dim = 0
519
+ )
520
+ shifts_sym = - shifts_final [i_neq_j ]
521
+
522
+ neighbors = torch .cat ((neighbors , neighbors_sym ), dim = 1 )
523
+ shifts = torch .cat ((shifts , shifts_sym ), dim = 0 )
524
+
525
+ if sorti :
526
+ idx = torch .argsort (neighbors [0 ])
527
+ neighbors = neighbors [:, idx ]
528
+ shifts = shifts [idx , :]
529
+
530
+ return neighbors , shifts
531
+
532
+ # Non-periodic case
533
+ deltas = positions [i_indices ] - positions [j_indices ]
534
+ distances = torch .linalg .norm (deltas , dim = 1 )
535
+
536
+ # Apply cutoff filter
537
+ within_cutoff = distances <= cutoff
538
+
539
+ # Exclude self-pairs where distance is zero
540
+ i_eq_j = i_indices == j_indices
541
+ exclude_self_zero_distance = i_eq_j & (distances == 0 )
542
+ within_cutoff = within_cutoff & (~ exclude_self_zero_distance )
543
+
544
+ num_within_cutoff = int (within_cutoff .sum ())
545
+
546
+ i_indices_final = i_indices [within_cutoff ]
547
+ j_indices_final = j_indices [within_cutoff ]
548
+
549
+ shifts_final = torch .zeros ((num_within_cutoff , 3 ), device = device , dtype = dtype )
550
+
551
+ # Generate neighbor pairs and shifts
552
+ neighbors = torch .stack ((i_indices_final , j_indices_final ), dim = 0 )
553
+ shifts = shifts_final
554
+
555
+ # Add symmetric pairs (j, i) and shifts (only if i != j)
556
+ i_neq_j = i_indices_final != j_indices_final
557
+ neighbors_sym = torch .stack (
558
+ (j_indices_final [i_neq_j ], i_indices_final [i_neq_j ]), dim = 0
559
+ )
560
+ shifts_sym = shifts_final [i_neq_j ] # shifts are zero
561
+
562
+ neighbors = torch .cat ((neighbors , neighbors_sym ), dim = 1 )
563
+ shifts = torch .cat ((shifts , shifts_sym ), dim = 0 )
564
+
565
+ if sorti :
566
+ idx = torch .argsort (neighbors [0 ])
567
+ neighbors = neighbors [:, idx ]
568
+ shifts = shifts [idx , :]
569
+
570
+ return neighbors , shifts
0 commit comments