-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #184 from AikidoSec/AIK-3589
AIK-3589 Add support for starlette
- Loading branch information
Showing
11 changed files
with
262 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
""" | ||
Init.py file for starlette module | ||
--- | ||
Starlette wrapping is subdivided in two parts : | ||
- starlette.applications : Wraps __call__ on Starlette class to run "init" stage. | ||
- starlette.routing : request_response function : Run pre_response code and | ||
also runs post_response code after getting response from user function. | ||
Folder also includes helper functions : | ||
- extract_data_from_request, which will extract the data from a request object safely, | ||
e.g. body, json, form. This also saves it inside the current context. | ||
""" | ||
|
||
import aikido_zen.sources.starlette.starlette_applications | ||
import aikido_zen.sources.starlette.starlette_routing |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
"""Exports function extract_data_from_request""" | ||
|
||
from aikido_zen.context import get_current_context | ||
|
||
|
||
async def extract_data_from_request(request): | ||
"""Extracts json, form_data or body from Starlette request""" | ||
context = get_current_context() | ||
if not context: | ||
return | ||
|
||
# Parse data | ||
try: | ||
context.body = await request.json() | ||
except ValueError: | ||
# Throws error if the body is not json | ||
pass | ||
if not context.body: | ||
form_data = await request.form() | ||
if form_data: | ||
# Convert to dict object : | ||
context.body = {key: value for key, value in form_data.items()} | ||
if not context.body: | ||
context.body = await request.body() | ||
|
||
context.set_as_current_context() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
"""Wraps starlette.applications for initial request_handler""" | ||
|
||
import copy | ||
import aikido_zen.importhook as importhook | ||
from aikido_zen.helpers.logging import logger | ||
from aikido_zen.context import Context | ||
from aikido_zen.background_process.packages import add_wrapped_package | ||
from ..functions.request_handler import request_handler | ||
|
||
|
||
@importhook.on_import("starlette.applications") | ||
def on_starlette_import(starlette): | ||
""" | ||
Hook 'n wrap on `starlette.applications` | ||
Our goal is to wrap the __call__ function of the Starlette class | ||
""" | ||
modified_starlette = importhook.copy_module(starlette) | ||
former_call = copy.deepcopy(starlette.Starlette.__call__) | ||
|
||
async def aikido___call__(app, scope, receive=None, send=None): | ||
return await aik_call_wrapper(former_call, app, scope, receive, send) | ||
|
||
setattr(modified_starlette.Starlette, "__call__", aikido___call__) | ||
add_wrapped_package("starlette") | ||
return modified_starlette | ||
|
||
|
||
async def aik_call_wrapper(former_call, app, scope, receive, send): | ||
"""Aikido's __call__ wrapper""" | ||
try: | ||
if scope["type"] != "http": | ||
return await former_call(app, scope, receive, send) | ||
context1 = Context(req=scope, source="starlette") | ||
context1.set_as_current_context() | ||
request_handler(stage="init") | ||
except Exception as e: | ||
logger.debug("Exception on aikido __call__ function : %s", e) | ||
return await former_call(app, scope, receive, send) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
""" | ||
Wraps starlette.applications for initial request_handler | ||
Attention: We will be using rr to refer to request_response. It's used a lot and | ||
readability would be impaired if we did not abbreviate this | ||
""" | ||
|
||
import copy | ||
import aikido_zen.importhook as importhook | ||
from aikido_zen.helpers.logging import logger | ||
from .extract_data_from_request import extract_data_from_request | ||
from ..functions.request_handler import request_handler | ||
|
||
|
||
@importhook.on_import("starlette.routing") | ||
def on_starlette_import(routing): | ||
""" | ||
Hook 'n wrap on `starlette.routing` | ||
Wraps the request_response function so we can wrap the function given to request_response | ||
""" | ||
modified_routing = importhook.copy_module(routing) | ||
former_rr_func = copy.deepcopy(routing.request_response) | ||
|
||
def aikido_rr_func(func): | ||
wrapped_route_func = aik_route_func_wrapper(func) | ||
return former_rr_func(wrapped_route_func) | ||
|
||
setattr(routing, "request_response", aikido_rr_func) | ||
setattr(modified_routing, "request_response", aikido_rr_func) | ||
return modified_routing | ||
|
||
|
||
def aik_route_func_wrapper(func): | ||
"""Aikido's __call__ wrapper""" | ||
|
||
async def aikido_route_func(*args, **kwargs): | ||
# Code before response (pre_response stage) | ||
try: | ||
req = args[0] | ||
if not req: | ||
return | ||
await extract_data_from_request(req) | ||
pre_response_results = request_handler(stage="pre_response") | ||
if pre_response_results: | ||
response = create_starlette_response(pre_response_results) | ||
if response: | ||
# Make sure to not return when an error occured or there is an invalid response | ||
return response | ||
except Exception as e: | ||
logger.debug("Exception occured in pre_response stage starlette : %s", e) | ||
|
||
# Calling the function, check if it's async or not (same checks as in codebase starlette) | ||
try: | ||
import functools | ||
from starlette.concurrency import run_in_threadpool | ||
from starlette._utils import is_async_callable | ||
except ImportError: | ||
logger.info("Make sure starlette install OK : .concurrency, ._utils") | ||
return await func(*args, **kwargs) | ||
res = None | ||
if is_async_callable(func): | ||
res = await func(*args, **kwargs) | ||
else: | ||
# `func` provided by the end-user is allowed to be both synchronous and asynchronous | ||
# Here we convert sync functions in the same way the starlette codebase does to ensure | ||
# there are no compatibility issues and the behaviour remains unchanged. | ||
res = await functools.partial(run_in_threadpool, func)(*args, **kwargs) | ||
|
||
# Code after response (post_response stage) | ||
request_handler(stage="post_response", status_code=res.status_code) | ||
return res | ||
|
||
return aikido_route_func | ||
|
||
|
||
def create_starlette_response(pre_response): | ||
"""Tries to import PlainTextResponse and generates starlette plain text response""" | ||
text, status_code = pre_response | ||
try: | ||
from starlette.responses import PlainTextResponse | ||
except ImportError: | ||
logger.info( | ||
"Ensure `starlette` install is valid, failed to import starlette.responses" | ||
) | ||
return None | ||
return PlainTextResponse(text, status_code) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Starlette | ||
|
||
1. Install `aikido_zen` package with pip : | ||
```sh | ||
pip install aikido_zen | ||
``` | ||
|
||
2. Add the following snippet to the top of your `app.py` file : | ||
```python | ||
import aikido_zen | ||
aikido_zen.protect() | ||
``` | ||
Make sure this is above any other import, including above builtin package imports. | ||
|
||
3. Setting your environment variables : | ||
Make sure to set your token in order to communicate with Aikido's servers | ||
```env | ||
AIKIDO_TOKEN="AIK_RUNTIME_YOUR_TOKEN_HERE" | ||
``` | ||
|
||
## Blocking mode | ||
|
||
By default, the firewall will run in non-blocking mode. When it detects an attack, the attack will be reported to Aikido and continue executing the call. | ||
|
||
You can enable blocking mode by setting the environment variable `AIKIDO_BLOCKING` to `true`: | ||
|
||
```sh | ||
AIKIDO_BLOCKING=true | ||
``` | ||
|
||
It's recommended to enable this on your staging environment for a considerable amount of time before enabling it on your production environment (e.g. one week). | ||
|
||
## Debug mode | ||
|
||
If you need to debug the firewall, you can run your code with the environment variable `AIKIDO_DEBUG` set to `true`: | ||
|
||
```sh | ||
AIKIDO_DEBUG=true | ||
``` | ||
|
||
This will output debug information to the console (e.g. no token was found, unsupported packages, extra information, ...). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import pytest | ||
import requests | ||
# e2e tests for flask_postgres sample app | ||
post_url_fw = "http://localhost:8102/create" | ||
post_url_nofw = "http://localhost:8103/create" | ||
sync_route_fw = "http://localhost:8102/sync_route" | ||
sync_route_nofw = "http://localhost:8103/sync_route" | ||
|
||
def test_safe_response_with_firewall(): | ||
dog_name = "Bobby Tables" | ||
res = requests.post(post_url_fw, data={'dog_name': dog_name}) | ||
assert res.status_code == 201 | ||
|
||
|
||
def test_safe_response_without_firewall(): | ||
dog_name = "Bobby Tables" | ||
res = requests.post(post_url_nofw, data={'dog_name': dog_name}) | ||
assert res.status_code == 201 | ||
|
||
|
||
def test_dangerous_response_with_firewall(): | ||
dog_name = "Dangerous Bobby', TRUE); -- " | ||
res = requests.post(post_url_fw, data={'dog_name': dog_name}) | ||
assert res.status_code == 500 | ||
|
||
def test_dangerous_response_without_firewall(): | ||
dog_name = "Dangerous Bobby', TRUE); -- " | ||
res = requests.post(post_url_nofw, data={'dog_name': dog_name}) | ||
assert res.status_code == 201 | ||
|
||
|
||
def test_sync_route_with_firewall(): | ||
res = requests.get(sync_route_fw) | ||
assert res.status_code == 200 | ||
|
||
def test_sync_route_without_firewall(): | ||
res = requests.get(sync_route_nofw) | ||
assert res.status_code == 200 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters