Skip to content

Commit 151fce7

Browse files
Merge branch 'main' into singleton-httpx-client
2 parents 54147ef + b5b49f4 commit 151fce7

37 files changed

+1827
-476
lines changed

backend_py/primary/primary/auth/auth_helper.py

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from primary import config
1414
from primary.services.utils.authenticated_user import AuthenticatedUser
15+
from primary.middleware.add_browser_cache import no_cache
1516

1617

1718
class AuthHelper:
@@ -24,6 +25,7 @@ def __init__(self) -> None:
2425
methods=["GET"],
2526
)
2627

28+
@no_cache
2729
async def _login_route(self, request: Request, redirect_url_after_login: Optional[str] = None) -> RedirectResponse:
2830
# print("######################### _login_route()")
2931

@@ -55,6 +57,7 @@ async def _login_route(self, request: Request, redirect_url_after_login: Optiona
5557

5658
return RedirectResponse(flow_dict["auth_uri"])
5759

60+
@no_cache
5861
async def _authorized_callback_route(self, request: Request) -> Response:
5962
# print("######################### _authorized_callback_route()")
6063

backend_py/primary/primary/config.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"ssdl": [SSDL_RESOURCE_SCOPE],
3030
}
3131

32-
print(f"{RESOURCE_SCOPES_DICT=}")
33-
32+
DEFAULT_CACHE_MAX_AGE = 3600 # 1 hour
33+
DEFAULT_STALE_WHILE_REVALIDATE = 3600 * 24 # 24 hour
3434
REDIS_USER_SESSION_URL = "redis://redis-user-session:6379"
3535
REDIS_CACHE_URL = "redis://redis-cache:6379"

backend_py/primary/primary/main.py

+4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from primary.auth.auth_helper import AuthHelper
1313
from primary.auth.enforce_logged_in_middleware import EnforceLoggedInMiddleware
1414
from primary.middleware.add_process_time_to_server_timing_middleware import AddProcessTimeToServerTimingMiddleware
15+
16+
from primary.middleware.add_browser_cache import AddBrowserCacheMiddleware
1517
from primary.routers.dev.router import router as dev_router
1618
from primary.routers.explore.router import router as explore_router
1719
from primary.routers.general import router as general_router
@@ -120,6 +122,7 @@ async def shutdown_event() -> None:
120122
# Also redirects to /login endpoint for some select paths
121123
unprotected_paths = ["/logged_in_user", "/alive", "/openapi.json"]
122124
paths_redirected_to_login = ["/", "/alive_protected"]
125+
123126
app.add_middleware(
124127
EnforceLoggedInMiddleware,
125128
unprotected_paths=unprotected_paths,
@@ -133,6 +136,7 @@ async def shutdown_event() -> None:
133136

134137
# This middleware instance measures execution time of the endpoints, including the cost of other middleware
135138
app.add_middleware(AddProcessTimeToServerTimingMiddleware, metric_name="total")
139+
app.add_middleware(AddBrowserCacheMiddleware)
136140

137141

138142
@app.get("/")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from functools import wraps
2+
from contextvars import ContextVar
3+
from typing import Dict, Any, Callable, Awaitable, Union, Never
4+
5+
from starlette.datastructures import MutableHeaders
6+
from starlette.types import ASGIApp, Scope, Receive, Send, Message
7+
from primary.config import DEFAULT_CACHE_MAX_AGE, DEFAULT_STALE_WHILE_REVALIDATE
8+
9+
# Initialize with a factory function to ensure a new dict for each context
10+
def get_default_context() -> Dict[str, Any]:
11+
return {"max_age": DEFAULT_CACHE_MAX_AGE, "stale_while_revalidate": DEFAULT_STALE_WHILE_REVALIDATE}
12+
13+
14+
cache_context: ContextVar[Dict[str, Any]] = ContextVar("cache_context", default=get_default_context())
15+
16+
17+
def add_custom_cache_time(max_age_s: int, stale_while_revalidate_s: int = 0) -> Callable:
18+
"""
19+
Decorator that sets a custom browser cache time for the endpoint response.
20+
21+
Args:
22+
max_age_s (int): The maximum age in seconds for the cache
23+
stale_while_revalidate_s (int): The stale-while-revalidate time in seconds
24+
25+
Example:
26+
@add_custom_cache_time(300, 600) # 5 minutes max age, 10 minutes stale-while-revalidate
27+
async def my_endpoint():
28+
return {"data": "some_data"}
29+
"""
30+
31+
def decorator(func: Callable) -> Callable:
32+
@wraps(func)
33+
async def wrapper(*args: Any, **kwargs: Any) -> Callable:
34+
context = cache_context.get()
35+
context["max_age"] = max_age_s
36+
context["stale_while_revalidate"] = stale_while_revalidate_s
37+
38+
return await func(*args, **kwargs)
39+
40+
return wrapper
41+
42+
return decorator
43+
44+
45+
def no_cache(func: Callable) -> Callable:
46+
"""
47+
Decorator that explicitly disables browser caching for the endpoint response.
48+
49+
Example:
50+
@no_cache
51+
async def my_endpoint():
52+
return {"data": "some_data"}
53+
"""
54+
55+
@wraps(func)
56+
async def wrapper(*args: Any, **kwargs: Any) -> Callable:
57+
context = cache_context.get()
58+
context["max_age"] = 0
59+
context["stale_while_revalidate"] = 0
60+
61+
return await func(*args, **kwargs)
62+
63+
return wrapper
64+
65+
66+
class AddBrowserCacheMiddleware:
67+
"""
68+
Adds cache-control to the response headers
69+
"""
70+
71+
def __init__(self, app: ASGIApp) -> None:
72+
self.app = app
73+
74+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
75+
if scope["type"] != "http":
76+
return await self.app(scope, receive, send)
77+
78+
# Set initial context and store token
79+
cache_context.set(get_default_context())
80+
81+
async def send_with_cache_header(message: Message) -> None:
82+
if message["type"] == "http.response.start":
83+
headers = MutableHeaders(scope=message)
84+
context = cache_context.get()
85+
cache_control_str = (
86+
f"max-age={context['max_age']}, stale-while-revalidate={context['stale_while_revalidate']}, private"
87+
)
88+
headers.append("cache-control", cache_control_str)
89+
90+
await send(message)
91+
92+
await self.app(scope, receive, send_with_cache_header)

backend_py/primary/primary/routers/explore/router.py

+5
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
from primary.services.sumo_access.case_inspector import CaseInspector
77
from primary.services.sumo_access.sumo_inspector import SumoInspector
88
from primary.services.utils.authenticated_user import AuthenticatedUser
9+
from primary.middleware.add_browser_cache import no_cache
910

1011
from . import schemas
1112

1213
router = APIRouter()
1314

1415

1516
@router.get("/fields")
17+
@no_cache
1618
async def get_fields(
1719
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
1820
) -> List[schemas.FieldInfo]:
@@ -27,6 +29,7 @@ async def get_fields(
2729

2830

2931
@router.get("/cases")
32+
@no_cache
3033
async def get_cases(
3134
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
3235
field_identifier: str = Query(description="Field identifier"),
@@ -43,6 +46,7 @@ async def get_cases(
4346

4447

4548
@router.get("/cases/{case_uuid}/ensembles")
49+
@no_cache
4650
async def get_ensembles(
4751
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
4852
case_uuid: str = Path(description="Sumo case uuid"),
@@ -55,6 +59,7 @@ async def get_ensembles(
5559

5660

5761
@router.get("/cases/{case_uuid}/ensembles/{ensemble_name}")
62+
@no_cache
5863
async def get_ensemble_details(
5964
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
6065
case_uuid: str = Path(description="Sumo case uuid"),

backend_py/primary/primary/routers/general.py

+4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from primary.auth.auth_helper import AuthHelper
1111
from primary.services.graph_access.graph_access import GraphApiAccess
12+
from primary.middleware.add_browser_cache import no_cache
1213

1314
LOGGER = logging.getLogger(__name__)
1415

@@ -25,18 +26,21 @@ class UserInfo(BaseModel):
2526

2627

2728
@router.get("/alive")
29+
@no_cache
2830
def get_alive() -> str:
2931
print("entering alive route")
3032
return f"ALIVE: Backend is alive at this time: {datetime.datetime.now()}"
3133

3234

3335
@router.get("/alive_protected")
36+
@no_cache
3437
def get_alive_protected() -> str:
3538
print("entering alive_protected route")
3639
return f"ALIVE_PROTECTED: Backend is alive at this time: {datetime.datetime.now()}"
3740

3841

3942
@router.get("/logged_in_user", response_model=UserInfo)
43+
@no_cache
4044
async def get_logged_in_user(
4145
request: Request,
4246
includeGraphApiInfo: bool = Query( # pylint: disable=invalid-name
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,100 @@
1-
from typing import List, Optional, Sequence
1+
from typing import Sequence
22

33
from primary.services.summary_vector_statistics import VectorStatistics
4-
from primary.services.sumo_access.summary_access import VectorMetadata
4+
from primary.services.sumo_access.summary_access import RealizationVector
55
from primary.services.utils.statistic_function import StatisticFunction
6+
from primary.services.summary_delta_vectors import RealizationDeltaVector
7+
from primary.services.summary_derived_vectors import DerivedVectorType, DerivedRealizationVector
68
from . import schemas
79

810

11+
def to_api_derived_vector_type(derived_type: DerivedVectorType) -> schemas.DerivedVectorType:
12+
"""
13+
Create API DerivedVectorType from service layer DerivedVectorType
14+
"""
15+
return schemas.DerivedVectorType(derived_type.value)
16+
17+
18+
def to_api_derived_vector_info(derived_type: DerivedVectorType, source_vector: str) -> schemas.DerivedVectorInfo:
19+
"""
20+
Create API DerivedVectorInfo from service layer DerivedVectorInfo
21+
"""
22+
return schemas.DerivedVectorInfo(
23+
type=to_api_derived_vector_type(derived_type),
24+
sourceVector=source_vector,
25+
)
26+
27+
28+
def realization_vector_list_to_api_vector_realization_data_list(
29+
realization_vector_list: list[RealizationVector],
30+
) -> list[schemas.VectorRealizationData]:
31+
"""
32+
Create API VectorRealizationData list from service layer RealizationVector list
33+
"""
34+
return [
35+
schemas.VectorRealizationData(
36+
realization=real_vec.realization,
37+
timestampsUtcMs=real_vec.timestamps_utc_ms,
38+
values=real_vec.values,
39+
unit=real_vec.metadata.unit,
40+
isRate=real_vec.metadata.is_rate,
41+
)
42+
for real_vec in realization_vector_list
43+
]
44+
45+
46+
def derived_vector_realizations_to_api_vector_realization_data_list(
47+
derived_realization_vector_list: list[DerivedRealizationVector], derived_vector_info: schemas.DerivedVectorInfo
48+
) -> list[schemas.VectorRealizationData]:
49+
"""
50+
Create API VectorRealizationData list from service layer DerivedRealizationVector list and derived vector info
51+
"""
52+
return [
53+
schemas.VectorRealizationData(
54+
realization=real_vec.realization,
55+
timestampsUtcMs=real_vec.timestamps_utc_ms,
56+
values=real_vec.values,
57+
unit=real_vec.unit,
58+
isRate=real_vec.is_rate,
59+
derivedVectorInfo=derived_vector_info,
60+
)
61+
for real_vec in derived_realization_vector_list
62+
]
63+
64+
65+
def realization_delta_vector_list_to_api_vector_realization_data_list(
66+
realization_delta_vector_list: list[RealizationDeltaVector],
67+
derived_vector_info: schemas.DerivedVectorInfo | None = None,
68+
) -> list[schemas.VectorRealizationData]:
69+
"""
70+
Create API VectorRealizationData list from service layer RealizationVector list
71+
72+
Optional derived_vector_info is included in the API VectorRealizationData if provided
73+
"""
74+
return [
75+
schemas.VectorRealizationData(
76+
realization=real_vec.realization,
77+
timestampsUtcMs=real_vec.timestamps_utc_ms,
78+
values=real_vec.values,
79+
unit=real_vec.unit,
80+
isRate=real_vec.is_rate,
81+
derivedVectorInfo=derived_vector_info,
82+
)
83+
for real_vec in realization_delta_vector_list
84+
]
85+
86+
987
def to_service_statistic_functions(
10-
api_stat_funcs: Optional[Sequence[schemas.StatisticFunction]],
11-
) -> Optional[List[StatisticFunction]]:
88+
api_stat_funcs: Sequence[schemas.StatisticFunction] | None = None,
89+
) -> list[StatisticFunction] | None:
1290
"""
1391
Convert incoming list of API statistic function enum values to service layer StatisticFunction enums,
1492
also accounting for the case where the list is None
1593
"""
1694
if api_stat_funcs is None:
1795
return None
1896

19-
service_stat_funcs: List[StatisticFunction] = []
97+
service_stat_funcs: list[StatisticFunction] = []
2098
for api_func_enum in api_stat_funcs:
2199
service_func_enum = StatisticFunction.from_string_value(api_func_enum.value)
22100
if service_func_enum:
@@ -26,36 +104,44 @@ def to_service_statistic_functions(
26104

27105

28106
def to_api_vector_statistic_data(
29-
vector_statistics: VectorStatistics, vector_metadata: VectorMetadata
107+
vector_statistics: VectorStatistics,
108+
is_rate: bool,
109+
unit: str,
110+
derived_vector_info: schemas.DerivedVectorInfo | None = None,
30111
) -> schemas.VectorStatisticData:
31112
"""
32113
Create API VectorStatisticData from service layer VectorStatistics
33114
"""
34115
value_objects = _create_statistic_value_object_list(vector_statistics)
35116
ret_data = schemas.VectorStatisticData(
36117
realizations=vector_statistics.realizations,
37-
timestamps_utc_ms=vector_statistics.timestamps_utc_ms,
38-
value_objects=value_objects,
39-
unit=vector_metadata.unit,
40-
is_rate=vector_metadata.is_rate,
118+
timestampsUtcMs=vector_statistics.timestamps_utc_ms,
119+
valueObjects=value_objects,
120+
unit=unit,
121+
isRate=is_rate,
122+
derivedVectorInfo=derived_vector_info,
41123
)
42124

43125
return ret_data
44126

45127

46128
def to_api_delta_ensemble_vector_statistic_data(
47-
vector_statistics: VectorStatistics, is_rate: bool, unit: str
129+
vector_statistics: VectorStatistics,
130+
is_rate: bool,
131+
unit: str,
132+
derived_vector_info: schemas.DerivedVectorInfo | None = None,
48133
) -> schemas.VectorStatisticData:
49134
"""
50135
Create API VectorStatisticData from service layer VectorStatistics
51136
"""
52137
value_objects = _create_statistic_value_object_list(vector_statistics)
53138
ret_data = schemas.VectorStatisticData(
54139
realizations=vector_statistics.realizations,
55-
timestamps_utc_ms=vector_statistics.timestamps_utc_ms,
56-
value_objects=value_objects,
140+
timestampsUtcMs=vector_statistics.timestamps_utc_ms,
141+
valueObjects=value_objects,
57142
unit=unit,
58-
is_rate=is_rate,
143+
isRate=is_rate,
144+
derivedVectorInfo=derived_vector_info,
59145
)
60146

61147
return ret_data
@@ -71,6 +157,6 @@ def _create_statistic_value_object_list(vector_statistics: VectorStatistics) ->
71157
if service_func_enum is not None:
72158
value_arr = vector_statistics.values_dict.get(service_func_enum)
73159
if value_arr is not None:
74-
value_objects.append(schemas.StatisticValueObject(statistic_function=api_func_enum, values=value_arr))
160+
value_objects.append(schemas.StatisticValueObject(statisticFunction=api_func_enum, values=value_arr))
75161

76162
return value_objects

0 commit comments

Comments
 (0)