保定ai问答主体项目
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 
bd_ai_fastapi/app/api/chat/ai/chat_router.py

457 lines
16 KiB

from fastapi import APIRouter, HTTPException
from starlette.requests import Request
from fastapi.responses import StreamingResponse
from typing import AsyncGenerator, List
from .chat_service import classify, extract_spot, query_flow, gen_markdown_stream, ai_chat_stream, \
handle_quick_question, fetch_and_parse_markdown
from app.models.ChatIn import ChatIn, AllScenicFlowRequest, ScenicDetailRequest, ScenicParkingRequest
from app.models.quick_question import QuickQuestion
from app.models.hot_question import HotQuestion
from pydantic import BaseModel
import hmac
import hashlib
import time
import json
from app.settings.config import settings
# 更新导入语句,添加新函数
from app.api.chat.ai.chat_service import (
query_flow,
handle_quick_question,
get_all_scenic_flow_data,
get_scenic_detail_data,
get_scenic_parking_data,
extract_multi_scenic,
query_multi_scenic_flow
)
# 导入用于异步执行同步函数的模块
from concurrent.futures import ThreadPoolExecutor
import asyncio
router = APIRouter()
SECRET_KEY = settings.SIGN_KEY # 约定的密钥
TIMESTAMP_TOLERANCE = 60 # 时间戳容忍度,单位:秒
CONVERSATION_EXPIRE_TIME = 600 # 对话历史过期时间,单位:秒
@router.post("/chat", summary="ai对话")
async def h5_chat_stream(request: Request, inp: ChatIn):
if inp.sign is None:
raise HTTPException(status_code=401, detail="缺少签名参数")
if not verify_signature(inp.dict(), inp.sign):
raise HTTPException(status_code=401, detail="无效的签名")
if not verify_timestamp(inp.timestamp):
raise HTTPException(status_code=401, detail="时间戳无效")
if inp.language is None:
inp.language = "zh_cn"
# 获取用户 ID
user_id = inp.user_id
if not user_id:
raise HTTPException(status_code=400, detail="缺少用户 ID")
# 从 Redis 中获取用户的对话历史
redis_client = request.app.state.redis_client
conversation_history_key = f"conversation:{user_id}"
conversation_history_str = await redis_client.get(conversation_history_key)
conversation_history = json.loads(conversation_history_str) if conversation_history_str else []
# 只提取用户的历史消息
user_messages = [msg for msg in conversation_history if msg.get('role') == 'user']
user_messages_content = [msg.get('content', '') for msg in user_messages]
# 合并当前消息和历史消息用于分类和景区提取
all_messages = user_messages_content + [inp.message]
# 先进行分类判断
cat = await classify(inp.message)
print(f"Message category: {cat}")
spot = None
if cat == "游玩判断" or cat == "保定文旅":
spot = await extract_spot(all_messages)
elif cat == "多景区比较":
# 对于多景区比较,提取多个景区名称
scenics = await extract_multi_scenic(all_messages)
if spot:
# 使用线程池异步执行同步函数
loop = asyncio.get_event_loop()
with ThreadPoolExecutor() as executor:
knowledge_task = loop.run_in_executor(executor, fetch_and_parse_markdown, user_id, inp.message)
# 获取开启的前4个问题(包含标题和内容)
questions = await QuickQuestion.filter(status="0", ischat="0").order_by("order_num").limit(4).values("title", "subtitle","content")
question_titles = [f"{q['subtitle']}{q['title']}" for q in questions]
# 检查消息是否在问题列表中
is_quick_question = inp.message in question_titles
# 将用户消息添加到对话历史记录中
if not is_quick_question:
conversation_history.append({"role": "user", "content": inp.message})
async def content_stream() -> AsyncGenerator[str, None]:
knowledge = None
nonlocal conversation_history
# 记录热门问题(包括快捷问题)
await record_hot_question(inp.message)
try:
if is_quick_question:
# 找到对应的问题内容
question_content = next(q["content"] for q in questions if f"{q['subtitle']}{q['title']}" == inp.message)
# 处理快捷问题,传递content
async for chunk in handle_quick_question(inp, question_content):
yield chunk
else:
# 原来的逻辑
if cat == "游玩判断":
if not spot:
# 如果是游玩判断但没提取到景区名称,使用默认回复
ai_response = "**未找到景区信息,请检查名称是否正确。**\n\n(内容由AI生成,仅供参考)"
async for chunk in gen_markdown_stream(inp.message, ai_response, inp.language, conversation_history):
yield chunk
else:
data = await query_flow(request, spot)
# 等待知识库查询结果
if spot:
knowledge = await knowledge_task
#如果知识库返回的内容不包含"知识库内未找到相应资源"则拼接字符串
if knowledge and "无法" not in knowledge:
data += "\n\n知识库查询到的景区内容:"+ knowledge
async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history):
yield chunk
elif cat == "多景区比较":
# 处理多景区比较
if not scenics:
# 如果没提取到景区名称,使用默认回复
ai_response = "**未找到景区信息,请检查名称是否正确。**\n\n(内容由AI生成,仅供参考)"
async for chunk in gen_markdown_stream(inp.message, ai_response, inp.language, conversation_history):
yield chunk
elif len(scenics) == 1:
# 如果只提取到一个景区,按单景区处理
data = await query_flow(request, scenics[0])
# 等待知识库查询结果
if scenics[0]:
knowledge = await knowledge_task
#如果知识库返回的内容不包含"知识库内未找到相应资源"则拼接字符串
if knowledge and "无法" not in knowledge:
data += "\n\n知识库查询到的景区内容:"+ knowledge
async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history):
yield chunk
else:
# 查询多个景区的客流数据并比较
ai_response = await query_multi_scenic_flow(request, scenics, inp.message)
async for chunk in gen_markdown_stream(inp.message, ai_response, inp.language, conversation_history):
yield chunk
else:
# 等待知识库查询结果
if spot:
knowledge = await knowledge_task
if knowledge and "无法" not in knowledge:
inp.message += "\n\n知识库查询到的景区内容:"+ knowledge
async for chunk in ai_chat_stream(inp, conversation_history):
yield chunk
# 将更新后的对话历史存回 Redis,并设置过期时间
# 只有非快捷问题才保存对话历史
if not is_quick_question:
await redis_client.setex(conversation_history_key, CONVERSATION_EXPIRE_TIME, json.dumps(conversation_history))
except Exception as e:
print(f"Error in content_stream: {e}")
raise
try:
return StreamingResponse(
content_stream(),
media_type="text/plain",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
except Exception as e:
print(f"Error in StreamingResponse: {e}")
raise HTTPException(status_code=500, detail="Internal Server Error")
class ClearConversationRequest(BaseModel):
user_id: int
@router.post("/clear_conversation", summary="清除对话记录")
async def clear_conversation(request: Request, body: ClearConversationRequest):
user_id = body.user_id
if not user_id:
raise HTTPException(status_code=400, detail="缺少用户 ID")
redis_client = request.app.state.redis_client
conversation_history_key = f"conversation:{user_id}"
await redis_client.delete(conversation_history_key)
return {"message": "对话历史已清除"}
def verify_signature(data: dict, sign: str) -> bool:
sorted_keys = sorted(data.keys())
sign_str = '&'.join([f'{key}={data[key]}' for key in sorted_keys if key != 'sign'])
sign_str += f'&secret={SECRET_KEY}'
calculated_sign = hmac.new(SECRET_KEY.encode(), sign_str.encode(), hashlib.sha256).hexdigest()
return calculated_sign == sign
def verify_timestamp(timestamp: int) -> bool:
current_timestamp = int(time.time() * 1000)
return abs(current_timestamp - timestamp) <= TIMESTAMP_TOLERANCE * 1000
async def record_hot_question(question: str):
"""记录热门问题,如果存在则次数加1,不存在则新增"""
try:
# 查找是否存在该问题
hot_q = await HotQuestion.filter(title=question).first()
if hot_q:
# 已存在,次数加1
hot_q.num += 1
await hot_q.save()
else:
# 不存在,新增记录
await HotQuestion.create(title=question, num=1)
except Exception as e:
print(f"记录热门问题失败: {e}")
# 定义获取问题的响应模型
class QuestionResponse(BaseModel):
id: int
title: str
class Config:
orm_mode = True
@router.post("/get_question", summary="获取开启的前4个问题")
async def get_question(request: Request, req: AllScenicFlowRequest):
# 验签逻辑
if not req.sign:
raise HTTPException(status_code=401, detail="缺少签名参数")
if not verify_timestamp(req.timestamp):
raise HTTPException(status_code=401, detail="时间戳无效")
try:
# 查询状态为正常(0)的问题,按order_num正序排序,取前4条
questions = await QuickQuestion.filter(status="0").order_by("order_num").limit(4).values("title","subtitle","logo","label")
return {
"code": 200,
"message": "查询成功",
"data": questions
}
except Exception as e:
print(f"查询快捷问题失败: {e}")
return {
"code": 500,
"message": f"查询失败: {str(e)}",
"data": []
}
class HotQuestionResponse(BaseModel):
id: int
title: str
num: int
update_time: str
class Config:
orm_mode = True
@router.post("/get_hot_questions", summary="获取热门问题top10")
async def get_hot_questions(request: Request, req: AllScenicFlowRequest):
# 验签逻辑
if not req.sign:
raise HTTPException(status_code=401, detail="缺少签名参数")
if not verify_timestamp(req.timestamp):
raise HTTPException(status_code=401, detail="时间戳无效")
"""
获取热门问题top10,按次数倒序排列
"""
try:
# 查询热门问题,按次数倒序排序,取前10条
hot_questions = await HotQuestion \
.filter() \
.order_by("-num") \
.limit(10) \
.values("id", "title", "num", "update_time")
# 格式化时间
for q in hot_questions:
q["update_time"] = q["update_time"].strftime("%Y-%m-%d %H:%M:%S")
return {
"code": 200,
"message": "查询成功",
"data": hot_questions
}
except Exception as e:
print(f"查询热门问题失败: {e}")
return {
"code": 500,
"message": f"查询失败: {str(e)}",
"data": []
}
@router.post("/get_all_scenic_flow")
async def get_all_scenic_flow(request: Request, req: AllScenicFlowRequest):
"""
查询所有景区的进入人数、离开人数,计算承载率并按承载率倒序排列
"""
# 验签逻辑
if not req.sign:
raise HTTPException(status_code=401, detail="缺少签名参数")
if not verify_timestamp(req.timestamp):
raise HTTPException(status_code=401, detail="时间戳无效")
# 构建验证数据(无其他参数,仅包含timestamp)
data = {"timestamp": req.timestamp}
if not verify_signature(data, req.sign):
raise HTTPException(status_code=401, detail="无效的签名")
try:
data = await get_all_scenic_flow_data(request)
if not data:
return {
"code": 404,
"message": "未找到景区客流数据",
"data": []
}
return {
"code": 200,
"message": "查询成功",
"data": data
}
except Exception as e:
print(f"查询所有景区客流数据异常: {e}")
return {
"code": 500,
"message": f"查询异常: {str(e)}",
"data": []
}
# 在现有路由下方添加新接口
@router.post("/get_scenic_detail")
async def get_scenic_detail(request: Request, req: ScenicDetailRequest):
"""
查询单个景区的详细信息,包含舒适度判断
"""
# 验签逻辑
if not req.sign:
raise HTTPException(status_code=401, detail="缺少签名参数")
if not verify_timestamp(req.timestamp):
raise HTTPException(status_code=401, detail="时间戳无效")
# 构建验证数据
data = {"id": req.id, "timestamp": req.timestamp}
if not verify_signature(data, req.sign):
raise HTTPException(status_code=401, detail="无效的签名")
if not req.id:
return {
"code": 400,
"message": "景区id不能为空",
"data": None
}
try:
data = await get_scenic_detail_data(request, req.id)
if not data:
return {
"code": 404,
"message": f"未找到景区信息",
"data": None
}
return {
"code": 200,
"message": "查询成功",
"data": data
}
except Exception as e:
print(f"查询景区详情异常: {e}")
return {
"code": 500,
"message": f"查询异常: {str(e)}",
"data": None
}
# 在现有路由下方添加新接口 - 景区停车场查询
@router.post("/get_scenic_parking")
async def get_scenic_parking(request: Request, req: ScenicParkingRequest):
"""
查询景区附近的停车场信息
"""
# 验签逻辑
if not req.sign:
raise HTTPException(status_code=401, detail="缺少签名参数")
if not verify_timestamp(req.timestamp):
raise HTTPException(status_code=401, detail="时间戳无效")
# 构建验证数据
data = {
"id": req.id,
"distance": req.distance,
"timestamp": req.timestamp
}
if not verify_signature(data, req.sign):
raise HTTPException(status_code=401, detail="无效的签名")
if not req.id:
return {
"code": 400,
"message": "景区id不能为空",
"data": []
}
if req.distance <= 0:
return {
"code": 400,
"message": "查询距离必须大于0",
"data": []
}
try:
data = await get_scenic_parking_data(request, req.id, req.distance)
if not data:
return {
"code": 404,
"message": f"未找到【{req.scenic_name}】附近的停车场信息",
"data": []
}
return {
"code": 200,
"message": "查询成功",
"data": data
}
except Exception as e:
print(f"查询景区停车场数据异常: {e}")
return {
"code": 500,
"message": f"查询异常: {str(e)}",
"data": []
}