保定ai问答主体项目
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 
bd_ai_fastapi/app/core/middlewares.py

198 lines
7.6 KiB

import json
import re
from datetime import datetime
from typing import Any, AsyncGenerator
from fastapi import FastAPI
from fastapi.responses import Response
from fastapi.routing import APIRoute
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.types import ASGIApp, Receive, Scope, Send
from app.core.dependency import AuthControl
from app.models.admin import AuditLog, User
from .bgtask import BgTasks
class SimpleBaseMiddleware:
    """Minimal ASGI middleware with overridable before/after hooks.

    Non-HTTP scopes are passed straight to the wrapped application. For HTTP
    scopes, ``before_request`` runs first, then the selected ASGI callable,
    then ``after_request``.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Store the wrapped ASGI application."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point: run the hooks around the downstream callable."""
        if scope["type"] == "http":
            request = Request(scope, receive=receive)
            # A hook may return its own ASGI callable to short-circuit the
            # request; any falsy result means "use the wrapped app".
            handler = await self.before_request(request)
            if not handler:
                handler = self.app
            await handler(request.scope, request.receive, send)
            await self.after_request(request)
        else:
            # Anything that is not an HTTP scope bypasses the hooks entirely.
            await self.app(scope, receive, send)

    async def before_request(self, request: Request):
        """Hook run before the request is handled; defaults to the wrapped app."""
        return self.app

    async def after_request(self, request: Request):
        """Hook run after the response has been sent; does nothing by default."""
        return None
class BackGroundTaskMiddleware(SimpleBaseMiddleware):
    """Middleware that brackets each HTTP request with ``BgTasks`` calls."""

    async def before_request(self, request):
        """Call ``BgTasks.init_bg_tasks_obj()`` before the request is handled.

        Returns ``None`` so the base class falls back to the wrapped app.
        """
        await BgTasks.init_bg_tasks_obj()
        return None

    async def after_request(self, request):
        """Call ``BgTasks.execute_tasks()`` once the response has been sent."""
        await BgTasks.execute_tasks()
class HttpAuditLogMiddleware(BaseHTTPMiddleware):
    """Create an ``AuditLog`` record for selected HTTP requests.

    For every request the parsed arguments are stored on
    ``request.state.request_args``. After the downstream handler has run,
    requests whose method is in ``methods`` and whose path matches none of the
    ``exclude_paths`` patterns are recorded with route metadata, user
    information, timing, request arguments and the response body.
    """

    def __init__(self, app, methods: list[str], exclude_paths: list[str]):
        """Configure which requests are audited.

        Args:
            app: The wrapped ASGI application.
            methods: HTTP methods that should be logged.
            exclude_paths: Regular-expression patterns; a request whose path
                matches any of them (case-insensitively) is not logged.
        """
        super().__init__(app)
        self.methods = methods
        self.exclude_paths = exclude_paths
        # Responses for these path prefixes have their nested "response_body"
        # fields stripped before being logged.
        self.audit_log_paths = ["/api/v1/auditlog/list"]
        self.max_body_size = 1024 * 1024  # 1 MiB limit on logged response bodies

    @staticmethod
    def _form_to_args(form: Any) -> dict:
        """Flatten parsed form data, replacing uploads with their file names."""
        flat: dict = {}
        for key, value in form.items():
            if hasattr(value, "filename"):  # single uploaded file
                flat[key] = value.filename
            elif isinstance(value, list) and value and hasattr(value[0], "filename"):
                flat[key] = [upload.filename for upload in value]
            else:
                flat[key] = value
        return flat

    async def get_request_args(self, request: Request) -> dict:
        """Collect query parameters and, for write methods, body parameters.

        The body is read as JSON unless the content type is multipart; if that
        yields no JSON object, the body is read as form data instead. Parsing
        is best-effort: any failure leaves the arguments gathered so far.

        Returns:
            A dict of argument names to values (file uploads appear as names).
        """
        args: dict = dict(request.query_params.items())
        if request.method not in ("POST", "PUT", "PATCH"):
            return args

        content_type = request.headers.get("content-type", "")
        if "multipart/form-data" not in content_type:
            try:
                payload = await request.json()
            except Exception:
                # Not JSON (or undecodable): fall through to form parsing.
                payload = None
            # Only JSON objects can be merged into the argument dict; arrays
            # and scalars are treated like a failed JSON parse.
            if isinstance(payload, dict):
                args.update(payload)
                return args

        try:
            args.update(self._form_to_args(await request.form()))
        except Exception:
            # Audit logging must never break the request itself.
            pass
        return args

    async def get_response_body(self, request: Request, response: Response) -> Any:
        """Return the response payload in a form suitable for logging.

        Bodies larger than ``max_body_size`` are replaced by a placeholder,
        whether the size is declared in ``Content-Length`` or only known after
        the body has been collected. When the body has to be read from
        ``response.body_iterator`` the iterator is replaced afterwards so the
        response can still be sent to the client.

        Returns:
            The decoded JSON value when the body is valid JSON, the raw body
            otherwise, the placeholder dict for oversized bodies, or ``None``
            if post-processing an audit-log response fails.
        """
        oversized = {"code": 0, "msg": "Response too large to log", "data": None}

        declared = response.headers.get("content-length")
        if declared:
            try:
                if int(declared) > self.max_body_size:
                    return oversized
            except ValueError:
                # Malformed header: rely on the measured size below instead.
                pass

        if hasattr(response, "body"):
            body = response.body
        else:
            chunks = []
            async for chunk in response.body_iterator:
                if not isinstance(chunk, bytes):
                    chunk = chunk.encode(response.charset)
                chunks.append(chunk)
            # The iterator is now exhausted; hand the client a fresh one.
            response.body_iterator = self._async_iter(chunks)
            body = b"".join(chunks)

        # Responses without a Content-Length header skip the check above.
        if len(body) > self.max_body_size:
            return oversized

        if not request.url.path.startswith(tuple(self.audit_log_paths)):
            return self.lenient_json(body)

        try:
            data = self.lenient_json(body)
            # Keep only the basic information: drop nested response bodies so
            # that listing audit logs does not log earlier logs' payloads.
            if isinstance(data, dict):
                data.pop("response_body", None)
                rows = data.get("data")
                if isinstance(rows, list):
                    for row in rows:
                        if isinstance(row, dict):
                            row.pop("response_body", None)
            return data
        except Exception:
            return None

    def lenient_json(self, v: Any) -> Any:
        """Parse ``v`` as JSON if it is a str/bytes holding valid JSON.

        Any other value, or text that fails to parse, is returned unchanged.
        """
        if not isinstance(v, (str, bytes)):
            return v
        try:
            return json.loads(v)
        except (ValueError, TypeError):
            return v

    async def _async_iter(self, items: list[bytes]) -> AsyncGenerator[bytes, None]:
        """Yield the given chunks one by one as an async generator."""
        for item in items:
            yield item

    async def get_request_log(self, request: Request, response: Response) -> dict:
        """Build the audit-log fields derived from the request and response.

        Returns:
            A dict with ``path``, ``status``, ``method``, ``user_id`` and
            ``username``, plus ``module`` and ``summary`` when a route matches.
        """
        data: dict = {"path": request.url.path, "status": response.status_code, "method": request.method}

        # Route information: the first route that matches is the one the
        # router dispatches to, so stop there instead of letting a later,
        # broader route overwrite the values.
        app: FastAPI = request.app
        for route in app.routes:
            if (
                isinstance(route, APIRoute)
                and route.path_regex.match(request.url.path)
                and request.method in route.methods
            ):
                data["module"] = ",".join(route.tags)
                data["summary"] = route.summary
                break

        # User information: fall back to an anonymous user when there is no
        # token or the token cannot be resolved.
        data["user_id"], data["username"] = 0, ""
        token = request.headers.get("token")
        if token:
            try:
                user_obj: User = await AuthControl.is_authed(token)
                if user_obj:
                    data["user_id"], data["username"] = user_obj.id, user_obj.username
            except Exception:
                pass
        return data

    async def before_request(self, request: Request):
        """Parse the request arguments and stash them on ``request.state``."""
        request.state.request_args = await self.get_request_args(request)

    async def after_request(self, request: Request, response: Response, process_time: int):
        """Create the audit-log record for an eligible request.

        Args:
            request: The incoming request.
            response: The response produced by the downstream handler.
            process_time: Handling time in milliseconds.

        Returns:
            ``None`` when the path is excluded, otherwise ``response``.
        """
        if request.method in self.methods:
            url_path = request.url.path
            if any(re.search(pattern, url_path, re.I) is not None for pattern in self.exclude_paths):
                return None
            data: dict = await self.get_request_log(request=request, response=response)
            data["response_time"] = process_time
            data["request_args"] = request.state.request_args
            data["response_body"] = await self.get_response_body(request, response)
            await AuditLog.create(**data)
        return response

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Run the request through the handler, timing it for the audit log."""
        start_time: datetime = datetime.now()
        await self.before_request(request)
        response = await call_next(request)
        end_time: datetime = datetime.now()
        # Elapsed handling time in whole milliseconds.
        process_time = int((end_time.timestamp() - start_time.timestamp()) * 1000)
        await self.after_request(request, response, process_time)
        return response