|
|
|
|
@@ -6,17 +6,25 @@ from app.models.ChatIn import ChatIn
|
|
|
|
from fastapi import Request |
|
|
|
|
from app.settings.config import settings |
|
|
|
|
import json |
|
|
|
|
import re |
|
|
|
|
from typing import List |
|
|
|
|
import requests |
|
|
|
|
from chinese_calendar import is_holiday |
|
|
|
|
from datetime import datetime |
|
|
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
|
|
|
|
async_client = AsyncOpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url=settings.DEEPSEEK_API_URL) |
|
|
|
|
|
|
|
|
|
# Knowledge base API endpoint
|
|
|
|
KNOWLEDGE_URL = "http://172.21.11.20:8886/v3/chat" |
|
|
|
|
# KNOWLEDGE_URL = "http://192.168.130.144:8888/v3/chat" |
|
|
|
|
# Knowledge base application (bot) ID
|
|
|
|
BOT_ID = "7550848889243303936" |
|
|
|
|
# Knowledge base access token
|
|
|
|
KNOWL_TOKEN = "Bearer pat_5cb11052e7b4b517015467902cd7775742120fc88fe66b926c93fde3a39843c7" |
|
|
|
|
|
|
|
|
|
# Classification prompt
|
|
|
|
CATEGORY_PROMPT = """你是一个分类助手,请根据用户的问题判断属于以下哪一类: |
|
|
|
|
如果用户的问题涉及保定市某个景区当前的人数、客流量、拥挤程度或是否适合前往(例如:“某个保定市景区现在人多么”、“某个保定市景区现在适不适合去”、"现在可以去吗"、"现在适合去吗",注意:只有涉及实时、现在等当前时间的,如果是明天、后天等未来时间的不包括在内),请返回:游玩判断。 |
|
|
|
|
@@ -206,7 +214,6 @@ async def ai_chat_stream(inp: ChatIn, conversation_history: list) -> AsyncGenera
|
|
|
|
messages = [{"role": "system", "content": chat_prompt}] + conversation_history |
|
|
|
|
messages.append({"role": "user", "content": inp.message}) |
|
|
|
|
|
|
|
|
|
print(f"Starting AI chat stream with input: {inp.message}") |
|
|
|
|
full_response = "" |
|
|
|
|
try: |
|
|
|
|
response = await async_client.chat.completions.create( |
|
|
|
|
@@ -241,6 +248,9 @@ async def ai_chat_stream(inp: ChatIn, conversation_history: list) -> AsyncGenera
|
|
|
|
|
|
|
|
|
if full_response: |
|
|
|
|
conversation_history.append({"role": "assistant", "content": full_response}) |
|
|
|
|
# Cap the conversation history at 10 messages (5 turns of dialogue)
|
|
|
|
if len(conversation_history) > 10: |
|
|
|
|
conversation_history = conversation_history[-10:] |
|
|
|
|
print("AI chat stream finished.") |
|
|
|
|
|
|
|
|
|
def get_formatted_prompt(user_language, msg, data):
|
|
|
|
@@ -252,7 +262,6 @@ async def gen_markdown_stream(msg: str, data: str, language: str, conversation_h
|
|
|
|
|
|
|
|
|
messages = conversation_history + [{"role": "user", "content": prompt}] |
|
|
|
|
|
|
|
|
|
print(f"Starting markdown stream with message: {msg} and data: {data}") |
|
|
|
|
full_response = "" |
|
|
|
|
try: |
|
|
|
|
response = await async_client.chat.completions.create( |
|
|
|
|
@@ -286,6 +295,9 @@ async def gen_markdown_stream(msg: str, data: str, language: str, conversation_h
|
|
|
|
|
|
|
|
|
if full_response: |
|
|
|
|
conversation_history.append({"role": "assistant", "content": full_response}) |
|
|
|
|
# Cap the conversation history at 10 messages (5 turns of dialogue)
|
|
|
|
if len(conversation_history) > 10: |
|
|
|
|
conversation_history = conversation_history[-10:] |
|
|
|
|
print("Markdown stream finished.") |
|
|
|
|
async def extract_spot(msg) -> str: |
|
|
|
|
# If msg is a list, join its items into a single string
|
|
|
|
@@ -294,7 +306,6 @@ async def extract_spot(msg) -> str:
|
|
|
|
else: |
|
|
|
|
msg_content = msg |
|
|
|
|
|
|
|
|
|
print(f"Starting spot extraction for message: {msg_content}") |
|
|
|
|
try: |
|
|
|
|
response = await async_client.chat.completions.create( |
|
|
|
|
model="deepseek-chat", |
|
|
|
|
@@ -326,11 +337,9 @@ async def query_flow(request: Request, spot: str, redis_client = None) -> str:
|
|
|
|
redis_client = request.app.state.redis_client |
|
|
|
|
|
|
|
|
|
# Step 1: Redis cache lookup
|
|
|
|
print(f"Querying Redis cache for key: {cache_key}") |
|
|
|
|
try: |
|
|
|
|
cached = await redis_client.get(cache_key) |
|
|
|
|
if cached: |
|
|
|
|
print(f"Found cached data for key: {cache_key}") |
|
|
|
|
return cached |
|
|
|
|
else: |
|
|
|
|
return f"未找到景区【{spot}】的客流相关信息,在园人数和舒适度未知;停车场信息:暂无数据。" |
|
|
|
|
@@ -465,10 +474,6 @@ async def handle_quick_question(inp: ChatIn, question_content: str) -> AsyncGene
|
|
|
|
print(error_msg) |
|
|
|
|
yield error_msg |
|
|
|
|
|
|
|
|
|
# Do not persist conversation history for quick questions
|
|
|
|
print("Quick question handling finished.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Recommended-question generation function added in chat_service.py
|
|
|
|
async def generate_recommended_questions(user_msg: str, ai_response: str) -> list: |
|
|
|
|
"""基于用户问题和AI回答生成1-3个纵向延伸的推荐问题""" |
|
|
|
|
@@ -794,82 +799,108 @@ async def get_scenic_parking_data(request: Request, scenic_id: int, distance: in
|
|
|
|
# New helper for fetching the complete response data
|
|
|
|
def fetch_and_parse_markdown(user_id: int, question: str) -> str: |
|
|
|
|
""" |
|
|
|
|
Extract only the final, complete markdown content (filtering out intermediate streaming fragments)
|
|
|
|
Responsibilities: send the request, parse the SSE stream, extract the complete knowledge-base content, and repair mojibake
|
|
|
|
Returns: clean knowledge-base markdown content (same format as the original fetch_and_parse_markdown)
|
|
|
|
""" |
|
|
|
|
encoded_question = requests.utils.quote(question) |
|
|
|
|
# url = f"http://cjy.aitto.net:45678/api/v3/user_share_chat_completions?random={user_id}&api_key=cjy-626e50140e934936b8c82a3be5f6dea3&app_code=f5b3d4ba-7e7a-11f0-9de7-00e04f309c26&user_input={encoded_question}" |
|
|
|
|
url = f"http://127.0.0.1:5679/api/v3/user_share_chat_completions?random={user_id}&api_key=cjy-626e50140e934936b8c82a3be5f6dea3&app_code=f5b3d4ba-7e7a-11f0-9de7-00e04f309c26&user_input={encoded_question}" |
|
|
|
|
|
|
|
|
|
all_markdowns: List[str] = [] |
|
|
|
|
final_content = "" # 存储最终完整内容 |
|
|
|
|
# 1. Base configuration for the new endpoint
|
|
|
|
HEADERS = { |
|
|
|
|
"Authorization": KNOWL_TOKEN, |
|
|
|
|
"Content-Type": "application/json", |
|
|
|
|
"Accept": "text/event-stream", |
|
|
|
|
"Accept-Charset": "utf-8" |
|
|
|
|
} |
|
|
|
|
# Request body (user_id and question are passed in dynamically; the other parameters are fixed)
|
|
|
|
PAYLOAD = { |
|
|
|
|
"bot_id": BOT_ID, # 接口固定bot_id |
|
|
|
|
"user_id": str(user_id), # 转为字符串适配接口 |
|
|
|
|
"additional_messages": [ |
|
|
|
|
{ |
|
|
|
|
"role": "user", |
|
|
|
|
"type": "question", |
|
|
|
|
"content": question, # 用户查询问题 |
|
|
|
|
"content_type": "text" |
|
|
|
|
} |
|
|
|
|
], |
|
|
|
|
"stream": False, |
|
|
|
|
"auto_save_history": True, |
|
|
|
|
"enable_card": True |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
full_answer = "" # 核心:拼接你需要的「最终完整回答」 |
|
|
|
|
current_event_type: Optional[str] = None |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
with requests.get(url, stream=True, timeout=60) as response: |
|
|
|
|
response.encoding = "utf-8" |
|
|
|
|
response.raise_for_status() |
|
|
|
|
|
|
|
|
|
for line in response.iter_lines(): |
|
|
|
|
if not line: |
|
|
|
|
continue |
|
|
|
|
line_str = line.decode("utf-8", errors="replace") |
|
|
|
|
if not line_str.startswith("data:"): |
|
|
|
|
# 2. Send the SSE request and receive the response as a stream
|
|
|
|
with requests.post( |
|
|
|
|
url=KNOWLEDGE_URL, |
|
|
|
|
headers=HEADERS, |
|
|
|
|
data=json.dumps(PAYLOAD, ensure_ascii=False).encode("utf-8"),  # UTF-8-encode the JSON body so non-ASCII characters are sent correctly
|
|
|
|
stream=True, |
|
|
|
|
timeout=60 |
|
|
|
|
) as response: |
|
|
|
|
response.raise_for_status()  # raise if the request failed
|
|
|
|
|
|
|
|
|
# 3. Parse the SSE stream line by line
|
|
|
|
for line_bytes in response.iter_lines(): |
|
|
|
|
if not line_bytes:  # skip blank lines (event separators)
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
data_str = line_str[5:].strip() |
|
|
|
|
# 4. Handle mojibake (mainly the double-encoding problem)
|
|
|
|
try: |
|
|
|
|
data_json = json.loads(data_str) |
|
|
|
|
vis_content = data_json.get("vis", "") |
|
|
|
|
|
|
|
|
|
# Extract all markdown content
|
|
|
|
code_blocks = re.findall(r'```(.*?)```', vis_content, re.DOTALL) |
|
|
|
|
for block in code_blocks: |
|
|
|
|
block_parts = block.split('\n', 1) |
|
|
|
|
if len(block_parts) < 2: |
|
|
|
|
continue |
|
|
|
|
block_type, block_content = block_parts |
|
|
|
|
block_content = block_content.strip() |
|
|
|
|
# Prefer UTF-8 decoding (the normal case)
|
|
|
|
line = line_bytes.decode("utf-8", errors="strict") |
|
|
|
|
except UnicodeDecodeError: |
|
|
|
|
# Not valid UTF-8: fall back to ISO-8859-1 (which never fails); the per-field repair below fixes "UTF-8 -> ISO-8859-1" double encoding, a common cause of garbled Chinese (see the sketch after this function)
|
|
|
|
line = line_bytes.decode("iso-8859-1").encode("iso-8859-1").decode("utf-8") |
|
|
|
|
|
|
|
|
|
# 5. Extract the event type (e.g. conversation.message.delta)
|
|
|
|
if line.startswith("event:"): |
|
|
|
|
current_event_type = line.split(":", 1)[1].strip() |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
# 6. Extract the event data (only answer fragments are of interest)
|
|
|
|
if line.startswith("data:"): |
|
|
|
|
data_str = line.split(":", 1)[1].strip() |
|
|
|
|
if not data_str: |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
# Parse the JSON data (handling possible encoding issues)
|
|
|
|
try: |
|
|
|
|
items = json.loads(block_content) |
|
|
|
|
if isinstance(items, list): |
|
|
|
|
for item in items: |
|
|
|
|
if isinstance(item, dict) and "markdown" in item: |
|
|
|
|
md_content = item["markdown"].strip() |
|
|
|
|
all_markdowns.append(md_content) |
|
|
|
|
|
|
|
|
|
# Handle nested markdown
|
|
|
|
nested_blocks = re.findall(r'```(.*?)```', md_content, re.DOTALL) |
|
|
|
|
for nested in nested_blocks: |
|
|
|
|
nested_parts = nested.split('\n', 1) |
|
|
|
|
if len(nested_parts) >= 2: |
|
|
|
|
nested_content = nested_parts[1].strip() |
|
|
|
|
try: |
|
|
|
|
nested_items = json.loads(nested_content) |
|
|
|
|
if isinstance(nested_items, list): |
|
|
|
|
for ni in nested_items: |
|
|
|
|
if isinstance(ni, dict) and "markdown" in ni: |
|
|
|
|
nested_md = ni["markdown"].strip() |
|
|
|
|
all_markdowns.append(nested_md) |
|
|
|
|
except json.JSONDecodeError: |
|
|
|
|
continue |
|
|
|
|
except json.JSONDecodeError: |
|
|
|
|
continue |
|
|
|
|
except json.JSONDecodeError: |
|
|
|
|
continue |
|
|
|
|
event_data = json.loads(data_str) |
|
|
|
|
except UnicodeDecodeError: |
|
|
|
|
data_str_fixed = data_str.encode("iso-8859-1").decode("utf-8") |
|
|
|
|
event_data = json.loads(data_str_fixed) |
|
|
|
|
|
|
|
|
|
# 7. Core step: concatenate answer fragments (only the answer content of "conversation.message.delta" events)
|
|
|
|
if (current_event_type == "conversation.message.delta" |
|
|
|
|
and event_data.get("type") == "answer"): |
|
|
|
|
# Extract the current fragment (e.g. "直", "隶", "总督署")
|
|
|
|
answer_chunk = event_data.get("content", "").strip() |
|
|
|
|
# Repair mojibake in the fragment (fallback)
|
|
|
|
try: |
|
|
|
|
answer_chunk = answer_chunk.encode("iso-8859-1").decode("utf-8") |
|
|
|
|
except (UnicodeEncodeError, UnicodeDecodeError):
|
|
|
|
pass |
|
|
|
|
full_answer += answer_chunk  # append to build the complete answer
|
|
|
|
|
|
|
|
|
except json.JSONDecodeError: |
|
|
|
|
continue  # skip invalid JSON without aborting the whole stream
|
|
|
|
|
|
|
|
|
# 8. Clean up the final answer (strip extra whitespace and blank lines)
|
|
|
|
full_answer = full_answer.strip() |
|
|
|
|
print("【调试】,知识库内容:", full_answer) |
|
|
|
|
return full_answer |
|
|
|
|
|
|
|
|
|
except requests.exceptions.RequestException as e: |
|
|
|
|
print(f"请求错误: {e}") |
|
|
|
|
error_msg = f"请求错误: {e}" |
|
|
|
|
print(error_msg) |
|
|
|
|
return "" # 错误时返回空字符串 |
|
|
|
|
except Exception as e: |
|
|
|
|
error_msg = f"解析错误: {e}" |
|
|
|
|
print(error_msg) |
|
|
|
|
return "" |
|
|
|
|
|
|
|
|
|
# Core logic: pick the longest complete content (in a streamed response, the last completed content is usually the longest)
|
|
|
|
if all_markdowns: |
|
|
|
|
# Take the longest non-empty content
|
|
|
|
all_markdowns = [md for md in all_markdowns if md]  # drop empty strings
|
|
|
|
if all_markdowns: |
|
|
|
|
final_content = max(all_markdowns, key=len) |
|
|
|
|
|
|
|
|
|
return final_content |
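# Illustrative sketch (standalone demo, not part of this patch) of the mojibake repair
# used in the function above: text that was UTF-8 encoded but mistakenly decoded as
# ISO-8859-1 can be restored by reversing that step.
def fix_double_encoding(text: str) -> str:
    try:
        return text.encode("iso-8859-1").decode("utf-8")
    except (UnicodeEncodeError, UnicodeDecodeError):
        return text  # already clean, or not repairable this way

# Build the mojibake programmatically, then repair it.
mojibake = "中文".encode("utf-8").decode("iso-8859-1")  # renders as garbled latin-1 text
assert fix_double_encoding(mojibake) == "中文"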
|
|
|
|
|
|
|
|
|
# New prompt for multi-scenic-spot comparison
|
|
|
|
MULTI_SCENIC_EXTRACT_PROMPT = """你是一名景区名称提取助手。用户的问题中可能包含多个景区名称,请根据下面的完整景区名称列表,准确提取用户提到的所有景区名称并返回,每个景区名称占一行。如果用户没有提到任何景区,返回空字符串。 |
|
|
|
|
完整景区名称列表: |
|
|
|
|
|