1
1
#!/usr/bin/env python3
2
2
# -*- coding: utf-8 -*-
3
- from typing import Any , AsyncGenerator
4
-
5
3
from asgiref .sync import sync_to_async
4
+ from fastapi import Response
6
5
from starlette .background import BackgroundTask
7
6
from starlette .datastructures import UploadFile
7
+ from starlette .middleware .base import BaseHTTPMiddleware
8
8
from starlette .requests import Request
9
- from starlette .types import ASGIApp , Receive , Scope , Send
10
9
11
10
from backend .app .common .enums import OperaLogCipherType
12
11
from backend .app .common .log import log
18
17
from backend .app .utils .timezone import timezone
19
18
20
19
21
- class OperaLogMiddleware :
20
+ class OperaLogMiddleware ( BaseHTTPMiddleware ) :
22
21
"""操作日志中间件"""
23
22
24
- def __init__ (self , app : ASGIApp ):
25
- self .app = app
26
-
27
- async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
28
- if scope ['type' ] != 'http' :
29
- await self .app (scope , receive , send )
30
- return
31
-
32
- request = Request (scope = scope , receive = receive )
33
-
23
+ async def dispatch (self , request : Request , call_next ) -> Response :
34
24
# 排除记录白名单
35
25
path = request .url .path
36
26
if path in settings .OPERA_LOG_EXCLUDE or not path .startswith (f'{ settings .API_V1_STR } ' ):
37
- await self .app (scope , receive , send )
38
- return
27
+ return await call_next (request )
39
28
40
- # 请求信息解析
29
+ # 请求解析
41
30
user_agent , device , os , browser = await parse_user_agent_info (request )
42
31
ip , country , region , city = await parse_ip_info (request )
43
32
try :
@@ -46,10 +35,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
46
35
except AttributeError :
47
36
username = None
48
37
method = request .method
49
- args = await self .get_request_args (request )
50
38
router = request .scope .get ('route' )
51
39
summary = getattr (router , 'summary' , None ) or ''
52
- args .update (request .path_params )
40
+ args = await self .get_request_args (request )
41
+ args = await self .desensitization (args )
53
42
54
43
# 设置附加请求信息
55
44
request .state .ip = ip
@@ -63,13 +52,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
63
52
64
53
# 执行请求
65
54
start_time = timezone .now ()
66
- code , msg , status , err = await self .execute_request (request , send )
55
+ code , msg , status , err , response = await self .execute_request (request , call_next )
67
56
end_time = timezone .now ()
68
57
cost_time = (end_time - start_time ).total_seconds () * 1000.0
69
58
70
- # 脱敏处理
71
- args = await self .desensitization (args )
72
-
73
59
# 日志创建
74
60
opera_log_in = CreateOperaLog (
75
61
username = username ,
@@ -98,19 +84,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
98
84
if err :
99
85
raise err from None
100
86
101
- async def execute_request (self , request : Request , send : Send ) -> tuple :
102
- err : Any = None
87
+ return response
88
+
89
+ async def execute_request (self , request : Request , call_next ) -> tuple :
90
+ """执行请求"""
91
+ err = None
92
+ response = None
103
93
try :
104
- # 详见 https://github.com/tiangolo/fastapi/discussions/8385#discussioncomment-6117967
105
- async def wrapped_rcv_gen () -> AsyncGenerator :
106
- async for _ in request .stream ():
107
- yield {'type' : 'http.request' , 'body' : await request .body ()}
108
- async for message in request .receive : # type: ignore
109
- yield message
110
-
111
- wrapped_rcv = wrapped_rcv_gen ().__anext__
112
- await self .app (request .scope , wrapped_rcv , send )
113
- code , msg , status = await self .exception_middleware_handler (request )
94
+ response = await call_next (request )
95
+ code , msg , status = await self .request_exception_handler (request )
114
96
except Exception as e :
115
97
log .exception (e )
116
98
# code 处理包含 SQLAlchemy 和 Pydantic
@@ -119,12 +101,12 @@ async def wrapped_rcv_gen() -> AsyncGenerator:
119
101
status = 0
120
102
err = e
121
103
122
- return str (code ), msg , status , err
104
+ return str (code ), msg , status , err , response
123
105
124
106
@staticmethod
125
107
@sync_to_async
126
- def exception_middleware_handler (request : Request ) -> tuple :
127
- # 预置响应信息
108
+ def request_exception_handler (request : Request ) -> tuple :
109
+ """请求异常处理器"""
128
110
code = 200
129
111
msg = 'Success'
130
112
status = 1
@@ -148,12 +130,16 @@ def exception_middleware_handler(request: Request) -> tuple:
148
130
149
131
@staticmethod
150
132
async def get_request_args (request : Request ) -> dict :
133
+ """获取请求参数"""
151
134
args = dict (request .query_params )
135
+ args .update (request .path_params )
136
+ # Tip: .body() 必须在 .form() 之前获取
137
+ # https://github.com/encode/starlette/discussions/1933
138
+ body_data = await request .body ()
152
139
form_data = await request .form ()
153
140
if len (form_data ) > 0 :
154
141
args .update ({k : v .filename if isinstance (v , UploadFile ) else v for k , v in form_data .items ()})
155
142
else :
156
- body_data = await request .body ()
157
143
if body_data :
158
144
json_data = await request .json ()
159
145
if not isinstance (json_data , dict ):
@@ -168,7 +154,15 @@ async def get_request_args(request: Request) -> dict:
168
154
@staticmethod
169
155
@sync_to_async
170
156
def desensitization (args : dict ) -> dict | None :
171
- if len (args ) > 0 :
157
+ """
158
+ 脱敏处理
159
+
160
+ :param args:
161
+ :return:
162
+ """
163
+ if not args :
164
+ args = None
165
+ else :
172
166
match settings .OPERA_LOG_ENCRYPT :
173
167
case OperaLogCipherType .aes :
174
168
for key in args .keys ():
@@ -188,4 +182,4 @@ def desensitization(args: dict) -> dict | None:
188
182
for key in args .keys ():
189
183
if key in settings .OPERA_LOG_ENCRYPT_INCLUDE :
190
184
args [key ] = '******'
191
- return args if len ( args ) > 0 else None
185
+ return args
0 commit comments