|
3 | 3 | from collections.abc import Callable |
4 | 4 | from dataclasses import dataclass, field |
5 | 5 | from functools import partial |
6 | | -from typing import Any |
| 6 | +from typing import Any, Protocol |
7 | 7 |
|
8 | 8 | from fastapi import Query |
9 | 9 |
|
@@ -176,3 +176,112 @@ def query( |
176 | 176 | ), |
177 | 177 | ) -> STORE_PARAMS: |
178 | 178 | return self._prepare_query(material_ids) |
| 179 | + |
| 180 | + |
| 181 | +# Allowed values for the optional ``id_format`` query parameter. Anything not |
| 182 | +# in this set is treated as if the parameter was absent (no-op reformatting), |
| 183 | +# which is the safer default for backwards compatibility. |
| 184 | +_ID_FORMAT_VALUES = ("legacy", "alpha") |
| 185 | + |
| 186 | + |
| 187 | +class IdFormatter(Protocol): |
| 188 | + """Callable signature for the formatters consumed by :class:`IdFormatQuery`. |
| 189 | +
|
| 190 | + Each registered formatter is invoked as |
| 191 | + ``formatter(value, legacy=<bool>)`` against every truthy id-field value |
| 192 | + on the response. ``legacy`` is passed by keyword to match the explicit |
| 193 | + signatures of the canonical formatters in :mod:`emmet.core.types.typing` |
| 194 | + and :mod:`emmet.core.xas`. |
| 195 | + """ |
| 196 | + |
| 197 | + def __call__(self, value: Any, *, legacy: bool) -> str: ... |
| 198 | + |
| 199 | + |
| 200 | +@dataclass |
| 201 | +class IdFormatQuery(QueryOperator): |
| 202 | + """Optional response-side reformatting of MP identifier fields. |
| 203 | +
|
| 204 | + Adds an ``id_format`` query parameter to an endpoint and, on |
| 205 | + ``post_process``, rewrites the identifier fields on each returned |
| 206 | + document according to the requested shape: |
| 207 | +
|
| 208 | + - ``id_format=legacy`` -> ``mp-149`` / ``mp-2658_Al`` / ``mp-779827-XANES-O-K`` |
| 209 | + - ``id_format=alpha`` -> ``mp-aaaaaaft`` / ``mp-aaaaadyg_Al`` / ``aaabsjpj-XANES-O-K`` |
| 210 | + - parameter absent (or any other value) -> documents are returned with |
| 211 | + identifier fields exactly as the database stores them; no rewriting |
| 212 | + is attempted. |
| 213 | +
|
| 214 | + This is purely a serialization concern: ``query()`` returns an empty |
| 215 | + criteria dict so this operator never affects which documents the |
| 216 | + database returns. It only mutates the response payload. |
| 217 | +
|
| 218 | + Constructor takes a list of ``(field_name, formatter)`` tuples. Each |
| 219 | + formatter must be a callable with signature ``formatter(value, legacy: bool) -> str`` |
| 220 | + and must be fault-tolerant (i.e. return the input unchanged on parse |
| 221 | + failure, never raise). The canonical formatters live in |
| 222 | + :mod:`emmet.core.types.typing` (``format_identifier``, |
| 223 | + ``format_compound_identifier``, ``format_task_id``) and |
| 224 | + :mod:`emmet.core.xas` (``format_spectrum_id``). |
| 225 | +
|
| 226 | + Example registration: |
| 227 | +
|
| 228 | + .. code-block:: python |
| 229 | +
|
| 230 | + from emmet.core.types.typing import format_identifier, format_task_id |
| 231 | + from emmet.core.xas import format_spectrum_id |
| 232 | +
|
| 233 | + # /materials/summary/ |
| 234 | + IdFormatQuery(id_fields=[("material_id", format_identifier)]) |
| 235 | +
|
| 236 | + # /materials/xas/ |
| 237 | + IdFormatQuery(id_fields=[ |
| 238 | + ("task_id", format_task_id), |
| 239 | + ("spectrum_id", format_spectrum_id), |
| 240 | + ]) |
| 241 | +
|
| 242 | + Attributes: |
| 243 | + id_fields: A list of ``(field_name, formatter)`` tuples describing |
| 244 | + which fields on each returned document to rewrite and how. |
| 245 | + Fields that are absent from a given document (e.g. due to |
| 246 | + sparse-fields projection) are silently skipped. |
| 247 | + """ |
| 248 | + |
| 249 | + id_fields: list[tuple[str, IdFormatter]] = field(default_factory=list) |
| 250 | + |
| 251 | + def query( |
| 252 | + self, |
| 253 | + id_format: str | None = Query( |
| 254 | + None, |
| 255 | + description=( |
| 256 | + "Optional. If set to 'legacy', MP identifier fields in the " |
| 257 | + "response are returned in the form 'mp-149'. If set to " |
| 258 | + "'alpha', they are returned in the padded AlphaID form " |
| 259 | + "'mp-aaaaaaft'. If omitted (or set to any other value), " |
| 260 | + "identifiers are returned in their stored form. This is a " |
| 261 | + "purely cosmetic transform; query inputs accept either " |
| 262 | + "shape regardless." |
| 263 | + ), |
| 264 | + ), |
| 265 | + ) -> STORE_PARAMS: |
| 266 | + # The store query is empty — this operator only affects response |
| 267 | + # serialization. The ``id_format`` value is threaded through the |
| 268 | + # returned ``STORE_PARAMS`` so ``post_process`` can read it back. |
| 269 | + return {"criteria": {}, "id_format": id_format} |
| 270 | + |
| 271 | + def post_process(self, docs: list[dict], query: dict) -> list[dict]: |
| 272 | + fmt = query.get("id_format") |
| 273 | + if fmt not in _ID_FORMAT_VALUES: |
| 274 | + # Absent / invalid value -> no-op. We deliberately do not 400 |
| 275 | + # on a bad value: existing clients that misspell the parameter |
| 276 | + # continue to receive a valid response. |
| 277 | + return docs |
| 278 | + |
| 279 | + legacy = fmt == "legacy" |
| 280 | + for doc in docs: |
| 281 | + if not isinstance(doc, dict): |
| 282 | + continue |
| 283 | + for field_name, formatter in self.id_fields: |
| 284 | + value = doc.get(field_name) |
| 285 | + if value: |
| 286 | + doc[field_name] = formatter(value, legacy=legacy) |
| 287 | + return docs |
0 commit comments