4
4
from pathlib import Path
5
5
from typing import TYPE_CHECKING , Literal
6
6
7
+ import numpy as np
7
8
import pytest
9
+ from pymatgen .util .coord import find_in_coord_list_pbc
8
10
9
11
if TYPE_CHECKING :
10
12
from collections .abc import Sequence
11
13
14
+ from pymatgen .core .structure import Structure
15
+
12
16
13
17
logger = logging .getLogger ("atomate2" )
14
18
@@ -130,9 +134,58 @@ def check_run_abi(ref_path: str | Path):
130
134
ref_str = file .read ()
131
135
ref = AbinitInputFile .from_string (ref_str .decode ("utf-8" ))
132
136
# Ignore the pseudos as the directory depends on the pseudo root directory
133
- diffs = user .get_differences (ref , ignore_vars = ["pseudos" ])
137
+ # diffs = user.get_differences(ref, ignore_vars=["pseudos"])
138
+ diffs = _get_differences_tol (user , ref , ignore_vars = ["pseudos" ])
134
139
# TODO: should we still add some check on the pseudos here ?
135
- assert diffs == [], "'run.abi' is different from reference."
140
+ assert diffs == [], f"'run.abi' is different from reference: \n { diffs } "
141
+
142
+
143
+ # Adapted from check_poscar in atomate2.utils.testing.vasp.py
144
+ def check_equivalent_frac_coords (
145
+ struct : Structure ,
146
+ struct_ref : Structure ,
147
+ atol = 1e-3 ,
148
+ ) -> None :
149
+ """Check that the frac. coords. of two structures are equivalent (includes pbc)."""
150
+
151
+ user_frac_coords = struct .frac_coords
152
+ ref_frac_coords = struct_ref .frac_coords
153
+
154
+ # In some cases, the ordering of sites can change when copying input files.
155
+ # To account for this, we check that the sites are the same, within a tolerance,
156
+ # while accounting for PBC.
157
+ coord_match = [
158
+ len (find_in_coord_list_pbc (ref_frac_coords , coord , atol = atol )) > 0
159
+ for coord in user_frac_coords
160
+ ]
161
+ assert all (coord_match ), f"The two structures have different frac. coords: \
162
+ { user_frac_coords } vs. { ref_frac_coords } ."
163
+
164
+
165
+ def check_equivalent_znucl_typat (
166
+ znucl_a : list | np .ndarray ,
167
+ znucl_b : list | np .ndarray ,
168
+ typat_a : list | np .ndarray ,
169
+ typat_b : list | np .ndarray ,
170
+ ) -> None :
171
+ """Check that the elements and their number of atoms are equivalent."""
172
+
173
+ sorted_znucl_a = sorted (znucl_a , reverse = True )
174
+ sorted_znucl_b = sorted (znucl_b , reverse = True )
175
+ assert (
176
+ sorted_znucl_a == sorted_znucl_b
177
+ ), f"The elements are different: { znucl_a } vs. { znucl_b } "
178
+
179
+ count_sorted_znucl_a = [
180
+ list (typat_a ).count (list (znucl_a ).index (s ) + 1 ) for s in sorted_znucl_a
181
+ ]
182
+ count_sorted_znucl_b = [
183
+ list (typat_b ).count (list (znucl_b ).index (s ) + 1 ) for s in sorted_znucl_b
184
+ ]
185
+ assert (
186
+ count_sorted_znucl_a == count_sorted_znucl_b
187
+ ), f"The number of same elements is different: \
188
+ { count_sorted_znucl_a } vs. { count_sorted_znucl_b } "
136
189
137
190
138
191
def check_abinit_input_json (ref_path : str | Path ):
@@ -141,9 +194,29 @@ def check_abinit_input_json(ref_path: str | Path):
141
194
142
195
user = loadfn ("abinit_input.json" )
143
196
assert isinstance (user , AbinitInput )
197
+ user_abivars = user .structure .to_abivars ()
198
+
144
199
ref = loadfn (ref_path / "inputs" / "abinit_input.json.gz" )
145
- assert user .structure == ref .structure
146
- assert user .runlevel == ref .runlevel
200
+ ref_abivars = ref .structure .to_abivars ()
201
+
202
+ check_equivalent_frac_coords (user .structure , ref .structure )
203
+ check_equivalent_znucl_typat (
204
+ user_abivars ["znucl" ],
205
+ ref_abivars ["znucl" ],
206
+ user_abivars ["typat" ],
207
+ ref_abivars ["typat" ],
208
+ )
209
+
210
+ for k , user_v in user_abivars .items ():
211
+ if k in ["xred" , "znucl" , "typat" ]:
212
+ continue
213
+ assert k in ref_abivars , f"{ k = } is not a key of the reference input."
214
+ ref_v = ref_abivars [k ]
215
+ if isinstance (user_v , str ):
216
+ assert user_v == ref_v , f"{ k = } -->{ user_v = } versus { ref_v = } "
217
+ else :
218
+ assert np .allclose (user_v , ref_v ), f"{ k = } -->{ user_v = } versus { ref_v = } "
219
+ assert user .runlevel == ref .runlevel , f"{ user .runlevel = } versus { ref .runlevel = } "
147
220
148
221
149
222
def clear_abinit_files ():
@@ -170,3 +243,93 @@ def copy_abinit_outputs(ref_path: str | Path):
170
243
if file .is_file ():
171
244
shutil .copy (file , data_dir )
172
245
decompress_file (str (Path (data_dir , file .name )))
246
+
247
+
248
+ # Patch to allow for a tolerance in the comparison of the ABINIT input variables
249
+ # TODO: remove once new version of Abipy is released
250
+ def _get_differences_tol (
251
+ abi1 , abi2 , ignore_vars = None , rtol = 1e-5 , atol = 1e-12
252
+ ) -> list [str ]:
253
+ """
254
+ Get the differences between this AbinitInputFile and another.
255
+ Allow tolerance for floats.
256
+ """
257
+ diffs = []
258
+ to_ignore = {
259
+ "acell" ,
260
+ "angdeg" ,
261
+ "rprim" ,
262
+ "ntypat" ,
263
+ "natom" ,
264
+ "znucl" ,
265
+ "typat" ,
266
+ "xred" ,
267
+ "xcart" ,
268
+ "xangst" ,
269
+ }
270
+ if ignore_vars is not None :
271
+ to_ignore .update (ignore_vars )
272
+ if abi1 .ndtset != abi2 .ndtset :
273
+ diffs .append (
274
+ f"Number of datasets in this file is { abi1 .ndtset } "
275
+ f"while other file has { abi2 .ndtset } datasets."
276
+ )
277
+ return diffs
278
+ for idataset , self_dataset in enumerate (abi1 .datasets ):
279
+ other_dataset = abi2 .datasets [idataset ]
280
+ if self_dataset .structure != other_dataset .structure :
281
+ diffs .append ("Structures are different." )
282
+ self_dataset_dict = dict (self_dataset )
283
+ other_dataset_dict = dict (other_dataset )
284
+ for k in to_ignore :
285
+ if k in self_dataset_dict :
286
+ del self_dataset_dict [k ]
287
+ if k in other_dataset_dict :
288
+ del other_dataset_dict [k ]
289
+ common_keys = set (self_dataset_dict .keys ()).intersection (
290
+ other_dataset_dict .keys ()
291
+ )
292
+ self_only_keys = set (self_dataset_dict .keys ()).difference (
293
+ other_dataset_dict .keys ()
294
+ )
295
+ other_only_keys = set (other_dataset_dict .keys ()).difference (
296
+ self_dataset_dict .keys ()
297
+ )
298
+ if self_only_keys :
299
+ diffs .append (
300
+ f"The following variables are in this file but not in other: "
301
+ f"{ ', ' .join ([str (k ) for k in self_only_keys ])} "
302
+ )
303
+ if other_only_keys :
304
+ diffs .append (
305
+ f"The following variables are in other file but not in this one: "
306
+ f"{ ', ' .join ([str (k ) for k in other_only_keys ])} "
307
+ )
308
+ for k in common_keys :
309
+ val1 = self_dataset_dict [k ]
310
+ val2 = other_dataset_dict [k ]
311
+ matched = False
312
+ if isinstance (val1 , str ):
313
+ if val1 .endswith (" Ha" ):
314
+ val1 = val1 .replace (" Ha" , "" )
315
+ if val1 .count ("." ) <= 1 and val1 .replace ("." , "" ).isdecimal ():
316
+ val1 = float (val1 )
317
+
318
+ if isinstance (val2 , str ):
319
+ if val2 .endswith (" Ha" ):
320
+ val2 = val2 .replace (" Ha" , "" )
321
+ if val2 .count ("." ) <= 1 and val2 .replace ("." , "" ).isdecimal ():
322
+ val2 = float (val2 )
323
+
324
+ if isinstance (val1 , float ):
325
+ matched = pytest .approx (val1 , rel = rtol , abs = atol ) == val2
326
+ else :
327
+ matched = self_dataset_dict [k ] == other_dataset_dict [k ]
328
+
329
+ if not matched :
330
+ diffs .append (
331
+ f"The variable '{ k } ' is different in the two files:\n "
332
+ f" - this file: '{ self_dataset_dict [k ]} '\n "
333
+ f" - other file: '{ other_dataset_dict [k ]} '"
334
+ )
335
+ return diffs
0 commit comments