|
|
|
@ -1,3 +1,6 @@ |
|
|
|
|
import random |
|
|
|
|
|
|
|
|
|
from anyio import sleep |
|
|
|
|
from fastapi import APIRouter, HTTPException, Depends,FastAPI |
|
|
|
|
|
|
|
|
|
from starlette.requests import Request |
|
|
|
@ -23,7 +26,8 @@ from app.api.chat.ai.chat_service import ( |
|
|
|
|
get_scenic_parking_data, |
|
|
|
|
extract_multi_scenic, |
|
|
|
|
query_multi_scenic_flow, |
|
|
|
|
get_all_toilet_data |
|
|
|
|
get_all_toilet_data, |
|
|
|
|
generate_recommended_questions |
|
|
|
|
) |
|
|
|
|
# 导入用于异步执行同步函数的模块 |
|
|
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
@ -54,7 +58,7 @@ async def universal_exception_handler(request: Request, exc: Exception): |
|
|
|
|
chunk_size = 8 |
|
|
|
|
for i in range(0, len(error_msg), chunk_size): |
|
|
|
|
chunk = error_msg[i:i + chunk_size] |
|
|
|
|
yield f"data: {json.dumps({'content': chunk})}\n\n" |
|
|
|
|
yield f"data: {chunk}\n\n" |
|
|
|
|
await asyncio.sleep(0.03) # 模拟自然打字速度 |
|
|
|
|
|
|
|
|
|
# 发送结束标记 |
|
|
|
@ -113,7 +117,7 @@ async def h5_chat_stream(request: Request, inp: ChatIn, redis_client = Depends(g |
|
|
|
|
chunk_size = 8 |
|
|
|
|
for i in range(0, len(error_msg), chunk_size): |
|
|
|
|
chunk = error_msg[i:i + chunk_size] |
|
|
|
|
yield f"data: {json.dumps({'content': chunk})}\n\n" |
|
|
|
|
yield f"data: {chunk}\n\n" |
|
|
|
|
await asyncio.sleep(0.03) |
|
|
|
|
|
|
|
|
|
yield "data: [DONE]\n\n" |
|
|
|
@ -159,30 +163,44 @@ async def _handle_chat_request( |
|
|
|
|
user_messages_content = [msg.get('content', '') for msg in user_messages] |
|
|
|
|
all_messages = user_messages_content + [inp.message] |
|
|
|
|
|
|
|
|
|
# 消息分类 |
|
|
|
|
cat = await classify(inp.message) |
|
|
|
|
print(f"Message category: {cat}") |
|
|
|
|
|
|
|
|
|
spot = None |
|
|
|
|
scenics = None |
|
|
|
|
if cat == "游玩判断" or cat == "保定文旅": |
|
|
|
|
spot = await extract_spot(all_messages) |
|
|
|
|
elif cat == "多景区比较": |
|
|
|
|
scenics = await extract_multi_scenic(all_messages) |
|
|
|
|
|
|
|
|
|
# 知识库查询准备 |
|
|
|
|
knowledge_task = None |
|
|
|
|
if spot: |
|
|
|
|
loop = asyncio.get_event_loop() |
|
|
|
|
with ThreadPoolExecutor() as executor: |
|
|
|
|
knowledge_task = loop.run_in_executor(executor, fetch_and_parse_markdown, user_id, inp.message) |
|
|
|
|
|
|
|
|
|
# 获取快捷问题 |
|
|
|
|
questions = await QuickQuestion.filter(status="0", ischat="0").order_by("order_num").limit(4).values("title", |
|
|
|
|
"subtitle", |
|
|
|
|
"content") |
|
|
|
|
question_titles = [f"{q['subtitle']}{q['title']}" for q in questions] |
|
|
|
|
is_quick_question = inp.message in question_titles |
|
|
|
|
try: |
|
|
|
|
cached = await redis_client.get(inp.message) |
|
|
|
|
if cached: |
|
|
|
|
is_quick_question = True |
|
|
|
|
# 缓存命中,直接返回缓存数据 |
|
|
|
|
questions = cached |
|
|
|
|
else: |
|
|
|
|
is_quick_question = False |
|
|
|
|
except Exception as e: |
|
|
|
|
print(f"[Redis] 查询缓存失败: {e}") |
|
|
|
|
is_quick_question = False |
|
|
|
|
|
|
|
|
|
if not is_quick_question: |
|
|
|
|
# 消息分类 |
|
|
|
|
cat = await classify(inp.message) |
|
|
|
|
print(f"Message category: {cat}") |
|
|
|
|
|
|
|
|
|
spot = None |
|
|
|
|
scenics = None |
|
|
|
|
if cat == "游玩判断" or cat == "保定文旅": |
|
|
|
|
spot = await extract_spot(all_messages) |
|
|
|
|
elif cat == "多景区比较": |
|
|
|
|
scenics = await extract_multi_scenic(all_messages) |
|
|
|
|
|
|
|
|
|
# 知识库查询准备 |
|
|
|
|
knowledge_task = None |
|
|
|
|
if spot: |
|
|
|
|
loop = asyncio.get_event_loop() |
|
|
|
|
with ThreadPoolExecutor() as executor: |
|
|
|
|
knowledge_task = loop.run_in_executor(executor, fetch_and_parse_markdown, user_id, inp.message) |
|
|
|
|
|
|
|
|
|
# # 获取快捷问题 |
|
|
|
|
# questions = await QuickQuestion.filter(status="0", ischat="0").order_by("order_num").limit(4).values("title", |
|
|
|
|
# "subtitle", |
|
|
|
|
# "content") |
|
|
|
|
# question_titles = [f"{q['subtitle']}{q['title']}" for q in questions] |
|
|
|
|
# is_quick_question = inp.message in question_titles |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 添加用户消息到历史 |
|
|
|
|
if not is_quick_question: |
|
|
|
@ -193,16 +211,36 @@ async def _handle_chat_request( |
|
|
|
|
nonlocal conversation_history |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
# 记录热门问题 |
|
|
|
|
await record_hot_question(inp.message) |
|
|
|
|
try: |
|
|
|
|
# 记录热门问题 |
|
|
|
|
await record_hot_question(inp.message) |
|
|
|
|
except Exception as e: |
|
|
|
|
print(f"记录热门问题失败: {e},问题内容为:{inp.message}") |
|
|
|
|
|
|
|
|
|
if is_quick_question: |
|
|
|
|
await asyncio.sleep(0.5) |
|
|
|
|
full_response = "" |
|
|
|
|
# 处理快捷问题 |
|
|
|
|
question_content = next( |
|
|
|
|
q["content"] for q in questions if f"{q['subtitle']}{q['title']}" == inp.message) |
|
|
|
|
async for chunk in handle_quick_question(inp, question_content): |
|
|
|
|
question_content = questions |
|
|
|
|
#每次输出随机1-10个字符 |
|
|
|
|
chunk_size = random.randint(5, 15) |
|
|
|
|
for i in range(0, len(question_content), chunk_size): |
|
|
|
|
chunk = question_content[i:i + chunk_size] |
|
|
|
|
full_response += chunk |
|
|
|
|
yield f"data: {chunk}\n\n" |
|
|
|
|
await asyncio.sleep(0.01) |
|
|
|
|
await asyncio.sleep(0.09) |
|
|
|
|
if full_response: |
|
|
|
|
recommended_questions = await generate_recommended_questions(inp.message, full_response) |
|
|
|
|
if recommended_questions: |
|
|
|
|
# 添加分隔符和标题 |
|
|
|
|
yield "\n\n### 您可能还想了解:" |
|
|
|
|
# 逐个返回推荐问题 |
|
|
|
|
for i, question in enumerate(recommended_questions, 1): |
|
|
|
|
yield f"\n{i}. {question}" |
|
|
|
|
|
|
|
|
|
# async for chunk in handle_quick_question(inp, question_content): |
|
|
|
|
# yield f"data: {chunk}\n\n" |
|
|
|
|
# await asyncio.sleep(0.01) |
|
|
|
|
else: |
|
|
|
|
# 处理不同分类的消息 |
|
|
|
|
if cat == "游玩判断": |
|
|
|
@ -272,7 +310,7 @@ async def _handle_chat_request( |
|
|
|
|
chunk_size = 8 |
|
|
|
|
for i in range(0, len(error_msg), chunk_size): |
|
|
|
|
chunk = error_msg[i:i + chunk_size] |
|
|
|
|
yield f"data: {json.dumps({'content': chunk})}\n\n" |
|
|
|
|
yield f"data: {chunk})\n\n" |
|
|
|
|
await asyncio.sleep(0.03) |
|
|
|
|
|
|
|
|
|
yield "data: [DONE]\n\n" |
|
|
|
@ -495,7 +533,7 @@ async def get_all_scenic_flow(request: Request, req: AllScenicFlowRequest, redis |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/get_all_toilet_info") |
|
|
|
|
async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest): |
|
|
|
|
async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest, redis_client = Depends(get_redis_client)): |
|
|
|
|
""" |
|
|
|
|
获取所有厕所信息 |
|
|
|
|
""" |
|
|
|
@ -512,7 +550,7 @@ async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest): |
|
|
|
|
raise HTTPException(status_code=401, detail="无效的签名") |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
data = await get_all_toilet_data(request) |
|
|
|
|
data = await get_all_toilet_data(request, redis_client) |
|
|
|
|
|
|
|
|
|
if not data: |
|
|
|
|
return { |
|
|
|
@ -524,7 +562,7 @@ async def get_all_toilet_info(request: Request, req: AllScenicFlowRequest): |
|
|
|
|
return { |
|
|
|
|
"code": 200, |
|
|
|
|
"message": "查询成功", |
|
|
|
|
"data": data |
|
|
|
|
"data": json.loads(data) |
|
|
|
|
} |
|
|
|
|
except Exception as e: |
|
|
|
|
print(f"查询所有厕所信息异常: {e}") |
|
|
|
|