保定ai问答主体项目
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
bd_ai_fastapi/app/api/chat/ai/chat_router.py

141 lines
5.4 KiB

3 months ago
import hashlib
import hmac
import json
import logging
import time
from typing import AsyncGenerator, List, Optional

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from starlette.requests import Request

from app.models.ChatIn import ChatIn
from app.models.quick_question import QuickQuestion
from app.settings.config import settings

from .chat_service import (
    ai_chat_stream,
    classify,
    extract_spot,
    gen_markdown_stream,
    handle_quick_question,
    query_flow,
)
router = APIRouter()

# Shared secret: used as the HMAC-SHA256 key and appended to the signed
# string as ``secret=<key>`` when requests are verified.
SECRET_KEY = settings.SIGN_KEY
# Maximum allowed difference between the client timestamp and server time,
# in seconds.
TIMESTAMP_TOLERANCE: int = 60
# Redis TTL for a user's stored conversation history, in seconds (10 minutes).
CONVERSATION_EXPIRE_TIME: int = 10 * 60
@router.post("/chat", summary="ai对话")
async def h5_chat_stream(request: Request, inp: ChatIn):
    """Authenticate a chat message and stream the AI answer back.

    Flow:
      1. Reject with 401 if the signature is missing or invalid, or if the
         timestamp fails ``verify_timestamp``.
      2. Default ``inp.language`` to ``"zh_cn"`` and require a truthy
         ``inp.user_id`` (400 otherwise).
      3. Load the user's history from Redis key ``conversation:{user_id}``
         (JSON-decoded; an empty list when the key is absent).
      4. If the message equals the title of one of the first four enabled
         quick questions, answer with that question's stored content via
         ``handle_quick_question``. Otherwise classify the message: the
         ``"游玩判断"`` category goes through spot extraction, flow lookup and
         markdown generation; anything else goes through ``ai_chat_stream``.
      5. For non-quick questions, write the history back to Redis with a
         ``CONVERSATION_EXPIRE_TIME`` TTL once the stream has finished.

    Returns:
        A ``text/plain`` ``StreamingResponse`` sent with
        ``Cache-Control: no-cache`` and ``X-Accel-Buffering: no``.

    Raises:
        HTTPException: 401 on authentication failure, 400 when the user id
            is missing.
    """
    if inp.sign is None:
        raise HTTPException(status_code=401, detail="缺少签名参数")
    # Verify over the fields exactly as the client sent them, i.e. before the
    # server fills in any defaults below.
    if not verify_signature(inp.dict(), inp.sign):
        raise HTTPException(status_code=401, detail="无效的签名")
    if not verify_timestamp(inp.timestamp):
        raise HTTPException(status_code=401, detail="时间戳无效")
    if inp.language is None:
        inp.language = "zh_cn"

    user_id = inp.user_id
    if not user_id:
        raise HTTPException(status_code=400, detail="缺少用户 ID")

    # Load this user's conversation history from Redis.
    redis_client = request.app.state.redis_client
    conversation_history_key = f"conversation:{user_id}"
    conversation_history_str = await redis_client.get(conversation_history_key)
    conversation_history = json.loads(conversation_history_str) if conversation_history_str else []

    # First four enabled (status "0") quick questions, ordered by order_num,
    # with both their title and stored answer content.
    questions = await QuickQuestion.filter(status="0").order_by("order_num").limit(4).values("title", "content")
    question_titles = [q["title"] for q in questions]
    is_quick_question = inp.message in question_titles

    # Classification is only needed for free-form (non-quick) messages.
    cat = None
    if not is_quick_question:
        cat = await classify(inp.message)

    async def content_stream() -> AsyncGenerator[str, None]:
        # ``conversation_history`` is never rebound here, so no ``nonlocal`` is
        # needed; the stream helpers receive the list itself (presumably they
        # append to it in place, since it is saved afterwards - TODO confirm
        # in chat_service).
        try:
            if is_quick_question:
                # Always matches: is_quick_question was derived from these titles.
                question_content = next(q["content"] for q in questions if q["title"] == inp.message)
                async for chunk in handle_quick_question(inp, question_content):
                    yield chunk
            else:
                if cat == "游玩判断":
                    spot = await extract_spot(inp.message)
                    data = await query_flow(request, spot)
                    async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history):
                        yield chunk
                else:
                    async for chunk in ai_chat_stream(inp, conversation_history):
                        yield chunk
                # Persist history (with TTL) only for non-quick questions,
                # and only after the whole answer has been streamed.
                await redis_client.setex(conversation_history_key, CONVERSATION_EXPIRE_TIME, json.dumps(conversation_history))
        except Exception:
            # Keep the traceback; the old print() discarded it.
            logging.getLogger(__name__).exception("Error in content_stream")
            raise

    # The generator body only runs after the 200 status and headers are sent,
    # so errors raised inside it cannot become an HTTPException here. The
    # former try/except around this constructor could never catch them.
    return StreamingResponse(
        content_stream(),
        media_type="text/plain",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
class ClearConversationRequest(BaseModel):
    """Request body for ``POST /clear_conversation``."""

    # Id of the user whose ``conversation:{user_id}`` Redis key is deleted.
    user_id: int
@router.post("/clear_conversation", summary="清除对话记录")
async def clear_conversation(request: Request, body: ClearConversationRequest):
    """Delete the stored conversation history for ``body.user_id``.

    Returns:
        A confirmation message dict.

    Raises:
        HTTPException: 400 when the user id is falsy (e.g. ``0``).

    NOTE(review): unlike ``/chat``, this endpoint performs no signature or
    timestamp check, so any caller can clear any user's history. Confirm
    whether that is intended.
    """
    if not body.user_id:
        raise HTTPException(status_code=400, detail="缺少用户 ID")
    await request.app.state.redis_client.delete(f"conversation:{body.user_id}")
    return {"message": "对话历史已清除"}
def verify_signature(data: dict, sign: str, *, secret_key: Optional[str] = None) -> bool:
    """Check an HMAC-SHA256 request signature.

    The signed string is every ``key=value`` pair of ``data`` except
    ``"sign"``, sorted by key and joined with ``&``, followed by
    ``&secret=<key>``. It is signed with HMAC-SHA256 under the same key and
    hex-encoded.

    Args:
        data: Request fields. Values are rendered with f-string formatting,
            so e.g. ``None`` becomes ``"None"``.
        sign: Hex digest supplied by the client.
        secret_key: Optional key override; defaults to module ``SECRET_KEY``.

    Returns:
        True if ``sign`` matches the expected digest. False otherwise,
        including when ``sign`` is not an ASCII string.
    """
    key = SECRET_KEY if secret_key is None else secret_key
    payload = "&".join(f"{k}={data[k]}" for k in sorted(data) if k != "sign")
    payload += f"&secret={key}"
    expected = hmac.new(key.encode(), payload.encode(), hashlib.sha256).hexdigest()
    # compare_digest raises TypeError for non-str or non-ASCII str input; a
    # valid hex digest is always ASCII, so such input can simply be rejected.
    if not isinstance(sign, str) or not sign.isascii():
        return False
    # Constant-time comparison: ``==`` can short-circuit at the first
    # differing character and leak how much of a forged signature is right.
    return hmac.compare_digest(expected, sign)
def verify_timestamp(timestamp: int, *, tolerance_seconds: Optional[float] = None) -> bool:
    """Return True if ``timestamp`` is within the allowed window of now.

    Args:
        timestamp: Client timestamp in milliseconds since the epoch (it is
            compared against ``int(time.time() * 1000)``).
        tolerance_seconds: Optional window override in seconds; defaults to
            module ``TIMESTAMP_TOLERANCE``.

    Returns:
        True if ``|now_ms - timestamp| <= tolerance * 1000``. False
        otherwise, including for ``None`` or non-numeric input, which
        previously raised TypeError and surfaced as a 500.
    """
    window = TIMESTAMP_TOLERANCE if tolerance_seconds is None else tolerance_seconds
    if isinstance(timestamp, bool) or not isinstance(timestamp, (int, float)):
        return False
    now_ms = int(time.time() * 1000)
    return abs(now_ms - timestamp) <= window * 1000
# Response model for fetching quick questions.
class QuestionResponse(BaseModel):
    """A quick-question item with its ``id`` and ``title``.

    NOTE(review): no code visible in this file references this model.
    ``get_question`` declares no ``response_model`` and returns
    ``title``/``subtitle``/``logo``/``label`` instead. Confirm whether it is
    still used elsewhere.
    """

    id: int
    title: str

    class Config:
        # pydantic v1 setting allowing construction from ORM objects.
        orm_mode = True
@router.get("/getQuestion", summary="获取开启的前4个问题")
async def get_question():
    """Return the first four enabled quick questions for display.

    "Enabled" means ``status == "0"`` (normal). Rows are ordered by
    ``order_num`` ascending, and each row is returned as a dict with
    ``title``, ``subtitle``, ``logo`` and ``label``.
    """
    enabled = QuickQuestion.filter(status="0")
    first_four = enabled.order_by("order_num").limit(4)
    return await first_four.values("title", "subtitle", "logo", "label")