11#!/usr/bin/env python3
22# -*- coding: utf-8 -*-
3- from typing import Any , AsyncGenerator
4-
53from asgiref .sync import sync_to_async
4+ from fastapi import Response
65from starlette .background import BackgroundTask
76from starlette .datastructures import UploadFile
7+ from starlette .middleware .base import BaseHTTPMiddleware
88from starlette .requests import Request
9- from starlette .types import ASGIApp , Receive , Scope , Send
109
1110from backend .app .common .enums import OperaLogCipherType
1211from backend .app .common .log import log
1817from backend .app .utils .timezone import timezone
1918
2019
21- class OperaLogMiddleware :
20+ class OperaLogMiddleware ( BaseHTTPMiddleware ) :
2221 """操作日志中间件"""
2322
24- def __init__ (self , app : ASGIApp ):
25- self .app = app
26-
27- async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
28- if scope ['type' ] != 'http' :
29- await self .app (scope , receive , send )
30- return
31-
32- request = Request (scope = scope , receive = receive )
33-
23+ async def dispatch (self , request : Request , call_next ) -> Response :
3424 # 排除记录白名单
3525 path = request .url .path
3626 if path in settings .OPERA_LOG_EXCLUDE or not path .startswith (f'{ settings .API_V1_STR } ' ):
37- await self .app (scope , receive , send )
38- return
27+ return await call_next (request )
3928
40- # 请求信息解析
29+ # 请求解析
4130 user_agent , device , os , browser = await parse_user_agent_info (request )
4231 ip , country , region , city = await parse_ip_info (request )
4332 try :
@@ -46,10 +35,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
4635 except AttributeError :
4736 username = None
4837 method = request .method
49- args = await self .get_request_args (request )
5038 router = request .scope .get ('route' )
5139 summary = getattr (router , 'summary' , None ) or ''
52- args .update (request .path_params )
40+ args = await self .get_request_args (request )
41+ args = await self .desensitization (args )
5342
5443 # 设置附加请求信息
5544 request .state .ip = ip
@@ -63,13 +52,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6352
6453 # 执行请求
6554 start_time = timezone .now ()
66- code , msg , status , err = await self .execute_request (request , send )
55+ code , msg , status , err , response = await self .execute_request (request , call_next )
6756 end_time = timezone .now ()
6857 cost_time = (end_time - start_time ).total_seconds () * 1000.0
6958
70- # 脱敏处理
71- args = await self .desensitization (args )
72-
7359 # 日志创建
7460 opera_log_in = CreateOperaLog (
7561 username = username ,
@@ -98,19 +84,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
9884 if err :
9985 raise err from None
10086
101- async def execute_request (self , request : Request , send : Send ) -> tuple :
102- err : Any = None
87+ return response
88+
89+ async def execute_request (self , request : Request , call_next ) -> tuple :
90+ """执行请求"""
91+ err = None
92+ response = None
10393 try :
104- # 详见 https://github.com/tiangolo/fastapi/discussions/8385#discussioncomment-6117967
105- async def wrapped_rcv_gen () -> AsyncGenerator :
106- async for _ in request .stream ():
107- yield {'type' : 'http.request' , 'body' : await request .body ()}
108- async for message in request .receive : # type: ignore
109- yield message
110-
111- wrapped_rcv = wrapped_rcv_gen ().__anext__
112- await self .app (request .scope , wrapped_rcv , send )
113- code , msg , status = await self .exception_middleware_handler (request )
94+ response = await call_next (request )
95+ code , msg , status = await self .request_exception_handler (request )
11496 except Exception as e :
11597 log .exception (e )
11698 # code 处理包含 SQLAlchemy 和 Pydantic
@@ -119,12 +101,12 @@ async def wrapped_rcv_gen() -> AsyncGenerator:
119101 status = 0
120102 err = e
121103
122- return str (code ), msg , status , err
104+ return str (code ), msg , status , err , response
123105
124106 @staticmethod
125107 @sync_to_async
126- def exception_middleware_handler (request : Request ) -> tuple :
127- # 预置响应信息
108+ def request_exception_handler (request : Request ) -> tuple :
109+ """请求异常处理器"""
128110 code = 200
129111 msg = 'Success'
130112 status = 1
@@ -148,12 +130,16 @@ def exception_middleware_handler(request: Request) -> tuple:
148130
149131 @staticmethod
150132 async def get_request_args (request : Request ) -> dict :
133+ """获取请求参数"""
151134 args = dict (request .query_params )
135+ args .update (request .path_params )
136+ # Tip: .body() 必须在 .form() 之前获取
137+ # https://github.com/encode/starlette/discussions/1933
138+ body_data = await request .body ()
152139 form_data = await request .form ()
153140 if len (form_data ) > 0 :
154141 args .update ({k : v .filename if isinstance (v , UploadFile ) else v for k , v in form_data .items ()})
155142 else :
156- body_data = await request .body ()
157143 if body_data :
158144 json_data = await request .json ()
159145 if not isinstance (json_data , dict ):
@@ -168,7 +154,15 @@ async def get_request_args(request: Request) -> dict:
168154 @staticmethod
169155 @sync_to_async
170156 def desensitization (args : dict ) -> dict | None :
171- if len (args ) > 0 :
157+ """
158+ 脱敏处理
159+
160+ :param args:
161+ :return:
162+ """
163+ if not args :
164+ args = None
165+ else :
172166 match settings .OPERA_LOG_ENCRYPT :
173167 case OperaLogCipherType .aes :
174168 for key in args .keys ():
@@ -188,4 +182,4 @@ def desensitization(args: dict) -> dict | None:
188182 for key in args .keys ():
189183 if key in settings .OPERA_LOG_ENCRYPT_INCLUDE :
190184 args [key ] = '******'
191- return args if len ( args ) > 0 else None
185+ return args
0 commit comments