You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
406 lines
13 KiB
406 lines
13 KiB
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package conversation
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"strconv"
|
|
|
|
"github.com/cloudwego/eino/schema"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
|
|
"github.com/coze-dev/coze-studio/backend/api/model/conversation/run"
|
|
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
|
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
|
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
|
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
|
saEntity "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
|
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
|
convEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
|
cmdEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
|
|
uploadService "github.com/coze-dev/coze-studio/backend/domain/upload/service"
|
|
sseImpl "github.com/coze-dev/coze-studio/backend/infra/impl/sse"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
|
"github.com/coze-dev/coze-studio/backend/types/consts"
|
|
"github.com/coze-dev/coze-studio/backend/types/errno"
|
|
)
|
|
|
|
func (a *OpenapiAgentRunApplication) OpenapiAgentRun(ctx context.Context, sseSender *sseImpl.SSenderImpl, ar *run.ChatV3Request) error {
|
|
|
|
apiKeyInfo := ctxutil.GetApiAuthFromCtx(ctx)
|
|
creatorID := apiKeyInfo.UserID
|
|
connectorID := apiKeyInfo.ConnectorID
|
|
|
|
if ptr.From(ar.ConnectorID) == consts.WebSDKConnectorID {
|
|
connectorID = ptr.From(ar.ConnectorID)
|
|
}
|
|
agentInfo, caErr := a.checkAgent(ctx, ar, connectorID)
|
|
if caErr != nil {
|
|
logs.CtxErrorf(ctx, "checkAgent err:%v", caErr)
|
|
return caErr
|
|
}
|
|
|
|
conversationData, ccErr := a.checkConversation(ctx, ar, creatorID, connectorID)
|
|
if ccErr != nil {
|
|
logs.CtxErrorf(ctx, "checkConversation err:%v", ccErr)
|
|
return ccErr
|
|
}
|
|
|
|
spaceID := agentInfo.SpaceID
|
|
arr, err := a.buildAgentRunRequest(ctx, ar, connectorID, spaceID, conversationData)
|
|
if err != nil {
|
|
logs.CtxErrorf(ctx, "buildAgentRunRequest err:%v", err)
|
|
return err
|
|
}
|
|
streamer, err := ConversationSVC.AgentRunDomainSVC.AgentRun(ctx, arr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
a.pullStream(ctx, sseSender, streamer)
|
|
return nil
|
|
}
|
|
|
|
func (a *OpenapiAgentRunApplication) checkConversation(ctx context.Context, ar *run.ChatV3Request, userID int64, connectorID int64) (*convEntity.Conversation, error) {
|
|
var conversationData *convEntity.Conversation
|
|
if ptr.From(ar.ConversationID) > 0 {
|
|
conData, err := ConversationSVC.ConversationDomainSVC.GetByID(ctx, ptr.From(ar.ConversationID))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
conversationData = conData
|
|
}
|
|
|
|
if ptr.From(ar.ConversationID) == 0 || conversationData == nil {
|
|
|
|
conData, err := ConversationSVC.ConversationDomainSVC.Create(ctx, &convEntity.CreateMeta{
|
|
AgentID: ar.BotID,
|
|
UserID: userID,
|
|
ConnectorID: connectorID,
|
|
Scene: common.Scene_SceneOpenApi,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if conData == nil {
|
|
return nil, errors.New("conversation data is nil")
|
|
}
|
|
conversationData = conData
|
|
|
|
ar.ConversationID = ptr.Of(conversationData.ID)
|
|
}
|
|
|
|
if conversationData.CreatorID != userID {
|
|
return nil, errors.New("conversation data not match")
|
|
}
|
|
|
|
return conversationData, nil
|
|
}
|
|
|
|
func (a *OpenapiAgentRunApplication) checkAgent(ctx context.Context, ar *run.ChatV3Request, connectorID int64) (*saEntity.SingleAgent, error) {
|
|
agentInfo, err := ConversationSVC.appContext.SingleAgentDomainSVC.ObtainAgentByIdentity(ctx, &singleagent.AgentIdentity{
|
|
AgentID: ar.BotID,
|
|
IsDraft: false,
|
|
ConnectorID: connectorID,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if agentInfo == nil {
|
|
return nil, errorx.New(errno.ErrAgentNotExists)
|
|
}
|
|
return agentInfo, nil
|
|
}
|
|
|
|
func (a *OpenapiAgentRunApplication) buildAgentRunRequest(ctx context.Context, ar *run.ChatV3Request, connectorID int64, spaceID int64, conversationData *convEntity.Conversation) (*entity.AgentRunMeta, error) {
|
|
|
|
shortcutCMDData, err := a.buildTools(ctx, ar.ShortcutCommand)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
multiContent, contentType, err := a.buildMultiContent(ctx, ar)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
displayContent := a.buildDisplayContent(ctx, ar)
|
|
arm := &entity.AgentRunMeta{
|
|
ConversationID: ptr.From(ar.ConversationID),
|
|
AgentID: ar.BotID,
|
|
Content: multiContent,
|
|
DisplayContent: displayContent,
|
|
SpaceID: spaceID,
|
|
UserID: ar.User,
|
|
SectionID: conversationData.SectionID,
|
|
PreRetrieveTools: shortcutCMDData,
|
|
IsDraft: false,
|
|
ConnectorID: connectorID,
|
|
ContentType: contentType,
|
|
Ext: ar.ExtraParams,
|
|
}
|
|
return arm, nil
|
|
}
|
|
|
|
func (a *OpenapiAgentRunApplication) buildTools(ctx context.Context, shortcmd *run.ShortcutCommandDetail) ([]*entity.Tool, error) {
|
|
var ts []*entity.Tool
|
|
|
|
if shortcmd == nil {
|
|
return ts, nil
|
|
}
|
|
|
|
var shortcutCMD *cmdEntity.ShortcutCmd
|
|
cmdMeta, err := a.ShortcutDomainSVC.GetByCmdID(ctx, shortcmd.CommandID, 0)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
shortcutCMD = cmdMeta
|
|
if shortcutCMD != nil {
|
|
argBytes, err := json.Marshal(shortcmd.Parameters)
|
|
if err == nil {
|
|
ts = append(ts, &entity.Tool{
|
|
PluginID: shortcutCMD.PluginID,
|
|
Arguments: string(argBytes),
|
|
ToolName: shortcutCMD.PluginToolName,
|
|
ToolID: shortcutCMD.PluginToolID,
|
|
Type: agentrun.ToolType(shortcutCMD.ToolType),
|
|
})
|
|
}
|
|
}
|
|
|
|
return ts, nil
|
|
}
|
|
|
|
func (a *OpenapiAgentRunApplication) buildDisplayContent(_ context.Context, ar *run.ChatV3Request) string {
|
|
for _, item := range ar.AdditionalMessages {
|
|
if item.ContentType == run.ContentTypeMixApi {
|
|
return item.Content
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (a *OpenapiAgentRunApplication) buildMultiContent(ctx context.Context, ar *run.ChatV3Request) ([]*message.InputMetaData, message.ContentType, error) {
|
|
var multiContents []*message.InputMetaData
|
|
contentType := message.ContentTypeText
|
|
|
|
for _, item := range ar.AdditionalMessages {
|
|
if item == nil {
|
|
continue
|
|
}
|
|
if item.Role != string(schema.User) {
|
|
return nil, contentType, errors.New("role not match")
|
|
}
|
|
if item.ContentType == run.ContentTypeText {
|
|
if item.Content == "" {
|
|
continue
|
|
}
|
|
multiContents = append(multiContents, &message.InputMetaData{
|
|
Type: message.InputTypeText,
|
|
Text: item.Content,
|
|
})
|
|
}
|
|
|
|
if item.ContentType == run.ContentTypeMixApi {
|
|
contentType = message.ContentTypeMix
|
|
var inputs []*run.AdditionalContent
|
|
err := json.Unmarshal([]byte(item.Content), &inputs)
|
|
|
|
logs.CtxInfof(ctx, "inputs:%v, err:%v", conv.DebugJsonToStr(inputs), err)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
for _, one := range inputs {
|
|
if one == nil {
|
|
continue
|
|
}
|
|
switch message.InputType(one.Type) {
|
|
case message.InputTypeText:
|
|
multiContents = append(multiContents, &message.InputMetaData{
|
|
Type: message.InputTypeText,
|
|
Text: ptr.From(one.Text),
|
|
})
|
|
case message.InputTypeImage, message.InputTypeFile:
|
|
|
|
var fileUrl, fileURI string
|
|
if one.GetFileURL() != "" {
|
|
fileUrl = one.GetFileURL()
|
|
} else if one.GetFileID() != 0 {
|
|
fileInfo, err := a.UploaodDomainSVC.GetFile(ctx, &uploadService.GetFileRequest{
|
|
ID: one.GetFileID(),
|
|
})
|
|
if err != nil {
|
|
return nil, contentType, err
|
|
}
|
|
fileUrl = fileInfo.File.Url
|
|
fileURI = fileInfo.File.TosURI
|
|
}
|
|
multiContents = append(multiContents, &message.InputMetaData{
|
|
Type: message.InputType(one.Type),
|
|
FileData: []*message.FileData{
|
|
{
|
|
Url: fileUrl,
|
|
URI: fileURI,
|
|
},
|
|
},
|
|
})
|
|
default:
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
return multiContents, contentType, nil
|
|
}
|
|
|
|
func (a *OpenapiAgentRunApplication) pullStream(ctx context.Context, sseSender *sseImpl.SSenderImpl, streamer *schema.StreamReader[*entity.AgentRunResponse]) {
|
|
for {
|
|
chunk, recvErr := streamer.Recv()
|
|
logs.CtxInfof(ctx, "chunk :%v, err:%v", conv.DebugJsonToStr(chunk), recvErr)
|
|
if recvErr != nil {
|
|
if errors.Is(recvErr, io.EOF) {
|
|
return
|
|
}
|
|
sseSender.Send(ctx, buildErrorEvent(errno.ErrConversationAgentRunError, recvErr.Error()))
|
|
return
|
|
}
|
|
|
|
switch chunk.Event {
|
|
|
|
case entity.RunEventError:
|
|
sseSender.Send(ctx, buildErrorEvent(chunk.Error.Code, chunk.Error.Msg))
|
|
case entity.RunEventStreamDone:
|
|
sseSender.Send(ctx, buildDoneEvent(string(entity.RunEventStreamDone)))
|
|
case entity.RunEventAck:
|
|
case entity.RunEventCreated, entity.RunEventCancelled, entity.RunEventInProgress, entity.RunEventFailed, entity.RunEventCompleted:
|
|
sseSender.Send(ctx, buildMessageChunkEvent(string(chunk.Event), buildARSM2ApiChatMessage(chunk)))
|
|
case entity.RunEventMessageDelta, entity.RunEventMessageCompleted:
|
|
sseSender.Send(ctx, buildMessageChunkEvent(string(chunk.Event), buildARSM2ApiMessage(chunk)))
|
|
|
|
default:
|
|
logs.CtxErrorf(ctx, "unknow handler event:%v", chunk.Event)
|
|
}
|
|
}
|
|
}
|
|
|
|
func buildARSM2ApiMessage(chunk *entity.AgentRunResponse) []byte {
|
|
chunkMessageItem := chunk.ChunkMessageItem
|
|
chunkMessage := &run.ChatV3MessageDetail{
|
|
ID: strconv.FormatInt(chunkMessageItem.ID, 10),
|
|
ConversationID: strconv.FormatInt(chunkMessageItem.ConversationID, 10),
|
|
BotID: strconv.FormatInt(chunkMessageItem.AgentID, 10),
|
|
Role: string(chunkMessageItem.Role),
|
|
Type: string(chunkMessageItem.MessageType),
|
|
Content: chunkMessageItem.Content,
|
|
ContentType: string(chunkMessageItem.ContentType),
|
|
MetaData: chunkMessageItem.Ext,
|
|
ChatID: strconv.FormatInt(chunkMessageItem.RunID, 10),
|
|
ReasoningContent: chunkMessageItem.ReasoningContent,
|
|
CreatedAt: ptr.Of(chunkMessageItem.CreatedAt / 1000),
|
|
}
|
|
|
|
mCM, _ := json.Marshal(chunkMessage)
|
|
return mCM
|
|
}
|
|
|
|
func buildARSM2ApiChatMessage(chunk *entity.AgentRunResponse) []byte {
|
|
chunkRunItem := chunk.ChunkRunItem
|
|
chunkMessage := &run.ChatV3ChatDetail{
|
|
ID: chunkRunItem.ID,
|
|
ConversationID: chunkRunItem.ConversationID,
|
|
BotID: chunkRunItem.AgentID,
|
|
Status: string(chunkRunItem.Status),
|
|
SectionID: ptr.Of(chunkRunItem.SectionID),
|
|
CreatedAt: ptr.Of(int32(chunkRunItem.CreatedAt / 1000)),
|
|
CompletedAt: ptr.Of(int32(chunkRunItem.CompletedAt / 1000)),
|
|
FailedAt: ptr.Of(int32(chunkRunItem.FailedAt / 1000)),
|
|
}
|
|
if chunkRunItem.Usage != nil {
|
|
chunkMessage.Usage = &run.Usage{
|
|
TokenCount: ptr.Of(int32(chunkRunItem.Usage.LlmTotalTokens)),
|
|
InputTokens: ptr.Of(int32(chunkRunItem.Usage.LlmPromptTokens)),
|
|
OutputTokens: ptr.Of(int32(chunkRunItem.Usage.LlmCompletionTokens)),
|
|
}
|
|
}
|
|
mCM, _ := json.Marshal(chunkMessage)
|
|
return mCM
|
|
}
|
|
|
|
func (a *OpenapiAgentRunApplication) CancelRun(ctx context.Context, req *run.CancelChatApiRequest) (*run.CancelChatApiResponse, error) {
|
|
resp := new(run.CancelChatApiResponse)
|
|
|
|
apiKeyInfo := ctxutil.GetApiAuthFromCtx(ctx)
|
|
userID := apiKeyInfo.UserID
|
|
|
|
runRecord, err := ConversationSVC.AgentRunDomainSVC.GetByID(ctx, req.ChatID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if runRecord == nil {
|
|
return nil, errorx.New(errno.ErrRecordNotFound)
|
|
}
|
|
|
|
conversationData, err := ConversationSVC.ConversationDomainSVC.GetByID(ctx, req.ConversationID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if userID != conversationData.CreatorID {
|
|
return nil, errorx.New(errno.ErrConversationPermissionCode, errorx.KV("msg", "user not match"))
|
|
}
|
|
|
|
if runRecord.Status != entity.RunStatusInProgress && runRecord.Status != entity.RunStatusCreated {
|
|
return nil, errorx.New(errno.ErrInProgressCanNotCancel)
|
|
}
|
|
|
|
runMeta, err := ConversationSVC.AgentRunDomainSVC.Cancel(ctx, &entity.CancelRunMeta{
|
|
RunID: req.ChatID,
|
|
ConversationID: req.ConversationID,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if runMeta == nil {
|
|
return nil, errorx.New(errno.ErrConversationAgentRunError)
|
|
}
|
|
|
|
resp.ChatV3ChatDetail = &run.ChatV3ChatDetail{
|
|
ID: runMeta.ID,
|
|
ConversationID: runMeta.ConversationID,
|
|
BotID: runMeta.AgentID,
|
|
Status: string(runMeta.Status),
|
|
SectionID: ptr.Of(runMeta.SectionID),
|
|
CreatedAt: ptr.Of(int32(runMeta.CreatedAt / 1000)),
|
|
CompletedAt: ptr.Of(int32(runMeta.CompletedAt / 1000)),
|
|
FailedAt: ptr.Of(int32(runMeta.FailedAt / 1000)),
|
|
}
|
|
if runMeta.Usage != nil {
|
|
resp.ChatV3ChatDetail.Usage = &run.Usage{
|
|
TokenCount: ptr.Of(int32(runMeta.Usage.LlmTotalTokens)),
|
|
InputTokens: ptr.Of(int32(runMeta.Usage.LlmPromptTokens)),
|
|
OutputTokens: ptr.Of(int32(runMeta.Usage.LlmCompletionTokens)),
|
|
}
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|