Skip to content

Commit 5c70469

Browse files
cache-time
1 parent d050598 commit 5c70469

File tree

7 files changed

+113
-2
lines changed

7 files changed

+113
-2
lines changed

backend_py/primary/primary/auth/auth_helper.py

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from primary import config
1414
from primary.services.utils.authenticated_user import AuthenticatedUser
15+
from primary.middleware.add_browser_cache import no_cache
1516

1617

1718
class AuthHelper:
@@ -24,6 +25,7 @@ def __init__(self) -> None:
2425
methods=["GET"],
2526
)
2627

28+
@no_cache
2729
async def _login_route(self, request: Request, redirect_url_after_login: Optional[str] = None) -> RedirectResponse:
2830
# print("######################### _login_route()")
2931

@@ -55,6 +57,7 @@ async def _login_route(self, request: Request, redirect_url_after_login: Optiona
5557

5658
return RedirectResponse(flow_dict["auth_uri"])
5759

60+
@no_cache
5861
async def _authorized_callback_route(self, request: Request) -> Response:
5962
# print("######################### _authorized_callback_route()")
6063

backend_py/primary/primary/config.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"ssdl": [SSDL_RESOURCE_SCOPE],
3030
}
3131

32-
print(f"{RESOURCE_SCOPES_DICT=}")
33-
32+
DEFAULT_CACHE_MAX_AGE = 3600 # 1 hour
33+
DEFAULT_STALE_WHILE_REVALIDATE = 3600 * 24 # 24 hour
3434
REDIS_USER_SESSION_URL = "redis://redis-user-session:6379"
3535
REDIS_CACHE_URL = "redis://redis-cache:6379"

backend_py/primary/primary/main.py

+4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from primary.auth.auth_helper import AuthHelper
1313
from primary.auth.enforce_logged_in_middleware import EnforceLoggedInMiddleware
1414
from primary.middleware.add_process_time_to_server_timing_middleware import AddProcessTimeToServerTimingMiddleware
15+
16+
from primary.middleware.add_browser_cache import AddBrowserCacheMiddleware
1517
from primary.routers.dev.router import router as dev_router
1618
from primary.routers.explore.router import router as explore_router
1719
from primary.routers.general import router as general_router
@@ -104,6 +106,7 @@ def custom_generate_unique_id(route: APIRoute) -> str:
104106
# Also redirects to /login endpoint for some select paths
105107
unprotected_paths = ["/logged_in_user", "/alive", "/openapi.json"]
106108
paths_redirected_to_login = ["/", "/alive_protected"]
109+
107110
app.add_middleware(
108111
EnforceLoggedInMiddleware,
109112
unprotected_paths=unprotected_paths,
@@ -117,6 +120,7 @@ def custom_generate_unique_id(route: APIRoute) -> str:
117120

118121
# This middleware instance measures execution time of the endpoints, including the cost of other middleware
119122
app.add_middleware(AddProcessTimeToServerTimingMiddleware, metric_name="total")
123+
app.add_middleware(AddBrowserCacheMiddleware)
120124

121125

122126
@app.get("/")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from functools import wraps
2+
from contextvars import ContextVar
3+
from typing import Dict, Any, Callable, Awaitable, Union, Never
4+
5+
from starlette.datastructures import MutableHeaders
6+
from starlette.types import ASGIApp, Scope, Receive, Send, Message
7+
from primary.config import DEFAULT_CACHE_MAX_AGE, DEFAULT_STALE_WHILE_REVALIDATE
8+
9+
# Initialize with a factory function to ensure a new dict for each context
10+
def get_default_context() -> Dict[str, Any]:
11+
return {"max_age": DEFAULT_CACHE_MAX_AGE, "stale_while_revalidate": DEFAULT_STALE_WHILE_REVALIDATE}
12+
13+
14+
cache_context: ContextVar[Dict[str, Any]] = ContextVar("cache_context", default=get_default_context())
15+
16+
17+
def add_custom_cache_time(max_age: int, stale_while_revalidate: int = 0) -> Callable:
18+
"""
19+
Decorator that sets a custom browser cache time for the endpoint response.
20+
21+
Args:
22+
max_age (int): The maximum age in seconds for the cache
23+
stale_while_revalidate (int): The stale-while-revalidate time in seconds
24+
25+
Example:
26+
@add_custom_cache_time(300, 600) # 5 minutes max age, 10 minutes stale-while-revalidate
27+
async def my_endpoint():
28+
return {"data": "some_data"}
29+
"""
30+
31+
def decorator(func: Callable) -> Callable:
32+
@wraps(func)
33+
async def wrapper(*args: Any, **kwargs: Any) -> Callable:
34+
context = cache_context.get()
35+
context["max_age"] = max_age
36+
context["stale_while_revalidate"] = stale_while_revalidate
37+
38+
return await func(*args, **kwargs)
39+
40+
return wrapper
41+
42+
return decorator
43+
44+
45+
def no_cache(func: Callable) -> Callable:
46+
"""
47+
Decorator that explicitly disables browser caching for the endpoint response.
48+
49+
Example:
50+
@no_cache
51+
async def my_endpoint():
52+
return {"data": "some_data"}
53+
"""
54+
55+
@wraps(func)
56+
async def wrapper(*args: Any, **kwargs: Any) -> Callable:
57+
context = cache_context.get()
58+
context["max_age"] = 0
59+
context["stale_while_revalidate"] = 0
60+
61+
return await func(*args, **kwargs)
62+
63+
return wrapper
64+
65+
66+
class AddBrowserCacheMiddleware:
67+
"""
68+
Adds cache-control to the response headers
69+
"""
70+
71+
def __init__(self, app: ASGIApp) -> None:
72+
self.app = app
73+
74+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
75+
if scope["type"] != "http":
76+
return await self.app(scope, receive, send)
77+
78+
# Set initial context and store token
79+
cache_context.set(get_default_context())
80+
81+
async def send_with_cache_header(message: Message) -> None:
82+
if message["type"] == "http.response.start":
83+
headers = MutableHeaders(scope=message)
84+
context = cache_context.get()
85+
cache_control_str = (
86+
f"max-age={context['max_age']}, stale-while-revalidate={context['stale_while_revalidate']}, private"
87+
)
88+
headers.append("cache-control", cache_control_str)
89+
90+
await send(message)
91+
92+
await self.app(scope, receive, send_with_cache_header)

backend_py/primary/primary/routers/explore/router.py

+5
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
from primary.services.sumo_access.case_inspector import CaseInspector
77
from primary.services.sumo_access.sumo_inspector import SumoInspector
88
from primary.services.utils.authenticated_user import AuthenticatedUser
9+
from primary.middleware.add_browser_cache import no_cache
910

1011
from . import schemas
1112

1213
router = APIRouter()
1314

1415

1516
@router.get("/fields")
17+
@no_cache
1618
async def get_fields(
1719
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
1820
) -> List[schemas.FieldInfo]:
@@ -27,6 +29,7 @@ async def get_fields(
2729

2830

2931
@router.get("/cases")
32+
@no_cache
3033
async def get_cases(
3134
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
3235
field_identifier: str = Query(description="Field identifier"),
@@ -43,6 +46,7 @@ async def get_cases(
4346

4447

4548
@router.get("/cases/{case_uuid}/ensembles")
49+
@no_cache
4650
async def get_ensembles(
4751
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
4852
case_uuid: str = Path(description="Sumo case uuid"),
@@ -55,6 +59,7 @@ async def get_ensembles(
5559

5660

5761
@router.get("/cases/{case_uuid}/ensembles/{ensemble_name}")
62+
@no_cache
5863
async def get_ensemble_details(
5964
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
6065
case_uuid: str = Path(description="Sumo case uuid"),

backend_py/primary/primary/routers/general.py

+4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from primary.auth.auth_helper import AuthHelper
1111
from primary.services.graph_access.graph_access import GraphApiAccess
12+
from primary.middleware.add_browser_cache import no_cache
1213

1314
LOGGER = logging.getLogger(__name__)
1415

@@ -25,18 +26,21 @@ class UserInfo(BaseModel):
2526

2627

2728
@router.get("/alive")
29+
@no_cache
2830
def get_alive() -> str:
2931
print("entering alive route")
3032
return f"ALIVE: Backend is alive at this time: {datetime.datetime.now()}"
3133

3234

3335
@router.get("/alive_protected")
36+
@no_cache
3437
def get_alive_protected() -> str:
3538
print("entering alive_protected route")
3639
return f"ALIVE_PROTECTED: Backend is alive at this time: {datetime.datetime.now()}"
3740

3841

3942
@router.get("/logged_in_user", response_model=UserInfo)
43+
@no_cache
4044
async def get_logged_in_user(
4145
request: Request,
4246
includeGraphApiInfo: bool = Query( # pylint: disable=invalid-name

backend_py/primary/primary/routers/well/router.py

+3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
from primary.services.ssdl_access.well_access import WellAccess as SsdlWellAccess
1313

14+
15+
from primary.middleware.add_browser_cache import add_custom_cache_time
1416
from . import schemas
1517
from . import converters
1618

@@ -40,6 +42,7 @@ async def get_drilled_wellbore_headers(
4042

4143

4244
@router.get("/well_trajectories/")
45+
@add_custom_cache_time(3600 * 24 * 7, 3600 * 24 * 7 * 10) # 1 week cache, 10 week stale-while-revalidate
4346
async def get_well_trajectories(
4447
# fmt:off
4548
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),

0 commit comments

Comments
 (0)