diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 9627f5d6..9449e0d3 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -888,6 +888,7 @@ def _submit_requests( # noqa if data_tuples and "meta" in data_tuples[0][0]: total_data["meta"]["time_stamp"] = data_tuples[0][0]["meta"]["time_stamp"] + total_data["meta"]["facets"] = data_tuples[0][0]["meta"].get("facet", None) if pbar is not None: pbar.close() @@ -1236,6 +1237,7 @@ def _get_all_documents( fields=None, chunk_size=1000, num_chunks=None, + facets=None, ) -> list[T] | list[dict]: """Iterates over pages until all documents are retrieved. Displays progress using tqdm. This method is designed to give a common @@ -1267,7 +1269,6 @@ def _get_all_documents( ) chosen_param = list_entries[0][0] if len(list_entries) > 0 else None - results = self._query_resource( query_params, fields=fields, @@ -1275,8 +1276,10 @@ def _get_all_documents( chunk_size=chunk_size, num_chunks=num_chunks, ) - - return results["data"] + if facets: + return results["data"], results["meta"] + else: + return results["data"] def count(self, criteria: dict | None = None) -> int | str: """Return a count of total documents. diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index 5e12f9aa..cb159d19 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from datetime import datetime from emmet.core.tasks import TaskDoc @@ -39,27 +40,41 @@ def search( elements: list[str] | None = None, exclude_elements: list[str] | None = None, formula: str | list[str] | None = None, + calc_type: str | None = None, + run_type: str | None = None, + task_type: str | None = None, + chemsys: str | list[str] | None = None, last_updated: tuple[datetime, datetime] | None = None, + batches: str | list[str] | None = None, num_chunks: int | None = None, chunk_size: int = 1000, - all_fields: bool = True, + all_fields: bool = False, fields: list[str] | None = None, + facets: str | list[str] | None = None, ) -> list[TaskDoc] | list[dict]: """Query core task docs using a variety of search criteria. Arguments: task_ids (str, List[str]): List of Materials Project IDs to return data for. - elements (List[str]): A list of elements. + chemsys: (str, List[str]): A list of chemical systems to search for. + elements: (List[str]): A list of elements to search for. exclude_elements (List[str]): A list of elements to exclude. formula (str, List[str]): A formula including anonymized formula or wild cards (e.g., Fe2O3, ABO3, Si*). A list of chemical formulas can also be passed (e.g., [Fe2O3, ABO3]). last_updated (tuple[datetime, datetime]): A tuple of min and max UTC formatted datetimes. + batches (str, List[str]): A list of batch IDs to search for. + run_type (str): The type of task to search for. Can be one of the following: + #TODO: check enum + task_type (str): The type of task to search for. Can be one of the following: + #TODO check enum + calc_type (str): The type of calculation to search for. A combination of the run_type and task_type. num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. chunk_size (int): Number of data entries per chunk. Max size is 100. all_fields (bool): Whether to return all fields in the document. Defaults to True. fields (List[str]): List of fields in TaskDoc to return data for. Default is material_id, last_updated, and formula_pretty if all_fields is False. + facets (str, List[str]): List of facets to return data for. Returns: ([TaskDoc], [dict]) List of task documents or dictionaries. @@ -73,7 +88,14 @@ def search( query_params.update({"task_ids": ",".join(validate_ids(task_ids))}) if formula: - query_params.update({"formula": formula}) + query_params.update( + {"formula": ",".join(formula) if isinstance(formula, list) else formula} + ) + + if chemsys: + query_params.update( + {"chemsys": ",".join(chemsys) if isinstance(chemsys, list) else chemsys} + ) if elements: query_params.update({"elements": ",".join(elements)}) @@ -89,6 +111,36 @@ def search( } ) + if task_type: + query_params.update({"task_type": task_type}) + + if calc_type: + query_params.update({"calc_type": calc_type}) + + if run_type: + query_params.update({"run_type": run_type}) + + if batches: + query_params.update( + {"batches": ".".join(batches) if isinstance(batches, list) else batches} + ) + + if facets: + query_params.update( + {"facets": ",".join(facets) if isinstance(facets, list) else facets} + ) + + if all_fields: + warnings.warn( + """Please only use all_fields=True when necessary, as it may cause slow query. + """ + ) + if fields and ("calcs_reversed" in fields or "orig_inputs" in fields): + warnings.warn( + """Please only include calcs_reversed and orig_inputs when necessary, as it may cause slow query. + """ + ) + return super()._search( num_chunks=num_chunks, chunk_size=chunk_size,