Skip to content

Commit 80fbb1b

Browse files
committed
add test for exclusion radius functionality and update changelog.rst
1 parent 80cf9a1 commit 80fbb1b

File tree

2 files changed

+56
-2
lines changed

2 files changed

+56
-2
lines changed

docs/src/references/changelog.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ Added
3030
* Add a nan check for ``KSpaceFilter``
3131
* Allow passing ``full_neighbor_list`` and ``prefactor`` to tuning functions
3232

33+
Fixed
34+
#####
35+
36+
* Fix exclusion_radius to work in a direct calculator
3337
`Version 0.3.0 <https://github.com/lab-cosmo/torch-pme/releases/tag/v0.3.0>`_ - 2025-02-21
3438
------------------------------------------------------------------------------------------
3539

tests/calculators/test_calculator.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,13 @@
1616

1717
# non-range-separated Coulomb direct calculator
1818
class CalculatorTest(Calculator):
19-
def __init__(self):
19+
def __init__(self, exclusion_radius=None, exclusion_degree=1):
2020
super().__init__(
21-
potential=CoulombPotential(smearing=None, exclusion_radius=None)
21+
potential=CoulombPotential(
22+
smearing=None,
23+
exclusion_radius=exclusion_radius,
24+
exclusion_degree=exclusion_degree,
25+
),
2226
)
2327

2428

@@ -255,3 +259,49 @@ def test_invalid_dtype_neighbor_distances():
255259
neighbor_indices=NEIGHBOR_INDICES,
256260
neighbor_distances=NEIGHBOR_DISTANCES.to(dtype=torch.float64),
257261
)
262+
263+
264+
def test_exclusion_radius():
265+
"""Test that the exclusion radius is applied correctly"""
266+
exclusion_radius = 4.0
267+
exclusion_degree = 8
268+
calculator1 = CalculatorTest()
269+
potential1 = calculator1.forward(
270+
positions=POSITIONS_1,
271+
charges=CHARGES_1,
272+
cell=CELL_1,
273+
neighbor_indices=NEIGHBOR_INDICES,
274+
neighbor_distances=NEIGHBOR_DISTANCES,
275+
)
276+
calculator2 = CalculatorTest(
277+
exclusion_radius=exclusion_radius, exclusion_degree=exclusion_degree
278+
)
279+
potential2 = calculator2.forward(
280+
positions=POSITIONS_1,
281+
charges=CHARGES_1,
282+
cell=CELL_1,
283+
neighbor_indices=NEIGHBOR_INDICES,
284+
neighbor_distances=NEIGHBOR_DISTANCES,
285+
)
286+
assert torch.allclose(
287+
(
288+
potential1
289+
* (
290+
1
291+
- (
292+
1
293+
- (
294+
(
295+
1
296+
- torch.cos(
297+
torch.pi * (NEIGHBOR_DISTANCES[0] / exclusion_radius)
298+
)
299+
)
300+
* 0.5
301+
)
302+
** exclusion_degree
303+
)
304+
)
305+
),
306+
potential2,
307+
)

0 commit comments

Comments
 (0)