# parent 442b592560
# commit 2e74d93f21
# (diff header: @ -0,0 +1,432 @)
|||||||
|
|
||||||
|
import asyncio |
||||||
|
import json |
||||||
|
import logging |
||||||
|
import time |
||||||
|
import uuid |
||||||
|
from typing import Optional, cast |
||||||
|
from datetime import datetime |
||||||
|
from fastapi import APIRouter, Body, Depends |
||||||
|
from fastapi.responses import StreamingResponse |
||||||
|
|
||||||
|
from dbgpt._private.config import Config |
||||||
|
from dbgpt.configs import TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE |
||||||
|
from dbgpt.core import ModelOutput |
||||||
|
from dbgpt.core.awel import BaseOperator, CommonLLMHttpRequestBody |
||||||
|
from dbgpt.core.awel.dag.dag_manager import DAGManager |
||||||
|
from dbgpt.core.awel.util.chat_util import safe_chat_stream_with_dag_task |
||||||
|
from dbgpt.core.schema.api import ( |
||||||
|
ChatCompletionResponse, |
||||||
|
ChatCompletionResponseChoice, |
||||||
|
ChatCompletionResponseStreamChoice, |
||||||
|
ChatCompletionStreamResponse, |
||||||
|
ChatMessage, |
||||||
|
DeltaMessage, |
||||||
|
UsageInfo, |
||||||
|
) |
||||||
|
from dbgpt.util.tracer import SpanType, root_tracer |
||||||
|
from dbgpt_app.knowledge.request.request import KnowledgeSpaceRequest |
||||||
|
from dbgpt_app.knowledge.service import KnowledgeService |
||||||
|
from dbgpt_app.openapi.api_view_model import ( |
||||||
|
ConversationVo, |
||||||
|
Result, |
||||||
|
) |
||||||
|
from dbgpt_app.scene import BaseChat, ChatFactory, ChatParam, ChatScene |
||||||
|
from dbgpt_serve.agent.db.gpts_app import UserRecentAppsDao, adapt_native_app_model |
||||||
|
from dbgpt_serve.core import blocking_func_to_async |
||||||
|
from dbgpt_serve.flow.service.service import Service as FlowService |
||||||
|
from dbgpt_serve.utils.auth import UserRequest, get_user_from_headers |
||||||
|
|
||||||
|
from dbgpt_system.sys_agent_user_share.agent_user_share_db import AgentUserShareDao as AgentUserShareDao |
||||||
|
from dbgpt_serve.agent.db.gpts_app import GptsAppDao |
||||||
|
|
||||||
|
router = APIRouter() |
||||||
|
CFG = Config() |
||||||
|
logger = logging.getLogger(__name__) |
||||||
|
user_recent_app_dao = UserRecentAppsDao() |
||||||
|
gpts_dao = GptsAppDao() |
||||||
|
CHAT_FACTORY = ChatFactory() |
||||||
|
knowledge_service = KnowledgeService() |
||||||
|
|
||||||
|
agentUserShareDao = AgentUserShareDao() |
||||||
|
def get_chat_flow() -> FlowService:
    """FastAPI dependency resolving the singleton flow-chat service."""
    return FlowService.get_instance(CFG.SYSTEM_APP)
||||||
|
|
||||||
|
async def chat_app_api(
    api_key: str,
    app_code: str,
    user_input: str,
    flow_service: FlowService = Depends(get_chat_flow),
):
    """Thin wrapper forwarding an external app-chat call to ``chat_completions``.

    Args:
        api_key: API key issued for a user share.
        app_code: Code of the app to chat with.
        user_input: The user's message.
        flow_service: Injected flow-chat service.

    Returns:
        Whatever ``chat_completions`` returns (a streaming response or a
        failed ``Result``).
    """
    # FIX: previously called chat_completions(dialogueConver, flow_service)
    # with an undefined name and the wrong arity, and discarded the result.
    return await chat_completions(api_key, app_code, user_input, flow_service)
||||||
|
|
||||||
|
@router.get("/user_share_chat_completions")
async def chat_completions(
    api_key: str,
    app_code: str,
    user_input: str,
    flow_service: FlowService = Depends(get_chat_flow),
):
    """External chat entry point guarded by a per-user-share API key.

    Validates the API key and app code against the share record, resolves the
    app's chat scene and model strategy, then streams the chat response as
    server-sent events.

    Args:
        api_key: Key issued for a user share; must match a share record.
        app_code: App to chat with; must be listed in the share's app_ids.
        user_input: The user's message.
        flow_service: Injected flow-chat service (used for ChatFlow scenes).

    Returns:
        A ``StreamingResponse`` with the chat output, or a failed ``Result``
        when validation fails.
    """
    logger.info("外部接口调用")
    if api_key is None:
        return Result.failed(msg="api_key不能为空")
    # Look up the share record that owns this api_key.
    user_share = agentUserShareDao.select_user_share_by_apiKey(api_key)
    if user_share is None:
        return Result.failed(msg="api_key不正确")
    # Validity-period check ('1' marks a time-bounded key).
    if user_share.validity_period == '1':
        # FIX: current_time/start_time/end_time were previously undefined
        # (NameError at runtime). Assumes the share record carries
        # start_time/end_time datetimes — TODO confirm field names against
        # AgentUserShareDao's model.
        current_time = datetime.now()
        start_time = user_share.start_time
        end_time = user_share.end_time
        if current_time < start_time or current_time > end_time:
            return Result.failed(msg="api_key已过期,请联系管理员")
    if app_code is None:
        return Result.failed(msg="app_code 不能为空")
    # The key may only be used with the apps it was shared for.
    if app_code not in user_share.app_ids:
        return Result.failed(msg="无此应用编码,请与管理员联系")

    # Resolve app details and chat scene from the app code; the share id
    # doubles as the conversation id and the user name.
    conv_uid = user_share.id
    gpt_app_detail = gpts_dao.app_detail(app_code)
    gpt_app = gpts_dao.get_app_by_code(app_code)
    if gpt_app.team_context is None:
        team_context = "chat_agent"
    else:
        team_context = json.loads(gpt_app.team_context).get("chat_scene")
    # Resolve which LLM to use for this app.
    if gpt_app_detail.details[0].llm_strategy == 'priority':
        # First model of the priority strategy list.
        llm_strategy = json.loads(gpt_app_detail.details[0].llm_strategy_value[0])
    else:
        # Fallback model when no priority strategy is configured.
        llm_strategy = "deepseek-reasoner"
    dialogue = ConversationVo(
        app_code=app_code,
        chat_mode=team_context,
        conv_uid=str(conv_uid),
        model_name=llm_strategy,
        select_param="",
        user_input=user_input,
        user_name=str(conv_uid),
    )
    logger.info(
        f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},"
        f"{dialogue.model_name}, timestamp={int(time.time() * 1000)}"
    )
    dialogue = adapt_native_app_model(dialogue)
    headers = {
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Transfer-Encoding": "chunked",
    }
    try:
        domain_type = _parse_domain_type(dialogue)
        if dialogue.chat_mode == ChatScene.ChatAgent.value():
            from dbgpt_serve.agent.agents.controller import multi_agents

            dialogue.ext_info.update({"model_name": dialogue.model_name})
            dialogue.ext_info.update({"incremental": dialogue.incremental})
            dialogue.ext_info.update({"temperature": dialogue.temperature})
            return StreamingResponse(
                multi_agents.app_agent_chat(
                    conv_uid=dialogue.conv_uid,
                    chat_mode=dialogue.chat_mode,
                    gpts_name=dialogue.app_code,
                    user_query=dialogue.user_input,
                    user_code=dialogue.user_name,
                    sys_code=dialogue.sys_code,
                    **dialogue.ext_info,
                ),
                headers=headers,
                media_type="text/event-stream",
            )
        elif dialogue.chat_mode == ChatScene.ChatFlow.value():
            flow_req = CommonLLMHttpRequestBody(
                model=dialogue.model_name,
                messages=dialogue.user_input,
                stream=True,
                conv_uid=dialogue.conv_uid,
                span_id=root_tracer.get_current_span_id(),
                chat_mode=dialogue.chat_mode,
                chat_param=dialogue.select_param,
                user_name=dialogue.user_name,
                sys_code=dialogue.sys_code,
                incremental=dialogue.incremental,
            )
            return StreamingResponse(
                flow_service.chat_stream_flow_str(dialogue.select_param, flow_req),
                headers=headers,
                media_type="text/event-stream",
            )
        elif domain_type is not None and domain_type != "Normal":
            # Knowledge chat with a dedicated domain DAG.
            return StreamingResponse(
                chat_with_domain_flow(dialogue, domain_type),
                headers=headers,
                media_type="text/event-stream",
            )
        else:
            with root_tracer.start_span(
                "get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
            ):
                chat: BaseChat = await get_chat_instance(dialogue)

            if not chat.prompt_template.stream_out:
                return StreamingResponse(
                    no_stream_generator(chat),
                    headers=headers,
                    media_type="text/event-stream",
                )
            else:
                return StreamingResponse(
                    stream_generator(
                        chat,
                        dialogue.incremental,
                        dialogue.model_name,
                        openai_format=dialogue.incremental,
                    ),
                    headers=headers,
                    media_type="text/plain",
                )
    except Exception as e:
        # FIX: logger.exception does not take the exception as a positional
        # format argument; use lazy %-style formatting instead.
        logger.exception("Chat Exception! %s", dialogue)

        async def error_text(err_msg):
            yield f"data:{err_msg}\n\n"

        return StreamingResponse(
            error_text(str(e)),
            headers=headers,
            media_type="text/plain",
        )
    finally:
        # Best-effort bookkeeping: record recent app usage.
        if dialogue.user_name is not None and dialogue.app_code is not None:
            user_recent_app_dao.upsert(
                user_code=dialogue.user_name,
                sys_code=dialogue.sys_code,
                app_code=dialogue.app_code,
            )
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_domain_type(dialogue: ConversationVo) -> Optional[str]:
    """Return the domain type of the knowledge space targeted by ``dialogue``.

    Only applies to knowledge chats; as a side effect the dialogue's
    ``select_param`` is normalized to the space's canonical name.

    Args:
        dialogue: The conversation, with ``select_param`` naming the space.

    Returns:
        The space's domain type, or ``None`` for non-knowledge chats or
        spaces without a domain type.

    Raises:
        ValueError: if the named knowledge space does not exist.
    """
    if dialogue.chat_mode != ChatScene.ChatKnowledge.value():
        return None
    # FIX: the original if/else on app_code ran byte-identical lookups in
    # both branches — collapsed to a single query by space name.
    spaces = knowledge_service.get_knowledge_space(
        KnowledgeSpaceRequest(name=dialogue.select_param)
    )
    if len(spaces) == 0:
        raise ValueError(f"Knowledge space {dialogue.select_param} not found")
    dialogue.select_param = spaces[0].name
    # Falsy domain types (None / "") are treated as "no domain".
    return spaces[0].domain_type if spaces[0].domain_type else None
||||||
|
|
||||||
|
async def chat_with_domain_flow(dialogue: ConversationVo, domain_type: str):
    """Stream a domain-specific knowledge chat through its AWEL DAG.

    Looks up the DAG tagged with ``domain_type``, resolves a connector DB
    whose name contains the selected space, and yields SSE ``data:`` lines
    produced by the DAG's leaf task.

    Args:
        dialogue: Conversation context; ``select_param`` names the space.
        domain_type: Domain tag used to locate the DAG.

    Yields:
        SSE-formatted strings; error chunks are prefixed with
        ``[SERVER_ERROR]`` and terminate the stream.

    Raises:
        ValueError: if no tagged DAG or no matching connector DB is found.
    """
    dag_manager = get_dag_manager()
    dags = dag_manager.get_dags_by_tag(TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE, domain_type)
    if not dags or not dags[0].leaf_nodes:
        raise ValueError(f"Cant find the DAG for domain type {domain_type}")

    # The DAG is driven from its first leaf operator.
    end_task = cast(BaseOperator, dags[0].leaf_nodes[0])
    space = dialogue.select_param
    connector_manager = CFG.local_db_manager
    # TODO: Some flow maybe not connector
    db_list = [item["db_name"] for item in connector_manager.get_db_list()]
    # Pick connector DBs whose name contains the space name (substring match).
    db_names = [item for item in db_list if space in item]
    if len(db_names) == 0:
        # NOTE(review): "fin repost" looks like a typo for "fin report" in
        # this error message — confirm before changing user-facing text.
        raise ValueError(f"fin repost dbname {space}_fin_report not found.")
    flow_ctx = {"space": space, "db_name": db_names[0]}
    request = CommonLLMHttpRequestBody(
        model=dialogue.model_name,
        messages=dialogue.user_input,
        stream=True,
        extra=flow_ctx,
        conv_uid=dialogue.conv_uid,
        span_id=root_tracer.get_current_span_id(),
        chat_mode=dialogue.chat_mode,
        chat_param=dialogue.select_param,
        user_name=dialogue.user_name,
        sys_code=dialogue.sys_code,
        incremental=dialogue.incremental,
    )
    async for output in safe_chat_stream_with_dag_task(end_task, request, False):
        text = output.gen_text_with_thinking()
        if text:
            # SSE events are newline-delimited; escape embedded newlines.
            text = text.replace("\n", "\\n")
        if output.error_code != 0:
            yield f"data:[SERVER_ERROR]{text}\n\n"
            break
        else:
            yield f"data:{text}\n\n"
||||||
|
|
||||||
|
async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
    """Build the ``BaseChat`` implementation described by ``dialogue``.

    Fills in a default chat mode and a fresh conversation id when missing,
    validates the chat mode, then constructs the scene implementation.

    Args:
        dialogue: Conversation parameters (mode, model, inputs, ids).

    Returns:
        The chat instance for the dialogue's scene.

    Raises:
        StopAsyncIteration: when the chat mode is invalid.
            NOTE(review): StopAsyncIteration looks like the wrong exception
            type for input validation — it can be silently swallowed by
            ``async for`` machinery; confirm callers depend on it before
            changing.
    """
    logger.info(f"get_chat_instance:{dialogue}")
    if not dialogue.chat_mode:
        # Default to a plain chat when the caller did not pick a scene.
        dialogue.chat_mode = ChatScene.ChatNormal.value()
    if not dialogue.conv_uid:
        # No conversation yet: mint a new one.
        conv_vo = __new_conversation(
            dialogue.chat_mode, dialogue.user_name, dialogue.sys_code
        )
        dialogue.conv_uid = conv_vo.conv_uid

    if not ChatScene.is_valid_mode(dialogue.chat_mode):
        raise StopAsyncIteration(
            Result.failed("Unsupported Chat Mode," + dialogue.chat_mode + "!")
        )

    chat_param = ChatParam(
        chat_session_id=dialogue.conv_uid,
        user_name=dialogue.user_name,
        sys_code=dialogue.sys_code,
        current_user_input=dialogue.user_input,
        select_param=dialogue.select_param,
        model_name=dialogue.model_name,
        app_code=dialogue.app_code,
        ext_info=dialogue.ext_info,
        temperature=dialogue.temperature,
        max_new_tokens=dialogue.max_new_tokens,
        prompt_code=dialogue.prompt_code,
        chat_mode=ChatScene.of_mode(dialogue.chat_mode),
    )
    # get_implementation is synchronous; dispatch it via blocking_func_to_async
    # so the event loop is not blocked while the chat is constructed.
    chat: BaseChat = await blocking_func_to_async(
        CFG.SYSTEM_APP,
        CHAT_FACTORY.get_implementation,
        dialogue.chat_mode,
        CFG.SYSTEM_APP,
        **{"chat_param": chat_param},
    )
    return chat
||||||
|
|
||||||
|
async def no_stream_generator(chat):
    """Yield the complete (non-streamed) chat answer as one SSE event."""
    with root_tracer.start_span("no_stream_generator"):
        answer = await chat.nostream_call()
        yield f"data: {answer}\n\n"
||||||
|
|
||||||
|
async def stream_generator(
    chat,
    incremental: bool,
    model_name: str,
    text_output: bool = True,
    openai_format: bool = False,
    conv_uid: Optional[str] = None,
):
    """Generate streaming responses.

    Our goal is to generate an openai-compatible streaming responses.
    Currently, the incremental response is compatible, and the full response
    will be transformed in the future.

    Args:
        chat (BaseChat): Chat instance.
        incremental (bool): Whether content is returned incrementally or in
            full each time. Requires ``openai_format``.
        model_name (str): The model name.
        text_output (bool): Passed through to ``chat.stream_call``.
        openai_format (bool): Emit OpenAI-compatible JSON chunks instead of
            raw text.
        conv_uid (Optional[str]): Stream id; a fresh ``chatcmpl-`` id is
            generated when absent.

    Yields:
        SSE-formatted strings; ``data: [DONE]`` terminates incremental
        streams, ``data: [SERVER_ERROR]...`` signals failure.

    Raises:
        ValueError: if ``incremental`` is requested without ``openai_format``.
    """
    span = root_tracer.start_span("stream_generator")
    stream_id = conv_uid or f"chatcmpl-{str(uuid.uuid1())}"
    try:
        if incremental and not openai_format:
            raise ValueError("Incremental response must be openai-compatible format.")
        async for chunk in chat.stream_call(
            text_output=text_output, incremental=incremental
        ):
            if not chunk:
                # Empty chunk: back off briefly and wait for more output.
                await asyncio.sleep(0.02)
                continue

            if openai_format:
                # Must be ModelOutput
                output: ModelOutput = cast(ModelOutput, chunk)
                text = output.text if output.has_text else None
                think_text = output.thinking_text if output.has_thinking else None
                if incremental:
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=0,
                        delta=DeltaMessage(
                            role="assistant", content=text, reasoning_content=think_text
                        ),
                    )
                    chunk = ChatCompletionStreamResponse(
                        id=stream_id, choices=[choice_data], model=model_name
                    )
                    _content = json.dumps(
                        chunk.dict(exclude_unset=True), ensure_ascii=False
                    )
                    yield f"data: {_content}\n\n"
                else:
                    choice_data = ChatCompletionResponseChoice(
                        index=0,
                        message=ChatMessage(
                            role="assistant",
                            content=output.text,
                            reasoning_content=output.thinking_text,
                        ),
                    )
                    usage = UsageInfo(**output.usage) if output.usage else UsageInfo()
                    full_response = ChatCompletionResponse(
                        id=stream_id,
                        choices=[choice_data],
                        model=model_name,
                        usage=usage,
                    )
                    # FIX: the original serialized the raw `chunk` here,
                    # discarding the ChatCompletionResponse it had just built.
                    _content = json.dumps(
                        full_response.dict(exclude_unset=True), ensure_ascii=False
                    )
                    yield f"data: {_content}\n\n"
            else:
                # Plain-text mode: strip replacement chars and escape newlines
                # so each SSE event stays on one line.
                msg = chunk.replace("\ufffd", "")
                msg = msg.replace("\n", "\\n")
                yield f"data:{msg}\n\n"
                await asyncio.sleep(0.02)
        if incremental:
            yield "data: [DONE]\n\n"
    except Exception as e:
        logger.exception("stream_generator error")
        yield f"data: [SERVER_ERROR]{str(e)}\n\n"
        if incremental:
            yield "data: [DONE]\n\n"
    finally:
        # FIX: end the tracer span on all paths, not only on success.
        span.end()
||||||
|
def get_dag_manager() -> DAGManager:
    """Return the process-wide DAGManager bound to the system app."""
    return DAGManager.get_instance(CFG.SYSTEM_APP)
||||||
|
|
||||||
|
def __new_conversation(chat_mode, user_name: str, sys_code: str) -> ConversationVo:
    """Create a fresh ConversationVo with a newly generated conversation id."""
    return ConversationVo(
        conv_uid=str(uuid.uuid1()),
        chat_mode=chat_mode,
        user_name=user_name,
        sys_code=sys_code,
    )
||||||
# (end of file)