|  |  |  | @ -2,14 +2,15 @@ from fastapi import APIRouter, HTTPException | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | from starlette.requests import Request | 
			
		
	
		
			
				
					|  |  |  |  | from fastapi.responses import StreamingResponse | 
			
		
	
		
			
				
					|  |  |  |  | from typing import AsyncGenerator | 
			
		
	
		
			
				
					|  |  |  |  | from .chat_service import classify, extract_spot, query_flow, gen_markdown_stream, ai_chat_stream | 
			
		
	
		
			
				
					|  |  |  |  | from typing import AsyncGenerator, List | 
			
		
	
		
			
				
					|  |  |  |  | from .chat_service import classify, extract_spot, query_flow, gen_markdown_stream, ai_chat_stream, handle_quick_question | 
			
		
	
		
			
				
					|  |  |  |  | from app.models.ChatIn import ChatIn | 
			
		
	
		
			
				
					|  |  |  |  | from app.models.quick_question import QuickQuestion | 
			
		
	
		
			
				
					|  |  |  |  | from pydantic import BaseModel | 
			
		
	
		
			
				
					|  |  |  |  | import hmac | 
			
		
	
		
			
				
					|  |  |  |  | import hashlib | 
			
		
	
		
			
				
					|  |  |  |  | import time | 
			
		
	
		
			
				
					|  |  |  |  | import json | 
			
		
	
		
			
				
					|  |  |  |  | from pydantic import BaseModel | 
			
		
	
		
			
				
					|  |  |  |  | from app.settings.config import settings | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | router = APIRouter() | 
			
		
	
	
		
			
				
					|  |  |  | @ -43,23 +44,42 @@ async def h5_chat_stream(request: Request, inp: ChatIn): | 
			
		
	
		
			
				
					|  |  |  |  |     conversation_history_str = await redis_client.get(conversation_history_key) | 
			
		
	
		
			
				
					|  |  |  |  |     conversation_history = json.loads(conversation_history_str) if conversation_history_str else [] | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     # 分类阶段(保留同步调用) | 
			
		
	
		
			
				
					|  |  |  |  |     cat = await classify(inp.message) | 
			
		
	
		
			
				
					|  |  |  |  |     # 获取开启的前4个问题(包含标题和内容) | 
			
		
	
		
			
				
					|  |  |  |  |     questions = await QuickQuestion.filter(status="0").order_by("order_num").limit(4).values("title", "content") | 
			
		
	
		
			
				
					|  |  |  |  |     question_titles = [q["title"] for q in questions] | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     # 检查消息是否在问题列表中 | 
			
		
	
		
			
				
					|  |  |  |  |     is_quick_question = inp.message in question_titles | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     # 分类阶段(如果不是快捷问题才执行) | 
			
		
	
		
			
				
					|  |  |  |  |     cat = None | 
			
		
	
		
			
				
					|  |  |  |  |     if not is_quick_question: | 
			
		
	
		
			
				
					|  |  |  |  |         cat = await classify(inp.message) | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     async def content_stream() -> AsyncGenerator[str, None]: | 
			
		
	
		
			
				
					|  |  |  |  |         nonlocal conversation_history | 
			
		
	
		
			
				
					|  |  |  |  |         try: | 
			
		
	
		
			
				
					|  |  |  |  |             if cat == "游玩判断": | 
			
		
	
		
			
				
					|  |  |  |  |                 spot = await extract_spot(inp.message) | 
			
		
	
		
			
				
					|  |  |  |  |                 data = await query_flow(request, spot) | 
			
		
	
		
			
				
					|  |  |  |  |                 async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history): | 
			
		
	
		
			
				
					|  |  |  |  |             if is_quick_question: | 
			
		
	
		
			
				
					|  |  |  |  |                 # 找到对应的问题内容 | 
			
		
	
		
			
				
					|  |  |  |  |                 question_content = next(q["content"] for q in questions if q["title"] == inp.message) | 
			
		
	
		
			
				
					|  |  |  |  |                 # 处理快捷问题,传递content | 
			
		
	
		
			
				
					|  |  |  |  |                 async for chunk in handle_quick_question(inp, question_content): | 
			
		
	
		
			
				
					|  |  |  |  |                     yield chunk | 
			
		
	
		
			
				
					|  |  |  |  |             else: | 
			
		
	
		
			
				
					|  |  |  |  |                 async for chunk in ai_chat_stream(inp, conversation_history): | 
			
		
	
		
			
				
					|  |  |  |  |                     yield chunk | 
			
		
	
		
			
				
					|  |  |  |  |                 # 原来的逻辑 | 
			
		
	
		
			
				
					|  |  |  |  |                 if cat == "游玩判断": | 
			
		
	
		
			
				
					|  |  |  |  |                     spot = await extract_spot(inp.message) | 
			
		
	
		
			
				
					|  |  |  |  |                     data = await query_flow(request, spot) | 
			
		
	
		
			
				
					|  |  |  |  |                     async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history): | 
			
		
	
		
			
				
					|  |  |  |  |                         yield chunk | 
			
		
	
		
			
				
					|  |  |  |  |                 else: | 
			
		
	
		
			
				
					|  |  |  |  |                     async for chunk in ai_chat_stream(inp, conversation_history): | 
			
		
	
		
			
				
					|  |  |  |  |                         yield chunk | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |             # 将更新后的对话历史存回 Redis,并设置过期时间 | 
			
		
	
		
			
				
					|  |  |  |  |             await redis_client.setex(conversation_history_key, CONVERSATION_EXPIRE_TIME, json.dumps(conversation_history)) | 
			
		
	
		
			
				
					|  |  |  |  |             # 只有非快捷问题才保存对话历史 | 
			
		
	
		
			
				
					|  |  |  |  |             if not is_quick_question: | 
			
		
	
		
			
				
					|  |  |  |  |                 await redis_client.setex(conversation_history_key, CONVERSATION_EXPIRE_TIME, json.dumps(conversation_history)) | 
			
		
	
		
			
				
					|  |  |  |  |         except Exception as e: | 
			
		
	
		
			
				
					|  |  |  |  |             print(f"Error in content_stream: {e}") | 
			
		
	
		
			
				
					|  |  |  |  |             raise | 
			
		
	
	
		
			
				
					|  |  |  | @ -103,3 +123,19 @@ def verify_signature(data: dict, sign: str) -> bool: | 
			
		
	
		
			
				
					|  |  |  |  | def verify_timestamp(timestamp: int) -> bool: | 
			
		
	
		
			
				
					|  |  |  |  |     current_timestamp = int(time.time() * 1000) | 
			
		
	
		
			
				
					|  |  |  |  |     return abs(current_timestamp - timestamp) <= TIMESTAMP_TOLERANCE * 1000 | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | # 定义获取问题的响应模型 | 
			
		
	
		
			
				
					|  |  |  |  | class QuestionResponse(BaseModel): | 
			
		
	
		
			
				
					|  |  |  |  |     id: int | 
			
		
	
		
			
				
					|  |  |  |  |     title: str | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     class Config: | 
			
		
	
		
			
				
					|  |  |  |  |         orm_mode = True | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | @router.get("/getQuestion", summary="获取开启的前4个问题") | 
			
		
	
		
			
				
					|  |  |  |  | async def get_question(): | 
			
		
	
		
			
				
					|  |  |  |  |     # 查询状态为正常(0)的问题,按order_num正序排序,取前4条 | 
			
		
	
		
			
				
					|  |  |  |  |     questions = await QuickQuestion.filter(status="0").order_by("order_num").limit(4).values("id", "title") | 
			
		
	
		
			
				
					|  |  |  |  |     return questions |