1
1
from functools import wraps
2
2
from contextvars import ContextVar
3
- from typing import Dict , Any
3
+ from typing import Dict , Any , Callable , TypeVar , Awaitable
4
4
5
5
from starlette .datastructures import MutableHeaders
6
+ from starlette .types import ASGIApp , Scope , Receive , Send , Message
6
7
from primary .config import DEFAULT_CACHE_MAX_AGE , DEFAULT_STALE_WHILE_REVALIDATE
7
8
9
+ T = TypeVar ("T" )
8
10
9
11
# Initialize with a factory function to ensure a new dict for each context
10
12
def get_default_context () -> Dict [str , Any ]:
@@ -14,7 +16,9 @@ def get_default_context() -> Dict[str, Any]:
14
16
cache_context : ContextVar [Dict [str , Any ]] = ContextVar ("cache_context" , default = get_default_context ())
15
17
16
18
17
- def add_custom_cache_time (max_age : int , stale_while_revalidate : int = 0 ) -> Any :
19
+ def add_custom_cache_time (
20
+ max_age : int , stale_while_revalidate : int = 0
21
+ ) -> Callable [[Callable [..., Awaitable [T ]]], Callable [..., Awaitable [T ]]]:
18
22
"""
19
23
Decorator that sets a custom cache time for the endpoint response.
20
24
@@ -28,9 +32,9 @@ async def my_endpoint():
28
32
return {"data": "some_data"}
29
33
"""
30
34
31
- def decorator (func ) :
35
+ def decorator (func : Callable [..., Awaitable [ T ]]) -> Callable [..., Awaitable [ T ]] :
32
36
@wraps (func )
33
- async def wrapper (* args , ** kwargs ) :
37
+ async def wrapper (* args : Any , ** kwargs : Any ) -> T :
34
38
context = cache_context .get ()
35
39
context ["max_age" ] = max_age
36
40
context ["stale_while_revalidate" ] = stale_while_revalidate
@@ -47,17 +51,17 @@ class AddBrowserCacheMiddleware:
47
51
Adds cache-control to the response headers
48
52
"""
49
53
50
- def __init__ (self , app ) :
54
+ def __init__ (self , app : ASGIApp ) -> None :
51
55
self .app = app
52
56
53
- async def __call__ (self , scope , receive , send ) -> None :
57
+ async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
54
58
if scope ["type" ] != "http" :
55
59
return await self .app (scope , receive , send )
56
60
57
61
# Set initial context and store token
58
62
cache_context .set (get_default_context ())
59
63
60
- async def send_with_cache_header (message ) -> None :
64
+ async def send_with_cache_header (message : Message ) -> None :
61
65
if message ["type" ] == "http.response.start" :
62
66
headers = MutableHeaders (scope = message )
63
67
context = cache_context .get ()
0 commit comments