-
-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add test_concurrency_correct_headers
- Loading branch information
Showing
1 changed file
with
80 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import uuid | ||
|
||
import httpx | ||
|
||
from starlette.applications import Starlette | ||
from starlette.middleware import Middleware | ||
from starlette.requests import Request | ||
from starlette.responses import JSONResponse | ||
|
||
from starlette_context import context, plugins | ||
from starlette_context.header_keys import HeaderKeys | ||
from starlette_context.middleware import RawContextMiddleware | ||
import pytest | ||
import pytest_asyncio | ||
from starlette.routing import Route | ||
from asgi_lifespan import LifespanManager | ||
from starlette.exceptions import HTTPException | ||
|
||
|
||
def should_raise(number: int) -> bool: | ||
return number % 2 == 0 | ||
|
||
|
||
@pytest_asyncio.fixture | ||
async def app(): | ||
class CloudProviderException(HTTPException): | ||
pass | ||
|
||
async def cloud_provider_exception_handler(request: Request, exc: HTTPException): | ||
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code, headers={ | ||
HeaderKeys.request_id: context[HeaderKeys.request_id] | ||
}) | ||
|
||
async def index(request: Request) -> JSONResponse: | ||
number = request.path_params['number'] | ||
if should_raise(number): | ||
raise CloudProviderException | ||
return JSONResponse(content={"trace_id": context[HeaderKeys.request_id], "from": "view"}) | ||
|
||
middleware = [ | ||
Middleware( | ||
RawContextMiddleware, | ||
plugins=( | ||
plugins.RequestIdPlugin(), | ||
), | ||
) | ||
] | ||
exception_handlers = { | ||
CloudProviderException: cloud_provider_exception_handler | ||
} | ||
app = Starlette( | ||
middleware=middleware, | ||
routes=[Route("/{number:int}", index)], | ||
exception_handlers=exception_handlers | ||
) | ||
|
||
async with LifespanManager(app): | ||
yield app | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_concurrency_correct_headers(app): | ||
transport = httpx.ASGITransport( | ||
app=app, | ||
raise_app_exceptions=False | ||
) | ||
async with httpx.AsyncClient(app=app, transport=transport, base_url="http://test") as client: | ||
for number in range(1, 101): | ||
rid = uuid.uuid4().hex | ||
resp = await client.get(f"/{number}", headers={HeaderKeys.request_id: rid}) | ||
|
||
if should_raise(number): | ||
assert resp.status_code == 500 | ||
assert resp.headers[HeaderKeys.request_id] == rid | ||
else: | ||
assert resp.status_code == 200 | ||
d = resp.json() | ||
assert d['from'] == 'view' | ||
assert d['trace_id'] == rid | ||
assert resp.headers[HeaderKeys.request_id] == rid |