main
zc 1 month ago
parent fac83cedce
commit fcbb33e8c5
  1. 112
      app/api/chat/ai/chat_router.py
  2. 27
      app/api/chat/ai/chat_service.py

@ -1,3 +1,6 @@
import random
from anyio import sleep
from fastapi import APIRouter, HTTPException, Depends,FastAPI
from starlette.requests import Request
@ -23,7 +26,8 @@ from app.api.chat.ai.chat_service import (
get_scenic_parking_data,
extract_multi_scenic,
query_multi_scenic_flow,
get_all_toilet_data
get_all_toilet_data,
generate_recommended_questions
)
# 导入用于异步执行同步函数的模块
from concurrent.futures import ThreadPoolExecutor
@ -54,7 +58,7 @@ async def universal_exception_handler(request: Request, exc: Exception):
chunk_size = 8
for i in range(0, len(error_msg), chunk_size):
chunk = error_msg[i:i + chunk_size]
yield f"data: {json.dumps({'content': chunk})}\n\n"
yield f"data: {chunk}\n\n"
await asyncio.sleep(0.03) # 模拟自然打字速度
# 发送结束标记
@ -113,7 +117,7 @@ async def h5_chat_stream(request: Request, inp: ChatIn, redis_client = Depends(g
chunk_size = 8
for i in range(0, len(error_msg), chunk_size):
chunk = error_msg[i:i + chunk_size]
yield f"data: {json.dumps({'content': chunk})}\n\n"
yield f"data: {chunk}\n\n"
await asyncio.sleep(0.03)
yield "data: [DONE]\n\n"
@ -159,30 +163,44 @@ async def _handle_chat_request(
user_messages_content = [msg.get('content', '') for msg in user_messages]
all_messages = user_messages_content + [inp.message]
# 消息分类
cat = await classify(inp.message)
print(f"Message category: {cat}")
spot = None
scenics = None
if cat == "游玩判断" or cat == "保定文旅":
spot = await extract_spot(all_messages)
elif cat == "多景区比较":
scenics = await extract_multi_scenic(all_messages)
# 知识库查询准备
knowledge_task = None
if spot:
loop = asyncio.get_event_loop()
with ThreadPoolExecutor() as executor:
knowledge_task = loop.run_in_executor(executor, fetch_and_parse_markdown, user_id, inp.message)
# 获取快捷问题
questions = await QuickQuestion.filter(status="0", ischat="0").order_by("order_num").limit(4).values("title",
"subtitle",
"content")
question_titles = [f"{q['subtitle']}{q['title']}" for q in questions]
is_quick_question = inp.message in question_titles
try:
cached = await redis_client.get(inp.message)
if cached:
is_quick_question = True
# 缓存命中,直接返回缓存数据
questions = cached
else:
is_quick_question = False
except Exception as e:
print(f"[Redis] 查询缓存失败: {e}")
is_quick_question = False
if not is_quick_question:
# 消息分类
cat = await classify(inp.message)
print(f"Message category: {cat}")
spot = None
scenics = None
if cat == "游玩判断" or cat == "保定文旅":
spot = await extract_spot(all_messages)
elif cat == "多景区比较":
scenics = await extract_multi_scenic(all_messages)
# 知识库查询准备
knowledge_task = None
if spot:
loop = asyncio.get_event_loop()
with ThreadPoolExecutor() as executor:
knowledge_task = loop.run_in_executor(executor, fetch_and_parse_markdown, user_id, inp.message)
# # 获取快捷问题
# questions = await QuickQuestion.filter(status="0", ischat="0").order_by("order_num").limit(4).values("title",
# "subtitle",
# "content")
# question_titles = [f"{q['subtitle']}{q['title']}" for q in questions]
# is_quick_question = inp.message in question_titles
# 添加用户消息到历史
if not is_quick_question:
@ -193,16 +211,36 @@ async def _handle_chat_request(
nonlocal conversation_history
try:
# 记录热门问题
await record_hot_question(inp.message)
try:
# 记录热门问题
await record_hot_question(inp.message)
except Exception as e:
print(f"记录热门问题失败: {e},问题内容为:{inp.message}")
if is_quick_question:
await asyncio.sleep(0.5)
full_response = ""
# 处理快捷问题
question_content = next(
q["content"] for q in questions if f"{q['subtitle']}{q['title']}" == inp.message)
async for chunk in handle_quick_question(inp, question_content):
question_content = questions
#每次输出随机1-10个字符
chunk_size = random.randint(5, 15)
for i in range(0, len(question_content), chunk_size):
chunk = question_content[i:i + chunk_size]
full_response += chunk
yield f"data: {chunk}\n\n"
await asyncio.sleep(0.01)
await asyncio.sleep(0.09)
if full_response:
recommended_questions = await generate_recommended_questions(inp.message, full_response)
if recommended_questions:
# 添加分隔符和标题
yield "\n\n### 您可能还想了解:"
# 逐个返回推荐问题
for i, question in enumerate(recommended_questions, 1):
yield f"\n{i}. {question}"
# async for chunk in handle_quick_question(inp, question_content):
# yield f"data: {chunk}\n\n"
# await asyncio.sleep(0.01)
else:
# 处理不同分类的消息
if cat == "游玩判断":
@ -272,7 +310,7 @@ async def _handle_chat_request(
chunk_size = 8
for i in range(0, len(error_msg), chunk_size):
chunk = error_msg[i:i + chunk_size]
yield f"data: {json.dumps({'content': chunk})}\n\n"
yield f"data: {chunk})\n\n"
await asyncio.sleep(0.03)
yield "data: [DONE]\n\n"
@ -495,7 +533,7 @@ async def get_all_scenic_flow(request: Request, req: AllScenicFlowRequest, redis
@router.post("/get_all_toilet_info")
async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest):
async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest, redis_client = Depends(get_redis_client)):
"""
获取所有厕所信息
"""
@ -512,7 +550,7 @@ async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest):
raise HTTPException(status_code=401, detail="无效的签名")
try:
data = await get_all_toilet_data(request)
data = await get_all_toilet_data(request, redis_client)
if not data:
return {
@ -524,7 +562,7 @@ async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest):
return {
"code": 200,
"message": "查询成功",
"data": data
"data": json.loads(data)
}
except Exception as e:
print(f"查询所有厕所信息异常: {e}")

@ -36,7 +36,7 @@ EXTRACT_PROMPT = """你是一名景区名称精准匹配助手。用户的问题
白石山景区
阜平云花溪谷-玫瑰谷
保定军校纪念馆
保定直隶总督署博物馆
直隶总督署博物馆
冉庄地道战遗址
刘伶醉景区
曲阳北岳庙景区
@ -51,7 +51,7 @@ EXTRACT_PROMPT = """你是一名景区名称精准匹配助手。用户的问题
满城汉墓景区
灵山聚龙洞旅游风景区
易县狼牙山风景区
留法勤工俭学纪念馆
留法勤工俭学运动纪念馆
白求恩柯棣华纪念馆
唐县秀水峪
腰山王氏庄园
@ -548,6 +548,8 @@ async def get_all_scenic_flow_data(request: Request, redis_client = None) -> lis
for row in rows:
id,scenic_name, enter_num, leave_num, max_capacity = row
in_park_num = abs(enter_num - leave_num) # 确保是正数
if in_park_num > max_capacity:
in_park_num = max_capacity
# 避免除以零的情况
if max_capacity > 0:
@ -625,7 +627,8 @@ async def get_scenic_detail_data(request: Request, id: int, redis_client = None)
# 计算在园人数
in_park_num = abs(enter_num - leave_num) # 确保是正数
if in_park_num > max_capacity:
in_park_num = max_capacity
# 计算承载率和舒适度
if max_capacity > 0:
capacity_rate = in_park_num / max_capacity
@ -874,7 +877,7 @@ MULTI_SCENIC_EXTRACT_PROMPT = """你是一名景区名称提取助手。用户
白石山景区
阜平云花溪谷-玫瑰谷
保定军校纪念馆
保定直隶总督署博物馆
直隶总督署博物馆
冉庄地道战遗址
刘伶醉景区
曲阳北岳庙景区
@ -889,7 +892,7 @@ MULTI_SCENIC_EXTRACT_PROMPT = """你是一名景区名称提取助手。用户
满城汉墓景区
灵山聚龙洞旅游风景区
易县狼牙山风景区
留法勤工俭学纪念馆
留法勤工俭学运动纪念馆
白求恩柯棣华纪念馆
唐县秀水峪
腰山王氏庄园
@ -998,11 +1001,23 @@ async def query_multi_scenic_flow(request: Request, scenics: list, msg: str, red
# 在文件末尾添加新函数用于获取所有厕所信息
async def get_all_toilet_data(request: Request) -> list:
async def get_all_toilet_data(request: Request, redis_client = None) -> list:
"""
查询所有厕所信息
"""
try:
cache_key = "all_toilet_list"
if redis_client is None:
redis_client = request.app.state.redis_client
try:
cached = await redis_client.get(cache_key)
if cached:
# 缓存命中,直接返回缓存数据
return json.loads(cached)
except Exception as e:
print(f"[Redis] 查询缓存失败: {e}")
pool = request.app.state.mysql_pool
async with pool.acquire() as conn:
async with conn.cursor() as cur:

Loading…
Cancel
Save