From 1aad862731f0ea941c52c7efad57f314db95da8f Mon Sep 17 00:00:00 2001 From: zc Date: Wed, 17 Sep 2025 08:55:16 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E8=AF=B7=E6=B1=82=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/chat/ai/chat_router.py | 80 +++++++---- app/api/chat/ai/chat_service.py | 235 ++++++++++++++++++-------------- 2 files changed, 184 insertions(+), 131 deletions(-) diff --git a/app/api/chat/ai/chat_router.py b/app/api/chat/ai/chat_router.py index 4001ae9..0c7a635 100644 --- a/app/api/chat/ai/chat_router.py +++ b/app/api/chat/ai/chat_router.py @@ -88,15 +88,17 @@ async def h5_chat_stream(request: Request, inp: ChatIn, redis_client = Depends(g except asyncio.TimeoutError: # 处理超时情况 async def timeout_stream() -> AsyncGenerator[str, None]: - error_msg = "亲爱的游客,目前系统暂时繁忙,可能是咨询过于火爆~ \n\n请您稍后重新发起咨询,我们会第一时间为您提供文旅信息服务,感谢您的耐心!" - - chunk_size = 8 - for i in range(0, len(error_msg), chunk_size): - chunk = error_msg[i:i + chunk_size] - yield f"data: {json.dumps({'content': chunk})}\n\n" - await asyncio.sleep(0.03) + try: + error_msg = "亲爱的游客,目前系统暂时繁忙,可能是咨询过于火爆~ \n\n请您稍后重新发起咨询,我们会第一时间为您提供文旅信息服务,感谢您的耐心!" - yield "data: [DONE]\n\n" + chunk_size = 8 + for i in range(0, len(error_msg), chunk_size): + chunk = error_msg[i:i + chunk_size] + yield f"data: {chunk}\n\n" + await asyncio.sleep(0.03) + finally: + # 确保在任何情况下都发送结束标记 + yield "data: [DONE]\n\n" return StreamingResponse( timeout_stream(), @@ -112,15 +114,17 @@ async def h5_chat_stream(request: Request, inp: ChatIn, redis_client = Depends(g print(f"聊天接口异常: {str(e)}") async def error_stream() -> AsyncGenerator[str, None]: - error_msg = "亲爱的游客,目前系统暂时繁忙,如同景区高峰期需要稍作等候~ \n\n请您稍后再次尝试,我们会尽快为您提供贴心的文旅咨询服务,感谢您的理解!" - - chunk_size = 8 - for i in range(0, len(error_msg), chunk_size): - chunk = error_msg[i:i + chunk_size] - yield f"data: {chunk}\n\n" - await asyncio.sleep(0.03) + try: + error_msg = "亲爱的游客,目前系统暂时繁忙,如同景区高峰期需要稍作等候~ \n\n请您稍后再次尝试,我们会尽快为您提供贴心的文旅咨询服务,感谢您的理解!" - yield "data: [DONE]\n\n" + chunk_size = 8 + for i in range(0, len(error_msg), chunk_size): + chunk = error_msg[i:i + chunk_size] + yield f"data: {chunk}\n\n" + await asyncio.sleep(0.03) + finally: + # 确保在任何情况下都发送结束标记 + yield "data: [DONE]\n\n" return StreamingResponse( error_stream(), @@ -158,6 +162,11 @@ async def _handle_chat_request( conversation_history_str = await redis_client.get(conversation_history_key) conversation_history = json.loads(conversation_history_str) if conversation_history_str else [] + # 限制对话历史长度为10条(5轮对话) + # 如果超过10条,删除最前面的2条(一轮对话) + if len(conversation_history) > 10: + conversation_history = conversation_history[2:] + # 提取用户历史消息 user_messages = [msg for msg in conversation_history if msg.get('role') == 'user'] user_messages_content = [msg.get('content', '') for msg in user_messages] @@ -178,7 +187,6 @@ async def _handle_chat_request( if not is_quick_question: # 消息分类 cat = await classify(inp.message) - print(f"Message category: {cat}") spot = None scenics = None @@ -253,8 +261,12 @@ async def _handle_chat_request( else: data = await query_flow(request, spot, redis_client) if spot and knowledge_task: - knowledge = await knowledge_task - if knowledge and "无法" not in knowledge: + try: + knowledge = await asyncio.wait_for(knowledge_task, timeout=20) + except TimeoutError: + print("获取知识库信息超时") + knowledge = None + if knowledge: data += "\n\n知识库查询到的景区内容:" + knowledge async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history): yield f"data: {chunk}\n\n" @@ -270,8 +282,12 @@ async def _handle_chat_request( elif len(scenics) == 1: data = await query_flow(request, scenics[0], redis_client) if scenics[0] and knowledge_task: - knowledge = await knowledge_task - if knowledge and "无法" not in knowledge: + try: + knowledge = await asyncio.wait_for(knowledge_task, timeout=20) + except TimeoutError: + print("获取知识库信息超时") + knowledge = None + if knowledge: data += "\n\n知识库查询到的景区内容:" + knowledge async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history): yield f"data: {chunk}\n\n" @@ -284,24 +300,29 @@ async def _handle_chat_request( await asyncio.sleep(0.01) else: if spot and knowledge_task: - knowledge = await knowledge_task - if knowledge and "无法" not in knowledge: + try: + knowledge = await asyncio.wait_for(knowledge_task, timeout=20) + except TimeoutError: + print("获取知识库信息超时") + knowledge = None + if knowledge: inp.message += "\n\n知识库查询到的景区内容:" + knowledge async for chunk in ai_chat_stream(inp, conversation_history): yield f"data: {chunk}\n\n" await asyncio.sleep(0.01) - # 保存对话历史 + # 在保存前再次检查长度,确保不超过10条记录 if not is_quick_question: + # 确保对话历史不超过10条记录(5轮对话) + if len(conversation_history) > 10: + conversation_history = conversation_history[-10:] + await redis_client.setex( conversation_history_key, CONVERSATION_EXPIRE_TIME, json.dumps(conversation_history) ) - # 发送结束标记 - yield "data: [DONE]\n\n" - except Exception as e: print(f"content_stream 异常: {str(e)}") # 流式返回错误信息 @@ -310,9 +331,10 @@ async def _handle_chat_request( chunk_size = 8 for i in range(0, len(error_msg), chunk_size): chunk = error_msg[i:i + chunk_size] - yield f"data: {chunk})\n\n" + yield f"data: {chunk}\n\n" await asyncio.sleep(0.03) - + finally: + # 确保在任何情况下都发送结束标记 yield "data: [DONE]\n\n" return StreamingResponse( diff --git a/app/api/chat/ai/chat_service.py b/app/api/chat/ai/chat_service.py index 6b21ddb..adefac6 100644 --- a/app/api/chat/ai/chat_service.py +++ b/app/api/chat/ai/chat_service.py @@ -6,17 +6,25 @@ from app.models.ChatIn import ChatIn from fastapi import Request from app.settings.config import settings import json -import re -from typing import List import requests from chinese_calendar import is_holiday from datetime import datetime +from typing import Optional + load_dotenv() async_client = AsyncOpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url=settings.DEEPSEEK_API_URL) +# 知识库接口 +KNOWLEDGE_URL = "http://172.21.11.20:8886/v3/chat" +# KNOWLEDGE_URL = "http://192.168.130.144:8888/v3/chat" +# 知识库应用id +BOT_ID = "7550848889243303936" +# 知识库token +KNOWL_TOKEN = "Bearer pat_5cb11052e7b4b517015467902cd7775742120fc88fe66b926c93fde3a39843c7" + #分类提示词 CATEGORY_PROMPT = """你是一个分类助手,请根据用户的问题判断属于以下哪一类: 如果用户的问题涉及保定市某个景区当前的人数、客流量、拥挤程度或是否适合前往(例如:“某个保定市景区现在人多么”、“某个保定市景区现在适不适合去”、"现在可以去吗"、"现在适合去吗",注意:只有涉及实时、现在等当前时间的,如果是明天、后天等未来时间的不包括在内),请返回:游玩判断。 @@ -206,7 +214,6 @@ async def ai_chat_stream(inp: ChatIn, conversation_history: list) -> AsyncGenera messages = [{"role": "system", "content": chat_prompt}] + conversation_history messages.append({"role": "user", "content": inp.message}) - print(f"Starting AI chat stream with input: {inp.message}") full_response = "" try: response = await async_client.chat.completions.create( @@ -241,6 +248,9 @@ async def ai_chat_stream(inp: ChatIn, conversation_history: list) -> AsyncGenera if full_response: conversation_history.append({"role": "assistant", "content": full_response}) + # 限制对话历史长度为10条(5轮对话) + if len(conversation_history) > 10: + conversation_history = conversation_history[-10:] print("AI chat stream finished.") def get_formatted_prompt(user_language,msg,data): @@ -252,7 +262,6 @@ async def gen_markdown_stream(msg: str, data: str, language: str, conversation_h messages = conversation_history + [{"role": "user", "content": prompt}] - print(f"Starting markdown stream with message: {msg} and data: {data}") full_response = "" try: response = await async_client.chat.completions.create( @@ -286,6 +295,9 @@ async def gen_markdown_stream(msg: str, data: str, language: str, conversation_h if full_response: conversation_history.append({"role": "assistant", "content": full_response}) + # 限制对话历史长度为10条(5轮对话) + if len(conversation_history) > 10: + conversation_history = conversation_history[-10:] print("Markdown stream finished.") async def extract_spot(msg) -> str: # 如果msg是列表,则将其内容连接成字符串 @@ -293,8 +305,7 @@ async def extract_spot(msg) -> str: msg_content = '\n'.join(msg) else: msg_content = msg - - print(f"Starting spot extraction for message: {msg_content}") + try: response = await async_client.chat.completions.create( model="deepseek-chat", @@ -326,11 +337,9 @@ async def query_flow(request: Request, spot: str, redis_client = None) -> str: redis_client = request.app.state.redis_client # Step 1: Redis 缓存查询 - print(f"Querying Redis cache for key: {cache_key}") try: cached = await redis_client.get(cache_key) if cached: - print(f"Found cached data for key: {cache_key}") return cached else: return f"未找到景区【{spot}】的客流相关信息,在园人数和舒适度未知;停车场信息:暂无数据。" @@ -360,7 +369,7 @@ async def query_flow(request: Request, spot: str, redis_client = None) -> str: await cur.execute(formatted_flow_query) row = await cur.fetchone() - + # 查询停车场信息 park_query = """SELECT t3.park_name AS park_name, IFNULL(t3.rate_info,'暂无收费标准信息') AS rate_info, t3.total_count AS total_count, t4.space AS space, t1.distance_value AS distance_value FROM cyjcpt_bd.scenic_pack_distance t1 @@ -387,7 +396,7 @@ async def query_flow(request: Request, spot: str, redis_client = None) -> str: except Exception as e: print(f"[MySQL] 查询失败: {e}") return f"**未找到景区【{spot}】的信息,请检查名称是否正确。\n\n(内容仅供参考)" - + result = "" if row and all(v is not None for v in row): # 使用变量名访问客流数据 @@ -465,10 +474,6 @@ async def handle_quick_question(inp: ChatIn, question_content: str) -> AsyncGene print(error_msg) yield error_msg - # 不保存快捷问题的对话历史 - print("Quick question handling finished.") - - # 在chat_service.py中添加推荐问题生成函数 async def generate_recommended_questions(user_msg: str, ai_response: str) -> list: """基于用户问题和AI回答生成1-3个纵向延伸的推荐问题""" @@ -543,7 +548,7 @@ async def get_all_scenic_flow_data(request: Request, redis_client = None) -> lis """ await cur.execute(query) rows = await cur.fetchall() - + # 处理结果 result = [] for row in rows: @@ -551,13 +556,13 @@ async def get_all_scenic_flow_data(request: Request, redis_client = None) -> lis in_park_num = abs(enter_num - leave_num) # 确保是正数 if in_park_num > max_capacity: in_park_num = max_capacity - + # 避免除以零的情况 if max_capacity > 0: capacity_rate = in_park_num / max_capacity else: capacity_rate = 0 - + result.append({ "id": id, "scenic_name": scenic_name, @@ -567,15 +572,15 @@ async def get_all_scenic_flow_data(request: Request, redis_client = None) -> lis "max_capacity": max_capacity, "capacity_rate": capacity_rate }) - + # 将结果存入Redis缓存,过期时间1分钟 try: await redis_client.setex(cache_key, 60, json.dumps(result)) except Exception as e: print(f"[Redis] 写缓存失败: {e}") - + return result - + except Exception as e: print(f"[MySQL] 查询所有景区客流数据失败: {e}") return [] @@ -620,12 +625,12 @@ async def get_scenic_detail_data(request: Request, id: int, redis_client = None) """ await cur.execute(query, (id,)) row = await cur.fetchone() - + if not row: return None - + scenic_name, enter_num, leave_num, max_capacity = row - + # 计算在园人数 in_park_num = abs(enter_num - leave_num) # 确保是正数 if in_park_num > max_capacity: @@ -650,7 +655,7 @@ async def get_scenic_detail_data(request: Request, id: int, redis_client = None) else: capacity_rate = 0.0 comfort_level = "舒适" - + result = { "scenic_name": scenic_name, "enter_num": enter_num or 0, @@ -660,15 +665,15 @@ async def get_scenic_detail_data(request: Request, id: int, redis_client = None) "capacity_rate": round(capacity_rate, 4), "comfort_level": comfort_level } - + # 将结果存入Redis缓存,过期时间1分钟 try: await redis_client.setex(cache_key, 60, json.dumps(result)) except Exception as e: print(f"[Redis] 写缓存失败: {e}") - + return result - + except Exception as e: print(f"[MySQL] 查询景区详情数据失败: {e}") return None @@ -678,13 +683,13 @@ async def get_scenic_detail_data(request: Request, id: int, redis_client = None) async def get_scenic_parking_data(request: Request, scenic_id: int, distance: int, redis_client = None) -> list: """ 查询景区附近的停车场信息 - + Args: request: FastAPI请求对象 scenic_id: 景区id distance: 查询距离(米),>=1000时查询全部 redis_client: Redis客户端实例(可选) - + Returns: list: 停车场信息列表,按距离排序 """ @@ -763,12 +768,12 @@ async def get_scenic_parking_data(request: Request, scenic_id: int, distance: in await cur.execute(formatted_park_base) rows = await cur.fetchall() - + # 处理结果 result = [] for row in rows: park_name, total_spaces, available_spaces, distance_meters, lon, lat, park_type = row - + result.append({ "park_name": park_name, "total_parking_spaces": total_spaces or 0, @@ -778,15 +783,15 @@ async def get_scenic_parking_data(request: Request, scenic_id: int, distance: in "lat": lat or 0, "park_type": park_type }) - + # 将结果存入Redis缓存,过期时间1分钟 try: await redis_client.setex(cache_key, 60, json.dumps(result)) except Exception as e: print(f"[Redis] 写缓存失败: {e}") - + return result - + except Exception as e: print(f"[MySQL] 查询景区停车场数据失败: {e}") return [] @@ -794,82 +799,108 @@ async def get_scenic_parking_data(request: Request, scenic_id: int, distance: in # 添加用于获取完整响应数据的新函数 def fetch_and_parse_markdown(user_id: int, question: str) -> str: """ - 只提取最终完整的markdown内容(过滤流式中间片段) + 功能:发送请求、解析SSE流、提取完整知识库内容、处理乱码 + 返回:纯净的知识库markdown内容(与原fetch_and_parse_markdown返回格式一致) """ - encoded_question = requests.utils.quote(question) - # url = f"http://cjy.aitto.net:45678/api/v3/user_share_chat_completions?random={user_id}&api_key=cjy-626e50140e934936b8c82a3be5f6dea3&app_code=f5b3d4ba-7e7a-11f0-9de7-00e04f309c26&user_input={encoded_question}" - url = f"http://127.0.0.1:5679/api/v3/user_share_chat_completions?random={user_id}&api_key=cjy-626e50140e934936b8c82a3be5f6dea3&app_code=f5b3d4ba-7e7a-11f0-9de7-00e04f309c26&user_input={encoded_question}" - - all_markdowns: List[str] = [] - final_content = "" # 存储最终完整内容 + # 1. 新接口基础配置 + HEADERS = { + "Authorization": KNOWL_TOKEN, + "Content-Type": "application/json", + "Accept": "text/event-stream", + "Accept-Charset": "utf-8" + } + # 请求体(user_id和question动态传入,其他参数固定) + PAYLOAD = { + "bot_id": BOT_ID, # 接口固定bot_id + "user_id": str(user_id), # 转为字符串适配接口 + "additional_messages": [ + { + "role": "user", + "type": "question", + "content": question, # 用户查询问题 + "content_type": "text" + } + ], + "stream": False, + "auto_save_history": True, + "enable_card": True + } + + full_answer = "" # 核心:拼接你需要的「最终完整回答」 + current_event_type: Optional[str] = None try: - with requests.get(url, stream=True, timeout=60) as response: - response.encoding = "utf-8" - response.raise_for_status() - - for line in response.iter_lines(): - if not line: - continue - line_str = line.decode("utf-8", errors="replace") - if not line_str.startswith("data:"): + # 2. 发送SSE请求并流式接收 + with requests.post( + url=KNOWLEDGE_URL, + headers=HEADERS, + data=json.dumps(PAYLOAD, ensure_ascii=False), # 请求体UTF-8编码 + stream=True, + timeout=60 + ) as response: + response.raise_for_status() # 检查请求是否成功 + + # 3. 逐行解析SSE流 + for line_bytes in response.iter_lines(): + if not line_bytes: # 跳过空行(事件分隔符) continue - data_str = line_str[5:].strip() + # 4. 处理乱码(重点解决双重编码问题) try: - data_json = json.loads(data_str) - vis_content = data_json.get("vis", "") - - # 提取所有markdown内容 - code_blocks = re.findall(r'```(.*?)```', vis_content, re.DOTALL) - for block in code_blocks: - block_parts = block.split('\n', 1) - if len(block_parts) < 2: - continue - block_type, block_content = block_parts - block_content = block_content.strip() + # 优先UTF-8解码(正常情况) + line = line_bytes.decode("utf-8", errors="strict") + except UnicodeDecodeError: + # 修复"UTF-8→ISO-8859-1"双重编码(常见中文乱码原因) + line = line_bytes.decode("iso-8859-1").encode("iso-8859-1").decode("utf-8") + + # 5. 提取事件类型(如 conversation.message.delta) + if line.startswith("event:"): + current_event_type = line.split(":", 1)[1].strip() + continue + + # 6. 提取事件数据(只关注回答片段) + if line.startswith("data:"): + data_str = line.split(":", 1)[1].strip() + if not data_str: + continue + try: + # 解析JSON数据(处理可能的编码问题) try: - items = json.loads(block_content) - if isinstance(items, list): - for item in items: - if isinstance(item, dict) and "markdown" in item: - md_content = item["markdown"].strip() - all_markdowns.append(md_content) - - # 处理嵌套的markdown - nested_blocks = re.findall(r'```(.*?)```', md_content, re.DOTALL) - for nested in nested_blocks: - nested_parts = nested.split('\n', 1) - if len(nested_parts) >= 2: - nested_content = nested_parts[1].strip() - try: - nested_items = json.loads(nested_content) - if isinstance(nested_items, list): - for ni in nested_items: - if isinstance(ni, dict) and "markdown" in ni: - nested_md = ni["markdown"].strip() - all_markdowns.append(nested_md) - except json.JSONDecodeError: - continue - except json.JSONDecodeError: - continue - except json.JSONDecodeError: - continue + event_data = json.loads(data_str) + except UnicodeDecodeError: + data_str_fixed = data_str.encode("iso-8859-1").decode("utf-8") + event_data = json.loads(data_str_fixed) + + # 7. 核心:拼接回答片段(只取 "conversation.message.delta" 事件的 answer 内容) + if (current_event_type == "conversation.message.delta" + and event_data.get("type") == "answer"): + # 提取当前片段(如"直"、"隶"、"总督署") + answer_chunk = event_data.get("content", "").strip() + # 修复片段中的乱码(兜底) + try: + answer_chunk = answer_chunk.encode("iso-8859-1").decode("utf-8") + except: + pass + full_answer += answer_chunk # 拼接成完整回答 + + except json.JSONDecodeError: + continue # 跳过无效JSON,不影响整体 + + # 8. 清理最终回答(去除多余空格/空行) + full_answer = full_answer.strip() + print("【调试】,知识库内容:", full_answer) + return full_answer except requests.exceptions.RequestException as e: - print(f"请求错误: {e}") + error_msg = f"请求错误: {e}" + print(error_msg) + return "" # 错误时返回空字符串 + except Exception as e: + error_msg = f"解析错误: {e}" + print(error_msg) return "" - # 核心逻辑:筛选出最长且完整的内容(流式响应中最后完成的内容通常最长) - if all_markdowns: - # 按长度倒序排序,取最长的非空内容 - all_markdowns = [md for md in all_markdowns if md] # 过滤空字符串 - if all_markdowns: - final_content = max(all_markdowns, key=len) - - return final_content - # 添加用于多景区比较的新提示词 MULTI_SCENIC_EXTRACT_PROMPT = """你是一名景区名称提取助手。用户的问题中可能包含多个景区名称,请根据下面的完整景区名称列表,准确提取用户提到的所有景区名称并返回,每个景区名称占一行。如果用户没有提到任何景区,返回空字符串。 完整景区名称列表: @@ -946,7 +977,7 @@ async def extract_multi_scenic(msg) -> list: msg_content = '\n'.join(msg) else: msg_content = msg - + print(f"Starting multi scenic extraction for message: {msg_content}") try: response = await async_client.chat.completions.create( @@ -966,7 +997,7 @@ async def query_multi_scenic_flow(request: Request, scenics: list, msg: str, red if not scenics: print("No scenics found, returning default message.") return "**未找到景区信息,请检查名称是否正确。**\n\n(内容由AI生成,仅供参考)" - + # 查询多个景区的客流数据 results = [] for scenic in scenics: @@ -975,7 +1006,7 @@ async def query_multi_scenic_flow(request: Request, scenics: list, msg: str, red "scenic": scenic, "data": data }) - + # 生成比较结果 if len(results) == 1: return results[0]["data"] @@ -997,7 +1028,7 @@ async def query_multi_scenic_flow(request: Request, scenics: list, msg: str, red # 如果AI比较失败,返回原始数据 result_str = "\n\n".join([f"**{r['scenic']}**:\n{r['data']}" for r in results]) return result_str - + return "**未找到景区信息,请检查名称是否正确。**\n\n(内容由AI生成,仅供参考)"