Skip to content

Commit

Permalink
Middleware starlettify (#8)
Browse files Browse the repository at this point in the history
drop with_plugins from middleware
  • Loading branch information
tomwojcik authored Apr 18, 2020
1 parent e222f73 commit 79aa8ff
Show file tree
Hide file tree
Showing 14 changed files with 135 additions and 105 deletions.
12 changes: 12 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.5.0
hooks:
- id: check-merge-conflict
- id: debug-statements
- id: no-commit-to-branch

- repo: https://github.com/ambv/black
rev: 19.10b0
hooks:
- id: black
27 changes: 15 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@
[![PyPI version](https://badge.fury.io/py/starlette-context.svg)](https://badge.fury.io/py/starlette-context)
[![PyPI license](https://img.shields.io/pypi/l/ansicolortags.svg)](https://pypi.python.org/pypi/ansicolortags/)
[![codecov](https://codecov.io/gh/tomwojcik/starlette-context/branch/master/graph/badge.svg)](https://codecov.io/gh/tomwojcik/starlette-context)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)

# starlette context
Middleware for Starlette that allows you to store and access the context data of a request. Can be used with logging so logs automatically use request headers such as x-request-id or x-correlation-id.

### Motivation

I use FastAPI. I needed something that will allow me to log with context data. Right now I can just `log.info('Message')` and I have log (in ELK) with request id and correlation id. I don't even think about passing this data to logger. It's there automatically.

### Installation

`$ pip install starlette-context`
Expand All @@ -30,33 +27,39 @@ https://github.com/MagicStack/contextvars
All other dependencies from `requirements-dev.txt` are only needed to run tests or examples. Test/dev env is dockerized if you want to try them yourself.

### Example
**examples/simple_examples/set_context_in_middleware.py**

```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse

import uvicorn
from starlette_context import context, plugins
from starlette_context.middleware import ContextMiddleware

middleware = [
Middleware(
ContextMiddleware,
plugins=(
plugins.RequestIdPlugin(),
plugins.CorrelationIdPlugin()
)
)
]

app = Starlette(debug=True)
app.add_middleware(ContextMiddleware.with_plugins( # easily extensible
plugins.RequestIdPlugin, # request id
plugins.CorrelationIdPlugin # correlation id
))
app = Starlette(middleware=middleware)


@app.route('/')
@app.route("/")
async def index(request: Request):
return JSONResponse(context.data)


uvicorn.run(app, host="0.0.0.0")

```
In this example the response containes a json with
In this example the response contains a json with
```json
{
"X-Correlation-ID":"5ca2f0b43115461bad07ccae5976a990",
Expand Down
43 changes: 24 additions & 19 deletions examples/example_with_exception_handling/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.middleware.base import (
BaseHTTPMiddleware,
RequestResponseEndpoint,
Expand All @@ -14,15 +15,6 @@
from examples.example_with_exception_handling.logger import log
from starlette_context import middleware, plugins

app = Starlette(debug=True)


@app.route("/")
async def index(request: Request):
log.info("pre exception")
_ = 1 / 0
return JSONResponse({"wont reach this place": None})


class ExceptionHandlingMiddleware(BaseHTTPMiddleware):
@staticmethod
Expand Down Expand Up @@ -54,15 +46,28 @@ async def dispatch(


# middleware order is important! exc handler has to be topmost
middleware = [
Middleware(
middleware.ContextMiddleware,
plugins=(
plugins.CorrelationIdPlugin(),
plugins.RequestIdPlugin(),
plugins.DateHeaderPlugin(),
plugins.ForwardedForPlugin(),
plugins.UserAgentPlugin(),
),
),
Middleware(ExceptionHandlingMiddleware),
]

app = Starlette(debug=True, middleware=middleware)


@app.route("/")
async def index(request: Request):
log.info("pre exception")
_ = 1 / 0
return JSONResponse({"wont reach this place": None})


app.add_middleware(ExceptionHandlingMiddleware)
app.add_middleware(
middleware.ContextMiddleware.with_plugins(
plugins.UserAgentPlugin,
plugins.ForwardedForPlugin,
plugins.DateHeaderPlugin,
plugins.RequestIdPlugin,
plugins.CorrelationIdPlugin,
)
)
uvicorn.run(app, host="0.0.0.0")
27 changes: 16 additions & 11 deletions examples/example_with_logger/app.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,32 @@
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse

import uvicorn
from examples.example_with_logger.logger import log
from starlette_context import context, middleware, plugins

app = Starlette(debug=True)
middleware = [
Middleware(
middleware.ContextMiddleware,
plugins=(
plugins.CorrelationIdPlugin(),
plugins.RequestIdPlugin(),
plugins.DateHeaderPlugin(),
plugins.ForwardedForPlugin(),
plugins.UserAgentPlugin(),
),
)
]

app = Starlette(debug=True, middleware=middleware)


@app.route("/")
async def index(request: Request):
async def index(_: Request):
log.info("Log from view")
return JSONResponse(context.data)


app.add_middleware(
middleware.ContextMiddleware.with_plugins(
plugins.CorrelationIdPlugin,
plugins.RequestIdPlugin,
plugins.DateHeaderPlugin,
plugins.ForwardedForPlugin,
plugins.UserAgentPlugin,
)
)
uvicorn.run(app, host="0.0.0.0")
13 changes: 8 additions & 5 deletions examples/simple_examples/set_context_in_middleware.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse

import uvicorn
from starlette_context import context, plugins
from starlette_context.middleware import ContextMiddleware

app = Starlette(debug=True)
app.add_middleware(
ContextMiddleware.with_plugins(
plugins.RequestIdPlugin, plugins.CorrelationIdPlugin
middleware = [
Middleware(
ContextMiddleware,
plugins=(plugins.RequestIdPlugin(), plugins.CorrelationIdPlugin()),
)
)
]

app = Starlette(debug=True, middleware=middleware)


@app.route("/")
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ autoflake
mypy
black
isort
pre-commit-hooks
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import setuptools

VERSION = "0.2.0"
VERSION = "0.2.1"


def get_long_description():
Expand Down
9 changes: 8 additions & 1 deletion starlette_context/ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@ def data(self) -> dict:
"""
Dump this to json. Object itself it not serializable.
"""
return _request_scope_context_storage.get()
try:
return _request_scope_context_storage.get()
except LookupError as e:
raise RuntimeError(
"You didn't use ContextMiddleware or "
"you're trying to access `context` object "
"outside of the request-response cycle."
) from e

def copy(self) -> dict:
"""
Expand Down
23 changes: 8 additions & 15 deletions starlette_context/middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from contextvars import Token
from typing import List, Type, Union
from typing import Optional, Sequence

from starlette.middleware.base import (
BaseHTTPMiddleware,
Expand All @@ -18,20 +18,13 @@ class ContextMiddleware(BaseHTTPMiddleware):
If not used, you won't be able to use context object.
"""

plugins: List[Plugin] = []

@classmethod
def with_plugins(
cls, *plugins: Union[Plugin, Type[Plugin]]
) -> Type["ContextMiddleware"]:
for plugin in plugins:
if isinstance(plugin, Plugin):
cls.plugins.append(plugin)
elif issubclass(plugin, Plugin):
cls.plugins.append(plugin())
else:
raise TypeError("Only plugins are allowed.")
return cls
def __init__(
self, plugins: Optional[Sequence[Plugin]] = None, *args, **kwargs
) -> None:
super().__init__(*args, **kwargs)
self.plugins = plugins or ()
if not all([isinstance(plugin, Plugin) for plugin in self.plugins]):
raise TypeError("This is not a valid instance of a plugin")

async def set_context(self, request: Request) -> dict:
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def headers():

@pytest.fixture(scope="function", autouse=True)
def mocked_middleware() -> middleware.ContextMiddleware:
return middleware.ContextMiddleware(MagicMock())
return middleware.ContextMiddleware(app=MagicMock())


@pytest.fixture(scope="function", autouse=True)
Expand Down
17 changes: 8 additions & 9 deletions tests/test_integration/test_async_body.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.testclient import TestClient

from starlette_context import context, plugins
from starlette_context.middleware import ContextMiddleware

from starlette_context import plugins, context


class GetPayloadUsingPlugin(plugins.Plugin):
key = "from_plugin"
Expand All @@ -24,10 +24,12 @@ async def set_context(self, request: Request) -> dict:
return {"from_middleware": await request.json(), **from_plugin}


app = Starlette()
app.add_middleware(
GetPayloadFromBodyMiddleware.with_plugins(GetPayloadUsingPlugin)
)
middleware = [
Middleware(
GetPayloadFromBodyMiddleware, plugins=(GetPayloadUsingPlugin(),)
)
]
app = Starlette(middleware=middleware)


@app.route("/", methods=["POST"])
Expand All @@ -46,6 +48,3 @@ def test_async_body():
"from_plugin": {"test": "payload"},
}
assert expected_resp == resp.json()

# ugly cleanup
ContextMiddleware.plugins = []
2 changes: 1 addition & 1 deletion tests/test_integration/test_context_no_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ async def index(request: Request):


def test_no_middleware():
with pytest.raises(LookupError):
with pytest.raises(RuntimeError):
client.get("/")
Loading

0 comments on commit 79aa8ff

Please sign in to comment.