Change the endpoint to a strict SSE protocol response

main
zc · 1 month ago
parent 5006b4b9d2 · commit 2d8552e5c8

  1. app/__init__.py (1)
  2. app/api/chat/ai/chat_router.py (79)
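The commit wraps every streamed chunk as an SSE event (`data: <chunk>\n\n`), appends a `data: [DONE]\n\n` end marker, and switches the response to `media_type="text/event-stream"`. A minimal, self-contained sketch of that framing, with an illustrative route path and placeholder chunks rather than the project's real handler:

```python
import asyncio

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.post("/chat")
async def chat_stream():
    async def content_stream():
        for chunk in ["你好", ",", "世界"]:
            # Each chunk becomes one SSE event: "data: <payload>\n\n"
            yield f"data: {chunk}\n\n"
            await asyncio.sleep(0.01)  # yield control so chunks flush promptly
        # End-of-stream marker the client watches for
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        content_stream(),
        media_type="text/event-stream",  # standard SSE MIME type
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable proxy (nginx) buffering
        },
    )
```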

@@ -22,7 +22,6 @@ except ImportError:
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     try:
-        mysql_port = int(os.getenv("MYSQL_PORT", 3306))
         app.state.mysql_pool = await create_pool(
             host=settings.FLOW_MYSQL_HOST,
             port=settings.FLOW_MYSQL_PORT,
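For context, the lifespan hook above opens a shared MySQL pool on startup. A sketch of that pattern, assuming aiomysql's `create_pool`; the credential keyword arguments and the teardown are illustrative, since only the `FLOW_MYSQL_*` settings and the `app.state.mysql_pool` assignment appear in the hunk:

```python
from contextlib import asynccontextmanager

from aiomysql import create_pool
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Open one shared MySQL connection pool for the application's lifetime
    app.state.mysql_pool = await create_pool(
        host="127.0.0.1",  # settings.FLOW_MYSQL_HOST in the real module
        port=3306,         # settings.FLOW_MYSQL_PORT in the real module
        user="flow",       # illustrative credentials
        password="***",
        db="flow",
    )
    try:
        yield
    finally:
        # Release the pool on shutdown
        app.state.mysql_pool.close()
        await app.state.mysql_pool.wait_closed()

app = FastAPI(lifespan=lifespan)
```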

@@ -35,6 +35,7 @@ SECRET_KEY = settings.SIGN_KEY  # agreed signing key
 TIMESTAMP_TOLERANCE = 60  # timestamp tolerance, in seconds
 CONVERSATION_EXPIRE_TIME = 600  # conversation-history expiry, in seconds
 @router.post("/chat", summary="ai对话")
 async def h5_chat_stream(request: Request, inp: ChatIn):
     if inp.sign is None:
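The handler rejects requests whose `sign` is missing; the verification itself lies outside this hunk. Purely as an illustration of how `SECRET_KEY` and `TIMESTAMP_TOLERANCE` are typically combined in such a check (hypothetical helper, HMAC chosen arbitrarily; the project's actual signing scheme may differ):

```python
import hashlib
import hmac
import time

SECRET_KEY = "change-me"   # settings.SIGN_KEY in the real module
TIMESTAMP_TOLERANCE = 60   # seconds

def verify_sign(message: str, timestamp: int, sign: str) -> bool:
    # Reject requests whose timestamp falls outside the tolerance window
    if abs(time.time() - timestamp) > TIMESTAMP_TOLERANCE:
        return False
    expected = hmac.new(SECRET_KEY.encode(),
                        f"{message}{timestamp}".encode(),
                        hashlib.sha256).hexdigest()
    # Constant-time comparison of the expected and supplied signatures
    return hmac.compare_digest(expected, sign)
```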
@@ -63,31 +64,32 @@ async def h5_chat_stream(request: Request, inp: ChatIn):
     # Extract only the user's messages from the history
     user_messages = [msg for msg in conversation_history if msg.get('role') == 'user']
     user_messages_content = [msg.get('content', '') for msg in user_messages]
     # Merge the current message with the history for classification and scenic-spot extraction
     all_messages = user_messages_content + [inp.message]
     # Classify the message first
     cat = await classify(inp.message)
     print(f"Message category: {cat}")
     spot = None
     if cat == "游玩判断" or cat == "保定文旅":
         spot = await extract_spot(all_messages)
     elif cat == "多景区比较":
         # For multi-spot comparison, extract several scenic-spot names
         scenics = await extract_multi_scenic(all_messages)
     if spot:
         # Run the synchronous fetch in a thread pool so it does not block the event loop
         loop = asyncio.get_event_loop()
         with ThreadPoolExecutor() as executor:
             knowledge_task = loop.run_in_executor(executor, fetch_and_parse_markdown, user_id, inp.message)
     # Fetch the first 4 enabled quick questions (including title and content)
-    questions = await QuickQuestion.filter(status="0", ischat="0").order_by("order_num").limit(4).values("title", "subtitle","content")
-    question_titles = [f"{q['subtitle']}{q['title']}" for q in questions]
+    questions = await QuickQuestion.filter(status="0", ischat="0").order_by("order_num").limit(4).values("title",
+                                                                                                         "subtitle",
+                                                                                                         "content")
+    question_titles = [f"{q['subtitle']}{q['title']}" for q in questions]
     # Check whether the message is one of the quick questions
     is_quick_question = inp.message in question_titles
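The knowledge-base lookup above is scheduled early via `loop.run_in_executor` and only awaited later in the stream. A minimal sketch of that pattern; the `fetch_and_parse_markdown` stand-in below just simulates a slow, blocking call:

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

def fetch_and_parse_markdown(user_id: str, message: str) -> str:
    time.sleep(1)  # simulate a slow, blocking knowledge-base query
    return f"knowledge for {message!r}"

async def handler(user_id: str, message: str) -> None:
    loop = asyncio.get_event_loop()
    with ThreadPoolExecutor() as executor:
        # Schedule the blocking call without awaiting it yet
        knowledge_task = loop.run_in_executor(executor, fetch_and_parse_markdown, user_id, message)
        # ... other async work (classification, flow queries) runs here ...
        knowledge = await knowledge_task  # collect the result only when it is needed
    print(knowledge)

asyncio.run(handler("u1", "野三坡"))
```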
@@ -106,77 +108,102 @@ async def h5_chat_stream(request: Request, inp: ChatIn):
         try:
             if is_quick_question:
                 # Find the matching question content
-                question_content = next(q["content"] for q in questions if f"{q['subtitle']}{q['title']}" == inp.message)
+                question_content = next(
+                    q["content"] for q in questions if f"{q['subtitle']}{q['title']}" == inp.message)
                 # Handle the quick question, passing its content
                 async for chunk in handle_quick_question(inp, question_content):
-                    yield chunk
+                    # Wrap in SSE format
+                    yield f"data: {chunk}\n\n"
+                    await asyncio.sleep(0.01)  # ensure chunked delivery
             else:
                 # Original logic
                 if cat == "游玩判断":
                     if not spot:
                         # "游玩判断" category but no scenic spot extracted: use the default reply
                         ai_response = "**未找到景区信息,请检查名称是否正确。**\n\n(内容由AI生成,仅供参考)"
-                        async for chunk in gen_markdown_stream(inp.message, ai_response, inp.language, conversation_history):
-                            yield chunk
+                        async for chunk in gen_markdown_stream(inp.message, ai_response, inp.language,
+                                                               conversation_history):
+                            # Wrap in SSE format
+                            yield f"data: {chunk}\n\n"
+                            await asyncio.sleep(0.01)
                     else:
                         data = await query_flow(request, spot)
                         # Wait for the knowledge-base query result
                         if spot:
                             knowledge = await knowledge_task
-                            # If the knowledge-base reply does not contain "知识库内未找到相应资源", append it
+                            # If the knowledge-base reply does not contain "无法", append it
                             if knowledge and "无法" not in knowledge:
-                                data += "\n\n知识库查询到的景区内容:"+ knowledge
+                                data += "\n\n知识库查询到的景区内容:" + knowledge
                         async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history):
-                            yield chunk
+                            # Wrap in SSE format
+                            yield f"data: {chunk}\n\n"
+                            await asyncio.sleep(0.01)
                 elif cat == "多景区比较":
                     # Handle multi-spot comparison
                     if not scenics:
                         # No scenic-spot names extracted: use the default reply
                         ai_response = "**未找到景区信息,请检查名称是否正确。**\n\n(内容由AI生成,仅供参考)"
-                        async for chunk in gen_markdown_stream(inp.message, ai_response, inp.language, conversation_history):
-                            yield chunk
+                        async for chunk in gen_markdown_stream(inp.message, ai_response, inp.language,
+                                                               conversation_history):
+                            # Wrap in SSE format
+                            yield f"data: {chunk}\n\n"
+                            await asyncio.sleep(0.01)
                     elif len(scenics) == 1:
                         # Only one spot extracted: handle it as a single spot
                         data = await query_flow(request, scenics[0])
                         # Wait for the knowledge-base query result
                         if scenics[0]:
                             knowledge = await knowledge_task
-                            # If the knowledge-base reply does not contain "知识库内未找到相应资源", append it
+                            # If the knowledge-base reply does not contain "无法", append it
                             if knowledge and "无法" not in knowledge:
-                                data += "\n\n知识库查询到的景区内容:"+ knowledge
+                                data += "\n\n知识库查询到的景区内容:" + knowledge
                         async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history):
-                            yield chunk
+                            # Wrap in SSE format
+                            yield f"data: {chunk}\n\n"
+                            await asyncio.sleep(0.01)
                     else:
                         # Query and compare visitor-flow data for multiple spots
                         ai_response = await query_multi_scenic_flow(request, scenics, inp.message)
-                        async for chunk in gen_markdown_stream(inp.message, ai_response, inp.language, conversation_history):
-                            yield chunk
+                        async for chunk in gen_markdown_stream(inp.message, ai_response, inp.language,
+                                                               conversation_history):
+                            # Wrap in SSE format
+                            yield f"data: {chunk}\n\n"
+                            await asyncio.sleep(0.01)
                 else:
                     # Wait for the knowledge-base query result
                     if spot:
                         knowledge = await knowledge_task
                         if knowledge and "无法" not in knowledge:
-                            inp.message += "\n\n知识库查询到的景区内容:"+ knowledge
+                            inp.message += "\n\n知识库查询到的景区内容:" + knowledge
                     async for chunk in ai_chat_stream(inp, conversation_history):
-                        yield chunk
+                        # Wrap in SSE format
+                        yield f"data: {chunk}\n\n"
+                        await asyncio.sleep(0.01)
             # Write the updated conversation history back to Redis with an expiry
             # Only non-quick questions persist conversation history
             if not is_quick_question:
-                await redis_client.setex(conversation_history_key, CONVERSATION_EXPIRE_TIME, json.dumps(conversation_history))
+                await redis_client.setex(conversation_history_key, CONVERSATION_EXPIRE_TIME,
+                                         json.dumps(conversation_history))
+            # Send the end-of-stream marker
+            yield "data: [DONE]\n\n"
         except Exception as e:
             print(f"Error in content_stream: {e}")
+            # Errors are also sent in SSE format
+            yield f"data: 发生错误:{str(e)}\n\n"
             raise
     try:
         return StreamingResponse(
             content_stream(),
-            media_type="text/plain",
+            media_type="text/event-stream",  # standard SSE MIME type
             headers={
                 "Cache-Control": "no-cache",
                 "Connection": "keep-alive",
-                "X-Accel-Buffering": "no"
+                "X-Accel-Buffering": "no",
+                "Access-Control-Allow-Origin": "*"  # allow cross-origin requests (adjust as needed)
             }
         )
     except Exception as e:
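With the strict SSE framing in place, a client reads `data:` lines until it sees the `[DONE]` marker. A sketch of consuming the stream with httpx; the host, port, route prefix, and payload shape are assumptions, since only `message`, `language`, and `sign` are visible in the handler above:

```python
import httpx

payload = {"message": "野三坡今天适合游玩吗", "language": "zh", "sign": "<signed value>"}

# Adjust host/port and the router's URL prefix to match the real deployment
with httpx.stream("POST", "http://localhost:8000/chat", json=payload, timeout=None) as resp:
    for line in resp.iter_lines():
        if not line.startswith("data: "):
            continue  # skip blank keep-alive lines between events
        data = line[len("data: "):]
        if data == "[DONE]":
            break  # end-of-stream marker sent by the server
        print(data, end="", flush=True)
```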
