保定ai问答主体项目
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 
bd_ai_fastapi/app/api/chat/ai/chat_router.py

715 lines
27 KiB

import random
from anyio import sleep
from fastapi import APIRouter, HTTPException, Depends,FastAPI
from starlette.requests import Request
from fastapi.responses import StreamingResponse
from typing import AsyncGenerator, List
from .chat_service import classify, extract_spot, query_flow, gen_markdown_stream, ai_chat_stream, \
handle_quick_question, fetch_and_parse_markdown
from app.models.ChatIn import ChatIn, AllScenicFlowRequest, ScenicDetailRequest, ScenicParkingRequest
from app.models.quick_question import QuickQuestion
from app.models.hot_question import HotQuestion
from pydantic import BaseModel
import hmac
import hashlib
import time
import json
from app.settings.config import settings
# 更新导入语句,添加新函数
from app.api.chat.ai.chat_service import (
query_flow,
handle_quick_question,
get_all_scenic_flow_data,
get_scenic_detail_data,
get_scenic_parking_data,
extract_multi_scenic,
query_multi_scenic_flow,
get_all_toilet_data,
generate_recommended_questions
)
# 导入用于异步执行同步函数的模块
from concurrent.futures import ThreadPoolExecutor
import asyncio
# 导入Redis依赖注入
from app.core.redis_dependency import get_redis_client, redis_dependency
# Router that every endpoint in this module is registered on.
router = APIRouter()
SECRET_KEY = settings.SIGN_KEY  # Shared secret agreed with clients, used for request signatures
TIMESTAMP_TOLERANCE = 120  # Accepted deviation of a request timestamp, in seconds
CONVERSATION_EXPIRE_TIME = 600  # Lifetime of a stored conversation history, in seconds
# NOTE(review): this FastAPI instance is created inside the router module and is
# only used to attach the exception handler defined right below it. Whether this
# is the application object that is actually served cannot be seen from this
# file - confirm, otherwise that handler never runs.
app = FastAPI()
# Catch-all exception handler registered on the module-level ``app``.
@app.exception_handler(Exception)
async def universal_exception_handler(request: Request, exc: Exception):
    """Answer any otherwise unhandled exception with a streamed apology.

    The exception text is only printed on the server; the client receives a
    fixed notice as server-sent events, 8 characters per event, followed by
    the ``[DONE]`` marker.
    """
    # Details stay in the server log and are never sent to the client.
    print(f"全局异常捕获: {exc}")

    async def error_stream() -> AsyncGenerator[str, None]:
        # Visitor-facing wording; must stay exactly as is.
        error_msg = "亲爱的游客,目前咨询人数较多,如同热门景点需要稍作等候~ \n\n请您稍后再次尝试,我们会尽快为您提供详尽的文旅咨询服务,感谢您的理解与支持!"
        chunk_size = 8
        offset = 0
        while offset < len(error_msg):
            yield f"data: {error_msg[offset:offset + chunk_size]}\n\n"
            await asyncio.sleep(0.03)  # short pause between events (typing effect)
            offset += chunk_size
        # End-of-stream marker.
        yield "data: [DONE]\n\n"

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
        "Access-Control-Allow-Origin": "*",
    }
    return StreamingResponse(
        error_stream(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
def _notice_sse_response(notice: str, headers: dict) -> StreamingResponse:
    """Build an SSE response that types out ``notice`` and then ends the stream.

    Args:
        notice: Text to send, split into events of 8 characters each.
        headers: HTTP headers to attach to the response.

    Returns:
        A ``text/event-stream`` response ending with ``data: [DONE]``.
    """

    async def notice_stream() -> AsyncGenerator[str, None]:
        chunk_size = 8
        for start in range(0, len(notice), chunk_size):
            yield f"data: {notice[start:start + chunk_size]}\n\n"
            await asyncio.sleep(0.03)
        # The end marker is sent on normal completion only. Yielding from a
        # ``finally`` block would also run while the generator is being closed
        # or cancelled (client gone), where an async generator must not yield.
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        notice_stream(),
        media_type="text/event-stream",
        headers=headers,
    )


@router.post("/chat", summary="ai对话")
async def h5_chat_stream(request: Request, inp: ChatIn, redis_client = Depends(get_redis_client)):
    """Chat endpoint: delegate to ``_handle_chat_request`` with a time limit.

    ``_handle_chat_request`` returns the ``StreamingResponse`` object, so the
    90 second limit covers the preparation done before streaming starts, not
    the streaming of the answer itself.

    Returns:
        The streaming response of the chat handler, or a streamed "system busy"
        notice when preparation times out or fails unexpectedly.

    Raises:
        HTTPException: Validation errors raised by the handler (missing or
            invalid signature, missing user id) are passed through unchanged.
    """
    try:
        return await asyncio.wait_for(
            _handle_chat_request(request, inp, redis_client),
            timeout=90.0,  # seconds
        )
    except HTTPException:
        # HTTPException is a subclass of Exception; without this clause the
        # generic handler below would turn a 401/400 into a 200 "busy" stream
        # and hide the real reason from the client.
        raise
    except asyncio.TimeoutError:
        return _notice_sse_response(
            "亲爱的游客,目前系统暂时繁忙,可能是咨询过于火爆~ \n\n请您稍后重新发起咨询,我们会第一时间为您提供文旅信息服务,感谢您的耐心!",
            {
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )
    except Exception as e:
        # Last line of defence: log the details, answer with a friendly notice.
        print(f"聊天接口异常: {str(e)}")
        return _notice_sse_response(
            "亲爱的游客,目前系统暂时繁忙,如同景区高峰期需要稍作等候~ \n\n请您稍后再次尝试,我们会尽快为您提供贴心的文旅咨询服务,感谢您的理解!",
            {
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )
async def _handle_chat_request(
    request: Request,
    inp: ChatIn,
    redis_client
) -> StreamingResponse:
    """Validate a chat request, prepare its context and build the SSE response.

    Work done before the response object is returned: signature and user-id
    validation, loading the conversation history from Redis, the quick-question
    cache lookup and, for ordinary messages, classification, scenic-spot
    extraction and the start of the background knowledge-base lookup. The
    answer itself is produced lazily by the nested ``content_stream``
    generator while the response is being sent.

    Args:
        request: Incoming request; forwarded to the flow-query helpers and used
            for the client address in the disconnect log.
        inp: Chat payload. ``language`` is defaulted to ``"zh_cn"`` in place,
            and ``message`` is extended in place with knowledge-base text on
            the general-chat path.
        redis_client: Async Redis client (``get`` / ``setex`` are awaited).

    Returns:
        A ``text/event-stream`` response that ends with ``data: [DONE]``.

    Raises:
        HTTPException: 401 for a missing or invalid signature, 400 for a
            missing user id.
    """
    # ---- request validation ----
    if inp.sign is None:
        raise HTTPException(status_code=401, detail="缺少签名参数")
    if not verify_signature(inp.dict(), inp.sign):
        raise HTTPException(status_code=401, detail="无效的签名")
    if inp.language is None:
        inp.language = "zh_cn"
    user_id = inp.user_id
    if not user_id:
        raise HTTPException(status_code=400, detail="缺少用户 ID")

    # ---- conversation history ----
    conversation_history_key = f"conversation:{user_id}"
    conversation_history_str = await redis_client.get(conversation_history_key)
    conversation_history = json.loads(conversation_history_str) if conversation_history_str else []
    # Keep at most 10 entries (5 rounds): when over the limit, drop the oldest
    # round (2 entries).
    if len(conversation_history) > 10:
        conversation_history = conversation_history[2:]
    # Earlier user messages plus the current one.
    user_messages_content = [
        msg.get('content', '') for msg in conversation_history if msg.get('role') == 'user'
    ]
    all_messages = user_messages_content + [inp.message]

    # ---- quick-question lookup ----
    # A message counts as a quick question when Redis holds a value under a key
    # equal to the message text.
    # NOTE(review): the raw user message is used as the Redis key, so a message
    # such as "conversation:<id>" or "hot_questions" would stream the content of
    # that key back to the caller. Consider a dedicated key prefix, coordinated
    # with whatever writes these entries (not visible in this file).
    questions = None
    is_quick_question = False
    try:
        cached = await redis_client.get(inp.message)
        if cached:
            is_quick_question = True
            questions = cached
    except Exception as e:
        print(f"[Redis] 查询缓存失败: {e}")

    # ---- preparation for ordinary (non-quick) messages ----
    cat = None
    spot = None
    scenics = None
    knowledge_task = None
    if not is_quick_question:
        cat = await classify(inp.message)
        if cat == "游玩判断" or cat == "保定文旅":
            spot = await extract_spot(all_messages)
        elif cat == "多景区比较":
            scenics = await extract_multi_scenic(all_messages)
        # Start the synchronous knowledge-base lookup on the loop's default
        # executor. A ``with ThreadPoolExecutor()`` block must not be used
        # here: leaving it calls ``shutdown(wait=True)``, which blocks the
        # event loop until the lookup has finished and makes the timeout
        # applied later meaningless.
        knowledge_task = asyncio.get_running_loop().run_in_executor(
            None, fetch_and_parse_markdown, user_id, inp.message
        )
        # Record the user message in the history.
        conversation_history.append({"role": "user", "content": inp.message})

    async def _await_knowledge():
        """Wait at most 20 s for the knowledge lookup; return None on timeout."""
        if not knowledge_task:
            return None
        try:
            return await asyncio.wait_for(knowledge_task, timeout=20)
        except asyncio.TimeoutError:
            # asyncio.TimeoutError only became the builtin TimeoutError in
            # Python 3.11; naming it explicitly works on every version.
            print("获取知识库信息超时")
            return None

    async def _flow_with_knowledge(target):
        """Query flow data for ``target`` and append knowledge text if any."""
        data = await query_flow(request, target, redis_client)
        knowledge = await _await_knowledge()
        if knowledge:
            data += "\n\n知识库查询到的景区内容:" + knowledge
        return data

    async def content_stream() -> AsyncGenerator[str, None]:
        nonlocal conversation_history
        steam_status = False  # becomes True once an answer chunk has been sent
        try:
            try:
                # Count the message as a hot question.
                await record_hot_question(inp.message)
            except Exception as e:
                print(f"记录热门问题失败: {e},问题内容为:{inp.message}")

            if is_quick_question:
                await asyncio.sleep(0.5)
                full_response = ""
                # The cached answer is stripped of surrounding double quotes.
                question_content = questions.strip('"')
                # One chunk size per response, drawn from 5..15 characters.
                chunk_size = random.randint(5, 15)
                for i in range(0, len(question_content), chunk_size):
                    chunk = question_content[i:i + chunk_size]
                    full_response += chunk
                    yield f"data: {chunk}\n\n"
                    await asyncio.sleep(0.05)
                    steam_status = True
                if full_response:
                    recommended_questions = await generate_recommended_questions(inp.message, full_response)
                    if recommended_questions:
                        # NOTE(review): this heading is sent without a "data:"
                        # prefix, so a standards-following SSE parser ignores
                        # it. Left unchanged because the client's parsing is
                        # not visible here - confirm before changing.
                        yield "\n\n### 您可能还想了解:"
                        for i, question in enumerate(recommended_questions, 1):
                            yield f"\ndata: {i}. {question}\n\n"
            else:
                not_found = "**未找到景区信息,请检查名称是否正确。**\n\n(内容由AI生成,仅供参考)"
                if cat == "游玩判断":
                    if not spot:
                        payload = not_found
                    else:
                        payload = await _flow_with_knowledge(spot)
                    answer = gen_markdown_stream(inp.message, payload, inp.language, conversation_history)
                elif cat == "多景区比较":
                    if not scenics:
                        payload = not_found
                    elif len(scenics) == 1:
                        payload = await _flow_with_knowledge(scenics[0])
                    else:
                        payload = await query_multi_scenic_flow(request, scenics, inp.message, redis_client)
                    answer = gen_markdown_stream(inp.message, payload, inp.language, conversation_history)
                else:
                    knowledge = await _await_knowledge()
                    if knowledge:
                        inp.message += "\n\n知识库查询到的景区内容:" + knowledge
                    answer = ai_chat_stream(inp, conversation_history)
                # Relay the answer to the client chunk by chunk.
                async for chunk in answer:
                    yield f"data: {chunk}\n\n"
                    await asyncio.sleep(0.03)
                    steam_status = True

            # Persist the history, capped at the 10 most recent entries.
            if not is_quick_question:
                if len(conversation_history) > 10:
                    conversation_history = conversation_history[-10:]
                await redis_client.setex(
                    conversation_history_key,
                    CONVERSATION_EXPIRE_TIME,
                    json.dumps(conversation_history)
                )
        except GeneratorExit:
            # The generator is being closed while suspended: log and re-raise.
            print(f"[客户端断开连接] 用户ID: {inp.user_id},当前请求: {inp.message},断开时机: SSE 流传输中")
            client_ip = request.client.host if request.client else "未知IP"
            # Only the first 50 characters of the message, to keep the log short.
            print(f"[客户端断开详情] IP: {client_ip},用户问题: {inp.message[:50]}")
            raise
        except Exception as e:
            print(f"content_stream 异常: {str(e)}")
            # Only send the apology when nothing of the answer went out yet.
            if not steam_status:
                error_msg = "亲爱的游客,当前咨询服务暂时繁忙,如同景区高峰期需要稍作等候~ \n\n请您稍后再次尝试,我们会尽快为您提供贴心的文旅咨询服务,感谢您的理解!"
                chunk_size = 8
                for i in range(0, len(error_msg), chunk_size):
                    chunk = error_msg[i:i + chunk_size]
                    yield f"data: {chunk}\n\n"
                    await asyncio.sleep(0.03)
        # End marker after normal completion or a handled error. It is
        # deliberately not placed in a ``finally`` block: that would also run
        # after the re-raised GeneratorExit (and on cancellation), and an async
        # generator that yields while being closed fails with
        # "RuntimeError: async generator ignored GeneratorExit".
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        content_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "Access-Control-Allow-Origin": "*"
        }
    )
class ClearConversationRequest(BaseModel):
    """Request body of the ``/clear_conversation`` endpoint."""
    # ID of the user whose stored conversation history is to be removed.
    user_id: int
@router.post("/clear_conversation", summary="清除对话记录")
async def clear_conversation(request: Request, body: ClearConversationRequest, redis_client = Depends(get_redis_client)):
    """Delete the stored conversation history of one user.

    Raises:
        HTTPException: 400 when ``body.user_id`` is falsy.

    Returns:
        A dict with a confirmation message.
    """
    if not body.user_id:
        raise HTTPException(status_code=400, detail="缺少用户 ID")
    # Same key layout as the chat handler: "conversation:<user_id>".
    await redis_client.delete(f"conversation:{body.user_id}")
    return {"message": "对话历史已清除"}
def verify_signature(data: dict, sign: str, secret_key=None) -> bool:
    """Check a request signature against the HMAC-SHA256 of the request fields.

    The signed text is ``key=value`` for every key of ``data`` except
    ``"sign"``, in sorted key order and joined with ``&``, followed by
    ``&secret=<key>``. The same key is also used as the HMAC key.

    Args:
        data: Request fields taking part in the signature.
        sign: Hex digest supplied by the caller.
        secret_key: Key to use; defaults to the module-level ``SECRET_KEY``.

    Returns:
        True when ``sign`` matches the computed digest, otherwise False
        (including when ``sign`` is not a string).
    """
    if not isinstance(sign, str):
        return False
    key = SECRET_KEY if secret_key is None else secret_key
    sign_str = '&'.join(f'{name}={data[name]}' for name in sorted(data) if name != 'sign')
    sign_str += f'&secret={key}'
    calculated_sign = hmac.new(key.encode(), sign_str.encode(), hashlib.sha256).hexdigest()
    # Constant-time comparison so response timing does not reveal how many
    # leading characters of a forged signature were correct. Both sides are
    # compared as bytes because compare_digest rejects non-ASCII str input.
    return hmac.compare_digest(calculated_sign.encode(), sign.encode())
def verify_timestamp(timestamp: int) -> bool:
    """Tell whether a millisecond timestamp is close enough to the current time.

    Args:
        timestamp: Time in milliseconds (it is compared with
            ``time.time() * 1000``).

    Returns:
        True when the absolute difference to now is at most
        ``TIMESTAMP_TOLERANCE`` seconds.
    """
    now_ms = int(time.time() * 1000)
    drift_ms = abs(now_ms - timestamp)
    return drift_ms <= TIMESTAMP_TOLERANCE * 1000
async def record_hot_question(question: str):
    """Count one occurrence of ``question`` in the hot-question table.

    A row whose title equals the question gets its counter increased by one;
    when there is no such row, one is created with a count of 1. Any failure
    is printed and swallowed, so this function does not raise to its caller.
    """
    try:
        existing = await HotQuestion.filter(title=question).first()
        if not existing:
            # First occurrence of this question.
            await HotQuestion.create(title=question, num=1)
            return
        # NOTE(review): read-modify-write; two concurrent requests for the same
        # question can both read the same value and lose one increment.
        existing.num += 1
        await existing.save()
    except Exception as e:
        print(f"记录热门问题失败: {e}")
# Response model describing a single question.
class QuestionResponse(BaseModel):
    """Question entry with its id and title.

    NOTE(review): no route in the visible part of this file references this
    model - confirm whether it is still used.
    """
    id: int
    title: str

    class Config:
        # Pydantic option that lets the model be built from object attributes.
        from_attributes = True
@router.post("/get_question", summary="获取开启的前4个问题")
async def get_question(request: Request, req: AllScenicFlowRequest, redis_client = Depends(get_redis_client)):
    """Return the first four enabled quick questions.

    The result is cached in Redis under ``quick_questions`` for 10 minutes.

    Raises:
        HTTPException: 401 when the signature is missing or invalid.

    Returns:
        A dict with ``code``, ``message`` and ``data``; ``code`` is 500 and
        ``data`` is empty when the database query fails.
    """
    # Authenticate before anything is served. Previously the cache was read
    # first, so a cache hit answered requests that carried no signature, and
    # the signature was only checked for presence. The signed payload matches
    # the other endpoints that take AllScenicFlowRequest.
    if not req.sign:
        raise HTTPException(status_code=401, detail="缺少签名参数")
    # NOTE: the timestamp freshness check (verify_timestamp) is disabled.
    if not verify_signature({"timestamp": req.timestamp}, req.sign):
        raise HTTPException(status_code=401, detail="无效的签名")

    cache_key = "quick_questions"
    try:
        cached = await redis_client.get(cache_key)
        if cached:
            payload = json.loads(cached)
            # The value written below is encoded once. Decode a second time
            # only when the first pass still yields a string, so a doubly
            # encoded entry is accepted as well. Decoding twice
            # unconditionally fails on the singly encoded value and made the
            # cache useless.
            if isinstance(payload, str):
                payload = json.loads(payload)
            return payload
    except Exception as e:
        print(f"[Redis] 查询缓存失败: {e}")

    try:
        # Questions with status "0", ordered by order_num, first 4 rows.
        questions = await QuickQuestion.filter(status="0").order_by("order_num").limit(4).values("title","subtitle","logo","label")
        result = {
            "code": 200,
            "message": "查询成功",
            "data": questions
        }
        # Cache for 10 minutes; a failed cache write must not fail the request.
        try:
            await redis_client.setex(cache_key, 600, json.dumps(result))
        except Exception as e:
            print(f"[Redis] 写缓存失败: {e}")
        return result
    except Exception as e:
        print(f"查询快捷问题失败: {e}")
        return {
            "code": 500,
            "message": f"查询失败: {str(e)}",
            "data": []
        }
class HotQuestionResponse(BaseModel):
    """Hot-question entry: id, title, counter and last update time as text.

    NOTE(review): no route in the visible part of this file references this
    model - confirm whether it is still used.
    """
    id: int
    title: str
    num: int
    update_time: str

    class Config:
        # Pydantic option that lets the model be built from object attributes.
        from_attributes = True
@router.post("/get_hot_questions", summary="获取热门问题top10")
async def get_hot_questions(request: Request, req: AllScenicFlowRequest, redis_client = Depends(get_redis_client)):
    """Return the 10 hot questions with the highest counter, highest first.

    The result is cached in Redis under ``hot_questions`` for 10 minutes.

    Raises:
        HTTPException: 401 when the signature is missing or invalid.

    Returns:
        A dict with ``code``, ``message`` and ``data``; ``code`` is 500 and
        ``data`` is empty when the database query fails.
    """
    # Authenticate before anything is served. Previously the cache was read
    # first, so a cache hit answered requests that carried no signature, and
    # the signature was only checked for presence. The signed payload matches
    # the other endpoints that take AllScenicFlowRequest.
    if not req.sign:
        raise HTTPException(status_code=401, detail="缺少签名参数")
    # NOTE: the timestamp freshness check (verify_timestamp) is disabled.
    if not verify_signature({"timestamp": req.timestamp}, req.sign):
        raise HTTPException(status_code=401, detail="无效的签名")

    cache_key = "hot_questions"
    try:
        cached = await redis_client.get(cache_key)
        if cached:
            payload = json.loads(cached)
            # The value written below is encoded once. Decode a second time
            # only when the first pass still yields a string, so a doubly
            # encoded entry is accepted as well.
            if isinstance(payload, str):
                payload = json.loads(payload)
            return payload
    except Exception as e:
        print(f"[Redis] 查询缓存失败: {e}")

    try:
        # Top 10 rows by counter, descending.
        hot_questions = await HotQuestion \
            .filter() \
            .order_by("-num") \
            .limit(10) \
            .values("id", "title", "num", "update_time")
        # Render the update time as text; rows without one keep None instead
        # of failing the whole request.
        for q in hot_questions:
            if q["update_time"] is not None:
                q["update_time"] = q["update_time"].strftime("%Y-%m-%d %H:%M:%S")
        result = {
            "code": 200,
            "message": "查询成功",
            "data": hot_questions
        }
        # Cache for 10 minutes; a failed cache write must not fail the request.
        try:
            await redis_client.setex(cache_key, 600, json.dumps(result))
        except Exception as e:
            print(f"[Redis] 写缓存失败: {e}")
        return result
    except Exception as e:
        print(f"查询热门问题失败: {e}")
        return {
            "code": 500,
            "message": f"查询失败: {str(e)}",
            "data": []
        }
@router.post("/get_all_scenic_flow")
async def get_all_scenic_flow(request: Request, req: AllScenicFlowRequest, redis_client = Depends(get_redis_client)):
    """
    Return the visitor-flow data of all scenic spots.

    The data comes from ``get_all_scenic_flow_data``; according to the original
    description it holds entry and exit counts and a load rate, sorted by load
    rate in descending order (that computation is not part of this function).

    Raises:
        HTTPException: 401 when the signature is missing or invalid.
    """
    if not req.sign:
        raise HTTPException(status_code=401, detail="缺少签名参数")
    # The timestamp freshness check (verify_timestamp) is currently disabled.
    # Only the timestamp takes part in the signature of this request.
    signed_fields = {"timestamp": req.timestamp}
    if not verify_signature(signed_fields, req.sign):
        raise HTTPException(status_code=401, detail="无效的签名")

    try:
        raw = await get_all_scenic_flow_data(request, redis_client)
        if raw:
            return {"code": 200, "message": "查询成功", "data": json.loads(raw)}
        return {"code": 404, "message": "未找到景区客流数据", "data": []}
    except Exception as e:
        print(f"查询所有景区客流数据异常: {e}")
        return {"code": 500, "message": f"查询异常: {str(e)}", "data": []}
@router.post("/get_all_toilet_info")
async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest, redis_client = Depends(get_redis_client)):
    """
    Return the information of all toilets, as provided by ``get_all_toilet_data``.

    Raises:
        HTTPException: 401 when the signature is missing or invalid.
    """
    if not req.sign:
        raise HTTPException(status_code=401, detail="缺少签名参数")
    # The timestamp freshness check (verify_timestamp) is currently disabled.
    # Only the timestamp takes part in the signature of this request.
    signed_fields = {"timestamp": req.timestamp}
    if not verify_signature(signed_fields, req.sign):
        raise HTTPException(status_code=401, detail="无效的签名")

    try:
        raw = await get_all_toilet_data(request, redis_client)
        if raw:
            return {"code": 200, "message": "查询成功", "data": json.loads(raw)}
        return {"code": 404, "message": "未找到厕所信息", "data": []}
    except Exception as e:
        print(f"查询所有厕所信息异常: {e}")
        return {"code": 500, "message": f"查询异常: {str(e)}", "data": []}
# Scenic-spot detail endpoint.
@router.post("/get_scenic_detail")
async def get_scenic_detail(request: Request, req: ScenicDetailRequest, redis_client = Depends(get_redis_client)):
    """
    Return the details of a single scenic spot.

    The data comes from ``get_scenic_detail_data``; according to the original
    description it includes a comfort assessment (not computed here).

    Raises:
        HTTPException: 401 when the signature is missing or invalid.
    """
    if not req.sign:
        raise HTTPException(status_code=401, detail="缺少签名参数")
    # The timestamp freshness check (verify_timestamp) is currently disabled.
    signed_fields = {"id": req.id, "timestamp": req.timestamp}
    if not verify_signature(signed_fields, req.sign):
        raise HTTPException(status_code=401, detail="无效的签名")

    # The id is validated only after the signature has been accepted.
    if not req.id:
        return {"code": 400, "message": "景区id不能为空", "data": None}

    try:
        raw = await get_scenic_detail_data(request, req.id, redis_client)
        if raw:
            return {"code": 200, "message": "查询成功", "data": json.loads(raw)}
        return {"code": 404, "message": "未找到景区信息", "data": None}
    except Exception as e:
        print(f"查询景区详情异常: {e}")
        return {"code": 500, "message": f"查询异常: {str(e)}", "data": None}
# Scenic-spot parking endpoint.
@router.post("/get_scenic_parking")
async def get_scenic_parking(request: Request, req: ScenicParkingRequest, redis_client = Depends(get_redis_client)):
    """
    Return the parking lots near a scenic spot.

    Raises:
        HTTPException: 401 when the signature is missing or invalid.

    Returns:
        A dict with ``code``, ``message`` and ``data``. ``code`` is 400 for a
        missing id or a missing / non-positive distance, 404 when nothing is
        found and 500 when the lookup fails.
    """
    if not req.sign:
        raise HTTPException(status_code=401, detail="缺少签名参数")
    # NOTE: the timestamp freshness check (verify_timestamp) is disabled.
    signed_fields = {
        "id": req.id,
        "distance": req.distance,
        "timestamp": req.timestamp
    }
    if not verify_signature(signed_fields, req.sign):
        raise HTTPException(status_code=401, detail="无效的签名")

    if not req.id:
        return {
            "code": 400,
            "message": "景区id不能为空",
            "data": []
        }
    # A missing distance is rejected together with non-positive values;
    # comparing None with 0 would raise outside the try block below.
    if req.distance is None or req.distance <= 0:
        return {
            "code": 400,
            "message": "查询距离必须大于0",
            "data": []
        }

    try:
        raw = await get_scenic_parking_data(request, req.id, req.distance, redis_client)
        if not raw:
            # NOTE(review): ``scenic_name`` is neither signed nor used anywhere
            # else in this endpoint; if the request model does not declare it,
            # reading it raised AttributeError and turned this 404 into a 500.
            # Fall back to the id in that case - confirm against the model.
            scenic_label = getattr(req, "scenic_name", req.id)
            return {
                "code": 404,
                "message": f"未找到【{scenic_label}】附近的停车场信息",
                "data": []
            }
        return {
            "code": 200,
            "message": "查询成功",
            "data": json.loads(raw)
        }
    except Exception as e:
        print(f"查询景区停车场数据异常: {e}")
        return {
            "code": 500,
            "message": f"查询异常: {str(e)}",
            "data": []
        }