Atlas search task #975

Open · wants to merge 6 commits into base: main
9 changes: 6 additions & 3 deletions mp_api/client/core/client.py
@@ -888,6 +888,7 @@ def _submit_requests( # noqa

if data_tuples and "meta" in data_tuples[0][0]:
total_data["meta"]["time_stamp"] = data_tuples[0][0]["meta"]["time_stamp"]
total_data["meta"]["facets"] = data_tuples[0][0]["meta"].get("facet", None)

if pbar is not None:
pbar.close()
@@ -1236,6 +1237,7 @@ def _get_all_documents(
fields=None,
chunk_size=1000,
num_chunks=None,
facets=None,
) -> list[T] | list[dict]:
"""Iterates over pages until all documents are retrieved. Displays
progress using tqdm. This method is designed to give a common
@@ -1267,16 +1269,17 @@ def _get_all_documents(
)

chosen_param = list_entries[0][0] if len(list_entries) > 0 else None

results = self._query_resource(
query_params,
fields=fields,
parallel_param=chosen_param,
chunk_size=chunk_size,
num_chunks=num_chunks,
)

return results["data"]
if facets:
return results["data"], results["meta"]
else:
return results["data"]

def count(self, criteria: dict | None = None) -> int | str:
"""Return a count of total documents.
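For orientation, a minimal standalone sketch of the pattern these client.py hunks implement: the first chunk's response meta (the time stamp plus the new facet buckets) is copied onto the merged result, and the facets flag switches the return value from a bare document list to a (data, meta) tuple. The helper name and shapes below are illustrative, not the client's actual internals.

from __future__ import annotations

from typing import Any


def merge_chunks(chunks: list[dict[str, Any]], facets: str | None = None):
    """Combine paginated responses the way the patched client does (sketch).

    Each chunk looks like {"data": [...], "meta": {...}}; the first chunk's
    meta supplies the time stamp and facet buckets for the merged result.
    """
    total: dict[str, Any] = {"data": [], "meta": {}}
    for chunk in chunks:
        total["data"].extend(chunk.get("data", []))
    if chunks and "meta" in chunks[0]:
        total["meta"]["time_stamp"] = chunks[0]["meta"].get("time_stamp")
        # The diff reads the singular "facet" key from the response meta
        # but stores it under "facets" on the merged result.
        total["meta"]["facets"] = chunks[0]["meta"].get("facet")
    # When facets were requested, hand the meta back alongside the data.
    return (total["data"], total["meta"]) if facets else total["data"]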
58 changes: 55 additions & 3 deletions mp_api/client/routes/materials/tasks.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from datetime import datetime

from emmet.core.tasks import TaskDoc
@@ -39,27 +40,41 @@ def search(
elements: list[str] | None = None,
exclude_elements: list[str] | None = None,
formula: str | list[str] | None = None,
calc_type: str | None = None,
run_type: str | None = None,
task_type: str | None = None,
chemsys: str | list[str] | None = None,
last_updated: tuple[datetime, datetime] | None = None,
batches: str | list[str] | None = None,
num_chunks: int | None = None,
chunk_size: int = 1000,
all_fields: bool = True,
all_fields: bool = False,
fields: list[str] | None = None,
facets: str | list[str] | None = None,
) -> list[TaskDoc] | list[dict]:
"""Query core task docs using a variety of search criteria.

Arguments:
task_ids (str, List[str]): List of Materials Project IDs to return data for.
elements (List[str]): A list of elements.
chemsys (str, List[str]): A list of chemical systems to search for.
elements (List[str]): A list of elements to search for.
exclude_elements (List[str]): A list of elements to exclude.
formula (str, List[str]): A formula including anonymized formula
or wild cards (e.g., Fe2O3, ABO3, Si*). A list of chemical formulas can also be passed
(e.g., [Fe2O3, ABO3]).
last_updated (tuple[datetime, datetime]): A tuple of min and max UTC formatted datetimes.
batches (str, List[str]): A list of batch IDs to search for.
run_type (str): The type of calculation run to search for. Can be one of the following:
#TODO: check enum
task_type (str): The type of task to search for. Can be one of the following:
#TODO check enum
calc_type (str): The type of calculation to search for. A combination of the run_type and task_type.
num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
chunk_size (int): Number of data entries per chunk. Max size is 100.
all_fields (bool): Whether to return all fields in the document. Defaults to False.
fields (List[str]): List of fields in TaskDoc to return data for.
Default is material_id, last_updated, and formula_pretty if all_fields is False.
facets (str, List[str]): Facet field name or list of facet field names to return aggregated counts for.

Returns:
([TaskDoc], [dict]) List of task documents or dictionaries.
@@ -73,7 +88,14 @@ def search(
query_params.update({"task_ids": ",".join(validate_ids(task_ids))})

if formula:
query_params.update({"formula": formula})
query_params.update(
{"formula": ",".join(formula) if isinstance(formula, list) else formula}
)

if chemsys:
query_params.update(
{"chemsys": ",".join(chemsys) if isinstance(chemsys, list) else chemsys}
)

if elements:
query_params.update({"elements": ",".join(elements)})
@@ -89,6 +111,36 @@ def search(
}
)

if task_type:
query_params.update({"task_type": task_type})

if calc_type:
query_params.update({"calc_type": calc_type})

if run_type:
query_params.update({"run_type": run_type})

if batches:
query_params.update(
{"batches": ".".join(batches) if isinstance(batches, list) else batches}
)

if facets:
query_params.update(
{"facets": ",".join(facets) if isinstance(facets, list) else facets}
)

if all_fields:
warnings.warn(
"""Please only use all_fields=True when necessary, as it may cause slow query.
"""
)
if fields and ("calcs_reversed" in fields or "orig_inputs" in fields):
warnings.warn(
"""Please only include calcs_reversed and orig_inputs when necessary, as it may cause slow query.
"""
)

return super()._search(
num_chunks=num_chunks,
chunk_size=chunk_size,
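A hedged usage sketch of the extended search() through the public client; the run_type/task_type values, the batch ID, the facet field name, and the (docs, meta) return shape when facets are requested are assumptions based on this diff rather than documented behavior.

from mp_api.client import MPRester

# Illustrative only; the API key is read from the MP_API_KEY environment variable.
with MPRester() as mpr:
    # Filtered query using the new run_type/task_type/batches arguments.
    docs = mpr.materials.tasks.search(
        chemsys="Si-O",
        run_type="GGA",                      # value assumed; see the enum TODOs above
        task_type="Structure Optimization",  # value assumed; see the enum TODOs above
        batches=["batch_2024_1"],            # hypothetical batch ID
        fields=["task_id", "last_updated"],
    )

    # With facets requested, the patched client is expected to hand back the
    # facet buckets from the response meta alongside the documents.
    docs, meta = mpr.materials.tasks.search(
        chemsys="Si-O",
        facets="run_type",                   # assumed facet field name
        fields=["task_id"],
    )
    print(meta["facets"])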