Skip to content

Commit beb4318

Browse files
Patch symmetry query (#1442)
* fix single symm query * fix license for thermo, chemenv, and oxi states * mypy as always
1 parent a903c4f commit beb4318

6 files changed

Lines changed: 49 additions & 15 deletions

File tree

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,56 @@
1-
from emmet.api.resource.core import HeaderProcessor
21
from fastapi import Response, Request
2+
from typing import Any
3+
4+
from emmet.api.resource.core import HeaderProcessor
35
from emmet.api.utils import STORE_PARAMS
46
from emmet.api.routes.materials.materials.query_operators import LicenseQuery
57

68

9+
def _get_header_key(headers, key: str, default: Any = None) -> Any:
10+
"""Get a case-insensitive key from a set of request headers."""
11+
try:
12+
return next(v for k, v in headers.items() if k.lower() == key.lower())
13+
except StopIteration:
14+
return default
15+
16+
717
class GlobalHeaderProcessor(HeaderProcessor):
8-
def process_header(self, response: Response, request: Request):
9-
groups = request.headers.get("X-Authenticated-Groups", None)
10-
if groups is not None and "api_all_nolimit" in [
18+
19+
def process_header(self, response: Response, request: Request) -> None:
20+
if (
21+
groups := _get_header_key(request.headers, "x-authenticated-groups")
22+
) is not None and "api_all_nolimit" in [
1123
group.strip() for group in groups.split(",")
1224
]:
1325
response.headers["X-Bypass-Rate-Limit"] = "ALL"
1426

1527
# forward Consumer Id header in response
16-
consumer_id = request.headers.get("X-Consumer-Id", "-")
28+
consumer_id = _get_header_key(request.headers, "x-consumer-id", default="-")
1729
response.headers["X-Consumer-Id"] = consumer_id
1830

19-
if "Content-Type" not in response.headers:
31+
if _get_header_key(response.headers, "Content-Type") is None:
2032
response.headers["Content-Type"] = "application/json"
2133

2234
def configure_query_on_request(
2335
self, request: Request, query_operator: LicenseQuery
2436
) -> STORE_PARAMS:
25-
groups = request.headers.get(
26-
"x-consumer-groups", request.headers.get("x-authenticated-groups", "")
27-
)
28-
if not groups:
29-
return query_operator.query(license="BY-C")
3037

31-
grps = set(group.strip() for group in groups.split(","))
32-
if grps & {"TERMS:ACCEPT-NC", "admin"}:
33-
return query_operator.query(license="All")
38+
if not (
39+
groups := _get_header_key(
40+
request.headers,
41+
"x-consumer-groups",
42+
default=_get_header_key(
43+
request.headers, "x-authenticated-groups", default=""
44+
),
45+
)
46+
):
47+
return query_operator.query(license="BY-C")
3448

35-
return query_operator.query(license="BY-C")
49+
return query_operator.query(
50+
license=(
51+
"All"
52+
if {group.strip() for group in groups.split(",")}
53+
& {"TERMS:ACCEPT-NC", "admin"}
54+
else "BY-C"
55+
)
56+
)

emmet-api/emmet/api/routes/materials/chemenv/resources.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from emmet.api.routes.materials.materials.query_operators import (
1111
MultiMaterialIDQuery,
1212
ElementsQuery,
13+
LicenseQuery,
1314
)
1415
from emmet.api.routes.materials.chemenv.query_operators import ChemEnvQuery
1516
from emmet.core.chemenv import ChemEnvDoc
@@ -31,6 +32,7 @@ def chemenv_resource(chemenv_store):
3132
),
3233
],
3334
header_processor=GlobalHeaderProcessor(),
35+
query_to_configure_on_request=LicenseQuery(),
3436
tags=["Materials Chemical Environment"],
3537
sub_path="/chemenv/",
3638
disable_validation=True,

emmet-api/emmet/api/routes/materials/materials/query_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def query(
253253

254254
if (
255255
len(spacegroup_numbers) == 1
256+
and crystal_systems
256257
and get_crystal_system_from_international_number(spacegroup_numbers[0])
257258
not in crystal_systems
258259
) or (

emmet-api/emmet/api/routes/materials/oxidation_states/resources.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def oxi_states_resource(oxi_states_store):
3535
LicenseQuery(),
3636
],
3737
header_processor=GlobalHeaderProcessor(),
38+
query_to_configure_on_request=LicenseQuery(),
3839
tags=["Materials Oxidation States"],
3940
sub_path="/oxidation_states/",
4041
disable_validation=True,

emmet-api/emmet/api/routes/materials/thermo/resources.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def thermo_resource(thermo_store):
4040
LicenseQuery(),
4141
],
4242
header_processor=GlobalHeaderProcessor(),
43+
query_to_configure_on_request=LicenseQuery(),
4344
tags=["Materials Thermo"],
4445
sub_path="/thermo/",
4546
disable_validation=True,

emmet-api/tests/materials/materials/test_query_operators.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,14 @@ def test_deprecation_query():
111111
def test_symmetry_query():
112112
op = SymmetryQuery()
113113

114+
assert op.query(crystal_system="Trigonal") == {
115+
"criteria": {"symmetry.crystal_system": "Trigonal"}
116+
}
117+
assert op.query(spacegroup_symbol="P6_3/mmc") == {
118+
"criteria": {"symmetry.number": 194}
119+
}
120+
assert op.query(spacegroup_number=194) == {"criteria": {"symmetry.number": 194}}
121+
114122
for aux_query in [
115123
{"spacegroup_number": 221},
116124
{"spacegroup_symbol": "Pm-3m"},

0 commit comments

Comments
 (0)