Skip to content

Commit eef94c8

Browse files
authored
Add a nan check for KSpaceFilter (#182)
1 parent 97fb4dc commit eef94c8

File tree

2 files changed

+50
-1
lines changed

2 files changed

+50
-1
lines changed

src/torchpme/lib/kspace_filter.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def forward(self, mesh_values: torch.Tensor) -> torch.Tensor:
178178

179179
filter_hat = mesh_hat * self._kfilter
180180

181-
return torch.fft.irfftn(
181+
result = torch.fft.irfftn(
182182
filter_hat,
183183
norm=self._ifft_norm,
184184
dim=dims,
@@ -188,6 +188,16 @@ def forward(self, mesh_values: torch.Tensor) -> torch.Tensor:
188188
s=mesh_values.shape[-3:],
189189
)
190190

191+
if torch.isnan(result).any():
192+
raise ValueError(
193+
"NaNs detected in the k-space filter result. This are probably caused "
194+
"by an unsuitable `mesh_spacing`, resulting in a problematic grid of "
195+
f"shape: {list(mesh_values.shape)}. Try adjsuting the grid by using a "
196+
"different `mesh_spacing` value."
197+
)
198+
199+
return result
200+
191201
def _prep_kvectors(
192202
self, cell: Optional[torch.Tensor], ns_mesh: Optional[torch.Tensor]
193203
):

tests/calculators/test_workflow.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,42 @@ def test_smearing_incompatability(self, CalculatorClass, params, device, dtype):
205205
TypeError, match="Must specify smearing to use a potential with .*"
206206
):
207207
CalculatorClass(**params)
208+
209+
210+
def test_kspace_filter_error_catch():
211+
interpolation_nodes = 5
212+
213+
calculator = P3MCalculator(
214+
potential=CoulombPotential(smearing=1, exclusion_radius=4.5),
215+
interpolation_nodes=interpolation_nodes,
216+
full_neighbor_list=True,
217+
mesh_spacing=0.5,
218+
)
219+
220+
charges = torch.ones([4, 1])
221+
positions = torch.arange(4 * 3).reshape(4, 3).to(torch.float32)
222+
cell = torch.tensor(
223+
[
224+
[-2.2958, -0.5882, -0.0797],
225+
[1.3575, -0.2575, -1.9272],
226+
[1.9694, -5.7254, 2.1524],
227+
],
228+
)
229+
230+
neighbor_indices = torch.zeros((0, 2), dtype=torch.int64)
231+
neighbor_distances = torch.zeros((0,))
232+
233+
match = (
234+
"NaNs detected in the k-space filter result. This are probably caused "
235+
"by an unsuitable `mesh_spacing`, resulting in a problematic grid of "
236+
r"shape: \[1, 16, 16, 32\]. Try adjsuting the grid by using a "
237+
"different `mesh_spacing` value."
238+
)
239+
with pytest.raises(ValueError, match=match):
240+
calculator.forward(
241+
charges=charges,
242+
positions=positions,
243+
cell=cell,
244+
neighbor_indices=neighbor_indices,
245+
neighbor_distances=neighbor_distances,
246+
)

0 commit comments

Comments
 (0)