From fcbb33e8c5d230f9563e43239e8aa500d3f20a9c Mon Sep 17 00:00:00 2001 From: zc Date: Sat, 13 Sep 2025 10:26:51 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/chat/ai/chat_router.py | 112 +++++++++++++++++++++----------- app/api/chat/ai/chat_service.py | 41 ++++++++---- 2 files changed, 103 insertions(+), 50 deletions(-) diff --git a/app/api/chat/ai/chat_router.py b/app/api/chat/ai/chat_router.py index 627b957..d263371 100644 --- a/app/api/chat/ai/chat_router.py +++ b/app/api/chat/ai/chat_router.py @@ -1,3 +1,6 @@ +import random + +from anyio import sleep from fastapi import APIRouter, HTTPException, Depends,FastAPI from starlette.requests import Request @@ -23,7 +26,8 @@ from app.api.chat.ai.chat_service import ( get_scenic_parking_data, extract_multi_scenic, query_multi_scenic_flow, - get_all_toilet_data + get_all_toilet_data, + generate_recommended_questions ) # 导入用于异步执行同步函数的模块 from concurrent.futures import ThreadPoolExecutor @@ -54,7 +58,7 @@ async def universal_exception_handler(request: Request, exc: Exception): chunk_size = 8 for i in range(0, len(error_msg), chunk_size): chunk = error_msg[i:i + chunk_size] - yield f"data: {json.dumps({'content': chunk})}\n\n" + yield f"data: {chunk}\n\n" await asyncio.sleep(0.03) # 模拟自然打字速度 # 发送结束标记 @@ -113,7 +117,7 @@ async def h5_chat_stream(request: Request, inp: ChatIn, redis_client = Depends(g chunk_size = 8 for i in range(0, len(error_msg), chunk_size): chunk = error_msg[i:i + chunk_size] - yield f"data: {json.dumps({'content': chunk})}\n\n" + yield f"data: {chunk}\n\n" await asyncio.sleep(0.03) yield "data: [DONE]\n\n" @@ -159,30 +163,44 @@ async def _handle_chat_request( user_messages_content = [msg.get('content', '') for msg in user_messages] all_messages = user_messages_content + [inp.message] - # 消息分类 - cat = await classify(inp.message) - print(f"Message category: {cat}") - - spot = None - scenics = None - if cat == "游玩判断" or cat == "保定文旅": - spot = await extract_spot(all_messages) - elif cat == "多景区比较": - scenics = await extract_multi_scenic(all_messages) - - # 知识库查询准备 - knowledge_task = None - if spot: - loop = asyncio.get_event_loop() - with ThreadPoolExecutor() as executor: - knowledge_task = loop.run_in_executor(executor, fetch_and_parse_markdown, user_id, inp.message) - - # 获取快捷问题 - questions = await QuickQuestion.filter(status="0", ischat="0").order_by("order_num").limit(4).values("title", - "subtitle", - "content") - question_titles = [f"{q['subtitle']}{q['title']}" for q in questions] - is_quick_question = inp.message in question_titles + try: + cached = await redis_client.get(inp.message) + if cached: + is_quick_question = True + # 缓存命中,直接返回缓存数据 + questions = cached + else: + is_quick_question = False + except Exception as e: + print(f"[Redis] 查询缓存失败: {e}") + is_quick_question = False + + if not is_quick_question: + # 消息分类 + cat = await classify(inp.message) + print(f"Message category: {cat}") + + spot = None + scenics = None + if cat == "游玩判断" or cat == "保定文旅": + spot = await extract_spot(all_messages) + elif cat == "多景区比较": + scenics = await extract_multi_scenic(all_messages) + + # 知识库查询准备 + knowledge_task = None + if spot: + loop = asyncio.get_event_loop() + with ThreadPoolExecutor() as executor: + knowledge_task = loop.run_in_executor(executor, fetch_and_parse_markdown, user_id, inp.message) + + # # 获取快捷问题 + # questions = await QuickQuestion.filter(status="0", ischat="0").order_by("order_num").limit(4).values("title", + # "subtitle", + # "content") + # question_titles = [f"{q['subtitle']}{q['title']}" for q in questions] + # is_quick_question = inp.message in question_titles + # 添加用户消息到历史 if not is_quick_question: @@ -193,16 +211,36 @@ async def _handle_chat_request( nonlocal conversation_history try: - # 记录热门问题 - await record_hot_question(inp.message) + try: + # 记录热门问题 + await record_hot_question(inp.message) + except Exception as e: + print(f"记录热门问题失败: {e},问题内容为:{inp.message}") if is_quick_question: + await asyncio.sleep(0.5) + full_response = "" # 处理快捷问题 - question_content = next( - q["content"] for q in questions if f"{q['subtitle']}{q['title']}" == inp.message) - async for chunk in handle_quick_question(inp, question_content): + question_content = questions + #每次输出随机1-10个字符 + chunk_size = random.randint(5, 15) + for i in range(0, len(question_content), chunk_size): + chunk = question_content[i:i + chunk_size] + full_response += chunk yield f"data: {chunk}\n\n" - await asyncio.sleep(0.01) + await asyncio.sleep(0.09) + if full_response: + recommended_questions = await generate_recommended_questions(inp.message, full_response) + if recommended_questions: + # 添加分隔符和标题 + yield "\n\n### 您可能还想了解:" + # 逐个返回推荐问题 + for i, question in enumerate(recommended_questions, 1): + yield f"\n{i}. {question}" + + # async for chunk in handle_quick_question(inp, question_content): + # yield f"data: {chunk}\n\n" + # await asyncio.sleep(0.01) else: # 处理不同分类的消息 if cat == "游玩判断": @@ -272,7 +310,7 @@ async def _handle_chat_request( chunk_size = 8 for i in range(0, len(error_msg), chunk_size): chunk = error_msg[i:i + chunk_size] - yield f"data: {json.dumps({'content': chunk})}\n\n" + yield f"data: {chunk})\n\n" await asyncio.sleep(0.03) yield "data: [DONE]\n\n" @@ -495,7 +533,7 @@ async def get_all_scenic_flow(request: Request, req: AllScenicFlowRequest, redis @router.post("/get_all_toilet_info") -async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest): +async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest, redis_client = Depends(get_redis_client)): """ 获取所有厕所信息 """ @@ -512,7 +550,7 @@ async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest): raise HTTPException(status_code=401, detail="无效的签名") try: - data = await get_all_toilet_data(request) + data = await get_all_toilet_data(request, redis_client) if not data: return { @@ -524,7 +562,7 @@ async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest): return { "code": 200, "message": "查询成功", - "data": data + "data": json.loads(data) } except Exception as e: print(f"查询所有厕所信息异常: {e}") diff --git a/app/api/chat/ai/chat_service.py b/app/api/chat/ai/chat_service.py index 665ba96..c8f51fe 100644 --- a/app/api/chat/ai/chat_service.py +++ b/app/api/chat/ai/chat_service.py @@ -36,7 +36,7 @@ EXTRACT_PROMPT = """你是一名景区名称精准匹配助手。用户的问题 白石山景区 阜平云花溪谷-玫瑰谷 保定军校纪念馆 -保定直隶总督署博物馆 +直隶总督署博物馆 冉庄地道战遗址 刘伶醉景区 曲阳北岳庙景区 @@ -51,7 +51,7 @@ EXTRACT_PROMPT = """你是一名景区名称精准匹配助手。用户的问题 满城汉墓景区 灵山聚龙洞旅游风景区 易县狼牙山风景区 -留法勤工俭学纪念馆 +留法勤工俭学运动纪念馆 白求恩柯棣华纪念馆 唐县秀水峪 腰山王氏庄园 @@ -548,6 +548,8 @@ async def get_all_scenic_flow_data(request: Request, redis_client = None) -> lis for row in rows: id,scenic_name, enter_num, leave_num, max_capacity = row in_park_num = abs(enter_num - leave_num) # 确保是正数 + if in_park_num > max_capacity: + in_park_num = max_capacity # 避免除以零的情况 if max_capacity > 0: @@ -625,7 +627,8 @@ async def get_scenic_detail_data(request: Request, id: int, redis_client = None) # 计算在园人数 in_park_num = abs(enter_num - leave_num) # 确保是正数 - + if in_park_num > max_capacity: + in_park_num = max_capacity # 计算承载率和舒适度 if max_capacity > 0: capacity_rate = in_park_num / max_capacity @@ -874,7 +877,7 @@ MULTI_SCENIC_EXTRACT_PROMPT = """你是一名景区名称提取助手。用户 白石山景区 阜平云花溪谷-玫瑰谷 保定军校纪念馆 -保定直隶总督署博物馆 +直隶总督署博物馆 冉庄地道战遗址 刘伶醉景区 曲阳北岳庙景区 @@ -889,7 +892,7 @@ MULTI_SCENIC_EXTRACT_PROMPT = """你是一名景区名称提取助手。用户 满城汉墓景区 灵山聚龙洞旅游风景区 易县狼牙山风景区 -留法勤工俭学纪念馆 +留法勤工俭学运动纪念馆 白求恩柯棣华纪念馆 唐县秀水峪 腰山王氏庄园 @@ -998,17 +1001,29 @@ async def query_multi_scenic_flow(request: Request, scenics: list, msg: str, red # 在文件末尾添加新函数用于获取所有厕所信息 -async def get_all_toilet_data(request: Request) -> list: +async def get_all_toilet_data(request: Request, redis_client = None) -> list: """ 查询所有厕所信息 """ try: + cache_key = "all_toilet_list" + if redis_client is None: + redis_client = request.app.state.redis_client + + try: + cached = await redis_client.get(cache_key) + if cached: + # 缓存命中,直接返回缓存数据 + return json.loads(cached) + except Exception as e: + print(f"[Redis] 查询缓存失败: {e}") + pool = request.app.state.mysql_pool async with pool.acquire() as conn: async with conn.cursor() as cur: # 查询所有厕所信息 query = """ - SELECT + SELECT id, banner, title, @@ -1026,14 +1041,14 @@ async def get_all_toilet_data(request: Request) -> list: is_aixin, createtime, updatetime - FROM + FROM cyjcpt_bd.ai_toilet_info - ORDER BY + ORDER BY id """ await cur.execute(query) rows = await cur.fetchall() - + # 处理结果 result = [] for row in rows: @@ -1056,7 +1071,7 @@ async def get_all_toilet_data(request: Request) -> list: createtime, updatetime ) = row - + result.append({ "id": id, "banner": banner, @@ -1076,9 +1091,9 @@ async def get_all_toilet_data(request: Request) -> list: "createtime": createtime, "updatetime": updatetime }) - + return result - + except Exception as e: print(f"[MySQL] 查询所有厕所数据失败: {e}") return []