Skip to content

Commit 4d8dba0

Browse files
author
Jason Munro
authored
Add task_ids query to legacy molecules rester (#897)
* Add task_ids query to legacy molecules rester * Fix jcesr test
1 parent daefc5f commit 4d8dba0

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

mp_api/client/routes/molecules/jcesr.py

+10
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pymatgen.core.periodic_table import Element
77

88
from mp_api.client.core import BaseRester
9+
from mp_api.client.core.utils import validate_ids
910

1011

1112
class JcesrMoleculesRester(BaseRester[MoleculesDoc]):
@@ -15,6 +16,7 @@ class JcesrMoleculesRester(BaseRester[MoleculesDoc]):
1516

1617
def search(
1718
self,
19+
task_ids: str | list[str] | None = None,
1820
charge: tuple[float, float] | None = None,
1921
elements: list[Element] | None = None,
2022
EA: tuple[float, float] | None = None,
@@ -30,6 +32,8 @@ def search(
3032
"""Query equations of state docs using a variety of search criteria.
3133
3234
Arguments:
35+
task_ids (str, List[str]): A single molecule task ID string or list of strings.
36+
(e.g., mol-45004, [mol-45004, mol-45228]).
3337
charge (Tuple[float,float]): Minimum and maximum value of the charge in +e to consider.
3438
elements (List[Element]): A list of elements.
3539
film_orientation (List[Elements]): List of elements that are in the molecule.
@@ -49,6 +53,12 @@ def search(
4953
"""
5054
query_params = defaultdict(dict) # type: dict
5155

56+
if task_ids:
57+
if isinstance(task_ids, str):
58+
task_ids = [task_ids]
59+
60+
query_params.update({"task_ids": ",".join(validate_ids(task_ids))})
61+
5262
if elements:
5363
query_params.update({"elements": ",".join([str(ele) for ele in elements])})
5464

tests/molecules/test_jcesr.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ def rester():
2525

2626
sub_doc_fields = [] # type: list
2727

28-
alt_name_dict = {} # type: dict
28+
alt_name_dict = {"task_ids": "task_id"} # type: dict
2929

3030
custom_field_tests = {
31+
"task_ids": ["mol-45228"],
3132
"elements": [Element("H")],
3233
"pointgroup": "C1",
3334
"smiles": "C#CC(=C)C.CNCCNC",

0 commit comments

Comments
 (0)