Skip to content

Commit 8c49db9

Browse files
committed
add tests anaddb
1 parent fd54870 commit 8c49db9

File tree

18 files changed

+168
-25
lines changed

18 files changed

+168
-25
lines changed

src/atomate2/abinit/jobs/anaddb.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import TYPE_CHECKING
99

1010
import jobflow
11+
import numpy as np
1112
from jobflow import Maker, Response, job
1213

1314
from atomate2 import SETTINGS
@@ -28,9 +29,7 @@
2829

2930
logger = logging.getLogger(__name__)
3031

31-
__all__ = [
32-
"AnaddbMaker",
33-
]
32+
__all__ = ["AnaddbMaker", "AnaddbDfptDteMaker"]
3433

3534

3635
@dataclass
@@ -71,7 +70,8 @@ def make(
7170
A JobHistory object containing the history of this job.
7271
"""
7372
# Flatten the list of previous outputs dir
74-
prev_outputs = [item for sublist in prev_outputs for item in sublist]
73+
# prev_outputs = [item for sublist in prev_outputs for item in sublist]
74+
prev_outputs = list(np.hstack(prev_outputs))
7575

7676
# Setup job and get general job configuration
7777
config = setup_job(

tests/abinit/conftest.py

+115-21
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
_FAKE_RUN_ABINIT_KWARGS = {}
1818
_MRGDDB_FILES = "mrgddb.in"
1919
_FAKE_RUN_MRGDDB_KWARGS = {}
20+
_ANADDB_FILES = ("anaddb.in", "anaddb_input.json")
21+
_FAKE_RUN_ANADDB_KWARGS = {}
2022

2123

2224
@pytest.fixture(scope="session")
@@ -195,7 +197,7 @@ def mock_run_mrgddb(wall_time=None, start_time=None):
195197
index = CURRENT_JOB.job.index
196198
ref_path = abinit_test_dir / _REF_PATHS[name][str(index)]
197199
check_mrgddb_inputs(ref_path)
198-
fake_run_mrgddb(ref_path)
200+
fake_run_abinit(ref_path)
199201

200202
mocker.patch.object(atomate2.abinit.run, "run_mrgddb", mock_run_mrgddb)
201203
mocker.patch.object(atomate2.abinit.jobs.mrgddb, "run_mrgddb", mock_run_mrgddb)
@@ -213,26 +215,6 @@ def _run(ref_paths, fake_run_mrgddb_kwargs=None):
213215
_FAKE_RUN_MRGDDB_KWARGS.clear()
214216

215217

216-
def fake_run_mrgddb(ref_path: str | Path):
217-
"""
218-
Emulate running Mrgddb.
219-
220-
Parameters
221-
----------
222-
ref_path
223-
Path to reference directory with Mrgddb input files in the folder named 'inputs'
224-
and output files in the folder named 'outputs'.
225-
"""
226-
logger.info("Running fake Mrgddb.")
227-
228-
ref_path = Path(ref_path)
229-
230-
copy_abinit_outputs(ref_path)
231-
232-
# pretend to run Mrgddb by copying pre-generated outputs from reference dir
233-
logger.info("Generated fake Mrgddb outputs")
234-
235-
236218
def check_mrgddb_inputs(
237219
ref_path: str | Path,
238220
check_inputs: Sequence[Literal["mrgddb.in"]] = _MRGDDB_FILES,
@@ -270,6 +252,118 @@ def check_mrgddb_in(ref_path: str | Path):
270252
assert str_in == ref_str, "'mrgddb.in' is different from reference."
271253

272254

255+
@pytest.fixture()
256+
def mock_anaddb(mocker, abinit_test_dir, abinit_integration_tests):
257+
"""
258+
This fixture allows one to mock running Anaddb.
259+
260+
It works by monkeypatching (replacing) calls to run_anaddb.
261+
262+
The primary idea is that instead of running Anaddb to generate the output files,
263+
reference files will be copied into the directory instead.
264+
"""
265+
import atomate2.abinit.files
266+
import atomate2.abinit.jobs.anaddb
267+
import atomate2.abinit.run
268+
269+
# Wrap the write_anaddb_input_set so that we can check inputs after calling it
270+
def wrapped_write_anaddb_input_set(*args, **kwargs):
271+
from jobflow import CURRENT_JOB
272+
273+
name = CURRENT_JOB.job.name
274+
index = CURRENT_JOB.job.index
275+
ref_path = abinit_test_dir / _REF_PATHS[name][str(index)]
276+
277+
atomate2.abinit.files.write_anaddb_input_set(*args, **kwargs)
278+
check_anaddb_inputs(ref_path)
279+
280+
mocker.patch.object(
281+
atomate2.abinit.jobs.anaddb,
282+
"write_anaddb_input_set",
283+
wrapped_write_anaddb_input_set,
284+
)
285+
286+
if not abinit_integration_tests:
287+
# Mock anaddb run (i.e. this will copy reference files)
288+
def mock_run_anaddb(wall_time=None, start_time=None):
289+
from jobflow import CURRENT_JOB
290+
291+
name = CURRENT_JOB.job.name
292+
index = CURRENT_JOB.job.index
293+
ref_path = abinit_test_dir / _REF_PATHS[name][str(index)]
294+
check_anaddb_inputs(ref_path)
295+
fake_run_abinit(ref_path)
296+
297+
mocker.patch.object(atomate2.abinit.run, "run_anaddb", mock_run_anaddb)
298+
mocker.patch.object(atomate2.abinit.jobs.anaddb, "run_anaddb", mock_run_anaddb)
299+
300+
def _run(ref_paths, fake_run_anaddb_kwargs=None):
301+
if fake_run_anaddb_kwargs is None:
302+
fake_run_anaddb_kwargs = {}
303+
_REF_PATHS.update(ref_paths)
304+
_FAKE_RUN_ANADDB_KWARGS.update(fake_run_anaddb_kwargs)
305+
306+
yield _run
307+
308+
mocker.stopall()
309+
_REF_PATHS.clear()
310+
_FAKE_RUN_ANADDB_KWARGS.clear()
311+
312+
313+
def check_anaddb_inputs(
314+
ref_path: str | Path,
315+
check_inputs: Sequence[Literal["anaddb.in"]] = _ANADDB_FILES,
316+
):
317+
ref_path = Path(ref_path)
318+
319+
if "anaddb.in" in check_inputs:
320+
check_anaddb_in(ref_path)
321+
322+
if "anaddb_input.json" in check_inputs:
323+
check_anaddb_input_json(ref_path)
324+
325+
logger.info("Verified inputs successfully")
326+
327+
328+
def convert_file_to_dict(file_path):
329+
import gzip
330+
331+
result_dict = {}
332+
333+
if file_path.endswith(".gz"):
334+
file_opener = gzip.open
335+
mode = "rt" # read text mode for gzip
336+
else:
337+
file_opener = open
338+
mode = "r"
339+
340+
with file_opener(file_path, mode) as file:
341+
for line in file:
342+
key, value = line.split()
343+
try:
344+
result_dict[key] = int(value) # Assuming values are integers
345+
except ValueError:
346+
result_dict[key] = str(value) # Fall back to string if not an integer
347+
return result_dict
348+
349+
350+
def check_anaddb_in(ref_path: str | Path):
351+
user = convert_file_to_dict("anaddb.in")
352+
ref = convert_file_to_dict(str(ref_path / "inputs" / "anaddb.in.gz"))
353+
assert user == ref, "'anaddb.in' is different from reference."
354+
355+
356+
def check_anaddb_input_json(ref_path: str | Path):
357+
from abipy.abio.inputs import AnaddbInput
358+
from monty.serialization import loadfn
359+
360+
user = loadfn("anaddb_input.json")
361+
assert isinstance(user, AnaddbInput)
362+
ref = loadfn(ref_path / "inputs" / "anaddb_input.json.gz")
363+
assert user.structure == ref.structure
364+
assert user == ref
365+
366+
273367
def copy_abinit_outputs(ref_path: str | Path):
274368
import shutil
275369

tests/abinit/jobs/test_anaddb.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
def test_anaddb_dfpt_dte_silicon_carbide_standard(
2+
mock_anaddb, abinit_test_dir, clean_dir
3+
):
4+
import os
5+
6+
from jobflow import run_locally
7+
from monty.serialization import loadfn
8+
from pymatgen.core.structure import Structure
9+
10+
from atomate2.abinit.schemas.anaddb import AnaddbTaskDoc
11+
12+
# load the initial structure, the maker and the ref_paths from the test_dir
13+
test_dir = (
14+
abinit_test_dir
15+
/ "jobs"
16+
/ "anaddb"
17+
/ "AnaddbDfptDteMaker"
18+
/ "silicon_carbide_standard"
19+
)
20+
structure = Structure.from_file(test_dir / "initial_structure.json.gz")
21+
maker_info = loadfn(test_dir / "maker.json.gz")
22+
maker = maker_info["maker"]
23+
ref_paths = loadfn(test_dir / "ref_paths.json.gz")
24+
25+
from pathlib import Path
26+
27+
from monty.shutil import copy_r, decompress_dir, remove
28+
29+
path_tmp_prev_outputs = Path(os.getcwd()) / "prev_outputs"
30+
if path_tmp_prev_outputs.exists():
31+
remove(path_tmp_prev_outputs)
32+
os.mkdir(path_tmp_prev_outputs)
33+
copy_r(src=test_dir / "prev_outputs", dst=path_tmp_prev_outputs)
34+
decompress_dir(path_tmp_prev_outputs)
35+
36+
prev_outputs = [
37+
path_tmp_prev_outputs / subdir
38+
for subdir in next(os.walk(test_dir / "prev_outputs"))[1]
39+
]
40+
41+
mock_anaddb(ref_paths)
42+
43+
# make the job, run it and ensure that it finished running successfully
44+
job = maker.make(structure=structure, prev_outputs=prev_outputs)
45+
responses = run_locally(job, create_folders=True, ensure_success=True)
46+
47+
# validation the outputs of the job
48+
output1 = responses[job.uuid][1].output
49+
assert isinstance(output1, AnaddbTaskDoc)

0 commit comments

Comments
 (0)