from fastapi import APIRouter, HTTPException from starlette.requests import Request from fastapi.responses import StreamingResponse from typing import AsyncGenerator, List from .chat_service import classify, extract_spot, query_flow, gen_markdown_stream, ai_chat_stream, \ handle_quick_question, fetch_and_parse_markdown from app.models.ChatIn import ChatIn, AllScenicFlowRequest, ScenicDetailRequest, ScenicParkingRequest from app.models.quick_question import QuickQuestion from app.models.hot_question import HotQuestion from pydantic import BaseModel import hmac import hashlib import time import json from app.settings.config import settings # 更新导入语句,添加新函数 from app.api.chat.ai.chat_service import ( query_flow, handle_quick_question, get_all_scenic_flow_data, get_scenic_detail_data, get_scenic_parking_data ) router = APIRouter() SECRET_KEY = settings.SIGN_KEY # 约定的密钥 TIMESTAMP_TOLERANCE = 60 # 时间戳容忍度,单位:秒 CONVERSATION_EXPIRE_TIME = 600 # 对话历史过期时间,单位:秒 @router.post("/chat", summary="ai对话") async def h5_chat_stream(request: Request, inp: ChatIn): if inp.sign is None: raise HTTPException(status_code=401, detail="缺少签名参数") if not verify_signature(inp.dict(), inp.sign): raise HTTPException(status_code=401, detail="无效的签名") if not verify_timestamp(inp.timestamp): raise HTTPException(status_code=401, detail="时间戳无效") if inp.language is None: inp.language = "zh_cn" # 获取用户 ID user_id = inp.user_id if not user_id: raise HTTPException(status_code=400, detail="缺少用户 ID") # 从 Redis 中获取用户的对话历史 redis_client = request.app.state.redis_client conversation_history_key = f"conversation:{user_id}" conversation_history_str = await redis_client.get(conversation_history_key) conversation_history = json.loads(conversation_history_str) if conversation_history_str else [] # 获取开启的前4个问题(包含标题和内容) questions = await QuickQuestion.filter(status="0").order_by("order_num").limit(4).values("title", "content") question_titles = [q["title"] for q in questions] # 检查消息是否在问题列表中 is_quick_question = inp.message in question_titles # 分类阶段(如果不是快捷问题才执行) cat = None if not is_quick_question: cat = await classify(inp.message) async def content_stream() -> AsyncGenerator[str, None]: nonlocal conversation_history try: if is_quick_question: # 找到对应的问题内容 question_content = next(q["content"] for q in questions if q["title"] == inp.message) # 处理快捷问题,传递content async for chunk in handle_quick_question(inp, question_content): yield chunk else: # 原来的逻辑 if cat == "游玩判断": spot = await extract_spot(inp.message) data = await query_flow(request, spot) knowledge = await fetch_and_parse_markdown(user_id,spot) #如果知识库返回的内容不包含"知识库内未找到相应资源"则拼接字符串 if "知识库内未找到相应资源" not in knowledge: data += "知识库查询到的景区内容:"+ knowledge async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history): yield chunk else: spot = await extract_spot(inp.message) if spot: knowledge = await fetch_and_parse_markdown(user_id, spot) if "知识库内未找到相应资源" not in knowledge: inp.message += ";知识库查询到的景区内容:"+ knowledge async for chunk in ai_chat_stream(inp, conversation_history): yield chunk # 将更新后的对话历史存回 Redis,并设置过期时间 # 只有非快捷问题才保存对话历史 if not is_quick_question: await redis_client.setex(conversation_history_key, CONVERSATION_EXPIRE_TIME, json.dumps(conversation_history)) # 记录热门问题(包括快捷问题) await record_hot_question(inp.message) except Exception as e: print(f"Error in content_stream: {e}") raise try: return StreamingResponse( content_stream(), media_type="text/plain", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no" } ) except Exception as e: print(f"Error in StreamingResponse: {e}") raise HTTPException(status_code=500, detail="Internal Server Error") class ClearConversationRequest(BaseModel): user_id: int @router.post("/clear_conversation", summary="清除对话记录") async def clear_conversation(request: Request, body: ClearConversationRequest): user_id = body.user_id if not user_id: raise HTTPException(status_code=400, detail="缺少用户 ID") redis_client = request.app.state.redis_client conversation_history_key = f"conversation:{user_id}" await redis_client.delete(conversation_history_key) return {"message": "对话历史已清除"} def verify_signature(data: dict, sign: str) -> bool: sorted_keys = sorted(data.keys()) sign_str = '&'.join([f'{key}={data[key]}' for key in sorted_keys if key != 'sign']) sign_str += f'&secret={SECRET_KEY}' calculated_sign = hmac.new(SECRET_KEY.encode(), sign_str.encode(), hashlib.sha256).hexdigest() return calculated_sign == sign def verify_timestamp(timestamp: int) -> bool: current_timestamp = int(time.time() * 1000) return abs(current_timestamp - timestamp) <= TIMESTAMP_TOLERANCE * 1000 async def record_hot_question(question: str): """记录热门问题,如果存在则次数加1,不存在则新增""" try: # 查找是否存在该问题 hot_q = await HotQuestion.filter(title=question).first() if hot_q: # 已存在,次数加1 hot_q.num += 1 await hot_q.save() else: # 不存在,新增记录 await HotQuestion.create(title=question, num=1) except Exception as e: print(f"记录热门问题失败: {e}") # 定义获取问题的响应模型 class QuestionResponse(BaseModel): id: int title: str class Config: orm_mode = True @router.post("/get_question", summary="获取开启的前4个问题") async def get_question(request: Request, req: AllScenicFlowRequest): # 验签逻辑 if not req.sign: raise HTTPException(status_code=401, detail="缺少签名参数") if not verify_timestamp(req.timestamp): raise HTTPException(status_code=401, detail="时间戳无效") try: # 查询状态为正常(0)的问题,按order_num正序排序,取前4条 questions = await QuickQuestion.filter(status="0").order_by("order_num").limit(4).values("title","subtitle","logo","label") return { "code": 200, "message": "查询成功", "data": questions } except Exception as e: print(f"查询快捷问题失败: {e}") return { "code": 500, "message": f"查询失败: {str(e)}", "data": [] } class HotQuestionResponse(BaseModel): id: int title: str num: int update_time: str class Config: orm_mode = True @router.post("/get_hot_questions", summary="获取热门问题top10") async def get_hot_questions(request: Request, req: AllScenicFlowRequest): # 验签逻辑 if not req.sign: raise HTTPException(status_code=401, detail="缺少签名参数") if not verify_timestamp(req.timestamp): raise HTTPException(status_code=401, detail="时间戳无效") """ 获取热门问题top10,按次数倒序排列 """ try: # 查询热门问题,按次数倒序排序,取前10条 hot_questions = await HotQuestion \ .filter() \ .order_by("-num") \ .limit(10) \ .values("id", "title", "num", "update_time") # 格式化时间 for q in hot_questions: q["update_time"] = q["update_time"].strftime("%Y-%m-%d %H:%M:%S") return { "code": 200, "message": "查询成功", "data": hot_questions } except Exception as e: print(f"查询热门问题失败: {e}") return { "code": 500, "message": f"查询失败: {str(e)}", "data": [] } @router.post("/get_all_scenic_flow") async def get_all_scenic_flow(request: Request, req: AllScenicFlowRequest): """ 查询所有景区的进入人数、离开人数,计算承载率并按承载率倒序排列 """ # 验签逻辑 if not req.sign: raise HTTPException(status_code=401, detail="缺少签名参数") if not verify_timestamp(req.timestamp): raise HTTPException(status_code=401, detail="时间戳无效") # 构建验证数据(无其他参数,仅包含timestamp) data = {"timestamp": req.timestamp} if not verify_signature(data, req.sign): raise HTTPException(status_code=401, detail="无效的签名") try: data = await get_all_scenic_flow_data(request) if not data: return { "code": 404, "message": "未找到景区客流数据", "data": [] } return { "code": 200, "message": "查询成功", "data": data } except Exception as e: print(f"查询所有景区客流数据异常: {e}") return { "code": 500, "message": f"查询异常: {str(e)}", "data": [] } # 在现有路由下方添加新接口 @router.post("/get_scenic_detail") async def get_scenic_detail(request: Request, req: ScenicDetailRequest): """ 查询单个景区的详细信息,包含舒适度判断 """ # 验签逻辑 if not req.sign: raise HTTPException(status_code=401, detail="缺少签名参数") if not verify_timestamp(req.timestamp): raise HTTPException(status_code=401, detail="时间戳无效") # 构建验证数据 data = {"id": req.id, "timestamp": req.timestamp} if not verify_signature(data, req.sign): raise HTTPException(status_code=401, detail="无效的签名") if not req.id: return { "code": 400, "message": "景区id不能为空", "data": None } try: data = await get_scenic_detail_data(request, req.id) if not data: return { "code": 404, "message": f"未找到景区信息", "data": None } return { "code": 200, "message": "查询成功", "data": data } except Exception as e: print(f"查询景区详情异常: {e}") return { "code": 500, "message": f"查询异常: {str(e)}", "data": None } # 在现有路由下方添加新接口 - 景区停车场查询 @router.post("/get_scenic_parking") async def get_scenic_parking(request: Request, req: ScenicParkingRequest): """ 查询景区附近的停车场信息 """ # 验签逻辑 if not req.sign: raise HTTPException(status_code=401, detail="缺少签名参数") if not verify_timestamp(req.timestamp): raise HTTPException(status_code=401, detail="时间戳无效") # 构建验证数据 data = { "scenic_name": req.scenic_name, "distance": req.distance, "timestamp": req.timestamp } if not verify_signature(data, req.sign): raise HTTPException(status_code=401, detail="无效的签名") if not req.scenic_name: return { "code": 400, "message": "景区名称不能为空", "data": [] } if req.distance <= 0: return { "code": 400, "message": "查询距离必须大于0", "data": [] } try: data = await get_scenic_parking_data(request, req.scenic_name, req.distance) if not data: return { "code": 404, "message": f"未找到【{req.scenic_name}】附近的停车场信息", "data": [] } return { "code": 200, "message": "查询成功", "data": data } except Exception as e: print(f"查询景区停车场数据异常: {e}") return { "code": 500, "message": f"查询异常: {str(e)}", "data": [] }