|
|
|
@ -35,6 +35,7 @@ SECRET_KEY = settings.SIGN_KEY # 约定的密钥 |
|
|
|
|
TIMESTAMP_TOLERANCE = 60 # 时间戳容忍度,单位:秒 |
|
|
|
|
CONVERSATION_EXPIRE_TIME = 600 # 对话历史过期时间,单位:秒 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/chat", summary="ai对话") |
|
|
|
|
async def h5_chat_stream(request: Request, inp: ChatIn): |
|
|
|
|
if inp.sign is None: |
|
|
|
@ -63,31 +64,32 @@ async def h5_chat_stream(request: Request, inp: ChatIn): |
|
|
|
|
# 只提取用户的历史消息 |
|
|
|
|
user_messages = [msg for msg in conversation_history if msg.get('role') == 'user'] |
|
|
|
|
user_messages_content = [msg.get('content', '') for msg in user_messages] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 合并当前消息和历史消息用于分类和景区提取 |
|
|
|
|
all_messages = user_messages_content + [inp.message] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 先进行分类判断 |
|
|
|
|
cat = await classify(inp.message) |
|
|
|
|
print(f"Message category: {cat}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spot = None |
|
|
|
|
if cat == "游玩判断" or cat == "保定文旅": |
|
|
|
|
spot = await extract_spot(all_messages) |
|
|
|
|
elif cat == "多景区比较": |
|
|
|
|
# 对于多景区比较,提取多个景区名称 |
|
|
|
|
scenics = await extract_multi_scenic(all_messages) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if spot: |
|
|
|
|
# 使用线程池异步执行同步函数 |
|
|
|
|
loop = asyncio.get_event_loop() |
|
|
|
|
with ThreadPoolExecutor() as executor: |
|
|
|
|
knowledge_task = loop.run_in_executor(executor, fetch_and_parse_markdown, user_id, inp.message) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 获取开启的前4个问题(包含标题和内容) |
|
|
|
|
questions = await QuickQuestion.filter(status="0", ischat="0").order_by("order_num").limit(4).values("title", "subtitle","content") |
|
|
|
|
question_titles = [f"{q['subtitle']}{q['title']}" for q in questions] |
|
|
|
|
questions = await QuickQuestion.filter(status="0", ischat="0").order_by("order_num").limit(4).values("title", |
|
|
|
|
"subtitle", |
|
|
|
|
"content") |
|
|
|
|
question_titles = [f"{q['subtitle']}{q['title']}" for q in questions] |
|
|
|
|
|
|
|
|
|
# 检查消息是否在问题列表中 |
|
|
|
|
is_quick_question = inp.message in question_titles |
|
|
|
@ -106,77 +108,102 @@ async def h5_chat_stream(request: Request, inp: ChatIn): |
|
|
|
|
try: |
|
|
|
|
if is_quick_question: |
|
|
|
|
# 找到对应的问题内容 |
|
|
|
|
question_content = next(q["content"] for q in questions if f"{q['subtitle']}{q['title']}" == inp.message) |
|
|
|
|
question_content = next( |
|
|
|
|
q["content"] for q in questions if f"{q['subtitle']}{q['title']}" == inp.message) |
|
|
|
|
# 处理快捷问题,传递content |
|
|
|
|
async for chunk in handle_quick_question(inp, question_content): |
|
|
|
|
yield chunk |
|
|
|
|
# SSE格式包装 |
|
|
|
|
yield f"data: {chunk}\n\n" |
|
|
|
|
await asyncio.sleep(0.01) # 确保分块发送 |
|
|
|
|
else: |
|
|
|
|
# 原来的逻辑 |
|
|
|
|
if cat == "游玩判断": |
|
|
|
|
if not spot: |
|
|
|
|
# 如果是游玩判断但没提取到景区名称,使用默认回复 |
|
|
|
|
ai_response = "**未找到景区信息,请检查名称是否正确。**\n\n(内容由AI生成,仅供参考)" |
|
|
|
|
async for chunk in gen_markdown_stream(inp.message, ai_response, inp.language, conversation_history): |
|
|
|
|
yield chunk |
|
|
|
|
async for chunk in gen_markdown_stream(inp.message, ai_response, inp.language, |
|
|
|
|
conversation_history): |
|
|
|
|
# SSE格式包装 |
|
|
|
|
yield f"data: {chunk}\n\n" |
|
|
|
|
await asyncio.sleep(0.01) |
|
|
|
|
else: |
|
|
|
|
data = await query_flow(request, spot) |
|
|
|
|
# 等待知识库查询结果 |
|
|
|
|
if spot: |
|
|
|
|
knowledge = await knowledge_task |
|
|
|
|
#如果知识库返回的内容不包含"知识库内未找到相应资源"则拼接字符串 |
|
|
|
|
# 如果知识库返回的内容不包含"无法"则拼接字符串 |
|
|
|
|
if knowledge and "无法" not in knowledge: |
|
|
|
|
data += "\n\n知识库查询到的景区内容:"+ knowledge |
|
|
|
|
data += "\n\n知识库查询到的景区内容:" + knowledge |
|
|
|
|
async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history): |
|
|
|
|
yield chunk |
|
|
|
|
# SSE格式包装 |
|
|
|
|
yield f"data: {chunk}\n\n" |
|
|
|
|
await asyncio.sleep(0.01) |
|
|
|
|
elif cat == "多景区比较": |
|
|
|
|
# 处理多景区比较 |
|
|
|
|
if not scenics: |
|
|
|
|
# 如果没提取到景区名称,使用默认回复 |
|
|
|
|
ai_response = "**未找到景区信息,请检查名称是否正确。**\n\n(内容由AI生成,仅供参考)" |
|
|
|
|
async for chunk in gen_markdown_stream(inp.message, ai_response, inp.language, conversation_history): |
|
|
|
|
yield chunk |
|
|
|
|
async for chunk in gen_markdown_stream(inp.message, ai_response, inp.language, |
|
|
|
|
conversation_history): |
|
|
|
|
# SSE格式包装 |
|
|
|
|
yield f"data: {chunk}\n\n" |
|
|
|
|
await asyncio.sleep(0.01) |
|
|
|
|
elif len(scenics) == 1: |
|
|
|
|
# 如果只提取到一个景区,按单景区处理 |
|
|
|
|
data = await query_flow(request, scenics[0]) |
|
|
|
|
# 等待知识库查询结果 |
|
|
|
|
if scenics[0]: |
|
|
|
|
knowledge = await knowledge_task |
|
|
|
|
#如果知识库返回的内容不包含"知识库内未找到相应资源"则拼接字符串 |
|
|
|
|
# 如果知识库返回的内容不包含"无法"则拼接字符串 |
|
|
|
|
if knowledge and "无法" not in knowledge: |
|
|
|
|
data += "\n\n知识库查询到的景区内容:"+ knowledge |
|
|
|
|
data += "\n\n知识库查询到的景区内容:" + knowledge |
|
|
|
|
async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history): |
|
|
|
|
yield chunk |
|
|
|
|
# SSE格式包装 |
|
|
|
|
yield f"data: {chunk}\n\n" |
|
|
|
|
await asyncio.sleep(0.01) |
|
|
|
|
else: |
|
|
|
|
# 查询多个景区的客流数据并比较 |
|
|
|
|
ai_response = await query_multi_scenic_flow(request, scenics, inp.message) |
|
|
|
|
async for chunk in gen_markdown_stream(inp.message, ai_response, inp.language, conversation_history): |
|
|
|
|
yield chunk |
|
|
|
|
async for chunk in gen_markdown_stream(inp.message, ai_response, inp.language, |
|
|
|
|
conversation_history): |
|
|
|
|
# SSE格式包装 |
|
|
|
|
yield f"data: {chunk}\n\n" |
|
|
|
|
await asyncio.sleep(0.01) |
|
|
|
|
else: |
|
|
|
|
# 等待知识库查询结果 |
|
|
|
|
if spot: |
|
|
|
|
knowledge = await knowledge_task |
|
|
|
|
if knowledge and "无法" not in knowledge: |
|
|
|
|
inp.message += "\n\n知识库查询到的景区内容:"+ knowledge |
|
|
|
|
inp.message += "\n\n知识库查询到的景区内容:" + knowledge |
|
|
|
|
async for chunk in ai_chat_stream(inp, conversation_history): |
|
|
|
|
yield chunk |
|
|
|
|
# SSE格式包装 |
|
|
|
|
yield f"data: {chunk}\n\n" |
|
|
|
|
await asyncio.sleep(0.01) |
|
|
|
|
|
|
|
|
|
# 将更新后的对话历史存回 Redis,并设置过期时间 |
|
|
|
|
# 只有非快捷问题才保存对话历史 |
|
|
|
|
if not is_quick_question: |
|
|
|
|
await redis_client.setex(conversation_history_key, CONVERSATION_EXPIRE_TIME, json.dumps(conversation_history)) |
|
|
|
|
await redis_client.setex(conversation_history_key, CONVERSATION_EXPIRE_TIME, |
|
|
|
|
json.dumps(conversation_history)) |
|
|
|
|
|
|
|
|
|
# 发送结束标记 |
|
|
|
|
yield "data: [DONE]\n\n" |
|
|
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
|
print(f"Error in content_stream: {e}") |
|
|
|
|
# 错误信息也按SSE格式发送 |
|
|
|
|
yield f"data: 发生错误:{str(e)}\n\n" |
|
|
|
|
raise |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
return StreamingResponse( |
|
|
|
|
content_stream(), |
|
|
|
|
media_type="text/plain", |
|
|
|
|
media_type="text/event-stream", # SSE标准MIME类型 |
|
|
|
|
headers={ |
|
|
|
|
"Cache-Control": "no-cache", |
|
|
|
|
"Connection": "keep-alive", |
|
|
|
|
"X-Accel-Buffering": "no" |
|
|
|
|
"X-Accel-Buffering": "no", |
|
|
|
|
"Access-Control-Allow-Origin": "*" # 允许跨域(根据实际情况调整) |
|
|
|
|
} |
|
|
|
|
) |
|
|
|
|
except Exception as e: |
|
|
|
|