Skip to content

Commit f39d11e

Browse files
authored
Fix use request.form() in middleware (#260)
* Fix request.form() in middleware * resume global exception handling * update the order in which data is processed * Add demo interface for uploading files * fix path params
1 parent 227221f commit f39d11e

File tree

3 files changed

+52
-50
lines changed

3 files changed

+52
-50
lines changed

backend/app/api/v1/mixed/tests.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,29 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3-
from fastapi import APIRouter
3+
from typing import Annotated
4+
5+
from fastapi import APIRouter, File, Form, UploadFile
46

57
from backend.app.common.response.response_schema import response_base
68
from backend.app.tasks import task_demo_async
79

810
router = APIRouter(prefix='/tests')
911

1012

11-
@router.post('/send', summary='测试异步任务')
12-
async def task_send():
13-
result = task_demo_async.delay()
14-
return await response_base.success(data=result.id)
15-
16-
1713
@router.post('/send', summary='异步任务演示')
1814
async def send_task():
1915
result = task_demo_async.delay()
2016
return await response_base.success(data=result.id)
17+
18+
19+
@router.post('/files', summary='上传文件演示')
20+
async def create_file(
21+
file: Annotated[bytes, File()],
22+
fileb: Annotated[UploadFile, File()],
23+
token: Annotated[str, Form()],
24+
):
25+
return {
26+
'file_size': len(file),
27+
'token': token,
28+
'fileb_content_type': fileb.content_type,
29+
}

backend/app/middleware/opera_log_middleware.py

+35-41
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3-
from typing import Any, AsyncGenerator
4-
53
from asgiref.sync import sync_to_async
4+
from fastapi import Response
65
from starlette.background import BackgroundTask
76
from starlette.datastructures import UploadFile
7+
from starlette.middleware.base import BaseHTTPMiddleware
88
from starlette.requests import Request
9-
from starlette.types import ASGIApp, Receive, Scope, Send
109

1110
from backend.app.common.enums import OperaLogCipherType
1211
from backend.app.common.log import log
@@ -18,26 +17,16 @@
1817
from backend.app.utils.timezone import timezone
1918

2019

21-
class OperaLogMiddleware:
20+
class OperaLogMiddleware(BaseHTTPMiddleware):
2221
"""操作日志中间件"""
2322

24-
def __init__(self, app: ASGIApp):
25-
self.app = app
26-
27-
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
28-
if scope['type'] != 'http':
29-
await self.app(scope, receive, send)
30-
return
31-
32-
request = Request(scope=scope, receive=receive)
33-
23+
async def dispatch(self, request: Request, call_next) -> Response:
3424
# 排除记录白名单
3525
path = request.url.path
3626
if path in settings.OPERA_LOG_EXCLUDE or not path.startswith(f'{settings.API_V1_STR}'):
37-
await self.app(scope, receive, send)
38-
return
27+
return await call_next(request)
3928

40-
# 请求信息解析
29+
# 请求解析
4130
user_agent, device, os, browser = await parse_user_agent_info(request)
4231
ip, country, region, city = await parse_ip_info(request)
4332
try:
@@ -46,10 +35,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
4635
except AttributeError:
4736
username = None
4837
method = request.method
49-
args = await self.get_request_args(request)
5038
router = request.scope.get('route')
5139
summary = getattr(router, 'summary', None) or ''
52-
args.update(request.path_params)
40+
args = await self.get_request_args(request)
41+
args = await self.desensitization(args)
5342

5443
# 设置附加请求信息
5544
request.state.ip = ip
@@ -63,13 +52,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6352

6453
# 执行请求
6554
start_time = timezone.now()
66-
code, msg, status, err = await self.execute_request(request, send)
55+
code, msg, status, err, response = await self.execute_request(request, call_next)
6756
end_time = timezone.now()
6857
cost_time = (end_time - start_time).total_seconds() * 1000.0
6958

70-
# 脱敏处理
71-
args = await self.desensitization(args)
72-
7359
# 日志创建
7460
opera_log_in = CreateOperaLog(
7561
username=username,
@@ -98,19 +84,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
9884
if err:
9985
raise err from None
10086

101-
async def execute_request(self, request: Request, send: Send) -> tuple:
102-
err: Any = None
87+
return response
88+
89+
async def execute_request(self, request: Request, call_next) -> tuple:
90+
"""执行请求"""
91+
err = None
92+
response = None
10393
try:
104-
# 详见 https://github.com/tiangolo/fastapi/discussions/8385#discussioncomment-6117967
105-
async def wrapped_rcv_gen() -> AsyncGenerator:
106-
async for _ in request.stream():
107-
yield {'type': 'http.request', 'body': await request.body()}
108-
async for message in request.receive: # type: ignore
109-
yield message
110-
111-
wrapped_rcv = wrapped_rcv_gen().__anext__
112-
await self.app(request.scope, wrapped_rcv, send)
113-
code, msg, status = await self.exception_middleware_handler(request)
94+
response = await call_next(request)
95+
code, msg, status = await self.request_exception_handler(request)
11496
except Exception as e:
11597
log.exception(e)
11698
# code 处理包含 SQLAlchemy 和 Pydantic
@@ -119,12 +101,12 @@ async def wrapped_rcv_gen() -> AsyncGenerator:
119101
status = 0
120102
err = e
121103

122-
return str(code), msg, status, err
104+
return str(code), msg, status, err, response
123105

124106
@staticmethod
125107
@sync_to_async
126-
def exception_middleware_handler(request: Request) -> tuple:
127-
# 预置响应信息
108+
def request_exception_handler(request: Request) -> tuple:
109+
"""请求异常处理器"""
128110
code = 200
129111
msg = 'Success'
130112
status = 1
@@ -148,12 +130,16 @@ def exception_middleware_handler(request: Request) -> tuple:
148130

149131
@staticmethod
150132
async def get_request_args(request: Request) -> dict:
133+
"""获取请求参数"""
151134
args = dict(request.query_params)
135+
args.update(request.path_params)
136+
# Tip: .body() 必须在 .form() 之前获取
137+
# https://github.com/encode/starlette/discussions/1933
138+
body_data = await request.body()
152139
form_data = await request.form()
153140
if len(form_data) > 0:
154141
args.update({k: v.filename if isinstance(v, UploadFile) else v for k, v in form_data.items()})
155142
else:
156-
body_data = await request.body()
157143
if body_data:
158144
json_data = await request.json()
159145
if not isinstance(json_data, dict):
@@ -168,7 +154,15 @@ async def get_request_args(request: Request) -> dict:
168154
@staticmethod
169155
@sync_to_async
170156
def desensitization(args: dict) -> dict | None:
171-
if len(args) > 0:
157+
"""
158+
脱敏处理
159+
160+
:param args:
161+
:return:
162+
"""
163+
if not args:
164+
args = None
165+
else:
172166
match settings.OPERA_LOG_ENCRYPT:
173167
case OperaLogCipherType.aes:
174168
for key in args.keys():
@@ -188,4 +182,4 @@ def desensitization(args: dict) -> dict | None:
188182
for key in args.keys():
189183
if key in settings.OPERA_LOG_ENCRYPT_INCLUDE:
190184
args[key] = '******'
191-
return args if len(args) > 0 else None
185+
return args

requirements.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ celery==5.3.6
1010
cryptography==41.0.7
1111
email-validator==2.0.0
1212
fast-captcha==0.2.1
13-
fastapi==0.105.0
13+
fastapi==0.108.0
1414
fastapi-limiter==0.1.5
1515
fastapi-pagination==0.12.13
1616
gunicorn==21.2.0
@@ -33,7 +33,6 @@ pytz==2023.3
3333
redis[hiredis]==4.5.5
3434
ruff==0.1.8
3535
SQLAlchemy==2.0.23
36-
starlette==0.27.0
3736
supervisor==4.2.5
3837
user-agents==2.2.0
3938
uvicorn[standard]==0.24.0

0 commit comments

Comments
 (0)