修改知识库请求路径

main
zc 1 month ago
parent 14858c483c
commit 1aad862731
  1. 80
      app/api/chat/ai/chat_router.py
  2. 235
      app/api/chat/ai/chat_service.py

@ -88,15 +88,17 @@ async def h5_chat_stream(request: Request, inp: ChatIn, redis_client = Depends(g
except asyncio.TimeoutError:
# 处理超时情况
async def timeout_stream() -> AsyncGenerator[str, None]:
error_msg = "亲爱的游客,目前系统暂时繁忙,可能是咨询过于火爆~ \n\n请您稍后重新发起咨询,我们会第一时间为您提供文旅信息服务,感谢您的耐心!"
chunk_size = 8
for i in range(0, len(error_msg), chunk_size):
chunk = error_msg[i:i + chunk_size]
yield f"data: {json.dumps({'content': chunk})}\n\n"
await asyncio.sleep(0.03)
try:
error_msg = "亲爱的游客,目前系统暂时繁忙,可能是咨询过于火爆~ \n\n请您稍后重新发起咨询,我们会第一时间为您提供文旅信息服务,感谢您的耐心!"
yield "data: [DONE]\n\n"
chunk_size = 8
for i in range(0, len(error_msg), chunk_size):
chunk = error_msg[i:i + chunk_size]
yield f"data: {chunk}\n\n"
await asyncio.sleep(0.03)
finally:
# 确保在任何情况下都发送结束标记
yield "data: [DONE]\n\n"
return StreamingResponse(
timeout_stream(),
@ -112,15 +114,17 @@ async def h5_chat_stream(request: Request, inp: ChatIn, redis_client = Depends(g
print(f"聊天接口异常: {str(e)}")
async def error_stream() -> AsyncGenerator[str, None]:
error_msg = "亲爱的游客,目前系统暂时繁忙,如同景区高峰期需要稍作等候~ \n\n请您稍后再次尝试,我们会尽快为您提供贴心的文旅咨询服务,感谢您的理解!"
chunk_size = 8
for i in range(0, len(error_msg), chunk_size):
chunk = error_msg[i:i + chunk_size]
yield f"data: {chunk}\n\n"
await asyncio.sleep(0.03)
try:
error_msg = "亲爱的游客,目前系统暂时繁忙,如同景区高峰期需要稍作等候~ \n\n请您稍后再次尝试,我们会尽快为您提供贴心的文旅咨询服务,感谢您的理解!"
yield "data: [DONE]\n\n"
chunk_size = 8
for i in range(0, len(error_msg), chunk_size):
chunk = error_msg[i:i + chunk_size]
yield f"data: {chunk}\n\n"
await asyncio.sleep(0.03)
finally:
# 确保在任何情况下都发送结束标记
yield "data: [DONE]\n\n"
return StreamingResponse(
error_stream(),
@ -158,6 +162,11 @@ async def _handle_chat_request(
conversation_history_str = await redis_client.get(conversation_history_key)
conversation_history = json.loads(conversation_history_str) if conversation_history_str else []
# 限制对话历史长度为10条(5轮对话)
# 如果超过10条,删除最前面的2条(一轮对话)
if len(conversation_history) > 10:
conversation_history = conversation_history[2:]
# 提取用户历史消息
user_messages = [msg for msg in conversation_history if msg.get('role') == 'user']
user_messages_content = [msg.get('content', '') for msg in user_messages]
@ -178,7 +187,6 @@ async def _handle_chat_request(
if not is_quick_question:
# 消息分类
cat = await classify(inp.message)
print(f"Message category: {cat}")
spot = None
scenics = None
@ -253,8 +261,12 @@ async def _handle_chat_request(
else:
data = await query_flow(request, spot, redis_client)
if spot and knowledge_task:
knowledge = await knowledge_task
if knowledge and "无法" not in knowledge:
try:
knowledge = await asyncio.wait_for(knowledge_task, timeout=20)
except TimeoutError:
print("获取知识库信息超时")
knowledge = None
if knowledge:
data += "\n\n知识库查询到的景区内容:" + knowledge
async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history):
yield f"data: {chunk}\n\n"
@ -270,8 +282,12 @@ async def _handle_chat_request(
elif len(scenics) == 1:
data = await query_flow(request, scenics[0], redis_client)
if scenics[0] and knowledge_task:
knowledge = await knowledge_task
if knowledge and "无法" not in knowledge:
try:
knowledge = await asyncio.wait_for(knowledge_task, timeout=20)
except TimeoutError:
print("获取知识库信息超时")
knowledge = None
if knowledge:
data += "\n\n知识库查询到的景区内容:" + knowledge
async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history):
yield f"data: {chunk}\n\n"
@ -284,24 +300,29 @@ async def _handle_chat_request(
await asyncio.sleep(0.01)
else:
if spot and knowledge_task:
knowledge = await knowledge_task
if knowledge and "无法" not in knowledge:
try:
knowledge = await asyncio.wait_for(knowledge_task, timeout=20)
except TimeoutError:
print("获取知识库信息超时")
knowledge = None
if knowledge:
inp.message += "\n\n知识库查询到的景区内容:" + knowledge
async for chunk in ai_chat_stream(inp, conversation_history):
yield f"data: {chunk}\n\n"
await asyncio.sleep(0.01)
# 保存对话历史
# 在保存前再次检查长度,确保不超过10条记录
if not is_quick_question:
# 确保对话历史不超过10条记录(5轮对话)
if len(conversation_history) > 10:
conversation_history = conversation_history[-10:]
await redis_client.setex(
conversation_history_key,
CONVERSATION_EXPIRE_TIME,
json.dumps(conversation_history)
)
# 发送结束标记
yield "data: [DONE]\n\n"
except Exception as e:
print(f"content_stream 异常: {str(e)}")
# 流式返回错误信息
@ -310,9 +331,10 @@ async def _handle_chat_request(
chunk_size = 8
for i in range(0, len(error_msg), chunk_size):
chunk = error_msg[i:i + chunk_size]
yield f"data: {chunk})\n\n"
yield f"data: {chunk}\n\n"
await asyncio.sleep(0.03)
finally:
# 确保在任何情况下都发送结束标记
yield "data: [DONE]\n\n"
return StreamingResponse(

@ -6,17 +6,25 @@ from app.models.ChatIn import ChatIn
from fastapi import Request
from app.settings.config import settings
import json
import re
from typing import List
import requests
from chinese_calendar import is_holiday
from datetime import datetime
from typing import Optional
load_dotenv()
async_client = AsyncOpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url=settings.DEEPSEEK_API_URL)
# 知识库接口
KNOWLEDGE_URL = "http://172.21.11.20:8886/v3/chat"
# KNOWLEDGE_URL = "http://192.168.130.144:8888/v3/chat"
# 知识库应用id
BOT_ID = "7550848889243303936"
# 知识库token
KNOWL_TOKEN = "Bearer pat_5cb11052e7b4b517015467902cd7775742120fc88fe66b926c93fde3a39843c7"
#分类提示词
CATEGORY_PROMPT = """你是一个分类助手,请根据用户的问题判断属于以下哪一类:
如果用户的问题涉及保定市某个景区当前的人数客流量拥挤程度或是否适合前往例如某个保定市景区现在人多么某个保定市景区现在适不适合去"现在可以去吗""现在适合去吗"注意只有涉及实时现在等当前时间的如果是明天后天等未来时间的不包括在内请返回游玩判断
@ -206,7 +214,6 @@ async def ai_chat_stream(inp: ChatIn, conversation_history: list) -> AsyncGenera
messages = [{"role": "system", "content": chat_prompt}] + conversation_history
messages.append({"role": "user", "content": inp.message})
print(f"Starting AI chat stream with input: {inp.message}")
full_response = ""
try:
response = await async_client.chat.completions.create(
@ -241,6 +248,9 @@ async def ai_chat_stream(inp: ChatIn, conversation_history: list) -> AsyncGenera
if full_response:
conversation_history.append({"role": "assistant", "content": full_response})
# 限制对话历史长度为10条(5轮对话)
if len(conversation_history) > 10:
conversation_history = conversation_history[-10:]
print("AI chat stream finished.")
def get_formatted_prompt(user_language,msg,data):
@ -252,7 +262,6 @@ async def gen_markdown_stream(msg: str, data: str, language: str, conversation_h
messages = conversation_history + [{"role": "user", "content": prompt}]
print(f"Starting markdown stream with message: {msg} and data: {data}")
full_response = ""
try:
response = await async_client.chat.completions.create(
@ -286,6 +295,9 @@ async def gen_markdown_stream(msg: str, data: str, language: str, conversation_h
if full_response:
conversation_history.append({"role": "assistant", "content": full_response})
# 限制对话历史长度为10条(5轮对话)
if len(conversation_history) > 10:
conversation_history = conversation_history[-10:]
print("Markdown stream finished.")
async def extract_spot(msg) -> str:
# 如果msg是列表,则将其内容连接成字符串
@ -293,8 +305,7 @@ async def extract_spot(msg) -> str:
msg_content = '\n'.join(msg)
else:
msg_content = msg
print(f"Starting spot extraction for message: {msg_content}")
try:
response = await async_client.chat.completions.create(
model="deepseek-chat",
@ -326,11 +337,9 @@ async def query_flow(request: Request, spot: str, redis_client = None) -> str:
redis_client = request.app.state.redis_client
# Step 1: Redis 缓存查询
print(f"Querying Redis cache for key: {cache_key}")
try:
cached = await redis_client.get(cache_key)
if cached:
print(f"Found cached data for key: {cache_key}")
return cached
else:
return f"未找到景区【{spot}】的客流相关信息,在园人数和舒适度未知;停车场信息:暂无数据。"
@ -360,7 +369,7 @@ async def query_flow(request: Request, spot: str, redis_client = None) -> str:
await cur.execute(formatted_flow_query)
row = await cur.fetchone()
# 查询停车场信息
park_query = """SELECT t3.park_name AS park_name, IFNULL(t3.rate_info,'暂无收费标准信息') AS rate_info, t3.total_count AS total_count, t4.space AS space, t1.distance_value AS distance_value
FROM cyjcpt_bd.scenic_pack_distance t1
@ -387,7 +396,7 @@ async def query_flow(request: Request, spot: str, redis_client = None) -> str:
except Exception as e:
print(f"[MySQL] 查询失败: {e}")
return f"**未找到景区【{spot}】的信息,请检查名称是否正确。\n\n(内容仅供参考)"
result = ""
if row and all(v is not None for v in row):
# 使用变量名访问客流数据
@ -465,10 +474,6 @@ async def handle_quick_question(inp: ChatIn, question_content: str) -> AsyncGene
print(error_msg)
yield error_msg
# 不保存快捷问题的对话历史
print("Quick question handling finished.")
# 在chat_service.py中添加推荐问题生成函数
async def generate_recommended_questions(user_msg: str, ai_response: str) -> list:
"""基于用户问题和AI回答生成1-3个纵向延伸的推荐问题"""
@ -543,7 +548,7 @@ async def get_all_scenic_flow_data(request: Request, redis_client = None) -> lis
"""
await cur.execute(query)
rows = await cur.fetchall()
# 处理结果
result = []
for row in rows:
@ -551,13 +556,13 @@ async def get_all_scenic_flow_data(request: Request, redis_client = None) -> lis
in_park_num = abs(enter_num - leave_num) # 确保是正数
if in_park_num > max_capacity:
in_park_num = max_capacity
# 避免除以零的情况
if max_capacity > 0:
capacity_rate = in_park_num / max_capacity
else:
capacity_rate = 0
result.append({
"id": id,
"scenic_name": scenic_name,
@ -567,15 +572,15 @@ async def get_all_scenic_flow_data(request: Request, redis_client = None) -> lis
"max_capacity": max_capacity,
"capacity_rate": capacity_rate
})
# 将结果存入Redis缓存,过期时间1分钟
try:
await redis_client.setex(cache_key, 60, json.dumps(result))
except Exception as e:
print(f"[Redis] 写缓存失败: {e}")
return result
except Exception as e:
print(f"[MySQL] 查询所有景区客流数据失败: {e}")
return []
@ -620,12 +625,12 @@ async def get_scenic_detail_data(request: Request, id: int, redis_client = None)
"""
await cur.execute(query, (id,))
row = await cur.fetchone()
if not row:
return None
scenic_name, enter_num, leave_num, max_capacity = row
# 计算在园人数
in_park_num = abs(enter_num - leave_num) # 确保是正数
if in_park_num > max_capacity:
@ -650,7 +655,7 @@ async def get_scenic_detail_data(request: Request, id: int, redis_client = None)
else:
capacity_rate = 0.0
comfort_level = "舒适"
result = {
"scenic_name": scenic_name,
"enter_num": enter_num or 0,
@ -660,15 +665,15 @@ async def get_scenic_detail_data(request: Request, id: int, redis_client = None)
"capacity_rate": round(capacity_rate, 4),
"comfort_level": comfort_level
}
# 将结果存入Redis缓存,过期时间1分钟
try:
await redis_client.setex(cache_key, 60, json.dumps(result))
except Exception as e:
print(f"[Redis] 写缓存失败: {e}")
return result
except Exception as e:
print(f"[MySQL] 查询景区详情数据失败: {e}")
return None
@ -678,13 +683,13 @@ async def get_scenic_detail_data(request: Request, id: int, redis_client = None)
async def get_scenic_parking_data(request: Request, scenic_id: int, distance: int, redis_client = None) -> list:
"""
查询景区附近的停车场信息
Args:
request: FastAPI请求对象
scenic_id: 景区id
distance: 查询距离()>=1000时查询全部
redis_client: Redis客户端实例可选
Returns:
list: 停车场信息列表按距离排序
"""
@ -763,12 +768,12 @@ async def get_scenic_parking_data(request: Request, scenic_id: int, distance: in
await cur.execute(formatted_park_base)
rows = await cur.fetchall()
# 处理结果
result = []
for row in rows:
park_name, total_spaces, available_spaces, distance_meters, lon, lat, park_type = row
result.append({
"park_name": park_name,
"total_parking_spaces": total_spaces or 0,
@ -778,15 +783,15 @@ async def get_scenic_parking_data(request: Request, scenic_id: int, distance: in
"lat": lat or 0,
"park_type": park_type
})
# 将结果存入Redis缓存,过期时间1分钟
try:
await redis_client.setex(cache_key, 60, json.dumps(result))
except Exception as e:
print(f"[Redis] 写缓存失败: {e}")
return result
except Exception as e:
print(f"[MySQL] 查询景区停车场数据失败: {e}")
return []
@ -794,82 +799,108 @@ async def get_scenic_parking_data(request: Request, scenic_id: int, distance: in
# 添加用于获取完整响应数据的新函数
def fetch_and_parse_markdown(user_id: int, question: str) -> str:
"""
只提取最终完整的markdown内容过滤流式中间片段
功能发送请求解析SSE流提取完整知识库内容处理乱码
返回纯净的知识库markdown内容与原fetch_and_parse_markdown返回格式一致
"""
encoded_question = requests.utils.quote(question)
# url = f"http://cjy.aitto.net:45678/api/v3/user_share_chat_completions?random={user_id}&api_key=cjy-626e50140e934936b8c82a3be5f6dea3&app_code=f5b3d4ba-7e7a-11f0-9de7-00e04f309c26&user_input={encoded_question}"
url = f"http://127.0.0.1:5679/api/v3/user_share_chat_completions?random={user_id}&api_key=cjy-626e50140e934936b8c82a3be5f6dea3&app_code=f5b3d4ba-7e7a-11f0-9de7-00e04f309c26&user_input={encoded_question}"
all_markdowns: List[str] = []
final_content = "" # 存储最终完整内容
# 1. 新接口基础配置
HEADERS = {
"Authorization": KNOWL_TOKEN,
"Content-Type": "application/json",
"Accept": "text/event-stream",
"Accept-Charset": "utf-8"
}
# 请求体(user_id和question动态传入,其他参数固定)
PAYLOAD = {
"bot_id": BOT_ID, # 接口固定bot_id
"user_id": str(user_id), # 转为字符串适配接口
"additional_messages": [
{
"role": "user",
"type": "question",
"content": question, # 用户查询问题
"content_type": "text"
}
],
"stream": False,
"auto_save_history": True,
"enable_card": True
}
full_answer = "" # 核心:拼接你需要的「最终完整回答」
current_event_type: Optional[str] = None
try:
with requests.get(url, stream=True, timeout=60) as response:
response.encoding = "utf-8"
response.raise_for_status()
for line in response.iter_lines():
if not line:
continue
line_str = line.decode("utf-8", errors="replace")
if not line_str.startswith("data:"):
# 2. 发送SSE请求并流式接收
with requests.post(
url=KNOWLEDGE_URL,
headers=HEADERS,
data=json.dumps(PAYLOAD, ensure_ascii=False), # 请求体UTF-8编码
stream=True,
timeout=60
) as response:
response.raise_for_status() # 检查请求是否成功
# 3. 逐行解析SSE流
for line_bytes in response.iter_lines():
if not line_bytes: # 跳过空行(事件分隔符)
continue
data_str = line_str[5:].strip()
# 4. 处理乱码(重点解决双重编码问题)
try:
data_json = json.loads(data_str)
vis_content = data_json.get("vis", "")
# 提取所有markdown内容
code_blocks = re.findall(r'```(.*?)```', vis_content, re.DOTALL)
for block in code_blocks:
block_parts = block.split('\n', 1)
if len(block_parts) < 2:
continue
block_type, block_content = block_parts
block_content = block_content.strip()
# 优先UTF-8解码(正常情况)
line = line_bytes.decode("utf-8", errors="strict")
except UnicodeDecodeError:
# 修复"UTF-8→ISO-8859-1"双重编码(常见中文乱码原因)
line = line_bytes.decode("iso-8859-1").encode("iso-8859-1").decode("utf-8")
# 5. 提取事件类型(如 conversation.message.delta)
if line.startswith("event:"):
current_event_type = line.split(":", 1)[1].strip()
continue
# 6. 提取事件数据(只关注回答片段)
if line.startswith("data:"):
data_str = line.split(":", 1)[1].strip()
if not data_str:
continue
try:
# 解析JSON数据(处理可能的编码问题)
try:
items = json.loads(block_content)
if isinstance(items, list):
for item in items:
if isinstance(item, dict) and "markdown" in item:
md_content = item["markdown"].strip()
all_markdowns.append(md_content)
# 处理嵌套的markdown
nested_blocks = re.findall(r'```(.*?)```', md_content, re.DOTALL)
for nested in nested_blocks:
nested_parts = nested.split('\n', 1)
if len(nested_parts) >= 2:
nested_content = nested_parts[1].strip()
try:
nested_items = json.loads(nested_content)
if isinstance(nested_items, list):
for ni in nested_items:
if isinstance(ni, dict) and "markdown" in ni:
nested_md = ni["markdown"].strip()
all_markdowns.append(nested_md)
except json.JSONDecodeError:
continue
except json.JSONDecodeError:
continue
except json.JSONDecodeError:
continue
event_data = json.loads(data_str)
except UnicodeDecodeError:
data_str_fixed = data_str.encode("iso-8859-1").decode("utf-8")
event_data = json.loads(data_str_fixed)
# 7. 核心:拼接回答片段(只取 "conversation.message.delta" 事件的 answer 内容)
if (current_event_type == "conversation.message.delta"
and event_data.get("type") == "answer"):
# 提取当前片段(如"直"、"隶"、"总督署")
answer_chunk = event_data.get("content", "").strip()
# 修复片段中的乱码(兜底)
try:
answer_chunk = answer_chunk.encode("iso-8859-1").decode("utf-8")
except:
pass
full_answer += answer_chunk # 拼接成完整回答
except json.JSONDecodeError:
continue # 跳过无效JSON,不影响整体
# 8. 清理最终回答(去除多余空格/空行)
full_answer = full_answer.strip()
print("【调试】,知识库内容:", full_answer)
return full_answer
except requests.exceptions.RequestException as e:
print(f"请求错误: {e}")
error_msg = f"请求错误: {e}"
print(error_msg)
return "" # 错误时返回空字符串
except Exception as e:
error_msg = f"解析错误: {e}"
print(error_msg)
return ""
# 核心逻辑:筛选出最长且完整的内容(流式响应中最后完成的内容通常最长)
if all_markdowns:
# 按长度倒序排序,取最长的非空内容
all_markdowns = [md for md in all_markdowns if md] # 过滤空字符串
if all_markdowns:
final_content = max(all_markdowns, key=len)
return final_content
# 添加用于多景区比较的新提示词
MULTI_SCENIC_EXTRACT_PROMPT = """你是一名景区名称提取助手。用户的问题中可能包含多个景区名称,请根据下面的完整景区名称列表,准确提取用户提到的所有景区名称并返回,每个景区名称占一行。如果用户没有提到任何景区,返回空字符串。
完整景区名称列表
@ -946,7 +977,7 @@ async def extract_multi_scenic(msg) -> list:
msg_content = '\n'.join(msg)
else:
msg_content = msg
print(f"Starting multi scenic extraction for message: {msg_content}")
try:
response = await async_client.chat.completions.create(
@ -966,7 +997,7 @@ async def query_multi_scenic_flow(request: Request, scenics: list, msg: str, red
if not scenics:
print("No scenics found, returning default message.")
return "**未找到景区信息,请检查名称是否正确。**\n\n(内容由AI生成,仅供参考)"
# 查询多个景区的客流数据
results = []
for scenic in scenics:
@ -975,7 +1006,7 @@ async def query_multi_scenic_flow(request: Request, scenics: list, msg: str, red
"scenic": scenic,
"data": data
})
# 生成比较结果
if len(results) == 1:
return results[0]["data"]
@ -997,7 +1028,7 @@ async def query_multi_scenic_flow(request: Request, scenics: list, msg: str, red
# 如果AI比较失败,返回原始数据
result_str = "\n\n".join([f"**{r['scenic']}**:\n{r['data']}" for r in results])
return result_str
return "**未找到景区信息,请检查名称是否正确。**\n\n(内容由AI生成,仅供参考)"

Loading…
Cancel
Save