Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

browser cache #841

Merged
merged 4 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backend_py/primary/primary/auth/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from primary import config
from primary.services.utils.authenticated_user import AuthenticatedUser
from primary.middleware.add_browser_cache import no_cache


class AuthHelper:
Expand All @@ -24,6 +25,7 @@ def __init__(self) -> None:
methods=["GET"],
)

@no_cache
async def _login_route(self, request: Request, redirect_url_after_login: Optional[str] = None) -> RedirectResponse:
# print("######################### _login_route()")

Expand Down Expand Up @@ -55,6 +57,7 @@ async def _login_route(self, request: Request, redirect_url_after_login: Optiona

return RedirectResponse(flow_dict["auth_uri"])

@no_cache
async def _authorized_callback_route(self, request: Request) -> Response:
# print("######################### _authorized_callback_route()")

Expand Down
4 changes: 2 additions & 2 deletions backend_py/primary/primary/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"ssdl": [SSDL_RESOURCE_SCOPE],
}

print(f"{RESOURCE_SCOPES_DICT=}")

DEFAULT_CACHE_MAX_AGE = 3600 # 1 hour
DEFAULT_STALE_WHILE_REVALIDATE = 3600 * 24 # 24 hour
REDIS_USER_SESSION_URL = "redis://redis-user-session:6379"
REDIS_CACHE_URL = "redis://redis-cache:6379"
4 changes: 4 additions & 0 deletions backend_py/primary/primary/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from primary.auth.auth_helper import AuthHelper
from primary.auth.enforce_logged_in_middleware import EnforceLoggedInMiddleware
from primary.middleware.add_process_time_to_server_timing_middleware import AddProcessTimeToServerTimingMiddleware

from primary.middleware.add_browser_cache import AddBrowserCacheMiddleware
from primary.routers.dev.router import router as dev_router
from primary.routers.explore.router import router as explore_router
from primary.routers.general import router as general_router
Expand Down Expand Up @@ -104,6 +106,7 @@ def custom_generate_unique_id(route: APIRoute) -> str:
# Also redirects to /login endpoint for some select paths
unprotected_paths = ["/logged_in_user", "/alive", "/openapi.json"]
paths_redirected_to_login = ["/", "/alive_protected"]

app.add_middleware(
EnforceLoggedInMiddleware,
unprotected_paths=unprotected_paths,
Expand All @@ -117,6 +120,7 @@ def custom_generate_unique_id(route: APIRoute) -> str:

# This middleware instance measures execution time of the endpoints, including the cost of other middleware
app.add_middleware(AddProcessTimeToServerTimingMiddleware, metric_name="total")
app.add_middleware(AddBrowserCacheMiddleware)


@app.get("/")
Expand Down
92 changes: 92 additions & 0 deletions backend_py/primary/primary/middleware/add_browser_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from functools import wraps
from contextvars import ContextVar
from typing import Dict, Any, Callable, Awaitable, Union, Never

from starlette.datastructures import MutableHeaders
from starlette.types import ASGIApp, Scope, Receive, Send, Message
from primary.config import DEFAULT_CACHE_MAX_AGE, DEFAULT_STALE_WHILE_REVALIDATE

# Initialize with a factory function to ensure a new dict for each context
def get_default_context() -> Dict[str, Any]:
return {"max_age": DEFAULT_CACHE_MAX_AGE, "stale_while_revalidate": DEFAULT_STALE_WHILE_REVALIDATE}


cache_context: ContextVar[Dict[str, Any]] = ContextVar("cache_context", default=get_default_context())


def add_custom_cache_time(max_age_s: int, stale_while_revalidate_s: int = 0) -> Callable:
"""
Decorator that sets a custom browser cache time for the endpoint response.

Args:
max_age_s (int): The maximum age in seconds for the cache
stale_while_revalidate_s (int): The stale-while-revalidate time in seconds

Example:
@add_custom_cache_time(300, 600) # 5 minutes max age, 10 minutes stale-while-revalidate
async def my_endpoint():
return {"data": "some_data"}
"""

def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Callable:
context = cache_context.get()
context["max_age"] = max_age_s
context["stale_while_revalidate"] = stale_while_revalidate_s

return await func(*args, **kwargs)

return wrapper

return decorator


def no_cache(func: Callable) -> Callable:
"""
Decorator that explicitly disables browser caching for the endpoint response.

Example:
@no_cache
async def my_endpoint():
return {"data": "some_data"}
"""

@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Callable:
context = cache_context.get()
context["max_age"] = 0
context["stale_while_revalidate"] = 0

return await func(*args, **kwargs)

return wrapper


class AddBrowserCacheMiddleware:
"""
Adds cache-control to the response headers
"""

def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
return await self.app(scope, receive, send)

# Set initial context and store token
cache_context.set(get_default_context())

async def send_with_cache_header(message: Message) -> None:
if message["type"] == "http.response.start":
headers = MutableHeaders(scope=message)
context = cache_context.get()
cache_control_str = (
f"max-age={context['max_age']}, stale-while-revalidate={context['stale_while_revalidate']}, private"
)
headers.append("cache-control", cache_control_str)

await send(message)

await self.app(scope, receive, send_with_cache_header)
5 changes: 5 additions & 0 deletions backend_py/primary/primary/routers/explore/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
from primary.services.sumo_access.case_inspector import CaseInspector
from primary.services.sumo_access.sumo_inspector import SumoInspector
from primary.services.utils.authenticated_user import AuthenticatedUser
from primary.middleware.add_browser_cache import no_cache

from . import schemas

router = APIRouter()


@router.get("/fields")
@no_cache
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This probably could be cached. It very rarely changes - and when it does the user has to wait for access to sync through infrastructure and Azure anyway. E.g. max-age of one minute and a default stale-while-revalidate of one day?

async def get_fields(
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
) -> List[schemas.FieldInfo]:
Expand All @@ -27,6 +29,7 @@ async def get_fields(


@router.get("/cases")
@no_cache
async def get_cases(
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
field_identifier: str = Query(description="Field identifier"),
Expand All @@ -43,6 +46,7 @@ async def get_cases(


@router.get("/cases/{case_uuid}/ensembles")
@no_cache
async def get_ensembles(
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
case_uuid: str = Path(description="Sumo case uuid"),
Expand All @@ -55,6 +59,7 @@ async def get_ensembles(


@router.get("/cases/{case_uuid}/ensembles/{ensemble_name}")
@no_cache
async def get_ensemble_details(
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
case_uuid: str = Path(description="Sumo case uuid"),
Expand Down
4 changes: 4 additions & 0 deletions backend_py/primary/primary/routers/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from primary.auth.auth_helper import AuthHelper
from primary.services.graph_access.graph_access import GraphApiAccess
from primary.middleware.add_browser_cache import no_cache

LOGGER = logging.getLogger(__name__)

Expand All @@ -25,18 +26,21 @@ class UserInfo(BaseModel):


@router.get("/alive")
@no_cache
def get_alive() -> str:
print("entering alive route")
return f"ALIVE: Backend is alive at this time: {datetime.datetime.now()}"


@router.get("/alive_protected")
@no_cache
def get_alive_protected() -> str:
print("entering alive_protected route")
return f"ALIVE_PROTECTED: Backend is alive at this time: {datetime.datetime.now()}"


@router.get("/logged_in_user", response_model=UserInfo)
@no_cache
async def get_logged_in_user(
request: Request,
includeGraphApiInfo: bool = Query( # pylint: disable=invalid-name
Expand Down
3 changes: 3 additions & 0 deletions backend_py/primary/primary/routers/well/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

from primary.services.ssdl_access.well_access import WellAccess as SsdlWellAccess


from primary.middleware.add_browser_cache import add_custom_cache_time
from . import schemas
from . import converters

Expand Down Expand Up @@ -40,6 +42,7 @@ async def get_drilled_wellbore_headers(


@router.get("/well_trajectories/")
@add_custom_cache_time(3600 * 24 * 7, 3600 * 24 * 7 * 10) # 1 week cache, 10 week stale-while-revalidate
async def get_well_trajectories(
# fmt:off
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
Expand Down
Loading