|
1 | | -from emmet.api.resource.core import HeaderProcessor |
2 | 1 | from fastapi import Response, Request |
| 2 | +from typing import Any |
| 3 | + |
| 4 | +from emmet.api.resource.core import HeaderProcessor |
3 | 5 | from emmet.api.utils import STORE_PARAMS |
4 | 6 | from emmet.api.routes.materials.materials.query_operators import LicenseQuery |
5 | 7 |
|
6 | 8 |
|
| 9 | +def _get_header_key(headers, key: str, default: Any = None) -> Any: |
| 10 | + """Get a case-insensitive key from a set of request headers.""" |
| 11 | + try: |
| 12 | + return next(v for k, v in headers.items() if k.lower() == key.lower()) |
| 13 | + except StopIteration: |
| 14 | + return default |
| 15 | + |
| 16 | + |
7 | 17 | class GlobalHeaderProcessor(HeaderProcessor): |
8 | | - def process_header(self, response: Response, request: Request): |
9 | | - groups = request.headers.get("X-Authenticated-Groups", None) |
10 | | - if groups is not None and "api_all_nolimit" in [ |
| 18 | + |
| 19 | + def process_header(self, response: Response, request: Request) -> None: |
| 20 | + if ( |
| 21 | + groups := _get_header_key(request.headers, "x-authenticated-groups") |
| 22 | + ) is not None and "api_all_nolimit" in [ |
11 | 23 | group.strip() for group in groups.split(",") |
12 | 24 | ]: |
13 | 25 | response.headers["X-Bypass-Rate-Limit"] = "ALL" |
14 | 26 |
|
15 | 27 | # forward Consumer Id header in response |
16 | | - consumer_id = request.headers.get("X-Consumer-Id", "-") |
| 28 | + consumer_id = _get_header_key(request.headers, "x-consumer-id", default="-") |
17 | 29 | response.headers["X-Consumer-Id"] = consumer_id |
18 | 30 |
|
19 | | - if "Content-Type" not in response.headers: |
| 31 | + if _get_header_key(response.headers, "Content-Type") is None: |
20 | 32 | response.headers["Content-Type"] = "application/json" |
21 | 33 |
|
22 | 34 | def configure_query_on_request( |
23 | 35 | self, request: Request, query_operator: LicenseQuery |
24 | 36 | ) -> STORE_PARAMS: |
25 | | - groups = request.headers.get( |
26 | | - "x-consumer-groups", request.headers.get("x-authenticated-groups", "") |
27 | | - ) |
28 | | - if not groups: |
29 | | - return query_operator.query(license="BY-C") |
30 | 37 |
|
31 | | - grps = set(group.strip() for group in groups.split(",")) |
32 | | - if grps & {"TERMS:ACCEPT-NC", "admin"}: |
33 | | - return query_operator.query(license="All") |
| 38 | + if not ( |
| 39 | + groups := _get_header_key( |
| 40 | + request.headers, |
| 41 | + "x-consumer-groups", |
| 42 | + default=_get_header_key( |
| 43 | + request.headers, "x-authenticated-groups", default="" |
| 44 | + ), |
| 45 | + ) |
| 46 | + ): |
| 47 | + return query_operator.query(license="BY-C") |
34 | 48 |
|
35 | | - return query_operator.query(license="BY-C") |
| 49 | + return query_operator.query( |
| 50 | + license=( |
| 51 | + "All" |
| 52 | + if {group.strip() for group in groups.split(",")} |
| 53 | + & {"TERMS:ACCEPT-NC", "admin"} |
| 54 | + else "BY-C" |
| 55 | + ) |
| 56 | + ) |
0 commit comments