|
|
|
@ -2,14 +2,15 @@ from fastapi import APIRouter, HTTPException |
|
|
|
|
|
|
|
|
|
from starlette.requests import Request |
|
|
|
|
from fastapi.responses import StreamingResponse |
|
|
|
|
from typing import AsyncGenerator |
|
|
|
|
from .chat_service import classify, extract_spot, query_flow, gen_markdown_stream, ai_chat_stream |
|
|
|
|
from typing import AsyncGenerator, List |
|
|
|
|
from .chat_service import classify, extract_spot, query_flow, gen_markdown_stream, ai_chat_stream, handle_quick_question |
|
|
|
|
from app.models.ChatIn import ChatIn |
|
|
|
|
from app.models.quick_question import QuickQuestion |
|
|
|
|
from pydantic import BaseModel |
|
|
|
|
import hmac |
|
|
|
|
import hashlib |
|
|
|
|
import time |
|
|
|
|
import json |
|
|
|
|
from pydantic import BaseModel |
|
|
|
|
from app.settings.config import settings |
|
|
|
|
|
|
|
|
|
router = APIRouter() |
|
|
|
@ -43,23 +44,42 @@ async def h5_chat_stream(request: Request, inp: ChatIn): |
|
|
|
|
conversation_history_str = await redis_client.get(conversation_history_key) |
|
|
|
|
conversation_history = json.loads(conversation_history_str) if conversation_history_str else [] |
|
|
|
|
|
|
|
|
|
# 分类阶段(保留同步调用) |
|
|
|
|
cat = await classify(inp.message) |
|
|
|
|
# 获取开启的前4个问题(包含标题和内容) |
|
|
|
|
questions = await QuickQuestion.filter(status="0").order_by("order_num").limit(4).values("title", "content") |
|
|
|
|
question_titles = [q["title"] for q in questions] |
|
|
|
|
|
|
|
|
|
# 检查消息是否在问题列表中 |
|
|
|
|
is_quick_question = inp.message in question_titles |
|
|
|
|
|
|
|
|
|
# 分类阶段(如果不是快捷问题才执行) |
|
|
|
|
cat = None |
|
|
|
|
if not is_quick_question: |
|
|
|
|
cat = await classify(inp.message) |
|
|
|
|
|
|
|
|
|
async def content_stream() -> AsyncGenerator[str, None]: |
|
|
|
|
nonlocal conversation_history |
|
|
|
|
try: |
|
|
|
|
if cat == "游玩判断": |
|
|
|
|
spot = await extract_spot(inp.message) |
|
|
|
|
data = await query_flow(request, spot) |
|
|
|
|
async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history): |
|
|
|
|
if is_quick_question: |
|
|
|
|
# 找到对应的问题内容 |
|
|
|
|
question_content = next(q["content"] for q in questions if q["title"] == inp.message) |
|
|
|
|
# 处理快捷问题,传递content |
|
|
|
|
async for chunk in handle_quick_question(inp, question_content): |
|
|
|
|
yield chunk |
|
|
|
|
else: |
|
|
|
|
async for chunk in ai_chat_stream(inp, conversation_history): |
|
|
|
|
yield chunk |
|
|
|
|
# 原来的逻辑 |
|
|
|
|
if cat == "游玩判断": |
|
|
|
|
spot = await extract_spot(inp.message) |
|
|
|
|
data = await query_flow(request, spot) |
|
|
|
|
async for chunk in gen_markdown_stream(inp.message, data, inp.language, conversation_history): |
|
|
|
|
yield chunk |
|
|
|
|
else: |
|
|
|
|
async for chunk in ai_chat_stream(inp, conversation_history): |
|
|
|
|
yield chunk |
|
|
|
|
|
|
|
|
|
# 将更新后的对话历史存回 Redis,并设置过期时间 |
|
|
|
|
await redis_client.setex(conversation_history_key, CONVERSATION_EXPIRE_TIME, json.dumps(conversation_history)) |
|
|
|
|
# 只有非快捷问题才保存对话历史 |
|
|
|
|
if not is_quick_question: |
|
|
|
|
await redis_client.setex(conversation_history_key, CONVERSATION_EXPIRE_TIME, json.dumps(conversation_history)) |
|
|
|
|
except Exception as e: |
|
|
|
|
print(f"Error in content_stream: {e}") |
|
|
|
|
raise |
|
|
|
@ -102,4 +122,20 @@ def verify_signature(data: dict, sign: str) -> bool: |
|
|
|
|
|
|
|
|
|
def verify_timestamp(timestamp: int) -> bool: |
|
|
|
|
current_timestamp = int(time.time() * 1000) |
|
|
|
|
return abs(current_timestamp - timestamp) <= TIMESTAMP_TOLERANCE * 1000 |
|
|
|
|
return abs(current_timestamp - timestamp) <= TIMESTAMP_TOLERANCE * 1000 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 定义获取问题的响应模型 |
|
|
|
|
class QuestionResponse(BaseModel): |
|
|
|
|
id: int |
|
|
|
|
title: str |
|
|
|
|
|
|
|
|
|
class Config: |
|
|
|
|
orm_mode = True |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/getQuestion", summary="获取开启的前4个问题") |
|
|
|
|
async def get_question(): |
|
|
|
|
# 查询状态为正常(0)的问题,按order_num正序排序,取前4条 |
|
|
|
|
questions = await QuickQuestion.filter(status="0").order_by("order_num").limit(4).values("id", "title") |
|
|
|
|
return questions |