|
|
|
from fastapi import APIRouter, HTTPException
|
|
|
|
|
|
|
|
from starlette.requests import Request
|
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
from typing import AsyncGenerator
|
|
|
|
from .chat_service import classify, extract_spot, query_flow, gen_markdown_stream, ai_chat_stream
|
|
|
|
from app.models.ChatIn import ChatIn
|
|
|
|
import hmac
|
|
|
|
import hashlib
|
|
|
|
import time
|
|
|
|
import json
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from app.settings.config import settings
|
|
|
|
|
|
|
|
router = APIRouter()

# Shared secret agreed with the client, used to sign/verify request payloads.
SECRET_KEY = settings.SIGN_KEY # the agreed-upon signing key

TIMESTAMP_TOLERANCE = 60 # allowed clock skew for request timestamps, in seconds

CONVERSATION_EXPIRE_TIME = 600 # TTL for per-user conversation history in Redis, in seconds
|
|
|
|
|
|
|
|
@router.post("/chat", summary="ai对话")
async def h5_chat_stream(request: Request, inp: ChatIn):
    """Signed, streaming AI chat endpoint.

    Validates the request signature and timestamp, loads the caller's
    conversation history from Redis, classifies the message, then streams
    the generated reply back as plain text chunks.

    Args:
        request: Incoming request; Redis client is read from app state.
        inp: Chat payload (message, sign, timestamp, user_id, language, ...).

    Raises:
        HTTPException 401: missing/invalid signature or stale timestamp.
        HTTPException 400: missing user ID.
        HTTPException 500: failure while constructing the streaming response.
    """
    # Reject unsigned requests outright.
    if inp.sign is None:
        raise HTTPException(status_code=401, detail="缺少签名参数")

    if not verify_signature(inp.dict(), inp.sign):
        raise HTTPException(status_code=401, detail="无效的签名")

    if not verify_timestamp(inp.timestamp):
        raise HTTPException(status_code=401, detail="时间戳无效")

    # Default to Simplified Chinese when the client omits the language.
    if inp.language is None:
        inp.language = "zh_cn"

    # Get the user ID.
    user_id = inp.user_id
    if not user_id:
        raise HTTPException(status_code=400, detail="缺少用户 ID")

    # Fetch the user's conversation history from Redis (empty when absent).
    redis_client = request.app.state.redis_client
    conversation_history_key = f"conversation:{user_id}"
    conversation_history_str = await redis_client.get(conversation_history_key)
    conversation_history = json.loads(conversation_history_str) if conversation_history_str else []

    # Classification stage (kept as a direct await before streaming begins).
    cat = await classify(inp.message)

    async def content_stream() -> AsyncGenerator[str, None]:
        # NOTE(review): conversation_history is presumably appended to in
        # place by gen_markdown_stream / ai_chat_stream — confirm in
        # chat_service before relying on the persisted value below.
        nonlocal conversation_history
        try:
            if cat == "游玩判断":
                # Attraction-visit flow: extract the spot, query its data,
                # and stream a markdown answer built from that data.
                spot = await extract_spot(inp.message)
                data = await query_flow(request, spot)
                async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history):
                    yield chunk
            else:
                # General chat flow.
                async for chunk in ai_chat_stream(inp, conversation_history):
                    yield chunk

            # Persist the updated conversation history back to Redis with a TTL.
            # Only reached when streaming completed without raising.
            await redis_client.setex(conversation_history_key, CONVERSATION_EXPIRE_TIME, json.dumps(conversation_history))
        except Exception as e:
            print(f"Error in content_stream: {e}")
            raise

    try:
        return StreamingResponse(
            content_stream(),
            media_type="text/plain",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Disable nginx proxy buffering so chunks flush immediately.
                "X-Accel-Buffering": "no"
            }
        )
    except Exception as e:
        print(f"Error in StreamingResponse: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
|
|
|
|
|
|
|
|
|
|
|
|
class ClearConversationRequest(BaseModel):
    """Request body for the /clear_conversation endpoint."""

    # ID of the user whose cached conversation history should be removed.
    user_id: int
|
|
|
|
|
|
|
|
@router.post("/clear_conversation", summary="清除对话记录")
async def clear_conversation(request: Request, body: ClearConversationRequest):
    """Delete the cached conversation history for a single user.

    Rejects the call with 400 when no user ID is supplied; otherwise
    removes the user's history entry from Redis and reports success.
    """
    uid = body.user_id
    if not uid:
        raise HTTPException(status_code=400, detail="缺少用户 ID")

    history_key = f"conversation:{uid}"
    await request.app.state.redis_client.delete(history_key)
    return {"message": "对话历史已清除"}
|
|
|
|
|
|
|
|
def verify_signature(data: dict, sign: str) -> bool:
    """Verify the HMAC-SHA256 signature of a request payload.

    The canonical string is built from the payload fields (excluding the
    'sign' entry) sorted by key and joined as 'k=v' pairs with '&', with
    '&secret=<SECRET_KEY>' appended, then HMAC'd with SECRET_KEY.

    Args:
        data: Flat dict of request fields, typically including 'sign'.
        sign: Hex-encoded signature supplied by the client.

    Returns:
        True when the client's signature matches the server-side HMAC.
    """
    sorted_keys = sorted(data.keys())
    sign_str = '&'.join(f'{key}={data[key]}' for key in sorted_keys if key != 'sign')
    sign_str += f'&secret={SECRET_KEY}'
    calculated_sign = hmac.new(SECRET_KEY.encode(), sign_str.encode(), hashlib.sha256).hexdigest()
    # Use a constant-time comparison: plain == short-circuits on the first
    # differing byte, which leaks match position through timing.
    return hmac.compare_digest(calculated_sign, sign)
|
|
|
|
|
|
|
|
def verify_timestamp(timestamp: int) -> bool:
    """Check that *timestamp* (epoch milliseconds) is within the allowed skew
    of the server clock, in either direction."""
    now_ms = int(time.time() * 1000)
    allowed_skew_ms = TIMESTAMP_TOLERANCE * 1000
    return -allowed_skew_ms <= now_ms - timestamp <= allowed_skew_ms
|