@ -6,17 +6,25 @@ from app.models.ChatIn import ChatIn
from fastapi import Request
from app . settings . config import settings
import json
import re
from typing import List
import requests
from chinese_calendar import is_holiday
from datetime import datetime
from typing import Optional
load_dotenv ( )
async_client = AsyncOpenAI ( api_key = settings . DEEPSEEK_API_KEY , base_url = settings . DEEPSEEK_API_URL )
# 知识库接口
KNOWLEDGE_URL = " http://172.21.11.20:8886/v3/chat "
# KNOWLEDGE_URL = "http://192.168.130.144:8888/v3/chat"
# 知识库应用id
BOT_ID = " 7550848889243303936 "
# 知识库token
KNOWL_TOKEN = " Bearer pat_5cb11052e7b4b517015467902cd7775742120fc88fe66b926c93fde3a39843c7 "
#分类提示词
CATEGORY_PROMPT = """ 你是一个分类助手,请根据用户的问题判断属于以下哪一类:
如果用户的问题涉及保定市某个景区当前的人数 、 客流量 、 拥挤程度或是否适合前往 ( 例如 : “ 某个保定市景区现在人多么 ” 、 “ 某个保定市景区现在适不适合去 ” 、 " 现在可以去吗 " 、 " 现在适合去吗 " , 注意 : 只有涉及实时 、 现在等当前时间的 , 如果是明天 、 后天等未来时间的不包括在内 ) , 请返回 : 游玩判断 。
@ -206,7 +214,6 @@ async def ai_chat_stream(inp: ChatIn, conversation_history: list) -> AsyncGenera
messages = [ { " role " : " system " , " content " : chat_prompt } ] + conversation_history
messages . append ( { " role " : " user " , " content " : inp . message } )
print ( f " Starting AI chat stream with input: { inp . message } " )
full_response = " "
try :
response = await async_client . chat . completions . create (
@ -241,6 +248,9 @@ async def ai_chat_stream(inp: ChatIn, conversation_history: list) -> AsyncGenera
if full_response :
conversation_history . append ( { " role " : " assistant " , " content " : full_response } )
# 限制对话历史长度为10条(5轮对话)
if len ( conversation_history ) > 10 :
conversation_history = conversation_history [ - 10 : ]
print ( " AI chat stream finished. " )
def get_formatted_prompt ( user_language , msg , data ) :
@ -252,7 +262,6 @@ async def gen_markdown_stream(msg: str, data: str, language: str, conversation_h
messages = conversation_history + [ { " role " : " user " , " content " : prompt } ]
print ( f " Starting markdown stream with message: { msg } and data: { data } " )
full_response = " "
try :
response = await async_client . chat . completions . create (
@ -286,6 +295,9 @@ async def gen_markdown_stream(msg: str, data: str, language: str, conversation_h
if full_response :
conversation_history . append ( { " role " : " assistant " , " content " : full_response } )
# 限制对话历史长度为10条(5轮对话)
if len ( conversation_history ) > 10 :
conversation_history = conversation_history [ - 10 : ]
print ( " Markdown stream finished. " )
async def extract_spot ( msg ) - > str :
# 如果msg是列表,则将其内容连接成字符串
@ -293,8 +305,7 @@ async def extract_spot(msg) -> str:
msg_content = ' \n ' . join ( msg )
else :
msg_content = msg
print ( f " Starting spot extraction for message: { msg_content } " )
try :
response = await async_client . chat . completions . create (
model = " deepseek-chat " ,
@ -326,11 +337,9 @@ async def query_flow(request: Request, spot: str, redis_client = None) -> str:
redis_client = request . app . state . redis_client
# Step 1: Redis 缓存查询
print ( f " Querying Redis cache for key: { cache_key } " )
try :
cached = await redis_client . get ( cache_key )
if cached :
print ( f " Found cached data for key: { cache_key } " )
return cached
else :
return f " 未找到景区【 { spot } 】的客流相关信息,在园人数和舒适度未知;停车场信息:暂无数据。 "
@ -360,7 +369,7 @@ async def query_flow(request: Request, spot: str, redis_client = None) -> str:
await cur . execute ( formatted_flow_query )
row = await cur . fetchone ( )
# 查询停车场信息
park_query = """ SELECT t3.park_name AS park_name, IFNULL(t3.rate_info, ' 暂无收费标准信息 ' ) AS rate_info, t3.total_count AS total_count, t4.space AS space, t1.distance_value AS distance_value
FROM cyjcpt_bd . scenic_pack_distance t1
@ -387,7 +396,7 @@ async def query_flow(request: Request, spot: str, redis_client = None) -> str:
except Exception as e :
print ( f " [MySQL] 查询失败: { e } " )
return f " **未找到景区【 { spot } 】的信息,请检查名称是否正确。 \n \n (内容仅供参考) "
result = " "
if row and all ( v is not None for v in row ) :
# 使用变量名访问客流数据
@ -465,10 +474,6 @@ async def handle_quick_question(inp: ChatIn, question_content: str) -> AsyncGene
print ( error_msg )
yield error_msg
# 不保存快捷问题的对话历史
print ( " Quick question handling finished. " )
# 在chat_service.py中添加推荐问题生成函数
async def generate_recommended_questions ( user_msg : str , ai_response : str ) - > list :
""" 基于用户问题和AI回答生成1-3个纵向延伸的推荐问题 """
@ -543,7 +548,7 @@ async def get_all_scenic_flow_data(request: Request, redis_client = None) -> lis
"""
await cur . execute ( query )
rows = await cur . fetchall ( )
# 处理结果
result = [ ]
for row in rows :
@ -551,13 +556,13 @@ async def get_all_scenic_flow_data(request: Request, redis_client = None) -> lis
in_park_num = abs ( enter_num - leave_num ) # 确保是正数
if in_park_num > max_capacity :
in_park_num = max_capacity
# 避免除以零的情况
if max_capacity > 0 :
capacity_rate = in_park_num / max_capacity
else :
capacity_rate = 0
result . append ( {
" id " : id ,
" scenic_name " : scenic_name ,
@ -567,15 +572,15 @@ async def get_all_scenic_flow_data(request: Request, redis_client = None) -> lis
" max_capacity " : max_capacity ,
" capacity_rate " : capacity_rate
} )
# 将结果存入Redis缓存,过期时间1分钟
try :
await redis_client . setex ( cache_key , 60 , json . dumps ( result ) )
except Exception as e :
print ( f " [Redis] 写缓存失败: { e } " )
return result
except Exception as e :
print ( f " [MySQL] 查询所有景区客流数据失败: { e } " )
return [ ]
@ -620,12 +625,12 @@ async def get_scenic_detail_data(request: Request, id: int, redis_client = None)
"""
await cur . execute ( query , ( id , ) )
row = await cur . fetchone ( )
if not row :
return None
scenic_name , enter_num , leave_num , max_capacity = row
# 计算在园人数
in_park_num = abs ( enter_num - leave_num ) # 确保是正数
if in_park_num > max_capacity :
@ -650,7 +655,7 @@ async def get_scenic_detail_data(request: Request, id: int, redis_client = None)
else :
capacity_rate = 0.0
comfort_level = " 舒适 "
result = {
" scenic_name " : scenic_name ,
" enter_num " : enter_num or 0 ,
@ -660,15 +665,15 @@ async def get_scenic_detail_data(request: Request, id: int, redis_client = None)
" capacity_rate " : round ( capacity_rate , 4 ) ,
" comfort_level " : comfort_level
}
# 将结果存入Redis缓存,过期时间1分钟
try :
await redis_client . setex ( cache_key , 60 , json . dumps ( result ) )
except Exception as e :
print ( f " [Redis] 写缓存失败: { e } " )
return result
except Exception as e :
print ( f " [MySQL] 查询景区详情数据失败: { e } " )
return None
@ -678,13 +683,13 @@ async def get_scenic_detail_data(request: Request, id: int, redis_client = None)
async def get_scenic_parking_data ( request : Request , scenic_id : int , distance : int , redis_client = None ) - > list :
"""
查询景区附近的停车场信息
Args :
request : FastAPI请求对象
scenic_id : 景区id
distance : 查询距离 ( 米 ) , > = 1000 时查询全部
redis_client : Redis客户端实例 ( 可选 )
Returns :
list : 停车场信息列表 , 按距离排序
"""
@ -763,12 +768,12 @@ async def get_scenic_parking_data(request: Request, scenic_id: int, distance: in
await cur . execute ( formatted_park_base )
rows = await cur . fetchall ( )
# 处理结果
result = [ ]
for row in rows :
park_name , total_spaces , available_spaces , distance_meters , lon , lat , park_type = row
result . append ( {
" park_name " : park_name ,
" total_parking_spaces " : total_spaces or 0 ,
@ -778,15 +783,15 @@ async def get_scenic_parking_data(request: Request, scenic_id: int, distance: in
" lat " : lat or 0 ,
" park_type " : park_type
} )
# 将结果存入Redis缓存,过期时间1分钟
try :
await redis_client . setex ( cache_key , 60 , json . dumps ( result ) )
except Exception as e :
print ( f " [Redis] 写缓存失败: { e } " )
return result
except Exception as e :
print ( f " [MySQL] 查询景区停车场数据失败: { e } " )
return [ ]
@ -794,82 +799,108 @@ async def get_scenic_parking_data(request: Request, scenic_id: int, distance: in
# 添加用于获取完整响应数据的新函数
def fetch_and_parse_markdown ( user_id : int , question : str ) - > str :
"""
只提取最终完整的markdown内容 ( 过滤流式中间片段 )
功能 : 发送请求 、 解析SSE流 、 提取完整知识库内容 、 处理乱码
返回 : 纯净的知识库markdown内容 ( 与原fetch_and_parse_markdown返回格式一致 )
"""
encoded_question = requests . utils . quote ( question )
# url = f"http://cjy.aitto.net:45678/api/v3/user_share_chat_completions?random={user_id}&api_key=cjy-626e50140e934936b8c82a3be5f6dea3&app_code=f5b3d4ba-7e7a-11f0-9de7-00e04f309c26&user_input={encoded_question}"
url = f " http://127.0.0.1:5679/api/v3/user_share_chat_completions?random= { user_id } &api_key=cjy-626e50140e934936b8c82a3be5f6dea3&app_code=f5b3d4ba-7e7a-11f0-9de7-00e04f309c26&user_input= { encoded_question } "
all_markdowns : List [ str ] = [ ]
final_content = " " # 存储最终完整内容
# 1. 新接口基础配置
HEADERS = {
" Authorization " : KNOWL_TOKEN ,
" Content-Type " : " application/json " ,
" Accept " : " text/event-stream " ,
" Accept-Charset " : " utf-8 "
}
# 请求体(user_id和question动态传入,其他参数固定)
PAYLOAD = {
" bot_id " : BOT_ID , # 接口固定bot_id
" user_id " : str ( user_id ) , # 转为字符串适配接口
" additional_messages " : [
{
" role " : " user " ,
" type " : " question " ,
" content " : question , # 用户查询问题
" content_type " : " text "
}
] ,
" stream " : False ,
" auto_save_history " : True ,
" enable_card " : True
}
full_answer = " " # 核心:拼接你需要的「最终完整回答」
current_event_type : Optional [ str ] = None
try :
with requests . get ( url , stream = True , timeout = 60 ) as response :
response . encoding = " utf-8 "
response . raise_for_status ( )
for line in response . iter_lines ( ) :
if not line :
continue
line_str = line . decode ( " utf-8 " , errors = " replace " )
if not line_str . startswith ( " data: " ) :
# 2. 发送SSE请求并流式接收
with requests . post (
url = KNOWLEDGE_URL ,
headers = HEADERS ,
data = json . dumps ( PAYLOAD , ensure_ascii = False ) , # 请求体UTF-8编码
stream = True ,
timeout = 60
) as response :
response . raise_for_status ( ) # 检查请求是否成功
# 3. 逐行解析SSE流
for line_bytes in response . iter_lines ( ) :
if not line_bytes : # 跳过空行(事件分隔符)
continue
data_str = line_str [ 5 : ] . strip ( )
# 4. 处理乱码(重点解决双重编码问题)
try :
data_json = json . loads ( data_str )
vis_content = data_json . get ( " vis " , " " )
# 提取所有markdown内容
code_blocks = re . findall ( r ' ```(.*?)``` ' , vis_content , re . DOTALL )
for block in code_blocks :
block_parts = block . split ( ' \n ' , 1 )
if len ( block_parts ) < 2 :
continue
block_type , block_content = block_parts
block_content = block_content . strip ( )
# 优先UTF-8解码(正常情况)
line = line_bytes . decode ( " utf-8 " , errors = " strict " )
except UnicodeDecodeError :
# 修复"UTF-8→ISO-8859-1"双重编码(常见中文乱码原因)
line = line_bytes . decode ( " iso-8859-1 " ) . encode ( " iso-8859-1 " ) . decode ( " utf-8 " )
# 5. 提取事件类型(如 conversation.message.delta)
if line . startswith ( " event: " ) :
current_event_type = line . split ( " : " , 1 ) [ 1 ] . strip ( )
continue
# 6. 提取事件数据(只关注回答片段)
if line . startswith ( " data: " ) :
data_str = line . split ( " : " , 1 ) [ 1 ] . strip ( )
if not data_str :
continue
try :
# 解析JSON数据(处理可能的编码问题)
try :
items = json . loads ( block_content )
if isinstance ( items , list ) :
for item in items :
if isinstance ( item , dict ) and " markdown " in item :
md_content = item [ " markdown " ] . strip ( )
all_markdowns . append ( md_content )
# 处理嵌套的markdown
nested_blocks = re . findall ( r ' ```(.*?)``` ' , md_content , re . DOTALL )
for nested in nested_blocks :
nested_parts = nested . split ( ' \n ' , 1 )
if len ( nested_parts ) > = 2 :
nested_content = nested_parts [ 1 ] . strip ( )
try :
nested_items = json . loads ( nested_content )
if isinstance ( nested_items , list ) :
for ni in nested_items :
if isinstance ( ni , dict ) and " markdown " in ni :
nested_md = ni [ " markdown " ] . strip ( )
all_markdowns . append ( nested_md )
except json . JSONDecodeError :
continue
except json . JSONDecodeError :
continue
except json . JSONDecodeError :
continue
event_data = json . loads ( data_str )
except UnicodeDecodeError :
data_str_fixed = data_str . encode ( " iso-8859-1 " ) . decode ( " utf-8 " )
event_data = json . loads ( data_str_fixed )
# 7. 核心:拼接回答片段(只取 "conversation.message.delta" 事件的 answer 内容)
if ( current_event_type == " conversation.message.delta "
and event_data . get ( " type " ) == " answer " ) :
# 提取当前片段(如"直"、"隶"、"总督署")
answer_chunk = event_data . get ( " content " , " " ) . strip ( )
# 修复片段中的乱码(兜底)
try :
answer_chunk = answer_chunk . encode ( " iso-8859-1 " ) . decode ( " utf-8 " )
except :
pass
full_answer + = answer_chunk # 拼接成完整回答
except json . JSONDecodeError :
continue # 跳过无效JSON,不影响整体
# 8. 清理最终回答(去除多余空格/空行)
full_answer = full_answer . strip ( )
print ( " 【调试】,知识库内容: " , full_answer )
return full_answer
except requests . exceptions . RequestException as e :
print ( f " 请求错误: { e } " )
error_msg = f " 请求错误: { e } "
print ( error_msg )
return " " # 错误时返回空字符串
except Exception as e :
error_msg = f " 解析错误: { e } "
print ( error_msg )
return " "
# 核心逻辑:筛选出最长且完整的内容(流式响应中最后完成的内容通常最长)
if all_markdowns :
# 按长度倒序排序,取最长的非空内容
all_markdowns = [ md for md in all_markdowns if md ] # 过滤空字符串
if all_markdowns :
final_content = max ( all_markdowns , key = len )
return final_content
# 添加用于多景区比较的新提示词
MULTI_SCENIC_EXTRACT_PROMPT = """ 你是一名景区名称提取助手。用户的问题中可能包含多个景区名称,请根据下面的完整景区名称列表,准确提取用户提到的所有景区名称并返回,每个景区名称占一行。如果用户没有提到任何景区,返回空字符串。
完整景区名称列表 :
@ -946,7 +977,7 @@ async def extract_multi_scenic(msg) -> list:
msg_content = ' \n ' . join ( msg )
else :
msg_content = msg
print ( f " Starting multi scenic extraction for message: { msg_content } " )
try :
response = await async_client . chat . completions . create (
@ -966,7 +997,7 @@ async def query_multi_scenic_flow(request: Request, scenics: list, msg: str, red
if not scenics :
print ( " No scenics found, returning default message. " )
return " **未找到景区信息,请检查名称是否正确。** \n \n (内容由AI生成,仅供参考) "
# 查询多个景区的客流数据
results = [ ]
for scenic in scenics :
@ -975,7 +1006,7 @@ async def query_multi_scenic_flow(request: Request, scenics: list, msg: str, red
" scenic " : scenic ,
" data " : data
} )
# 生成比较结果
if len ( results ) == 1 :
return results [ 0 ] [ " data " ]
@ -997,7 +1028,7 @@ async def query_multi_scenic_flow(request: Request, scenics: list, msg: str, red
# 如果AI比较失败,返回原始数据
result_str = " \n \n " . join ( [ f " ** { r [ ' scenic ' ] } **: \n { r [ ' data ' ] } " for r in results ] )
return result_str
return " **未找到景区信息,请检查名称是否正确。** \n \n (内容由AI生成,仅供参考) "