You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
198 lines
7.6 KiB
198 lines
7.6 KiB
import json
|
|
import re
|
|
from datetime import datetime
|
|
from typing import Any, AsyncGenerator
|
|
|
|
from fastapi import FastAPI
|
|
from fastapi.responses import Response
|
|
from fastapi.routing import APIRoute
|
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
from starlette.requests import Request
|
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
|
|
from app.core.dependency import AuthControl
|
|
from app.models.admin import AuditLog, User
|
|
|
|
from .bgtask import BgTasks
|
|
|
|
|
|
class SimpleBaseMiddleware:
    """Minimal pure-ASGI middleware exposing before/after request hooks.

    Only ``http`` scopes are intercepted; every other scope type is forwarded
    to the wrapped application unchanged.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Store the wrapped ASGI application."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point: run the hooks around the downstream handler."""
        if scope["type"] == "http":
            request = Request(scope, receive=receive)

            # A truthy value returned by ``before_request`` is used as the ASGI
            # callable for this request; otherwise the wrapped app handles it.
            handler = await self.before_request(request)
            if not handler:
                handler = self.app

            await handler(request.scope, request.receive, send)
            # Runs only after the handler returned without raising.
            await self.after_request(request)
        else:
            await self.app(scope, receive, send)

    async def before_request(self, request: Request):
        """Hook awaited before the request is handled.

        Returns the ASGI callable that should handle the request; the default
        implementation returns the wrapped app.
        """
        return self.app

    async def after_request(self, request: Request):
        """Hook awaited after the handler has finished; does nothing by default."""
        return None
|
|
|
|
|
|
class BackGroundTaskMiddleware(SimpleBaseMiddleware):
    """Wrap each HTTP request with the ``BgTasks`` init/execute calls."""

    async def before_request(self, request: Request) -> None:
        """Await ``BgTasks.init_bg_tasks_obj()`` before the request is handled.

        Returns ``None``, so the base class falls back to the wrapped app.
        """
        await BgTasks.init_bg_tasks_obj()
        return None

    async def after_request(self, request: Request) -> None:
        """Await ``BgTasks.execute_tasks()`` once the handler has finished."""
        await BgTasks.execute_tasks()
        return None
|
|
|
|
|
|
class HttpAuditLogMiddleware(BaseHTTPMiddleware):
    """Persist an ``AuditLog`` row for selected HTTP requests.

    The request arguments (query string plus JSON/form body) are captured for
    every request and stored on ``request.state.request_args``. After the
    downstream app has answered, requests whose method is listed in
    ``methods`` and whose path matches none of the ``exclude_paths`` patterns
    are written via ``AuditLog.create``.
    """

    def __init__(self, app, methods: list[str], exclude_paths: list[str]):
        """Configure the middleware.

        Args:
            app: The wrapped ASGI application.
            methods: HTTP methods whose requests are audited.
            exclude_paths: Regular expressions, searched case-insensitively
                against the request path; a match disables auditing.
        """
        super().__init__(app)
        self.methods = methods
        self.exclude_paths = exclude_paths
        # Responses of these endpoints have nested ``response_body`` fields
        # stripped before they are logged (see ``get_response_body``).
        self.audit_log_paths = ["/api/v1/auditlog/list"]
        self.max_body_size = 1024 * 1024  # 1 MB cap on logged response bodies

    @staticmethod
    def _form_to_args(form) -> dict:
        """Flatten parsed form data, replacing uploaded files by their filenames."""
        flattened: dict = {}
        for key, value in form.items():
            if hasattr(value, "filename"):  # single uploaded file
                flattened[key] = value.filename
            elif isinstance(value, list) and value and hasattr(value[0], "filename"):
                flattened[key] = [upload.filename for upload in value]
            else:
                flattened[key] = value
        return flattened

    @staticmethod
    def _strip_logged_bodies(data: Any) -> Any:
        """Remove ``response_body`` from a payload and from the dict rows in its ``data`` list."""
        if isinstance(data, dict):
            data.pop("response_body", None)
            rows = data.get("data")
            if isinstance(rows, list):
                for row in rows:
                    # Skip non-dict rows instead of discarding the whole payload.
                    if isinstance(row, dict):
                        row.pop("response_body", None)
        return data

    async def get_request_args(self, request: Request) -> dict:
        """Collect query parameters and, for POST/PUT/PATCH, the request body.

        Body capture is best-effort: any parsing failure is swallowed so that
        audit bookkeeping can never break the request itself.

        Returns:
            A dict of argument names to values; uploaded files are represented
            by their filenames.
        """
        # Query parameters.
        args: dict = dict(request.query_params.items())

        if request.method not in ("POST", "PUT", "PATCH"):
            return args

        content_type = request.headers.get("content-type", "")

        if "multipart/form-data" in content_type:
            # File upload request: parse with form().
            try:
                # NOTE(review): Starlette's BaseHTTPMiddleware is documented to
                # replay the request body to downstream handlers only when it
                # was read through body(); reading it solely via form()/stream()
                # leaves them with an empty body. Caching it first keeps
                # uploads working, at the cost of buffering the upload in
                # memory - confirm against the pinned Starlette version.
                await request.body()
                args.update(self._form_to_args(await request.form()))
            except Exception:
                pass  # best-effort capture only
            return args

        # Try JSON first.
        payload = None
        try:
            payload = await request.json()
        except Exception:
            pass  # not JSON (or undecodable); fall through to form parsing

        if isinstance(payload, dict):
            args.update(payload)
        else:
            # Fall back to treating the body as ordinary form data.
            try:
                args.update(self._form_to_args(await request.form()))
            except Exception:
                pass  # best-effort capture only
        return args

    async def get_response_body(self, request: Request, response: Response) -> Any:
        """Return the response payload to store in the audit log.

        Bodies larger than ``max_body_size`` are replaced by a placeholder
        dict. When the response only exposes ``body_iterator``, the iterator
        is drained and then re-installed so the client still receives every
        chunk.
        """
        too_large = {"code": 0, "msg": "Response too large to log", "data": None}

        # Cheap check first: trust a well-formed Content-Length header.
        content_length = response.headers.get("content-length")
        if content_length and content_length.isdigit() and int(content_length) > self.max_body_size:
            return too_large

        if hasattr(response, "body"):
            body = response.body
        else:
            chunks = []
            async for chunk in response.body_iterator:
                if not isinstance(chunk, bytes):
                    chunk = chunk.encode(response.charset)
                chunks.append(chunk)

            # The iterator is now exhausted; hand the chunks back for sending.
            response.body_iterator = self._async_iter(chunks)
            body = b"".join(chunks)

        # Enforce the cap for responses that carried no Content-Length as well.
        if isinstance(body, (bytes, bytearray, str)) and len(body) > self.max_body_size:
            return too_large

        data = self.lenient_json(body)
        if request.url.path.startswith(tuple(self.audit_log_paths)):
            # Keep only the basic information and drop nested response bodies,
            # so listing audit logs does not log the logs themselves.
            return self._strip_logged_bodies(data)
        return data

    def lenient_json(self, v: Any) -> Any:
        """Parse ``v`` as JSON when it is ``str``/``bytes``; otherwise, or on failure, return it unchanged."""
        if isinstance(v, (str, bytes)):
            try:
                return json.loads(v)
            except (ValueError, TypeError):
                pass
        return v

    async def _async_iter(self, items: list[bytes]) -> AsyncGenerator[bytes, None]:
        """Yield ``items`` one by one as an async generator."""
        for item in items:
            yield item

    async def get_request_log(self, request: Request, response: Response) -> dict:
        """Build the audit record fields derived from the request and response.

        Returns:
            A dict with ``path``, ``status`` and ``method``; ``module`` and
            ``summary`` when an ``APIRoute`` matches; and ``user_id`` /
            ``username`` (``0`` / ``""`` when the caller is not authenticated).
        """
        data: dict = {"path": request.url.path, "status": response.status_code, "method": request.method}

        # Route information.
        app: FastAPI = request.app
        for route in app.routes:
            if (
                isinstance(route, APIRoute)
                and route.path_regex.match(request.url.path)
                and request.method in route.methods
            ):
                data["module"] = ",".join(route.tags)
                data["summary"] = route.summary
                # The router dispatches to the first matching route, so stop
                # here rather than letting a later overlapping route win.
                break

        # User information.
        try:
            token = request.headers.get("token")
            user_obj = None
            if token:
                user_obj = await AuthControl.is_authed(token)
            data["user_id"] = user_obj.id if user_obj else 0
            data["username"] = user_obj.username if user_obj else ""
        except Exception:
            # Any authentication failure is logged as an anonymous request.
            data["user_id"] = 0
            data["username"] = ""
        return data

    async def before_request(self, request: Request):
        """Capture the request arguments onto ``request.state.request_args``."""
        request.state.request_args = await self.get_request_args(request)

    async def after_request(self, request: Request, response: Response, process_time: int):
        """Write the audit record for an auditable request.

        Args:
            request: The incoming request (``before_request`` must have run).
            response: The response produced by the downstream app.
            process_time: Handling time in milliseconds.

        Returns:
            ``None`` when the path is excluded, otherwise ``response``.
        """
        if request.method in self.methods:
            if any(re.search(pattern, request.url.path, re.I) for pattern in self.exclude_paths):
                return None

            data: dict = await self.get_request_log(request=request, response=response)
            data["response_time"] = process_time
            data["request_args"] = request.state.request_args
            data["response_body"] = await self.get_response_body(request, response)
            await AuditLog.create(**data)

        return response

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Run the request through the app, timing it and recording the audit log."""
        start_time: datetime = datetime.now()
        await self.before_request(request)
        response = await call_next(request)
        end_time: datetime = datetime.now()
        # Elapsed wall-clock time in whole milliseconds.
        process_time = int((end_time.timestamp() - start_time.timestamp()) * 1000)
        await self.after_request(request, response, process_time)
        return response
|
|
|