main
zc 1 month ago
parent fac83cedce
commit fcbb33e8c5
  1. 112
      app/api/chat/ai/chat_router.py
  2. 41
      app/api/chat/ai/chat_service.py

@ -1,3 +1,6 @@
import random
from anyio import sleep
from fastapi import APIRouter, HTTPException, Depends,FastAPI from fastapi import APIRouter, HTTPException, Depends,FastAPI
from starlette.requests import Request from starlette.requests import Request
@ -23,7 +26,8 @@ from app.api.chat.ai.chat_service import (
get_scenic_parking_data, get_scenic_parking_data,
extract_multi_scenic, extract_multi_scenic,
query_multi_scenic_flow, query_multi_scenic_flow,
get_all_toilet_data get_all_toilet_data,
generate_recommended_questions
) )
# 导入用于异步执行同步函数的模块 # 导入用于异步执行同步函数的模块
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -54,7 +58,7 @@ async def universal_exception_handler(request: Request, exc: Exception):
chunk_size = 8 chunk_size = 8
for i in range(0, len(error_msg), chunk_size): for i in range(0, len(error_msg), chunk_size):
chunk = error_msg[i:i + chunk_size] chunk = error_msg[i:i + chunk_size]
yield f"data: {json.dumps({'content': chunk})}\n\n" yield f"data: {chunk}\n\n"
await asyncio.sleep(0.03) # 模拟自然打字速度 await asyncio.sleep(0.03) # 模拟自然打字速度
# 发送结束标记 # 发送结束标记
@ -113,7 +117,7 @@ async def h5_chat_stream(request: Request, inp: ChatIn, redis_client = Depends(g
chunk_size = 8 chunk_size = 8
for i in range(0, len(error_msg), chunk_size): for i in range(0, len(error_msg), chunk_size):
chunk = error_msg[i:i + chunk_size] chunk = error_msg[i:i + chunk_size]
yield f"data: {json.dumps({'content': chunk})}\n\n" yield f"data: {chunk}\n\n"
await asyncio.sleep(0.03) await asyncio.sleep(0.03)
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
@ -159,30 +163,44 @@ async def _handle_chat_request(
user_messages_content = [msg.get('content', '') for msg in user_messages] user_messages_content = [msg.get('content', '') for msg in user_messages]
all_messages = user_messages_content + [inp.message] all_messages = user_messages_content + [inp.message]
# 消息分类 try:
cat = await classify(inp.message) cached = await redis_client.get(inp.message)
print(f"Message category: {cat}") if cached:
is_quick_question = True
spot = None # 缓存命中,直接返回缓存数据
scenics = None questions = cached
if cat == "游玩判断" or cat == "保定文旅": else:
spot = await extract_spot(all_messages) is_quick_question = False
elif cat == "多景区比较": except Exception as e:
scenics = await extract_multi_scenic(all_messages) print(f"[Redis] 查询缓存失败: {e}")
is_quick_question = False
# 知识库查询准备
knowledge_task = None if not is_quick_question:
if spot: # 消息分类
loop = asyncio.get_event_loop() cat = await classify(inp.message)
with ThreadPoolExecutor() as executor: print(f"Message category: {cat}")
knowledge_task = loop.run_in_executor(executor, fetch_and_parse_markdown, user_id, inp.message)
spot = None
# 获取快捷问题 scenics = None
questions = await QuickQuestion.filter(status="0", ischat="0").order_by("order_num").limit(4).values("title", if cat == "游玩判断" or cat == "保定文旅":
"subtitle", spot = await extract_spot(all_messages)
"content") elif cat == "多景区比较":
question_titles = [f"{q['subtitle']}{q['title']}" for q in questions] scenics = await extract_multi_scenic(all_messages)
is_quick_question = inp.message in question_titles
# 知识库查询准备
knowledge_task = None
if spot:
loop = asyncio.get_event_loop()
with ThreadPoolExecutor() as executor:
knowledge_task = loop.run_in_executor(executor, fetch_and_parse_markdown, user_id, inp.message)
# # 获取快捷问题
# questions = await QuickQuestion.filter(status="0", ischat="0").order_by("order_num").limit(4).values("title",
# "subtitle",
# "content")
# question_titles = [f"{q['subtitle']}{q['title']}" for q in questions]
# is_quick_question = inp.message in question_titles
# 添加用户消息到历史 # 添加用户消息到历史
if not is_quick_question: if not is_quick_question:
@ -193,16 +211,36 @@ async def _handle_chat_request(
nonlocal conversation_history nonlocal conversation_history
try: try:
# 记录热门问题 try:
await record_hot_question(inp.message) # 记录热门问题
await record_hot_question(inp.message)
except Exception as e:
print(f"记录热门问题失败: {e},问题内容为:{inp.message}")
if is_quick_question: if is_quick_question:
await asyncio.sleep(0.5)
full_response = ""
# 处理快捷问题 # 处理快捷问题
question_content = next( question_content = questions
q["content"] for q in questions if f"{q['subtitle']}{q['title']}" == inp.message) #每次输出随机1-10个字符
async for chunk in handle_quick_question(inp, question_content): chunk_size = random.randint(5, 15)
for i in range(0, len(question_content), chunk_size):
chunk = question_content[i:i + chunk_size]
full_response += chunk
yield f"data: {chunk}\n\n" yield f"data: {chunk}\n\n"
await asyncio.sleep(0.01) await asyncio.sleep(0.09)
if full_response:
recommended_questions = await generate_recommended_questions(inp.message, full_response)
if recommended_questions:
# 添加分隔符和标题
yield "\n\n### 您可能还想了解:"
# 逐个返回推荐问题
for i, question in enumerate(recommended_questions, 1):
yield f"\n{i}. {question}"
# async for chunk in handle_quick_question(inp, question_content):
# yield f"data: {chunk}\n\n"
# await asyncio.sleep(0.01)
else: else:
# 处理不同分类的消息 # 处理不同分类的消息
if cat == "游玩判断": if cat == "游玩判断":
@ -272,7 +310,7 @@ async def _handle_chat_request(
chunk_size = 8 chunk_size = 8
for i in range(0, len(error_msg), chunk_size): for i in range(0, len(error_msg), chunk_size):
chunk = error_msg[i:i + chunk_size] chunk = error_msg[i:i + chunk_size]
yield f"data: {json.dumps({'content': chunk})}\n\n" yield f"data: {chunk})\n\n"
await asyncio.sleep(0.03) await asyncio.sleep(0.03)
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
@ -495,7 +533,7 @@ async def get_all_scenic_flow(request: Request, req: AllScenicFlowRequest, redis
@router.post("/get_all_toilet_info") @router.post("/get_all_toilet_info")
async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest): async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest, redis_client = Depends(get_redis_client)):
""" """
获取所有厕所信息 获取所有厕所信息
""" """
@ -512,7 +550,7 @@ async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest):
raise HTTPException(status_code=401, detail="无效的签名") raise HTTPException(status_code=401, detail="无效的签名")
try: try:
data = await get_all_toilet_data(request) data = await get_all_toilet_data(request, redis_client)
if not data: if not data:
return { return {
@ -524,7 +562,7 @@ async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest):
return { return {
"code": 200, "code": 200,
"message": "查询成功", "message": "查询成功",
"data": data "data": json.loads(data)
} }
except Exception as e: except Exception as e:
print(f"查询所有厕所信息异常: {e}") print(f"查询所有厕所信息异常: {e}")

@ -36,7 +36,7 @@ EXTRACT_PROMPT = """你是一名景区名称精准匹配助手。用户的问题
白石山景区 白石山景区
阜平云花溪谷-玫瑰谷 阜平云花溪谷-玫瑰谷
保定军校纪念馆 保定军校纪念馆
保定直隶总督署博物馆 直隶总督署博物馆
冉庄地道战遗址 冉庄地道战遗址
刘伶醉景区 刘伶醉景区
曲阳北岳庙景区 曲阳北岳庙景区
@ -51,7 +51,7 @@ EXTRACT_PROMPT = """你是一名景区名称精准匹配助手。用户的问题
满城汉墓景区 满城汉墓景区
灵山聚龙洞旅游风景区 灵山聚龙洞旅游风景区
易县狼牙山风景区 易县狼牙山风景区
留法勤工俭学纪念馆 留法勤工俭学运动纪念馆
白求恩柯棣华纪念馆 白求恩柯棣华纪念馆
唐县秀水峪 唐县秀水峪
腰山王氏庄园 腰山王氏庄园
@ -548,6 +548,8 @@ async def get_all_scenic_flow_data(request: Request, redis_client = None) -> lis
for row in rows: for row in rows:
id,scenic_name, enter_num, leave_num, max_capacity = row id,scenic_name, enter_num, leave_num, max_capacity = row
in_park_num = abs(enter_num - leave_num) # 确保是正数 in_park_num = abs(enter_num - leave_num) # 确保是正数
if in_park_num > max_capacity:
in_park_num = max_capacity
# 避免除以零的情况 # 避免除以零的情况
if max_capacity > 0: if max_capacity > 0:
@ -625,7 +627,8 @@ async def get_scenic_detail_data(request: Request, id: int, redis_client = None)
# 计算在园人数 # 计算在园人数
in_park_num = abs(enter_num - leave_num) # 确保是正数 in_park_num = abs(enter_num - leave_num) # 确保是正数
if in_park_num > max_capacity:
in_park_num = max_capacity
# 计算承载率和舒适度 # 计算承载率和舒适度
if max_capacity > 0: if max_capacity > 0:
capacity_rate = in_park_num / max_capacity capacity_rate = in_park_num / max_capacity
@ -874,7 +877,7 @@ MULTI_SCENIC_EXTRACT_PROMPT = """你是一名景区名称提取助手。用户
白石山景区 白石山景区
阜平云花溪谷-玫瑰谷 阜平云花溪谷-玫瑰谷
保定军校纪念馆 保定军校纪念馆
保定直隶总督署博物馆 直隶总督署博物馆
冉庄地道战遗址 冉庄地道战遗址
刘伶醉景区 刘伶醉景区
曲阳北岳庙景区 曲阳北岳庙景区
@ -889,7 +892,7 @@ MULTI_SCENIC_EXTRACT_PROMPT = """你是一名景区名称提取助手。用户
满城汉墓景区 满城汉墓景区
灵山聚龙洞旅游风景区 灵山聚龙洞旅游风景区
易县狼牙山风景区 易县狼牙山风景区
留法勤工俭学纪念馆 留法勤工俭学运动纪念馆
白求恩柯棣华纪念馆 白求恩柯棣华纪念馆
唐县秀水峪 唐县秀水峪
腰山王氏庄园 腰山王氏庄园
@ -998,17 +1001,29 @@ async def query_multi_scenic_flow(request: Request, scenics: list, msg: str, red
# 在文件末尾添加新函数用于获取所有厕所信息 # 在文件末尾添加新函数用于获取所有厕所信息
async def get_all_toilet_data(request: Request) -> list: async def get_all_toilet_data(request: Request, redis_client = None) -> list:
""" """
查询所有厕所信息 查询所有厕所信息
""" """
try: try:
cache_key = "all_toilet_list"
if redis_client is None:
redis_client = request.app.state.redis_client
try:
cached = await redis_client.get(cache_key)
if cached:
# 缓存命中,直接返回缓存数据
return json.loads(cached)
except Exception as e:
print(f"[Redis] 查询缓存失败: {e}")
pool = request.app.state.mysql_pool pool = request.app.state.mysql_pool
async with pool.acquire() as conn: async with pool.acquire() as conn:
async with conn.cursor() as cur: async with conn.cursor() as cur:
# 查询所有厕所信息 # 查询所有厕所信息
query = """ query = """
SELECT SELECT
id, id,
banner, banner,
title, title,
@ -1026,14 +1041,14 @@ async def get_all_toilet_data(request: Request) -> list:
is_aixin, is_aixin,
createtime, createtime,
updatetime updatetime
FROM FROM
cyjcpt_bd.ai_toilet_info cyjcpt_bd.ai_toilet_info
ORDER BY ORDER BY
id id
""" """
await cur.execute(query) await cur.execute(query)
rows = await cur.fetchall() rows = await cur.fetchall()
# 处理结果 # 处理结果
result = [] result = []
for row in rows: for row in rows:
@ -1056,7 +1071,7 @@ async def get_all_toilet_data(request: Request) -> list:
createtime, createtime,
updatetime updatetime
) = row ) = row
result.append({ result.append({
"id": id, "id": id,
"banner": banner, "banner": banner,
@ -1076,9 +1091,9 @@ async def get_all_toilet_data(request: Request) -> list:
"createtime": createtime, "createtime": createtime,
"updatetime": updatetime "updatetime": updatetime
}) })
return result return result
except Exception as e: except Exception as e:
print(f"[MySQL] 查询所有厕所数据失败: {e}") print(f"[MySQL] 查询所有厕所数据失败: {e}")
return [] return []

Loading…
Cancel
Save