parent 442b592560
commit 2e74d93f21
@@ -0,0 +1,432 @@
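"""Router that exposes shared agent apps to external callers via an api_key.

Validates the share api_key and app_code, resolves the app's chat scene and
model, then streams the chat response back as Server-Sent Events.
"""
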
import asyncio
import json
import logging
import time
import uuid
from datetime import datetime
from typing import Optional, cast

from fastapi import APIRouter, Body, Depends
from fastapi.responses import StreamingResponse

from dbgpt._private.config import Config
from dbgpt.configs import TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE
from dbgpt.core import ModelOutput
from dbgpt.core.awel import BaseOperator, CommonLLMHttpRequestBody
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.core.awel.util.chat_util import safe_chat_stream_with_dag_task
from dbgpt.core.schema.api import (
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    DeltaMessage,
    UsageInfo,
)
from dbgpt.util.tracer import SpanType, root_tracer
from dbgpt_app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt_app.knowledge.service import KnowledgeService
from dbgpt_app.openapi.api_view_model import (
    ConversationVo,
    Result,
)
from dbgpt_app.scene import BaseChat, ChatFactory, ChatParam, ChatScene
from dbgpt_serve.agent.db.gpts_app import (
    GptsAppDao,
    UserRecentAppsDao,
    adapt_native_app_model,
)
from dbgpt_serve.core import blocking_func_to_async
from dbgpt_serve.flow.service.service import Service as FlowService
from dbgpt_serve.utils.auth import UserRequest, get_user_from_headers
from dbgpt_system.sys_agent_user_share.agent_user_share_db import AgentUserShareDao

router = APIRouter()
CFG = Config()
logger = logging.getLogger(__name__)

user_recent_app_dao = UserRecentAppsDao()
gpts_dao = GptsAppDao()
CHAT_FACTORY = ChatFactory()
knowledge_service = KnowledgeService()
agent_user_share_dao = AgentUserShareDao()


def get_chat_flow() -> FlowService:
    """Get Chat Flow Service."""
    return FlowService.get_instance(CFG.SYSTEM_APP)


async def chat_app_api(
    api_key: str,
    app_code: str,
    user_input: str,
    flow_service: FlowService = Depends(get_chat_flow),
):
    """Thin wrapper that forwards to the user-share chat_completions endpoint."""
    return await chat_completions(api_key, app_code, user_input, flow_service)

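
# Example call (illustrative; host, port, and router mount prefix are
# deployment-specific assumptions, not taken from this file):
#   curl -N "http://localhost:5670/user_share_chat_completions?api_key=KEY&app_code=APP&user_input=hello"
# A successful call streams Server-Sent Events ("data: ..." frames).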
@router.get("/user_share_chat_completions")
async def chat_completions(
    api_key: str,
    app_code: str,
    user_input: str,
    flow_service: FlowService = Depends(get_chat_flow),
):
    logger.info("External share API called")
    if api_key is None:
        return Result.failed(msg="api_key cannot be empty")
    # Check that the api_key exists
    user_share = agent_user_share_dao.select_user_share_by_apiKey(api_key)
    if user_share is None:
        return Result.failed(msg="Invalid api_key")
    # Time-window validity check; start_time/end_time are assumed to be
    # fields on the share record.
    if user_share.validity_period == "1":
        current_time = datetime.now()
        start_time = user_share.start_time
        end_time = user_share.end_time
        if current_time < start_time or current_time > end_time:
            return Result.failed(
                msg="api_key has expired, please contact the administrator"
            )
    # Check that app_code is not empty
    if app_code is None:
        return Result.failed(msg="app_code cannot be empty")
    # Check that app_code is authorized for this share
    if app_code not in user_share.app_ids:
        return Result.failed(msg="Unknown app code, please contact the administrator")

    # Look up the app details by app_code
    conv_uid = user_share.id
    gpt_app_detail = gpts_dao.app_detail(app_code)
    gpt_app = gpts_dao.get_app_by_code(app_code)
    if gpt_app.team_context is None:
        team_context = "chat_agent"
    else:
        team_context = json.loads(gpt_app.team_context).get("chat_scene")
    # Resolve which LLM to use
    if gpt_app_detail.details[0].llm_strategy == "priority":
        # llm_strategy_value is assumed to be a JSON-encoded list of model
        # names; take the highest-priority one.
        llm_strategy = json.loads(gpt_app_detail.details[0].llm_strategy_value)[0]
    else:
        # Fall back to a default model
        llm_strategy = "deepseek-reasoner"
    dialogues = {
        "app_code": app_code,
        "chat_mode": team_context,
        "conv_uid": str(conv_uid),
        "model_name": llm_strategy,
        "select_param": "",
        "user_input": user_input,
        "user_name": str(conv_uid),
    }
    dialogue = ConversationVo(**dialogues)
    logger.info(
        f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},"
        f"{dialogue.model_name}, timestamp={int(time.time() * 1000)}"
    )
    dialogue = adapt_native_app_model(dialogue)
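    # Standard SSE response headers: disable caching and keep the connection
    # open so chunks can be flushed as they are produced.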
    headers = {
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Transfer-Encoding": "chunked",
    }
    try:
        domain_type = _parse_domain_type(dialogue)
        if dialogue.chat_mode == ChatScene.ChatAgent.value():
            from dbgpt_serve.agent.agents.controller import multi_agents

            dialogue.ext_info.update({"model_name": dialogue.model_name})
            dialogue.ext_info.update({"incremental": dialogue.incremental})
            dialogue.ext_info.update({"temperature": dialogue.temperature})
            return StreamingResponse(
                multi_agents.app_agent_chat(
                    conv_uid=dialogue.conv_uid,
                    chat_mode=dialogue.chat_mode,
                    gpts_name=dialogue.app_code,
                    user_query=dialogue.user_input,
                    user_code=dialogue.user_name,
                    sys_code=dialogue.sys_code,
                    **dialogue.ext_info,
                ),
                headers=headers,
                media_type="text/event-stream",
            )
        elif dialogue.chat_mode == ChatScene.ChatFlow.value():
            flow_req = CommonLLMHttpRequestBody(
                model=dialogue.model_name,
                messages=dialogue.user_input,
                stream=True,
                # context=flow_ctx,
                # temperature=
                # max_new_tokens=
                # enable_vis=
                conv_uid=dialogue.conv_uid,
                span_id=root_tracer.get_current_span_id(),
                chat_mode=dialogue.chat_mode,
                chat_param=dialogue.select_param,
                user_name=dialogue.user_name,
                sys_code=dialogue.sys_code,
                incremental=dialogue.incremental,
            )
            return StreamingResponse(
                flow_service.chat_stream_flow_str(dialogue.select_param, flow_req),
                headers=headers,
                media_type="text/event-stream",
            )
        elif domain_type is not None and domain_type != "Normal":
            return StreamingResponse(
                chat_with_domain_flow(dialogue, domain_type),
                headers=headers,
                media_type="text/event-stream",
            )
        else:
            with root_tracer.start_span(
                "get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
            ):
                chat: BaseChat = await get_chat_instance(dialogue)

            if not chat.prompt_template.stream_out:
                return StreamingResponse(
                    no_stream_generator(chat),
                    headers=headers,
                    media_type="text/event-stream",
                )
            else:
                return StreamingResponse(
                    stream_generator(
                        chat,
                        dialogue.incremental,
                        dialogue.model_name,
                        openai_format=dialogue.incremental,
                    ),
                    headers=headers,
                    media_type="text/plain",
                )
    except Exception as e:
        logger.exception(f"Chat Exception!{dialogue}")

        async def error_text(err_msg):
            yield f"data:{err_msg}\n\n"

        return StreamingResponse(
            error_text(str(e)),
            headers=headers,
            media_type="text/plain",
        )
    finally:
        # Record this app in the user's recent-apps list.
        if dialogue.user_name is not None and dialogue.app_code is not None:
            user_recent_app_dao.upsert(
                user_code=dialogue.user_name,
                sys_code=dialogue.sys_code,
                app_code=dialogue.app_code,
            )


def _parse_domain_type(dialogue: ConversationVo) -> Optional[str]:
    if dialogue.chat_mode == ChatScene.ChatKnowledge.value():
        # Supported in the knowledge chat; the space is resolved by name
        # regardless of whether a custom app_code is set.
        spaces = knowledge_service.get_knowledge_space(
            KnowledgeSpaceRequest(name=dialogue.select_param)
        )
        if len(spaces) == 0:
            raise ValueError(f"Knowledge space {dialogue.select_param} not found")
        dialogue.select_param = spaces[0].name
        if spaces[0].domain_type:
            return spaces[0].domain_type
        else:
            return None
    return None


async def chat_with_domain_flow(dialogue: ConversationVo, domain_type: str):
    """Chat with domain flow."""
    dag_manager = get_dag_manager()
    dags = dag_manager.get_dags_by_tag(TAG_KEY_KNOWLEDGE_CHAT_DOMAIN_TYPE, domain_type)
    if not dags or not dags[0].leaf_nodes:
        raise ValueError(f"Can't find the DAG for domain type {domain_type}")

    end_task = cast(BaseOperator, dags[0].leaf_nodes[0])
    space = dialogue.select_param
    connector_manager = CFG.local_db_manager
    # TODO: Some flows may not have a connector
    db_list = [item["db_name"] for item in connector_manager.get_db_list()]
    db_names = [item for item in db_list if space in item]
    if len(db_names) == 0:
        raise ValueError(f"Financial report db name {space}_fin_report not found.")
    flow_ctx = {"space": space, "db_name": db_names[0]}
    request = CommonLLMHttpRequestBody(
        model=dialogue.model_name,
        messages=dialogue.user_input,
        stream=True,
        extra=flow_ctx,
        conv_uid=dialogue.conv_uid,
        span_id=root_tracer.get_current_span_id(),
        chat_mode=dialogue.chat_mode,
        chat_param=dialogue.select_param,
        user_name=dialogue.user_name,
        sys_code=dialogue.sys_code,
        incremental=dialogue.incremental,
    )
    async for output in safe_chat_stream_with_dag_task(end_task, request, False):
        text = output.gen_text_with_thinking()
        if text:
            # Escape raw newlines so each SSE "data:" frame stays on one line.
            text = text.replace("\n", "\\n")
        if output.error_code != 0:
            yield f"data:[SERVER_ERROR]{text}\n\n"
            break
        else:
            yield f"data:{text}\n\n"


async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
    logger.info(f"get_chat_instance:{dialogue}")
    if not dialogue.chat_mode:
        dialogue.chat_mode = ChatScene.ChatNormal.value()
    if not dialogue.conv_uid:
        conv_vo = __new_conversation(
            dialogue.chat_mode, dialogue.user_name, dialogue.sys_code
        )
        dialogue.conv_uid = conv_vo.conv_uid

    if not ChatScene.is_valid_mode(dialogue.chat_mode):
        raise StopAsyncIteration(
            Result.failed("Unsupported Chat Mode," + dialogue.chat_mode + "!")
        )

    chat_param = ChatParam(
        chat_session_id=dialogue.conv_uid,
        user_name=dialogue.user_name,
        sys_code=dialogue.sys_code,
        current_user_input=dialogue.user_input,
        select_param=dialogue.select_param,
        model_name=dialogue.model_name,
        app_code=dialogue.app_code,
        ext_info=dialogue.ext_info,
        temperature=dialogue.temperature,
        max_new_tokens=dialogue.max_new_tokens,
        prompt_code=dialogue.prompt_code,
        chat_mode=ChatScene.of_mode(dialogue.chat_mode),
    )
    chat: BaseChat = await blocking_func_to_async(
        CFG.SYSTEM_APP,
        CHAT_FACTORY.get_implementation,
        dialogue.chat_mode,
        CFG.SYSTEM_APP,
        **{"chat_param": chat_param},
    )
    return chat


async def no_stream_generator(chat):
    with root_tracer.start_span("no_stream_generator"):
        msg = await chat.nostream_call()
        yield f"data: {msg}\n\n"


async def stream_generator(
    chat,
    incremental: bool,
    model_name: str,
    text_output: bool = True,
    openai_format: bool = False,
    conv_uid: Optional[str] = None,
):
    """Generate streaming responses.

    Our goal is to generate OpenAI-compatible streaming responses.
    Currently, the incremental response is compatible, and the full response
    will be transformed in the future.

    Args:
        chat (BaseChat): Chat instance.
        incremental (bool): Whether the content is returned incrementally or
            in full each time.
        model_name (str): The model name.

    Yields:
        str: Streaming response chunks.
    """
    span = root_tracer.start_span("stream_generator")
    msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."

    stream_id = conv_uid or f"chatcmpl-{str(uuid.uuid1())}"
    try:
        if incremental and not openai_format:
            raise ValueError("Incremental response must be in OpenAI-compatible format.")
        async for chunk in chat.stream_call(
            text_output=text_output, incremental=incremental
        ):
            if not chunk:
                await asyncio.sleep(0.02)
                continue

            if openai_format:
                # Must be a ModelOutput
                output: ModelOutput = cast(ModelOutput, chunk)
                text = None
                think_text = None
                if output.has_text:
                    text = output.text
                if output.has_thinking:
                    think_text = output.thinking_text
                if incremental:
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=0,
                        delta=DeltaMessage(
                            role="assistant", content=text, reasoning_content=think_text
                        ),
                    )
                    chunk = ChatCompletionStreamResponse(
                        id=stream_id, choices=[choice_data], model=model_name
                    )
                    _content = json.dumps(
                        chunk.dict(exclude_unset=True), ensure_ascii=False
                    )
                    yield f"data: {_content}\n\n"
                else:
                    choice_data = ChatCompletionResponseChoice(
                        index=0,
                        message=ChatMessage(
                            role="assistant",
                            content=output.text,
                            reasoning_content=output.thinking_text,
                        ),
                    )
                    if output.usage:
                        usage = UsageInfo(**output.usage)
                    else:
                        usage = UsageInfo()
                    full_resp = ChatCompletionResponse(
                        id=stream_id,
                        choices=[choice_data],
                        model=model_name,
                        usage=usage,
                    )
                    # Serialize the full response just built (not the raw chunk).
                    _content = json.dumps(
                        full_resp.dict(exclude_unset=True), ensure_ascii=False
                    )
                    yield f"data: {_content}\n\n"
            else:
                msg = chunk.replace("\ufffd", "")
                msg = msg.replace("\n", "\\n")
                yield f"data:{msg}\n\n"
                await asyncio.sleep(0.02)
        if incremental:
            yield "data: [DONE]\n\n"
        span.end()
    except Exception as e:
        logger.exception("stream_generator error")
        yield f"data: [SERVER_ERROR]{str(e)}\n\n"
        if incremental:
            yield "data: [DONE]\n\n"
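

# Illustrative shape of one incremental SSE frame emitted by stream_generator
# (OpenAI-compatible; the id and content values here are made up):
#   data: {"id": "chatcmpl-...", "model": "deepseek-reasoner",
#          "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hi"}}]}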


def get_dag_manager() -> DAGManager:
    """Get the global default DAGManager."""
    return DAGManager.get_instance(CFG.SYSTEM_APP)


def __new_conversation(chat_mode, user_name: str, sys_code: str) -> ConversationVo:
    unique_id = uuid.uuid1()
    return ConversationVo(
        conv_uid=str(unique_id),
        chat_mode=chat_mode,
        user_name=user_name,
        sys_code=sys_code,
    )