From 61e8331a5496e4dd4f2d3071512ee120c43952bd Mon Sep 17 00:00:00 2001 From: junwen-lee Date: Mon, 15 Sep 2025 15:54:29 +0800 Subject: [PATCH] =?UTF-8?q?feat(singleagent):=20openapi=20v3/chat=20additi?= =?UTF-8?q?onal=20message=20support=20assista=E2=80=A6=20(#2067)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../conversation/openapi_agent_run.go | 106 +++++++++++++----- .../crossdomain/contract/message/message.go | 1 + .../message/messagemock/message_mock.go | 15 +++ backend/crossdomain/impl/message/message.go | 3 + .../agentrun/entity/run_record.go | 46 +++++--- .../agentrun/internal/chatflow_run.go | 4 +- .../agentrun/internal/message_builder.go | 45 ++++++++ .../agentrun/internal/message_event.go | 46 +++++--- .../conversation/agentrun/internal/run.go | 22 +++- .../agentrun/internal/singleagent_run.go | 2 +- .../message/internal/dal/message.go | 19 ++++ .../message/repository/repository.go | 1 + .../conversation/message/service/message.go | 1 + .../message/service/message_impl.go | 4 + .../message/service/message_test.go | 52 +++++++++ backend/go.mod | 1 + 16 files changed, 299 insertions(+), 69 deletions(-) diff --git a/backend/application/conversation/openapi_agent_run.go b/backend/application/conversation/openapi_agent_run.go index b7c7145e..ca993408 100644 --- a/backend/application/conversation/openapi_agent_run.go +++ b/backend/application/conversation/openapi_agent_run.go @@ -21,6 +21,7 @@ import ( "encoding/json" "errors" "io" + "slices" "strconv" "github.com/cloudwego/eino/schema" @@ -102,7 +103,7 @@ func (a *OpenapiAgentRunApplication) checkConversation(ctx context.Context, ar * return nil, err } if conData == nil { - return nil, errors.New("conversation data is nil") + return nil, errorx.New(errno.ErrConversationNotFound) } conversationData = conData @@ -110,7 +111,7 @@ func (a *OpenapiAgentRunApplication) checkConversation(ctx context.Context, ar * } if conversationData.CreatorID != userID { - return nil, errors.New("conversation data not match") + return nil, errorx.New(errno.ErrConversationPermissionCode, errorx.KV("msg","user not match")) } return conversationData, nil @@ -138,26 +139,31 @@ func (a *OpenapiAgentRunApplication) buildAgentRunRequest(ctx context.Context, a if err != nil { return nil, err } - multiContent, contentType, err := a.buildMultiContent(ctx, ar) + multiAdditionalMessages, err := a.parseAdditionalMessages(ctx, ar) + if err != nil { + return nil, err + } + filterMultiAdditionalMessages, multiContent, contentType, err := a.parseQueryContent(ctx, multiAdditionalMessages) if err != nil { return nil, err } displayContent := a.buildDisplayContent(ctx, ar) arm := &entity.AgentRunMeta{ - ConversationID: ptr.From(ar.ConversationID), - AgentID: ar.BotID, - Content: multiContent, - DisplayContent: displayContent, - SpaceID: spaceID, - UserID: ar.User, - SectionID: conversationData.SectionID, - PreRetrieveTools: shortcutCMDData, - IsDraft: false, - ConnectorID: connectorID, - ContentType: contentType, - Ext: ar.ExtraParams, - CustomVariables: ar.CustomVariables, - CozeUID: conversationData.CreatorID, + ConversationID: ptr.From(ar.ConversationID), + AgentID: ar.BotID, + Content: multiContent, + DisplayContent: displayContent, + SpaceID: spaceID, + UserID: ar.User, + SectionID: conversationData.SectionID, + PreRetrieveTools: shortcutCMDData, + IsDraft: false, + ConnectorID: connectorID, + ContentType: contentType, + Ext: ar.ExtraParams, + CustomVariables: ar.CustomVariables, + CozeUID: conversationData.CreatorID, + AdditionalMessages: filterMultiAdditionalMessages, } return arm, nil } @@ -200,29 +206,68 @@ func (a *OpenapiAgentRunApplication) buildDisplayContent(_ context.Context, ar * return "" } -func (a *OpenapiAgentRunApplication) buildMultiContent(ctx context.Context, ar *run.ChatV3Request) ([]*message.InputMetaData, message.ContentType, error) { - var multiContents []*message.InputMetaData - contentType := message.ContentTypeText +func (a *OpenapiAgentRunApplication) parseQueryContent(ctx context.Context, multiAdditionalMessages []*entity.AdditionalMessage) ([]*entity.AdditionalMessage, []*message.InputMetaData, message.ContentType, error) { + + var multiContent []*message.InputMetaData + var contentType message.ContentType + var filterMultiAdditionalMessages []*entity.AdditionalMessage + filterMultiAdditionalMessages = multiAdditionalMessages + + if len(multiAdditionalMessages) > 0 { + lastMessage := multiAdditionalMessages[len(multiAdditionalMessages)-1] + if lastMessage != nil && lastMessage.Role == schema.User { + multiContent = lastMessage.Content + contentType = lastMessage.ContentType + filterMultiAdditionalMessages = multiAdditionalMessages[:len(multiAdditionalMessages)-1] + } + } + + return filterMultiAdditionalMessages, multiContent, contentType, nil +} + +func (a *OpenapiAgentRunApplication) parseAdditionalMessages(ctx context.Context, ar *run.ChatV3Request) ([]*entity.AdditionalMessage, error) { + + additionalMessages := make([]*entity.AdditionalMessage, 0, len(ar.AdditionalMessages)) for _, item := range ar.AdditionalMessages { if item == nil { continue } - if item.Role != string(schema.User) { - return nil, contentType, errors.New("role not match") + if item.Role != string(schema.User) && item.Role != string(schema.Assistant) { + return nil, errors.New("additional message role only support user and assistant") + } + if item.Type != nil && !slices.Contains([]message.MessageType{message.MessageTypeQuestion, message.MessageTypeAnswer}, message.MessageType(*item.Type)) { + return nil, errors.New("additional message type only support question and answer now") } + + addOne := entity.AdditionalMessage{ + Role: schema.RoleType(item.Role), + } + if item.Type != nil { + addOne.Type = message.MessageType(*item.Type) + } else { + addOne.Type = message.MessageTypeQuestion + } + if item.ContentType == run.ContentTypeText { if item.Content == "" { continue } - multiContents = append(multiContents, &message.InputMetaData{ + + addOne.ContentType = message.ContentTypeText + addOne.Content = []*message.InputMetaData{{ Type: message.InputTypeText, Text: item.Content, - }) + }} } if item.ContentType == run.ContentTypeMixApi { - contentType = message.ContentTypeMix + + if ptr.From(item.Type) == string(message.MessageTypeAnswer) { + return nil, errors.New(" answer messages only support text content") + } + + addOne.ContentType = message.ContentTypeMix var inputs []*run.AdditionalContent err := json.Unmarshal([]byte(item.Content), &inputs) @@ -236,7 +281,8 @@ func (a *OpenapiAgentRunApplication) buildMultiContent(ctx context.Context, ar * } switch message.InputType(one.Type) { case message.InputTypeText: - multiContents = append(multiContents, &message.InputMetaData{ + + addOne.Content = append(addOne.Content, &message.InputMetaData{ Type: message.InputTypeText, Text: ptr.From(one.Text), }) @@ -250,12 +296,12 @@ func (a *OpenapiAgentRunApplication) buildMultiContent(ctx context.Context, ar * ID: one.GetFileID(), }) if err != nil { - return nil, contentType, err + return nil, err } fileUrl = fileInfo.File.Url fileURI = fileInfo.File.TosURI } - multiContents = append(multiContents, &message.InputMetaData{ + addOne.Content = append(addOne.Content, &message.InputMetaData{ Type: message.InputType(one.Type), FileData: []*message.FileData{ { @@ -269,10 +315,10 @@ func (a *OpenapiAgentRunApplication) buildMultiContent(ctx context.Context, ar * } } } - + additionalMessages = append(additionalMessages, &addOne) } - return multiContents, contentType, nil + return additionalMessages, nil } func (a *OpenapiAgentRunApplication) pullStream(ctx context.Context, sseSender *sseImpl.SSenderImpl, streamer *schema.StreamReader[*entity.AgentRunResponse]) { diff --git a/backend/crossdomain/contract/message/message.go b/backend/crossdomain/contract/message/message.go index 20ad431c..b6d425c1 100644 --- a/backend/crossdomain/contract/message/message.go +++ b/backend/crossdomain/contract/message/message.go @@ -30,6 +30,7 @@ type Message interface { GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*message.Message, error) PreCreate(ctx context.Context, msg *message.Message) (*message.Message, error) Create(ctx context.Context, msg *message.Message) (*message.Message, error) + BatchCreate(ctx context.Context, msg []*message.Message) ([]*message.Message, error) List(ctx context.Context, meta *entity.ListMeta) (*entity.ListResult, error) ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) Edit(ctx context.Context, msg *message.Message) (*message.Message, error) diff --git a/backend/crossdomain/contract/message/messagemock/message_mock.go b/backend/crossdomain/contract/message/messagemock/message_mock.go index 959566cc..49570270 100644 --- a/backend/crossdomain/contract/message/messagemock/message_mock.go +++ b/backend/crossdomain/contract/message/messagemock/message_mock.go @@ -59,6 +59,21 @@ func (m *MockMessage) EXPECT() *MockMessageMockRecorder { return m.recorder } +// BatchCreate mocks base method. +func (m *MockMessage) BatchCreate(ctx context.Context, msg []*message.Message) ([]*message.Message, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BatchCreate", ctx, msg) + ret0, _ := ret[0].([]*message.Message) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BatchCreate indicates an expected call of BatchCreate. +func (mr *MockMessageMockRecorder) BatchCreate(ctx, msg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchCreate", reflect.TypeOf((*MockMessage)(nil).BatchCreate), ctx, msg) +} + // Create mocks base method. func (m *MockMessage) Create(ctx context.Context, msg *message.Message) (*message.Message, error) { m.ctrl.T.Helper() diff --git a/backend/crossdomain/impl/message/message.go b/backend/crossdomain/impl/message/message.go index 2bd0b217..eafe75a3 100644 --- a/backend/crossdomain/impl/message/message.go +++ b/backend/crossdomain/impl/message/message.go @@ -170,6 +170,9 @@ func (c *impl) GetMessageByID(ctx context.Context, id int64) (*entity.Message, e func (c *impl) ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) { return c.DomainSVC.ListWithoutPair(ctx, req) } +func (c *impl) BatchCreate(ctx context.Context, msgs []*entity.Message) ([]*entity.Message, error) { + return c.DomainSVC.BatchCreate(ctx, msgs) +} func convertToConvAndSchemaMessage(ctx context.Context, msgs []*entity.Message) ([]*crossmessage.WfMessage, []*schema.Message, error) { messages := make([]*schema.Message, 0) diff --git a/backend/domain/conversation/agentrun/entity/run_record.go b/backend/domain/conversation/agentrun/entity/run_record.go index 445dfb95..077e0531 100644 --- a/backend/domain/conversation/agentrun/entity/run_record.go +++ b/backend/domain/conversation/agentrun/entity/run_record.go @@ -106,24 +106,34 @@ type MetaInfo struct { } type AgentRunMeta struct { - ConversationID int64 `json:"conversation_id"` - ConnectorID int64 `json:"connector_id"` - SpaceID int64 `json:"space_id"` - Scene common.Scene `json:"scene"` - SectionID int64 `json:"section_id"` - Name string `json:"name"` - UserID string `json:"user_id"` - CozeUID int64 `json:"coze_uid"` - AgentID int64 `json:"agent_id"` - ContentType message.ContentType `json:"content_type"` - Content []*message.InputMetaData `json:"content"` - PreRetrieveTools []*Tool `json:"tools"` - IsDraft bool `json:"is_draft"` - CustomerConfig *CustomerConfig `json:"customer_config"` - DisplayContent string `json:"display_content"` - CustomVariables map[string]string `json:"custom_variables"` - Version string `json:"version"` - Ext map[string]string `json:"ext"` + ConversationID int64 `json:"conversation_id"` + ConnectorID int64 `json:"connector_id"` + SpaceID int64 `json:"space_id"` + Scene common.Scene `json:"scene"` + SectionID int64 `json:"section_id"` + Name string `json:"name"` + UserID string `json:"user_id"` + CozeUID int64 `json:"coze_uid"` + AgentID int64 `json:"agent_id"` + ContentType message.ContentType `json:"content_type"` + Content []*message.InputMetaData `json:"content"` + PreRetrieveTools []*Tool `json:"tools"` + IsDraft bool `json:"is_draft"` + CustomerConfig *CustomerConfig `json:"customer_config"` + DisplayContent string `json:"display_content"` + CustomVariables map[string]string `json:"custom_variables"` + Version string `json:"version"` + Ext map[string]string `json:"ext"` + AdditionalMessages []*AdditionalMessage `json:"additional_messages"` +} + +type AdditionalMessage struct { + Role schema.RoleType `json:"role"` + Type message.MessageType `json:"type"` + Content []*message.InputMetaData `json:"content"` + ContentType message.ContentType `json:"content_type"` + Name *string `json:"name"` + Meta map[string]string `json:"meta"` } type UpdateMeta struct { diff --git a/backend/domain/conversation/agentrun/internal/chatflow_run.go b/backend/domain/conversation/agentrun/internal/chatflow_run.go index 87018a85..fbbf7a88 100644 --- a/backend/domain/conversation/agentrun/internal/chatflow_run.go +++ b/backend/domain/conversation/agentrun/internal/chatflow_run.go @@ -41,7 +41,7 @@ import ( func (art *AgentRuntime) ChatflowRun(ctx context.Context, imagex imagex.ImageX) (err error) { - mh := &MesssageEventHanlder{ + mh := &MessageEventHandler{ sw: art.SW, messageEvent: art.MessageEvent, } @@ -110,7 +110,7 @@ func concatWfInput(rtDependence *AgentRuntime) string { return strings.Trim(input, ",") } -func (art *AgentRuntime) pullWfStream(ctx context.Context, events *schema.StreamReader[*crossworkflow.WorkflowMessage], mh *MesssageEventHanlder) { +func (art *AgentRuntime) pullWfStream(ctx context.Context, events *schema.StreamReader[*crossworkflow.WorkflowMessage], mh *MessageEventHandler) { fullAnswerContent := bytes.NewBuffer([]byte{}) var usage *msgEntity.UsageExt diff --git a/backend/domain/conversation/agentrun/internal/message_builder.go b/backend/domain/conversation/agentrun/internal/message_builder.go index 6891f5a8..0eb1e789 100644 --- a/backend/domain/conversation/agentrun/internal/message_builder.go +++ b/backend/domain/conversation/agentrun/internal/message_builder.go @@ -221,6 +221,51 @@ func preCreateAnswer(ctx context.Context, rtDependence *AgentRuntime) (*msgEntit return crossmessage.DefaultSVC().PreCreate(ctx, msgMeta) } +func buildAdditionalMessage2Create(ctx context.Context, runRecord *entity.RunRecordMeta, additionalMessage *entity.AdditionalMessage, userID string) *message.Message { + + msg := &msgEntity.Message{ + ConversationID: runRecord.ConversationID, + RunID: runRecord.ID, + AgentID: runRecord.AgentID, + SectionID: runRecord.SectionID, + UserID: userID, + MessageType: additionalMessage.Type, + } + + switch additionalMessage.Type { + case message.MessageTypeQuestion: + msg.Role = schema.User + msg.ContentType = additionalMessage.ContentType + for _, content := range additionalMessage.Content { + if content.Type == message.InputTypeText { + msg.Content = content.Text + break + } + } + msg.MultiContent = additionalMessage.Content + + case message.MessageTypeAnswer: + msg.Role = schema.Assistant + msg.ContentType = message.ContentTypeText + for _, content := range additionalMessage.Content { + if content.Type == message.InputTypeText { + msg.Content = content.Text + break + } + } + modelContent := &schema.Message{ + Role: schema.Assistant, + Content: msg.Content, + } + + jsonContent, err := json.Marshal(modelContent) + if err == nil { + msg.ModelContent = string(jsonContent) + } + } + return msg +} + func buildAgentMessage2Create(ctx context.Context, chunk *entity.AgentRespEvent, messageType message.MessageType, rtDependence *AgentRuntime) *message.Message { arm := rtDependence.GetRunMeta() msg := &msgEntity.Message{ diff --git a/backend/domain/conversation/agentrun/internal/message_event.go b/backend/domain/conversation/agentrun/internal/message_event.go index 7a7f9528..bdf510ac 100644 --- a/backend/domain/conversation/agentrun/internal/message_event.go +++ b/backend/domain/conversation/agentrun/internal/message_event.go @@ -98,12 +98,12 @@ func (e *Event) SendStreamDoneEvent(sw *schema.StreamWriter[*entity.AgentRunResp sw.Send(resp, nil) } -type MesssageEventHanlder struct { +type MessageEventHandler struct { messageEvent *Event sw *schema.StreamWriter[*entity.AgentRunResponse] } -func (mh *MesssageEventHanlder) handlerErr(_ context.Context, err error) { +func (mh *MessageEventHandler) handlerErr(_ context.Context, err error) { var errMsg string var statusErr errorx.StatusError @@ -123,7 +123,7 @@ func (mh *MesssageEventHanlder) handlerErr(_ context.Context, err error) { }) } -func (mh *MesssageEventHanlder) handlerAckMessage(_ context.Context, input *msgEntity.Message) error { +func (mh *MessageEventHandler) handlerAckMessage(_ context.Context, input *msgEntity.Message) error { sendMsg := &entity.ChunkMessageItem{ ID: input.ID, ConversationID: input.ConversationID, @@ -142,7 +142,7 @@ func (mh *MesssageEventHanlder) handlerAckMessage(_ context.Context, input *msgE return nil } -func (mh *MesssageEventHanlder) handlerFunctionCall(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error { +func (mh *MessageEventHandler) handlerFunctionCall(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error { cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeFunctionCall, rtDependence) cmData, err := crossmessage.DefaultSVC().Create(ctx, cm) @@ -156,7 +156,7 @@ func (mh *MesssageEventHanlder) handlerFunctionCall(ctx context.Context, chunk * return nil } -func (mh *MesssageEventHanlder) handlerTooResponse(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, preToolResponseMsg *msgEntity.Message, toolResponseMsgContent string) error { +func (mh *MessageEventHandler) handlerTooResponse(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, preToolResponseMsg *msgEntity.Message, toolResponseMsgContent string) error { cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeToolResponse, rtDependence) @@ -184,7 +184,7 @@ func (mh *MesssageEventHanlder) handlerTooResponse(ctx context.Context, chunk *e return nil } -func (mh *MesssageEventHanlder) handlerSuggest(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error { +func (mh *MessageEventHandler) handlerSuggest(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error { cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeFlowUp, rtDependence) cmData, err := crossmessage.DefaultSVC().Create(ctx, cm) @@ -199,7 +199,7 @@ func (mh *MesssageEventHanlder) handlerSuggest(ctx context.Context, chunk *entit return nil } -func (mh *MesssageEventHanlder) handlerKnowledge(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error { +func (mh *MessageEventHandler) handlerKnowledge(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error { cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeKnowledge, rtDependence) cmData, err := crossmessage.DefaultSVC().Create(ctx, cm) if err != nil { @@ -212,7 +212,7 @@ func (mh *MesssageEventHanlder) handlerKnowledge(ctx context.Context, chunk *ent return nil } -func (mh *MesssageEventHanlder) handlerAnswer(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt, rtDependence *AgentRuntime, preAnswerMsg *msgEntity.Message) error { +func (mh *MessageEventHandler) handlerAnswer(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt, rtDependence *AgentRuntime, preAnswerMsg *msgEntity.Message) error { if len(msg.Content) == 0 && len(ptr.From(msg.ReasoningContent)) == 0 { return nil @@ -265,7 +265,7 @@ func (mh *MesssageEventHanlder) handlerAnswer(ctx context.Context, msg *entity.C return nil } -func (mh *MesssageEventHanlder) handlerFinalAnswerFinish(ctx context.Context, rtDependence *AgentRuntime) error { +func (mh *MessageEventHandler) handlerFinalAnswerFinish(ctx context.Context, rtDependence *AgentRuntime) error { cm := buildAgentMessage2Create(ctx, nil, message.MessageTypeVerbose, rtDependence) cmData, err := crossmessage.DefaultSVC().Create(ctx, cm) if err != nil { @@ -278,7 +278,7 @@ func (mh *MesssageEventHanlder) handlerFinalAnswerFinish(ctx context.Context, rt return nil } -func (mh *MesssageEventHanlder) handlerInterruptVerbose(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error { +func (mh *MessageEventHandler) handlerInterruptVerbose(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error { cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeInterrupt, rtDependence) cmData, err := crossmessage.DefaultSVC().Create(ctx, cm) if err != nil { @@ -291,7 +291,7 @@ func (mh *MesssageEventHanlder) handlerInterruptVerbose(ctx context.Context, chu return nil } -func (mh *MesssageEventHanlder) handlerWfUsage(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt) error { +func (mh *MessageEventHandler) handlerWfUsage(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt) error { if msg.Ext == nil { msg.Ext = map[string]string{} @@ -314,7 +314,7 @@ func (mh *MesssageEventHanlder) handlerWfUsage(ctx context.Context, msg *entity. return nil } -func (mh *MesssageEventHanlder) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, firstAnswerMsg *msgEntity.Message, reasoningContent string) error { +func (mh *MessageEventHandler) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, firstAnswerMsg *msgEntity.Message, reasoningContent string) error { interruptData, cType, err := parseInterruptData(ctx, chunk.Interrupt) if err != nil { return err @@ -366,7 +366,7 @@ func (mh *MesssageEventHanlder) handlerInterrupt(ctx context.Context, chunk *ent return nil } -func (mh *MesssageEventHanlder) handlerWfInterruptMsg(ctx context.Context, stateMsg *crossworkflow.StateMessage, rtDependence *AgentRuntime) { +func (mh *MessageEventHandler) handlerWfInterruptMsg(ctx context.Context, stateMsg *crossworkflow.StateMessage, rtDependence *AgentRuntime) { interruptData, cType, err := handlerWfInterruptEvent(ctx, stateMsg.InterruptEvent) if err != nil { return @@ -412,7 +412,7 @@ func (mh *MesssageEventHanlder) handlerWfInterruptMsg(ctx context.Context, state } } -func (mh *MesssageEventHanlder) HandlerInput(ctx context.Context, rtDependence *AgentRuntime) (*msgEntity.Message, error) { +func (mh *MessageEventHandler) HandlerInput(ctx context.Context, rtDependence *AgentRuntime) (*msgEntity.Message, error) { msgMeta := buildAgentMessage2Create(ctx, nil, message.MessageTypeQuestion, rtDependence) cm, err := crossmessage.DefaultSVC().Create(ctx, msgMeta) @@ -426,3 +426,21 @@ func (mh *MesssageEventHanlder) HandlerInput(ctx context.Context, rtDependence * } return cm, nil } + +func (mh *MessageEventHandler) ParseAdditionalMessages(ctx context.Context, rtDependence *AgentRuntime, runRecord *entity.RunRecordMeta) error { + + if len(rtDependence.GetRunMeta().AdditionalMessages) == 0 { + return nil + } + + additionalMessages := make([]*message.Message, 0, len(rtDependence.GetRunMeta().AdditionalMessages)) + + for _, msg := range rtDependence.GetRunMeta().AdditionalMessages { + cm := buildAdditionalMessage2Create(ctx, runRecord, msg, rtDependence.GetRunMeta().UserID) + additionalMessages = append(additionalMessages, cm) + } + + _, err := crossmessage.DefaultSVC().BatchCreate(ctx, additionalMessages) + + return err +} diff --git a/backend/domain/conversation/agentrun/internal/run.go b/backend/domain/conversation/agentrun/internal/run.go index a315f78e..839f0e18 100644 --- a/backend/domain/conversation/agentrun/internal/run.go +++ b/backend/domain/conversation/agentrun/internal/run.go @@ -107,6 +107,11 @@ func (rd *AgentRuntime) GetHistory() []*msgEntity.Message { func (art *AgentRuntime) Run(ctx context.Context) (err error) { + mh := &MessageEventHandler{ + messageEvent: art.MessageEvent, + sw: art.SW, + } + agentInfo, err := getAgentInfo(ctx, art.GetRunMeta().AgentID, art.GetRunMeta().IsDraft, art.GetRunMeta().ConnectorID) if err != nil { return @@ -114,6 +119,18 @@ func (art *AgentRuntime) Run(ctx context.Context) (err error) { art.SetAgentInfo(agentInfo) + if len(art.GetRunMeta().AdditionalMessages) > 0 { + var additionalRunRecord *entity.RunRecordMeta + additionalRunRecord, err = art.RunRecordRepo.Create(ctx, art.GetRunMeta()) + if err != nil { + return + } + err = mh.ParseAdditionalMessages(ctx, art, additionalRunRecord) + if err != nil { + return + } + } + history, err := art.getHistory(ctx) if err != nil { return @@ -140,10 +157,7 @@ func (art *AgentRuntime) Run(ctx context.Context) (err error) { } art.RunProcess.StepToComplete(ctx, srRecord, art.SW, art.GetUsage()) }() - mh := &MesssageEventHanlder{ - messageEvent: art.MessageEvent, - sw: art.SW, - } + input, err := mh.HandlerInput(ctx, art) if err != nil { return diff --git a/backend/domain/conversation/agentrun/internal/singleagent_run.go b/backend/domain/conversation/agentrun/internal/singleagent_run.go index 147ef050..fe598676 100644 --- a/backend/domain/conversation/agentrun/internal/singleagent_run.go +++ b/backend/domain/conversation/agentrun/internal/singleagent_run.go @@ -80,7 +80,7 @@ func (art *AgentRuntime) AgentStreamExecute(ctx context.Context, imagex imagex.I func (art *AgentRuntime) push(ctx context.Context, mainChan chan *entity.AgentRespEvent) { - mh := &MesssageEventHanlder{ + mh := &MessageEventHandler{ sw: art.SW, messageEvent: art.MessageEvent, } diff --git a/backend/domain/conversation/message/internal/dal/message.go b/backend/domain/conversation/message/internal/dal/message.go index 739978d7..3d230842 100644 --- a/backend/domain/conversation/message/internal/dal/message.go +++ b/backend/domain/conversation/message/internal/dal/message.go @@ -72,6 +72,25 @@ func (dao *MessageDAO) Create(ctx context.Context, msg *entity.Message) (*entity return dao.messagePO2DO(poData), nil } +func (dao *MessageDAO) BatchCreate(ctx context.Context, msg []*entity.Message) ([]*entity.Message, error) { + poList := make([]*model.Message, 0, len(msg)) + for _, m := range msg { + po, err := dao.messageDO2PO(ctx, m) + if err != nil { + return nil, err + } + poList = append(poList, po) + } + + do := dao.query.Message.WithContext(ctx).Debug() + cErr := do.CreateInBatches(poList, len(poList)) + if cErr != nil { + return nil, cErr + } + + return dao.batchMessagePO2DO(poList), nil +} + func (dao *MessageDAO) List(ctx context.Context, listMeta *entity.ListMeta) ([]*entity.Message, bool, error) { m := dao.query.Message do := m.WithContext(ctx).Debug().Where(m.ConversationID.Eq(listMeta.ConversationID)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable))) diff --git a/backend/domain/conversation/message/repository/repository.go b/backend/domain/conversation/message/repository/repository.go index 11796436..f375255c 100644 --- a/backend/domain/conversation/message/repository/repository.go +++ b/backend/domain/conversation/message/repository/repository.go @@ -34,6 +34,7 @@ func NewMessageRepo(db *gorm.DB, idGen idgen.IDGenerator) MessageRepo { type MessageRepo interface { PreCreate(ctx context.Context, msg *entity.Message) (*entity.Message, error) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) + BatchCreate(ctx context.Context, msg []*entity.Message) ([]*entity.Message, error) List(ctx context.Context, listMeta *entity.ListMeta) ([]*entity.Message, bool, error) GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error) Edit(ctx context.Context, msgID int64, message *message.Message) (int64, error) diff --git a/backend/domain/conversation/message/service/message.go b/backend/domain/conversation/message/service/message.go index 8cab1258..1a367ae8 100644 --- a/backend/domain/conversation/message/service/message.go +++ b/backend/domain/conversation/message/service/message.go @@ -27,6 +27,7 @@ type Message interface { ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) PreCreate(ctx context.Context, req *entity.Message) (*entity.Message, error) Create(ctx context.Context, req *entity.Message) (*entity.Message, error) + BatchCreate(ctx context.Context, req []*entity.Message) ([]*entity.Message, error) GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error) GetByID(ctx context.Context, id int64) (*entity.Message, error) Edit(ctx context.Context, req *entity.Message) (*entity.Message, error) diff --git a/backend/domain/conversation/message/service/message_impl.go b/backend/domain/conversation/message/service/message_impl.go index 833ac2d3..37615725 100644 --- a/backend/domain/conversation/message/service/message_impl.go +++ b/backend/domain/conversation/message/service/message_impl.go @@ -124,6 +124,10 @@ func (m *messageImpl) GetByID(ctx context.Context, id int64) (*entity.Message, e return m.MessageRepo.GetByID(ctx, id) } +func (m *messageImpl) BatchCreate(ctx context.Context, req []*entity.Message) ([]*entity.Message, error) { + return m.MessageRepo.BatchCreate(ctx, req) +} + func (m *messageImpl) Broken(ctx context.Context, req *entity.BrokenMeta) error { _, err := m.MessageRepo.Edit(ctx, req.ID, &message.Message{ diff --git a/backend/domain/conversation/message/service/message_test.go b/backend/domain/conversation/message/service/message_test.go index 721e1ef1..04437182 100644 --- a/backend/domain/conversation/message/service/message_test.go +++ b/backend/domain/conversation/message/service/message_test.go @@ -494,3 +494,55 @@ func TestListWithoutPair(t *testing.T) { assert.Equal(t, "Answer message", resp.Messages[0].Content) }) } + +func TestBatchCreate(t *testing.T) { + ctx := context.Background() + mockDBGen := orm.NewMockDB() + mockDBGen.AddTable(&model.Message{}) + mockDB, err := mockDBGen.DB() + assert.NoError(t, err) + + components := &Components{ + MessageRepo: repository.NewMessageRepo(mockDB, nil), + } + + + t.Run("success_single_message", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // 准备测试数据 + inputMsgs := []*entity.Message{ + { + ID: 1, + ConversationID: 100, + RunID: 200, + AgentID: 300, + UserID: "user123", + Content: "Hello World", + Role: schema.User, + ContentType: message.ContentTypeText, + MessageType: message.MessageTypeQuestion, + Status: message.MessageStatusAvailable, + }, + { + ID: 2, + ConversationID: 100, + RunID: 200, + AgentID: 300, + UserID: "user123", + Content: "Hello World", + Role: schema.Assistant, + ContentType: message.ContentTypeText, + MessageType: message.MessageTypeQuestion, + Status: message.MessageStatusAvailable, + }, + } + + result, err := NewService(components).BatchCreate(ctx, inputMsgs) + + assert.NoError(t, err) + assert.Len(t, result, 2) + assert.Equal(t, inputMsgs[1].ID, result[1].ID) + }) +} diff --git a/backend/go.mod b/backend/go.mod index bfa9075b..ee91b93d 100755 --- a/backend/go.mod +++ b/backend/go.mod @@ -289,5 +289,6 @@ require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/eino-contrib/jsonschema v1.0.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect )