Skip to content

Commit

Permalink
Merge pull request #184 from AikidoSec/AIK-3589
Browse files Browse the repository at this point in the history
AIK-3589 Add support for starlette
  • Loading branch information
willem-delbare authored Sep 12, 2024
2 parents 43fdc40 + e724865 commit 8a6b7f7
Show file tree
Hide file tree
Showing 11 changed files with 262 additions and 28 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/end2end.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ jobs:
working-directory: ./sample-apps/quart-postgres-uvicorn
run: |
docker compose -f docker-compose.yml -f docker-compose.benchmark.yml up --build -d
- name: Start starlette-postgres-uvicorn
working-directory: ./sample-apps/starlette-postgres-uvicorn
run: |
docker compose -f docker-compose.yml -f docker-compose.benchmark.yml up --build -d
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ Zen for Python 3 is compatible with:
*[Django](docs/django.md)
*[Flask](docs/flask.md)
*[Quart](docs/quart.md)
*[Starlette](docs/starlette.md)


### WSGI servers
*[Gunicorn](docs/gunicorn.md)
Expand Down
1 change: 1 addition & 0 deletions aikido_zen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def protect(mode="daemon"):
import aikido_zen.sources.django
import aikido_zen.sources.flask
import aikido_zen.sources.quart
import aikido_zen.sources.starlette
import aikido_zen.sources.xml
import aikido_zen.sources.lxml

Expand Down
2 changes: 1 addition & 1 deletion aikido_zen/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
current_context = contextvars.ContextVar("current_context", default=None)

WSGI_SOURCES = ["django", "flask"]
ASGI_SOURCES = ["quart", "django_async"]
ASGI_SOURCES = ["quart", "django_async", "starlette"]


def get_current_context():
Expand Down
15 changes: 15 additions & 0 deletions aikido_zen/sources/starlette/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
Init.py file for starlette module
---
Starlette wrapping is subdivided in two parts :
- starlette.applications : Wraps __call__ on Starlette class to run "init" stage.
- starlette.routing : request_response function : Run pre_response code and
also runs post_response code after getting response from user function.
Folder also includes helper functions :
- extract_data_from_request, which will extract the data from a request object safely,
e.g. body, json, form. This also saves it inside the current context.
"""

import aikido_zen.sources.starlette.starlette_applications
import aikido_zen.sources.starlette.starlette_routing
26 changes: 26 additions & 0 deletions aikido_zen/sources/starlette/extract_data_from_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Exports function extract_data_from_request"""

from aikido_zen.context import get_current_context


async def extract_data_from_request(request):
"""Extracts json, form_data or body from Starlette request"""
context = get_current_context()
if not context:
return

# Parse data
try:
context.body = await request.json()
except ValueError:
# Throws error if the body is not json
pass
if not context.body:
form_data = await request.form()
if form_data:
# Convert to dict object :
context.body = {key: value for key, value in form_data.items()}
if not context.body:
context.body = await request.body()

context.set_as_current_context()
38 changes: 38 additions & 0 deletions aikido_zen/sources/starlette/starlette_applications.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Wraps starlette.applications for initial request_handler"""

import copy
import aikido_zen.importhook as importhook
from aikido_zen.helpers.logging import logger
from aikido_zen.context import Context
from aikido_zen.background_process.packages import add_wrapped_package
from ..functions.request_handler import request_handler


@importhook.on_import("starlette.applications")
def on_starlette_import(starlette):
"""
Hook 'n wrap on `starlette.applications`
Our goal is to wrap the __call__ function of the Starlette class
"""
modified_starlette = importhook.copy_module(starlette)
former_call = copy.deepcopy(starlette.Starlette.__call__)

async def aikido___call__(app, scope, receive=None, send=None):
return await aik_call_wrapper(former_call, app, scope, receive, send)

setattr(modified_starlette.Starlette, "__call__", aikido___call__)
add_wrapped_package("starlette")
return modified_starlette


async def aik_call_wrapper(former_call, app, scope, receive, send):
"""Aikido's __call__ wrapper"""
try:
if scope["type"] != "http":
return await former_call(app, scope, receive, send)
context1 = Context(req=scope, source="starlette")
context1.set_as_current_context()
request_handler(stage="init")
except Exception as e:
logger.debug("Exception on aikido __call__ function : %s", e)
return await former_call(app, scope, receive, send)
85 changes: 85 additions & 0 deletions aikido_zen/sources/starlette/starlette_routing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
Wraps starlette.applications for initial request_handler
Attention: We will be using rr to refer to request_response. It's used a lot and
readability would be impaired if we did not abbreviate this
"""

import copy
import aikido_zen.importhook as importhook
from aikido_zen.helpers.logging import logger
from .extract_data_from_request import extract_data_from_request
from ..functions.request_handler import request_handler


@importhook.on_import("starlette.routing")
def on_starlette_import(routing):
"""
Hook 'n wrap on `starlette.routing`
Wraps the request_response function so we can wrap the function given to request_response
"""
modified_routing = importhook.copy_module(routing)
former_rr_func = copy.deepcopy(routing.request_response)

def aikido_rr_func(func):
wrapped_route_func = aik_route_func_wrapper(func)
return former_rr_func(wrapped_route_func)

setattr(routing, "request_response", aikido_rr_func)
setattr(modified_routing, "request_response", aikido_rr_func)
return modified_routing


def aik_route_func_wrapper(func):
"""Aikido's __call__ wrapper"""

async def aikido_route_func(*args, **kwargs):
# Code before response (pre_response stage)
try:
req = args[0]
if not req:
return
await extract_data_from_request(req)
pre_response_results = request_handler(stage="pre_response")
if pre_response_results:
response = create_starlette_response(pre_response_results)
if response:
# Make sure to not return when an error occured or there is an invalid response
return response
except Exception as e:
logger.debug("Exception occured in pre_response stage starlette : %s", e)

# Calling the function, check if it's async or not (same checks as in codebase starlette)
try:
import functools
from starlette.concurrency import run_in_threadpool
from starlette._utils import is_async_callable
except ImportError:
logger.info("Make sure starlette install OK : .concurrency, ._utils")
return await func(*args, **kwargs)
res = None
if is_async_callable(func):
res = await func(*args, **kwargs)
else:
# `func` provided by the end-user is allowed to be both synchronous and asynchronous
# Here we convert sync functions in the same way the starlette codebase does to ensure
# there are no compatibility issues and the behaviour remains unchanged.
res = await functools.partial(run_in_threadpool, func)(*args, **kwargs)

# Code after response (post_response stage)
request_handler(stage="post_response", status_code=res.status_code)
return res

return aikido_route_func


def create_starlette_response(pre_response):
"""Tries to import PlainTextResponse and generates starlette plain text response"""
text, status_code = pre_response
try:
from starlette.responses import PlainTextResponse
except ImportError:
logger.info(
"Ensure `starlette` install is valid, failed to import starlette.responses"
)
return None
return PlainTextResponse(text, status_code)
41 changes: 41 additions & 0 deletions docs/starlette.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Starlette

1. Install `aikido_zen` package with pip :
```sh
pip install aikido_zen
```

2. Add the following snippet to the top of your `app.py` file :
```python
import aikido_zen
aikido_zen.protect()
```
Make sure this is above any other import, including above builtin package imports.

3. Setting your environment variables :
Make sure to set your token in order to communicate with Aikido's servers
```env
AIKIDO_TOKEN="AIK_RUNTIME_YOUR_TOKEN_HERE"
```

## Blocking mode

By default, the firewall will run in non-blocking mode. When it detects an attack, the attack will be reported to Aikido and continue executing the call.

You can enable blocking mode by setting the environment variable `AIKIDO_BLOCKING` to `true`:

```sh
AIKIDO_BLOCKING=true
```

It's recommended to enable this on your staging environment for a considerable amount of time before enabling it on your production environment (e.g. one week).

## Debug mode

If you need to debug the firewall, you can run your code with the environment variable `AIKIDO_DEBUG` set to `true`:

```sh
AIKIDO_DEBUG=true
```

This will output debug information to the console (e.g. no token was found, unsupported packages, extra information, ...).
38 changes: 38 additions & 0 deletions end2end/starlette_postgres_uvicorn_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
import requests
# e2e tests for flask_postgres sample app
post_url_fw = "http://localhost:8102/create"
post_url_nofw = "http://localhost:8103/create"
sync_route_fw = "http://localhost:8102/sync_route"
sync_route_nofw = "http://localhost:8103/sync_route"

def test_safe_response_with_firewall():
dog_name = "Bobby Tables"
res = requests.post(post_url_fw, data={'dog_name': dog_name})
assert res.status_code == 201


def test_safe_response_without_firewall():
dog_name = "Bobby Tables"
res = requests.post(post_url_nofw, data={'dog_name': dog_name})
assert res.status_code == 201


def test_dangerous_response_with_firewall():
dog_name = "Dangerous Bobby', TRUE); -- "
res = requests.post(post_url_fw, data={'dog_name': dog_name})
assert res.status_code == 500

def test_dangerous_response_without_firewall():
dog_name = "Dangerous Bobby', TRUE); -- "
res = requests.post(post_url_nofw, data={'dog_name': dog_name})
assert res.status_code == 201


def test_sync_route_with_firewall():
res = requests.get(sync_route_fw)
assert res.status_code == 200

def test_sync_route_without_firewall():
res = requests.get(sync_route_nofw)
assert res.status_code == 200
37 changes: 11 additions & 26 deletions sample-apps/starlette-postgres-uvicorn/app.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from dotenv import load_dotenv
import os
import asyncpg
from starlette.applications import Starlette
from starlette.responses import HTMLResponse, JSONResponse
from starlette.routing import Route
from starlette.templating import Jinja2Templates
from starlette.requests import Request

load_dotenv()
firewall_disabled = os.getenv("FIREWALL_DISABLED")
if firewall_disabled is not None:
if firewall_disabled.lower() != "1":
import aikido_zen # Aikido package import
aikido_zen.protect()

import asyncpg
from starlette.applications import Starlette
from starlette.responses import HTMLResponse, JSONResponse
from starlette.routing import Route
from starlette.templating import Jinja2Templates
from starlette.requests import Request

templates = Jinja2Templates(directory="templates")

async def get_db_connection():
Expand All @@ -40,9 +40,6 @@ async def get_dogpage(request: Request):
async def show_create_dog_form(request: Request):
return templates.TemplateResponse('create_dog.html', {"request": request})

async def show_create_dog_form_many(request: Request):
return templates.TemplateResponse('create_dog.html', {"request": request})

async def create_dog(request: Request):
data = await request.form()
dog_name = data.get('dog_name')
Expand All @@ -58,26 +55,14 @@ async def create_dog(request: Request):

return JSONResponse({"message": f'Dog {dog_name} created successfully'}, status_code=201)

async def create_dog_many(request: Request):
data = await request.form()
dog_names = data.getlist('dog_name') # Expecting a list of dog names

if not dog_names:
return JSONResponse({"error": "dog_names must be a list and cannot be empty"}, status_code=400)

conn = await get_db_connection()
try:
await conn.executemany("INSERT INTO dogs (dog_name, isAdmin) VALUES ($1, FALSE)", [(name,) for name in dog_names])
finally:
await conn.close()

return JSONResponse({"message": f'{", ".join(dog_names)} created successfully'}, status_code=201)
def sync_route(request):
data = {"message": "This is a non-async route!"}
return JSONResponse(data)

app = Starlette(routes=[
Route("/", homepage),
Route("/dogpage/{dog_id:int}", get_dogpage),
Route("/create", show_create_dog_form, methods=["GET"]),
Route("/create_many", show_create_dog_form_many, methods=["GET"]),
Route("/create", create_dog, methods=["POST"]),
Route("/create_many", create_dog_many, methods=["POST"]),
Route("/sync_route", sync_route)
])

0 comments on commit 8a6b7f7

Please sign in to comment.