From 4bfce5a8cb468ce4e09f5097afa2946ce23c4e6d Mon Sep 17 00:00:00 2001 From: lvxinyu-1117 Date: Wed, 10 Sep 2025 15:52:23 +0800 Subject: [PATCH] refactor(workflow): Calculate chat history rounds during schema convertion (#1990) Co-authored-by: zhuangjie.1125 --- .../api/handler/coze/workflow_service_test.go | 267 +++++++- backend/application/workflow/chatflow.go | 68 +- backend/application/workflow/chatflow_test.go | 604 ++++++++++++++++++ .../crossdomain/contract/message/message.go | 4 +- backend/crossdomain/contract/upload/upload.go | 1 + .../contract/upload/uploadmock/upload_mock.go | 73 +++ backend/crossdomain/impl/message/message.go | 4 +- .../domain/workflow/component_interface.go | 8 +- backend/domain/workflow/entity/vo/chatflow.go | 5 + .../domain/workflow/entity/vo/conversation.go | 4 +- .../internal/canvas/adaptor/from_node.go | 1 + .../internal/canvas/adaptor/to_schema.go | 2 +- .../chatflow/chat_run_with_interrupt.json | 275 ++++++++ .../canvas/examples/chatflow/llm_chat.json | 397 ++++++++++++ .../chatflow/llm_chat_with_history.json | 397 ++++++++++++ .../nodes/conversation/conversationhistory.go | 2 +- .../nodes/conversation/createconversation.go | 4 +- .../nodes/conversation/createmessage.go | 16 +- .../nodes/conversation/messagelist.go | 8 +- .../domain/workflow/internal/nodes/llm/llm.go | 8 +- .../domain/workflow/internal/nodes/node.go | 5 - .../internal/repo/conversation_repository.go | 28 +- .../workflow/internal/schema/node_schema.go | 5 + .../internal/schema/workflow_schema.go | 15 + .../workflow/service/conversation_impl.go | 22 +- .../workflow/service/executable_impl.go | 141 ++-- .../workflow/service/executable_impl_test.go | 286 +++++++++ .../domain/workflow/service/service_impl.go | 11 +- backend/domain/workflow/service/utils.go | 215 ------- 29 files changed, 2492 insertions(+), 384 deletions(-) create mode 100644 backend/application/workflow/chatflow_test.go create mode 100644 backend/crossdomain/contract/upload/uploadmock/upload_mock.go create mode 100644 backend/domain/workflow/internal/canvas/examples/chatflow/chat_run_with_interrupt.json create mode 100644 backend/domain/workflow/internal/canvas/examples/chatflow/llm_chat.json create mode 100644 backend/domain/workflow/internal/canvas/examples/chatflow/llm_chat_with_history.json create mode 100644 backend/domain/workflow/service/executable_impl_test.go diff --git a/backend/api/handler/coze/workflow_service_test.go b/backend/api/handler/coze/workflow_service_test.go index 3806995b..353975a3 100644 --- a/backend/api/handler/coze/workflow_service_test.go +++ b/backend/api/handler/coze/workflow_service_test.go @@ -228,6 +228,7 @@ func newWfTestRunner(t *testing.T) *wfTestRunner { h.POST("/api/workflow_api/chat_flow_role/delete", DeleteChatFlowRole) h.POST("/api/workflow_api/chat_flow_role/create", CreateChatFlowRole) h.GET("/api/workflow_api/chat_flow_role/get", GetChatFlowRole) + h.POST("/v1/workflows/chat", OpenAPIChatFlowRun) ctrl := gomock.NewController(t, gomock.WithOverridableExpectations()) mockIDGen := mock.NewMockIDGenerator(ctrl) @@ -1082,6 +1083,46 @@ func (r *wfTestRunner) openapiResume(id string, eventID string, resumeData strin return re } +func (r *wfTestRunner) openapiChatFlowRun(wfID string, cID, appID, botID *string, input any, additionalMessage []*workflow.EnterMessage) *sse.Reader { + inputStr, _ := sonic.MarshalString(input) + + req := &workflow.ChatFlowRunRequest{ + WorkflowID: wfID, + Parameters: ptr.Of(inputStr), + AdditionalMessages: additionalMessage, + } + if cID != nil { + req.ConversationID = cID + } + if appID != nil { + req.AppID = appID + } + if botID != nil { + req.BotID = botID + } + + m, err := sonic.Marshal(req) + assert.NoError(r.t, err) + + c, _ := client.NewClient() + hReq, hResp := protocol.AcquireRequest(), protocol.AcquireResponse() + hReq.SetRequestURI("http://localhost:8888" + "/v1/workflows/chat") + hReq.SetMethod("POST") + hReq.SetBody(m) + hReq.SetHeader("Content-Type", "application/json") + err = c.Do(context.Background(), hReq, hResp) + assert.NoError(r.t, err) + + if hResp.StatusCode() != http.StatusOK { + r.t.Errorf("unexpected status code: %d, body: %s", hResp.StatusCode(), string(hResp.Body())) + } + + re, err := sse.NewReader(hResp) + assert.NoError(r.t, err) + + return re +} + func (r *wfTestRunner) runServer() func() { go func() { _ = r.h.Run() @@ -5491,7 +5532,7 @@ func TestConversationOfChatFlow(t *testing.T) { if err != nil { return err } - if v.Name == "CONVERSATION_NAME" { + if v.Name == vo.ConversationNameKey { v.DefaultValue = cName } startNode.Data.Outputs[idx] = v @@ -5522,7 +5563,7 @@ func TestConversationOfChatFlow(t *testing.T) { for _, vAny := range node.Data.Outputs { v, err := vo.ParseVariable(vAny) assert.NoError(t, err) - if v.Name == "CONVERSATION_NAME" { + if v.Name == vo.ConversationNameKey { assert.Equal(t, v.DefaultValue, updateName) } } @@ -5569,7 +5610,7 @@ func TestConversationOfChatFlow(t *testing.T) { for _, vAny := range node.Data.Outputs { v, err := vo.ParseVariable(vAny) assert.NoError(t, err) - if v.Name == "CONVERSATION_NAME" { + if v.Name == vo.ConversationNameKey { assert.Equal(t, v.DefaultValue, cName+"copy") } } @@ -5988,3 +6029,223 @@ func TestConversationHistoryNodes(t *testing.T) { assert.Equal(t, []any{}, outputMap["history_list"]) }) } + +func TestChatFlowRun(t *testing.T) { + mockey.PatchConvey("chat flow run", t, func() { + r := newWfTestRunner(t) + appworkflow.SVC.IDGenerator = r.idGen + defer r.closeFn() + defer r.runServer()() + + chatModel1 := &testutil.UTChatModel{ + StreamResultProvider: func(_ int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) { + sr := schema.StreamReaderFromArray([]*schema.Message{ + { + Role: schema.Assistant, + Content: "I ", + }, + { + Role: schema.Assistant, + Content: "don't know.", + }, + }) + return sr, nil + }, + } + r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel1, nil, nil).AnyTimes() + + id := r.load("chatflow/llm_chat.json", withMode(workflow.WorkflowMode_ChatFlow)) + r.publish(id, "v0.0.1", true) + cID := time.Now().UnixNano() + cIDStr := strconv.FormatInt(cID, 10) + appID := time.Now().UnixNano() + appIDStr := strconv.FormatInt(appID, 10) + + // Create conversation first + r.conversation.EXPECT().CreateConversation(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{ + ID: cID, + }, nil).AnyTimes() + idStr := r.load("conversation_manager/update_dynamic_conversation.json") + r.publish(idStr, "v0.0.1", true) + ret, _ := r.openapiSyncRun(idStr, map[string]string{ + "input": "v1", + "new_name": "v2", + }, withRunProjectID(appID)) + assert.Equal(t, map[string]any{"conversationId": strconv.FormatInt(cID, 10), "isExisted": false, "isSuccess": true}, ret["obj"]) + + msg := []*workflow.EnterMessage{ + { + Role: "user", + ContentType: "text", + Content: "你好", + }, + } + sID := time.Now().UnixNano() + r.conversation.EXPECT().GetByID(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{ + ID: cID, + SectionID: sID, + }, nil).AnyTimes() + rID := time.Now().UnixNano() + r.agentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{ + ID: rID, + }, nil).AnyTimes() + mID := time.Now().Unix() + r.message.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&message.Message{ + ID: mID, + }, nil).AnyTimes() + + t.Run("chat flow run in app", func(t *testing.T) { + sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{ + vo.ConversationNameKey: "Default", + }, msg) + err := sseReader.ForEach(t.Context(), func(e *sse.Event) error { + t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data)) + return nil + }) + assert.NoError(t, err) + }) + + t.Run("chat flow run in bot", func(t *testing.T) { + botID := time.Now().UnixNano() + botIDStr := strconv.FormatInt(botID, 10) + sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), nil, ptr.Of(botIDStr), map[string]any{ + vo.ConversationNameKey: "Default", + }, msg) + err := sseReader.ForEach(t.Context(), func(e *sse.Event) error { + t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data)) + return nil + }) + assert.NoError(t, err) + }) + + t.Run("chat flow run without cID", func(t *testing.T) { + sseReader := r.openapiChatFlowRun(id, nil, ptr.Of(appIDStr), nil, map[string]any{ + vo.ConversationNameKey: "Default", + }, msg) + err := sseReader.ForEach(t.Context(), func(e *sse.Event) error { + t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data)) + return nil + }) + assert.NoError(t, err) + }) + + t.Run("chat flow run with additional messages", func(t *testing.T) { + additionalMsg := []*workflow.EnterMessage{ + { + Role: "user", + ContentType: "text", + Content: "你好, 我叫小明", + }, + { + Role: "assistant", + ContentType: "text", + Content: "你好小明, 很高兴认识你", + }, + { + Role: "user", + ContentType: "text", + Content: "你好", + }, + } + sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{ + vo.ConversationNameKey: "Default", + }, additionalMsg) + err := sseReader.ForEach(t.Context(), func(e *sse.Event) error { + t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data)) + return nil + }) + assert.NoError(t, err) + }) + + t.Run("chat flow run with history messages", func(t *testing.T) { + id := r.load("chatflow/llm_chat_with_history.json", withMode(workflow.WorkflowMode_ChatFlow)) + r.publish(id, "v0.0.1", true) + r.message.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{rID}, nil).AnyTimes() + r.message.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&message0.GetMessagesByRunIDsResponse{ + Messages: []*message0.WfMessage{ + { + ID: mID, + Role: schema.User, + Text: ptr.Of("你好"), + }, + }, + }, nil).AnyTimes() + sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{ + vo.ConversationNameKey: "Default", + }, msg) + err := sseReader.ForEach(t.Context(), func(e *sse.Event) error { + t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data)) + return nil + }) + assert.NoError(t, err) + }) + + t.Run("chat flow run with interrupt nodes ", func(t *testing.T) { + // 生成一个携带 input, 问答文本 问答选项的三个中断节点 做测试 + id := r.load("chatflow/chat_run_with_interrupt.json", withMode(workflow.WorkflowMode_ChatFlow)) + r.publish(id, "v0.0.1", true) + sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{ + vo.ConversationNameKey: "Default", + }, msg) + + err := sseReader.ForEach(t.Context(), func(e *sse.Event) error { + t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data)) + if e.ID == "3" { + assert.Equal(t, e.Type, "conversation.message.completed") + assert.Contains(t, string(e.Data), "7383997384420262000") + } + return nil + }) + assert.NoError(t, err) + + sseReader = r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{ + vo.ConversationNameKey: "Default", + }, []*workflow.EnterMessage{ + {Role: string(schema.User), Content: "input:1", ContentType: "text"}, + }) + + err = sseReader.ForEach(t.Context(), func(e *sse.Event) error { + t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data)) + if e.ID == "4" { + assert.Equal(t, e.Type, "conversation.message.completed") + assert.Contains(t, string(e.Data), "你好") + } + return nil + }) + assert.NoError(t, err) + + sseReader = r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{ + vo.ConversationNameKey: "Default", + }, []*workflow.EnterMessage{ + {Role: string(schema.User), Content: "hello", ContentType: "text"}, + }) + + err = sseReader.ForEach(t.Context(), func(e *sse.Event) error { + if e.ID == "3" { + assert.Equal(t, e.Type, "conversation.message.completed") + assert.Contains(t, string(e.Data), "question_card_data", "请选择") + } + return nil + }) + assert.NoError(t, err) + + sseReader = r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{ + vo.ConversationNameKey: "Default", + }, []*workflow.EnterMessage{ + {Role: string(schema.User), Content: "A", ContentType: "text"}, + }) + err = sseReader.ForEach(t.Context(), func(e *sse.Event) error { + t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data)) + + if e.ID == "4" { + assert.Equal(t, e.Type, "conversation.message.completed") + assert.Contains(t, string(e.Data), "answer", "A") + } + return nil + }) + assert.NoError(t, err) + + }) + + }) +} diff --git a/backend/application/workflow/chatflow.go b/backend/application/workflow/chatflow.go index ef9d4325..13728249 100644 --- a/backend/application/workflow/chatflow.go +++ b/backend/application/workflow/chatflow.go @@ -221,7 +221,7 @@ const ( "id": "5fJt3qKpSz", "name": "list", "defaultValue": [ - + ] } }, @@ -504,7 +504,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl workflowID = mustParseInt64(req.GetWorkflowID()) isDebug = req.GetExecuteMode() == "DEBUG" appID, agentID *int64 - resolveAppID int64 + bizID int64 conversationID int64 sectionID int64 version string @@ -521,11 +521,11 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl if req.IsSetAppID() { appID = ptr.Of(mustParseInt64(req.GetAppID())) - resolveAppID = mustParseInt64(req.GetAppID()) + bizID = mustParseInt64(req.GetAppID()) } if req.IsSetBotID() { agentID = ptr.Of(mustParseInt64(req.GetBotID())) - resolveAppID = mustParseInt64(req.GetBotID()) + bizID = mustParseInt64(req.GetBotID()) } if appID != nil && agentID != nil { @@ -564,16 +564,16 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl sectionID = cInfo.SectionID // only trust the conversation name under the app - conversationName, existed, err := GetWorkflowDomainSVC().GetConversationNameByID(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), resolveAppID, connectorID, conversationID) + conversationName, existed, err := GetWorkflowDomainSVC().GetConversationNameByID(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), bizID, connectorID, conversationID) if err != nil { return nil, err } if !existed { return nil, fmt.Errorf("conversation not found") } - parameters["CONVERSATION_NAME"] = conversationName + parameters[vo.ConversationNameKey] = conversationName } else if req.IsSetConversationID() && req.IsSetBotID() { - parameters["CONVERSATION_NAME"] = "Default" + parameters[vo.ConversationNameKey] = "Default" conversationID = mustParseInt64(req.GetConversationID()) cInfo, err := crossconversation.DefaultSVC().GetByID(ctx, conversationID) if err != nil { @@ -581,11 +581,11 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl } sectionID = cInfo.SectionID } else { - conversationName, ok := parameters["CONVERSATION_NAME"].(string) + conversationName, ok := parameters[vo.ConversationNameKey].(string) if !ok { return nil, fmt.Errorf("conversation name is requried") } - cID, sID, err := GetWorkflowDomainSVC().GetOrCreateConversation(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), resolveAppID, connectorID, userID, conversationName) + cID, sID, err := GetWorkflowDomainSVC().GetOrCreateConversation(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), bizID, connectorID, userID, conversationName) if err != nil { return nil, err } @@ -594,7 +594,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl } runRecord, err := crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{ - AgentID: resolveAppID, + AgentID: bizID, ConversationID: conversationID, UserID: strconv.FormatInt(userID, 10), ConnectorID: connectorID, @@ -606,7 +606,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl roundID := runRecord.ID - userMessage, err := toConversationMessage(ctx, resolveAppID, conversationID, userID, roundID, sectionID, message.MessageTypeQuestion, lastUserMessage) + userMessage, err := toConversationMessage(ctx, bizID, conversationID, userID, roundID, sectionID, message.MessageTypeQuestion, lastUserMessage) if err != nil { return nil, err } @@ -648,7 +648,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl return nil, err } return schema.StreamReaderWithConvert(sr, w.convertToChatFlowRunResponseList(ctx, convertToChatFlowInfo{ - appID: resolveAppID, + bizID: bizID, conversationID: conversationID, roundID: roundID, workflowID: workflowID, @@ -684,7 +684,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl Cancellable: isDebug, } - historyMessages, err := makeChatFlowHistoryMessages(ctx, resolveAppID, conversationID, userID, sectionID, connectorID, messages[:len(req.GetAdditionalMessages())-1]) + historyMessages, err := makeChatFlowHistoryMessages(ctx, bizID, conversationID, userID, sectionID, connectorID, messages[:len(req.GetAdditionalMessages())-1]) if err != nil { return nil, err } @@ -706,7 +706,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl logs.CtxWarnf(ctx, "create history message failed, err=%v", err) } } - parameters["USER_INPUT"], err = w.makeChatFlowUserInput(ctx, lastUserMessage) + parameters[vo.UserInputKey], err = w.makeChatFlowUserInput(ctx, lastUserMessage) if err != nil { return nil, err } @@ -717,7 +717,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl } return schema.StreamReaderWithConvert(sr, w.convertToChatFlowRunResponseList(ctx, convertToChatFlowInfo{ - appID: resolveAppID, + bizID: bizID, conversationID: conversationID, roundID: roundID, workflowID: workflowID, @@ -731,7 +731,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Context, info convertToChatFlowInfo) func(msg *entity.Message) (responses []*workflow.ChatFlowRunResponse, err error) { var ( - appID = info.appID + bizID = info.bizID conversationID = info.conversationID roundID = info.roundID workflowID = info.workflowID @@ -798,7 +798,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex ChatID: strconv.FormatInt(roundID, 10), ConversationID: strconv.FormatInt(conversationID, 10), SectionID: strconv.FormatInt(sectionID, 10), - BotID: strconv.FormatInt(appID, 10), + BotID: strconv.FormatInt(bizID, 10), Role: string(schema.Assistant), Type: "follow_up", ContentType: "text", @@ -815,7 +815,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex ID: strconv.FormatInt(roundID, 10), ConversationID: strconv.FormatInt(conversationID, 10), SectionID: strconv.FormatInt(sectionID, 10), - BotID: strconv.FormatInt(appID, 10), + BotID: strconv.FormatInt(bizID, 10), Status: vo.Completed, ExecuteID: strconv.FormatInt(executeID, 10), Usage: &vo.Usage{ @@ -929,7 +929,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex } _, err = crossmessage.DefaultSVC().Create(ctx, &message.Message{ - AgentID: appID, + AgentID: bizID, RunID: roundID, SectionID: sectionID, Content: msgContent, @@ -947,7 +947,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex ChatID: strconv.FormatInt(roundID, 10), ConversationID: strconv.FormatInt(conversationID, 10), SectionID: strconv.FormatInt(sectionID, 10), - BotID: strconv.FormatInt(appID, 10), + BotID: strconv.FormatInt(bizID, 10), Role: string(schema.Assistant), Type: string(entity.Answer), ContentType: string(contentType), @@ -1046,7 +1046,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex } intermediateMessage = &message.Message{ ID: id, - AgentID: appID, + AgentID: bizID, RunID: roundID, SectionID: sectionID, ConversationID: conversationID, @@ -1066,7 +1066,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex ChatID: strconv.FormatInt(roundID, 10), ConversationID: strconv.FormatInt(conversationID, 10), SectionID: strconv.FormatInt(sectionID, 10), - BotID: strconv.FormatInt(appID, 10), + BotID: strconv.FormatInt(bizID, 10), Role: string(dataMessage.Role), Type: string(dataMessage.Type), ContentType: string(message.ContentTypeText), @@ -1092,7 +1092,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex ChatID: strconv.FormatInt(roundID, 10), ConversationID: strconv.FormatInt(conversationID, 10), SectionID: strconv.FormatInt(sectionID, 10), - BotID: strconv.FormatInt(appID, 10), + BotID: strconv.FormatInt(bizID, 10), Role: string(dataMessage.Role), Type: string(dataMessage.Type), ContentType: string(message.ContentTypeText), @@ -1155,9 +1155,9 @@ func (w *ApplicationService) makeChatFlowUserInput(ctx context.Context, message } else { return "", fmt.Errorf("invalid message ccontent type %v", message.ContentType) } - } -func makeChatFlowHistoryMessages(ctx context.Context, appID, conversationID, userID, sectionID, connectorID int64, messages []*workflow.EnterMessage) ([]*message.Message, error) { + +func makeChatFlowHistoryMessages(ctx context.Context, bizID, conversationID, userID, sectionID, connectorID int64, messages []*workflow.EnterMessage) ([]*message.Message, error) { var ( rID int64 @@ -1170,7 +1170,7 @@ func makeChatFlowHistoryMessages(ctx context.Context, appID, conversationID, use for _, msg := range messages { if msg.Role == userRole { runRecord, err = crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{ - AgentID: appID, + AgentID: bizID, ConversationID: conversationID, UserID: strconv.FormatInt(userID, 10), ConnectorID: connectorID, @@ -1180,13 +1180,15 @@ func makeChatFlowHistoryMessages(ctx context.Context, appID, conversationID, use return nil, err } rID = runRecord.ID - } else if msg.Role == assistantRole && rID == 0 { - continue + } else if msg.Role == assistantRole { + if rID == 0 { + continue + } } else { return nil, fmt.Errorf("invalid role type %v", msg.Role) } - m, err := toConversationMessage(ctx, appID, conversationID, userID, rID, sectionID, ternary.IFElse(msg.Role == userRole, message.MessageTypeQuestion, message.MessageTypeAnswer), msg) + m, err := toConversationMessage(ctx, bizID, conversationID, userID, rID, sectionID, ternary.IFElse(msg.Role == userRole, message.MessageTypeQuestion, message.MessageTypeAnswer), msg) if err != nil { return nil, err } @@ -1274,7 +1276,7 @@ func (w *ApplicationService) OpenAPICreateConversation(ctx context.Context, req }, nil } -func toConversationMessage(ctx context.Context, appID, cid, userID, roundID, sectionID int64, messageType message.MessageType, msg *workflow.EnterMessage) (*message.Message, error) { +func toConversationMessage(ctx context.Context, bizID, cid, userID, roundID, sectionID int64, messageType message.MessageType, msg *workflow.EnterMessage) (*message.Message, error) { type content struct { Type string `json:"type"` FileID *string `json:"file_id"` @@ -1284,7 +1286,7 @@ func toConversationMessage(ctx context.Context, appID, cid, userID, roundID, sec return &message.Message{ Role: schema.User, ConversationID: cid, - AgentID: appID, + AgentID: bizID, RunID: roundID, Content: msg.Content, ContentType: message.ContentTypeText, @@ -1304,7 +1306,7 @@ func toConversationMessage(ctx context.Context, appID, cid, userID, roundID, sec Role: schema.User, MessageType: messageType, ConversationID: cid, - AgentID: appID, + AgentID: bizID, UserID: strconv.FormatInt(userID, 10), RunID: roundID, ContentType: message.ContentTypeMix, @@ -1432,7 +1434,7 @@ func toSchemaMessage(ctx context.Context, msg *workflow.EnterMessage) (*schema.M type convertToChatFlowInfo struct { userMessage *schema.Message - appID int64 + bizID int64 conversationID int64 roundID int64 workflowID int64 diff --git a/backend/application/workflow/chatflow_test.go b/backend/application/workflow/chatflow_test.go new file mode 100644 index 00000000..1bcd4a5b --- /dev/null +++ b/backend/application/workflow/chatflow_test.go @@ -0,0 +1,604 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package workflow + +import ( + "context" + "errors" + "strconv" + "testing" + + "github.com/cloudwego/eino/schema" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + messageentity "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message" + "github.com/coze-dev/coze-studio/backend/api/model/workflow" + crossagentrun "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun/agentrunmock" + crossupload "github.com/coze-dev/coze-studio/backend/crossdomain/contract/upload" + "github.com/coze-dev/coze-studio/backend/crossdomain/contract/upload/uploadmock" + agententity "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity" + uploadentity "github.com/coze-dev/coze-studio/backend/domain/upload/entity" + "github.com/coze-dev/coze-studio/backend/domain/upload/service" +) + +func TestApplicationService_makeChatFlowUserInput(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockUpload := uploadmock.NewMockUploader(ctrl) + crossupload.SetDefaultSVC(mockUpload) + + tests := []struct { + name string + message *workflow.EnterMessage + setupMock func() + expected string + expectErr bool + }{ + { + name: "content type text", + message: &workflow.EnterMessage{ + ContentType: "text", + Content: "hello", + }, + setupMock: func() {}, + expected: "hello", + expectErr: false, + }, + { + name: "content type object_string with text", + message: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `[{"type": "text", "text": "hello world"}]`, + }, + setupMock: func() {}, + expected: "hello world", + expectErr: false, + }, + { + name: "content type object_string with file", + message: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `[{"type": "file", "file_id": "123"}]`, + }, + setupMock: func() { + mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{ + File: &uploadentity.File{Url: "https://example.com/file"}, + }, nil) + }, + expected: "https://example.com/file", + expectErr: false, + }, + { + name: "content type object_string with text and file", + message: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `[{"type": "text", "text": "see this file"}, {"type": "file", "file_id": "123"}]`, + }, + setupMock: func() { + mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{ + File: &uploadentity.File{Url: "https://example.com/file"}, + }, nil) + }, + expected: "see this file,https://example.com/file", + expectErr: false, + }, + { + name: "get file error", + message: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `[{"type": "file", "file_id": "123"}]`, + }, + setupMock: func() { + mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(nil, errors.New("get file error")) + }, + expectErr: true, + }, + { + name: "file not found", + message: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `[{"type": "file", "file_id": "123"}]`, + }, + setupMock: func() { + mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{ + File: nil, + }, nil) + }, + expectErr: true, + }, + { + name: "invalid content type", + message: &workflow.EnterMessage{ + ContentType: "invalid", + }, + setupMock: func() {}, + expectErr: true, + }, + { + name: "invalid json", + message: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `invalid-json`, + }, + setupMock: func() {}, + expectErr: true, + }, + } + + w := &ApplicationService{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupMock() + result, err := w.makeChatFlowUserInput(ctx, tt.message) + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func Test_toConversationMessage(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockUpload := uploadmock.NewMockUploader(ctrl) + crossupload.SetDefaultSVC(mockUpload) + + bizID, cid, userID, roundID, sectionID := int64(2), int64(1), int64(4), int64(3), int64(5) + + tests := []struct { + name string + msg *workflow.EnterMessage + messageType messageentity.MessageType + setupMock func() + expected *messageentity.Message + expectErr bool + }{ + { + name: "content type text", + msg: &workflow.EnterMessage{ + ContentType: "text", + Content: "hello", + }, + messageType: messageentity.MessageTypeQuestion, + setupMock: func() {}, + expected: &messageentity.Message{ + Role: schema.User, + ConversationID: cid, + AgentID: bizID, + RunID: roundID, + Content: "hello", + ContentType: messageentity.ContentTypeText, + MessageType: messageentity.MessageTypeQuestion, + UserID: strconv.FormatInt(userID, 10), + SectionID: sectionID, + }, + expectErr: false, + }, + { + name: "content type object_string with text", + msg: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `[{"type": "text", "text": "hello"}]`, + }, + messageType: messageentity.MessageTypeQuestion, + setupMock: func() {}, + expected: &messageentity.Message{ + Role: schema.User, + MessageType: messageentity.MessageTypeQuestion, + ConversationID: cid, + AgentID: bizID, + UserID: strconv.FormatInt(userID, 10), + RunID: roundID, + ContentType: messageentity.ContentTypeMix, + MultiContent: []*messageentity.InputMetaData{ + {Type: messageentity.InputTypeText, Text: "hello"}, + }, + SectionID: sectionID, + }, + expectErr: false, + }, + { + name: "content type object_string with file", + msg: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `[{"type": "file", "file_id": "123"}]`, + }, + messageType: messageentity.MessageTypeQuestion, + setupMock: func() { + mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{ + File: &uploadentity.File{Url: "https://example.com/file", TosURI: "tos://uri", Name: "file.txt"}, + }, nil) + }, + expected: &messageentity.Message{ + Role: schema.User, + MessageType: messageentity.MessageTypeQuestion, + ConversationID: cid, + AgentID: bizID, + UserID: strconv.FormatInt(userID, 10), + RunID: roundID, + ContentType: messageentity.ContentTypeMix, + MultiContent: []*messageentity.InputMetaData{ + { + Type: "file", + FileData: []*messageentity.FileData{ + {Url: "https://example.com/file", URI: "tos://uri", Name: "file.txt"}, + }, + }, + }, + SectionID: sectionID, + }, + expectErr: false, + }, + { + name: "get file error", + msg: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `[{"type": "file", "file_id": "123"}]`, + }, + setupMock: func() { + mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(nil, errors.New("get file error")) + }, + expectErr: true, + }, + { + name: "file not found", + msg: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `[{"type": "file", "file_id": "123"}]`, + }, + setupMock: func() { + mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{}, nil) + }, + expectErr: true, + }, + { + name: "invalid content type", + msg: &workflow.EnterMessage{ + ContentType: "invalid", + }, + setupMock: func() {}, + expectErr: true, + }, + { + name: "invalid json", + msg: &workflow.EnterMessage{ + ContentType: "object_string", + Content: "invalid-json", + }, + setupMock: func() {}, + expectErr: true, + }, + { + name: "invalid input type", + msg: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `[{"type": "invalid"}]`, + }, + setupMock: func() {}, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupMock() + result, err := toConversationMessage(ctx, bizID, cid, userID, roundID, sectionID, tt.messageType, tt.msg) + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func Test_toSchemaMessage(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockUpload := uploadmock.NewMockUploader(ctrl) + crossupload.SetDefaultSVC(mockUpload) + + tests := []struct { + name string + msg *workflow.EnterMessage + setupMock func() + expected *schema.Message + expectErr bool + }{ + { + name: "content type text", + msg: &workflow.EnterMessage{ + ContentType: "text", + Content: "hello", + }, + setupMock: func() {}, + expected: &schema.Message{ + Role: schema.User, + Content: "hello", + }, + expectErr: false, + }, + { + name: "content type object_string with text", + msg: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `[{"type": "text", "text": "hello"}]`, + }, + setupMock: func() {}, + expected: &schema.Message{ + Role: schema.User, + MultiContent: []schema.ChatMessagePart{ + {Type: schema.ChatMessagePartTypeText, Text: "hello"}, + }, + }, + expectErr: false, + }, + { + name: "content type object_string with image", + msg: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `[{"type": "image", "file_id": "123"}]`, + }, + setupMock: func() { + mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{ + File: &uploadentity.File{Url: "https://example.com/image.png"}, + }, nil) + }, + expected: &schema.Message{ + Role: schema.User, + MultiContent: []schema.ChatMessagePart{ + { + Type: schema.ChatMessagePartTypeImageURL, + ImageURL: &schema.ChatMessageImageURL{URL: "https://example.com/image.png"}, + }, + }, + }, + expectErr: false, + }, + { + name: "content type object_string with various file types", + msg: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `[{"type": "file", "file_id": "1"}, {"type": "audio", "file_id": "2"}, {"type": "video", "file_id": "3"}]`, + }, + setupMock: func() { + mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 1}).Return(&service.GetFileResponse{File: &uploadentity.File{Url: "https://example.com/file"}}, nil) + mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 2}).Return(&service.GetFileResponse{File: &uploadentity.File{Url: "https://example.com/audio"}}, nil) + mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 3}).Return(&service.GetFileResponse{File: &uploadentity.File{Url: "https://example.com/video"}}, nil) + }, + expected: &schema.Message{ + Role: schema.User, + MultiContent: []schema.ChatMessagePart{ + {Type: schema.ChatMessagePartTypeFileURL, FileURL: &schema.ChatMessageFileURL{URL: "https://example.com/file"}}, + {Type: schema.ChatMessagePartTypeAudioURL, AudioURL: &schema.ChatMessageAudioURL{URL: "https://example.com/audio"}}, + {Type: schema.ChatMessagePartTypeVideoURL, VideoURL: &schema.ChatMessageVideoURL{URL: "https://example.com/video"}}, + }, + }, + expectErr: false, + }, + { + name: "get file error", + msg: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `[{"type": "file", "file_id": "123"}]`, + }, + setupMock: func() { + mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(nil, errors.New("get file error")) + }, + expectErr: true, + }, + { + name: "file not found", + msg: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `[{"type": "file", "file_id": "123"}]`, + }, + setupMock: func() { + mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{}, nil) + }, + expectErr: true, + }, + { + name: "invalid content type", + msg: &workflow.EnterMessage{ + ContentType: "invalid", + }, + setupMock: func() {}, + expectErr: true, + }, + { + name: "invalid json", + msg: &workflow.EnterMessage{ + ContentType: "object_string", + Content: "invalid-json", + }, + setupMock: func() {}, + expectErr: true, + }, + { + name: "invalid input type", + msg: &workflow.EnterMessage{ + ContentType: "object_string", + Content: `[{"type": "invalid"}]`, + }, + setupMock: func() {}, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupMock() + result, err := toSchemaMessage(ctx, tt.msg) + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func Test_makeChatFlowHistoryMessages(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAgentRun := agentrunmock.NewMockAgentRun(ctrl) + crossagentrun.SetDefaultSVC(mockAgentRun) + mockUpload := uploadmock.NewMockUploader(ctrl) + crossupload.SetDefaultSVC(mockUpload) + + bizID, conversationID, userID, sectionID, connectorID := int64(2), int64(1), int64(3), int64(4), int64(5) + + tests := []struct { + name string + messages []*workflow.EnterMessage + setupMock func() + expected []*messageentity.Message + expectErr bool + }{ + { + name: "empty messages", + messages: []*workflow.EnterMessage{}, + setupMock: func() {}, + expected: []*messageentity.Message{}, + expectErr: false, + }, + { + name: "one user message", + messages: []*workflow.EnterMessage{ + {Role: "user", ContentType: "text", Content: "hello"}, + }, + setupMock: func() { + mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{ID: 100}, nil).Times(1) + }, + expected: []*messageentity.Message{ + { + Role: schema.User, + ConversationID: conversationID, + AgentID: bizID, + RunID: 100, + Content: "hello", + ContentType: messageentity.ContentTypeText, + MessageType: messageentity.MessageTypeQuestion, + UserID: strconv.FormatInt(userID, 10), + SectionID: sectionID, + }, + }, + expectErr: false, + }, + { + name: "user and assistant message", + messages: []*workflow.EnterMessage{ + {Role: "user", ContentType: "text", Content: "hello"}, + {Role: "assistant", ContentType: "text", Content: "hi"}, + }, + setupMock: func() { + mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{ID: 100}, nil).Times(1) + }, + expected: []*messageentity.Message{ + { + Role: schema.User, + ConversationID: conversationID, + AgentID: bizID, + RunID: 100, + Content: "hello", + ContentType: messageentity.ContentTypeText, + MessageType: messageentity.MessageTypeQuestion, + UserID: strconv.FormatInt(userID, 10), + SectionID: sectionID, + }, + { + Role: schema.User, + ConversationID: conversationID, + AgentID: bizID, + RunID: 100, + Content: "hi", + ContentType: messageentity.ContentTypeText, + MessageType: messageentity.MessageTypeAnswer, + UserID: strconv.FormatInt(userID, 10), + SectionID: sectionID, + }, + }, + expectErr: false, + }, + { + name: "only assistant message", + messages: []*workflow.EnterMessage{ + {Role: "assistant", ContentType: "text", Content: "hi"}, + }, + setupMock: func() {}, + expected: []*messageentity.Message{}, + expectErr: false, + }, + { + name: "create run record error", + messages: []*workflow.EnterMessage{ + {Role: "user", ContentType: "text", Content: "hello"}, + }, + setupMock: func() { + mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil, errors.New("db error")) + }, + expectErr: true, + }, + { + name: "invalid role", + messages: []*workflow.EnterMessage{ + {Role: "system", ContentType: "text", Content: "hello"}, + }, + setupMock: func() {}, + expectErr: true, + }, + { + name: "toConversationMessage error", + messages: []*workflow.EnterMessage{ + {Role: "user", ContentType: "invalid", Content: "hello"}, + }, + setupMock: func() { + mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{ID: 100}, nil) + }, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupMock() + result, err := makeChatFlowHistoryMessages(ctx, bizID, conversationID, userID, sectionID, connectorID, tt.messages) + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} diff --git a/backend/crossdomain/contract/message/message.go b/backend/crossdomain/contract/message/message.go index 0a3d564e..20ad431c 100644 --- a/backend/crossdomain/contract/message/message.go +++ b/backend/crossdomain/contract/message/message.go @@ -58,7 +58,7 @@ type MessageListRequest struct { BeforeID *string AfterID *string UserID int64 - AppID int64 + BizID int64 OrderBy *string } @@ -88,7 +88,7 @@ type WfMessage struct { type GetLatestRunIDsRequest struct { ConversationID int64 UserID int64 - AppID int64 + BizID int64 Rounds int64 SectionID int64 InitRunID *int64 diff --git a/backend/crossdomain/contract/upload/upload.go b/backend/crossdomain/contract/upload/upload.go index 1c29e1e5..2be76435 100644 --- a/backend/crossdomain/contract/upload/upload.go +++ b/backend/crossdomain/contract/upload/upload.go @@ -24,6 +24,7 @@ import ( var defaultSVC Uploader +//go:generate mockgen -destination uploadmock/upload_mock.go --package uploadmock -source upload.go type Uploader interface { GetFile(ctx context.Context, req *service.GetFileRequest) (resp *service.GetFileResponse, err error) } diff --git a/backend/crossdomain/contract/upload/uploadmock/upload_mock.go b/backend/crossdomain/contract/upload/uploadmock/upload_mock.go new file mode 100644 index 00000000..f094a580 --- /dev/null +++ b/backend/crossdomain/contract/upload/uploadmock/upload_mock.go @@ -0,0 +1,73 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by MockGen. DO NOT EDIT. +// Source: upload.go +// +// Generated by this command: +// +// mockgen -destination uploadmock/upload_mock.go --package uploadmock -source upload.go +// + +// Package uploadmock is a generated GoMock package. +package uploadmock + +import ( + context "context" + reflect "reflect" + + service "github.com/coze-dev/coze-studio/backend/domain/upload/service" + gomock "go.uber.org/mock/gomock" +) + +// MockUploader is a mock of Uploader interface. +type MockUploader struct { + ctrl *gomock.Controller + recorder *MockUploaderMockRecorder + isgomock struct{} +} + +// MockUploaderMockRecorder is the mock recorder for MockUploader. +type MockUploaderMockRecorder struct { + mock *MockUploader +} + +// NewMockUploader creates a new mock instance. +func NewMockUploader(ctrl *gomock.Controller) *MockUploader { + mock := &MockUploader{ctrl: ctrl} + mock.recorder = &MockUploaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUploader) EXPECT() *MockUploaderMockRecorder { + return m.recorder +} + +// GetFile mocks base method. +func (m *MockUploader) GetFile(ctx context.Context, req *service.GetFileRequest) (*service.GetFileResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetFile", ctx, req) + ret0, _ := ret[0].(*service.GetFileResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetFile indicates an expected call of GetFile. +func (mr *MockUploaderMockRecorder) GetFile(ctx, req any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFile", reflect.TypeOf((*MockUploader)(nil).GetFile), ctx, req) +} diff --git a/backend/crossdomain/impl/message/message.go b/backend/crossdomain/impl/message/message.go index 9fd7fdf9..2bd0b217 100644 --- a/backend/crossdomain/impl/message/message.go +++ b/backend/crossdomain/impl/message/message.go @@ -54,7 +54,7 @@ func (c *impl) MessageList(ctx context.Context, req *crossmessage.MessageListReq ConversationID: req.ConversationID, Limit: int(req.Limit), // Since the value of limit is checked inside the node, the type cast here is safe UserID: strconv.FormatInt(req.UserID, 10), - AgentID: req.AppID, + AgentID: req.BizID, OrderBy: req.OrderBy, } if req.BeforeID != nil { @@ -96,7 +96,7 @@ func (c *impl) MessageList(ctx context.Context, req *crossmessage.MessageListReq func (c *impl) GetLatestRunIDs(ctx context.Context, req *crossmessage.GetLatestRunIDsRequest) ([]int64, error) { listMeta := &agententity.ListRunRecordMeta{ ConversationID: req.ConversationID, - AgentID: req.AppID, + AgentID: req.BizID, Limit: int32(req.Rounds), SectionID: req.SectionID, } diff --git a/backend/domain/workflow/component_interface.go b/backend/domain/workflow/component_interface.go index 9ebc924f..9e40bc3c 100644 --- a/backend/domain/workflow/component_interface.go +++ b/backend/domain/workflow/component_interface.go @@ -75,11 +75,11 @@ type Conversation interface { ListDynamicConversation(ctx context.Context, env vo.Env, policy *vo.ListConversationPolicy) ([]*entity.DynamicConversation, error) ReleaseConversationTemplate(ctx context.Context, appID int64, version string) error InitApplicationDefaultConversationTemplate(ctx context.Context, spaceID int64, appID int64, userID int64) error - GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, int64, error) + GetOrCreateConversation(ctx context.Context, env vo.Env, bizID, connectorID, userID int64, conversationName string) (int64, int64, error) UpdateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, error) GetTemplateByName(ctx context.Context, env vo.Env, appID int64, templateName string) (*entity.ConversationTemplate, bool, error) GetDynamicConversationByName(ctx context.Context, env vo.Env, appID, connectorID, userID int64, name string) (*entity.DynamicConversation, bool, error) - GetConversationNameByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error) + GetConversationNameByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error) } type InterruptEventStore interface { @@ -143,8 +143,8 @@ type ConversationRepository interface { UpdateStaticConversation(ctx context.Context, env vo.Env, templateID int64, connectorID int64, userID int64, newConversationID int64) error UpdateDynamicConversation(ctx context.Context, env vo.Env, conversationID, newConversationID int64) error CopyTemplateConversationByAppID(ctx context.Context, appID int64, toAppID int64) error - GetStaticConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error) - GetDynamicConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error) + GetStaticConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error) + GetDynamicConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error) } type WorkflowConfig interface { GetNodeOfCodeConfig() *config.NodeOfCodeConfig diff --git a/backend/domain/workflow/entity/vo/chatflow.go b/backend/domain/workflow/entity/vo/chatflow.go index 133b7a6f..4bb1c5e9 100644 --- a/backend/domain/workflow/entity/vo/chatflow.go +++ b/backend/domain/workflow/entity/vo/chatflow.go @@ -32,6 +32,11 @@ const ( ChatFlowMessageCompleted ChatFlowEvent = "conversation.message.completed" ) +const ( + ConversationNameKey = "CONVERSATION_NAME" + UserInputKey = "USER_INPUT" +) + type Usage struct { TokenCount *int32 `form:"token_count" json:"token_count,omitempty"` OutputTokens *int32 `form:"output_count" json:"output_count,omitempty"` diff --git a/backend/domain/workflow/entity/vo/conversation.go b/backend/domain/workflow/entity/vo/conversation.go index 1ed5c739..dbd238ad 100644 --- a/backend/domain/workflow/entity/vo/conversation.go +++ b/backend/domain/workflow/entity/vo/conversation.go @@ -59,14 +59,14 @@ type ListConversationPolicy struct { } type CreateStaticConversation struct { - AppID int64 + BizID int64 UserID int64 ConnectorID int64 TemplateID int64 } type CreateDynamicConversation struct { - AppID int64 + BizID int64 UserID int64 ConnectorID int64 diff --git a/backend/domain/workflow/internal/canvas/adaptor/from_node.go b/backend/domain/workflow/internal/canvas/adaptor/from_node.go index f93b45ed..3fdfe8e8 100644 --- a/backend/domain/workflow/internal/canvas/adaptor/from_node.go +++ b/backend/domain/workflow/internal/canvas/adaptor/from_node.go @@ -235,6 +235,7 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) ( if enabled { trimmedSC.GeneratedNodes = append(trimmedSC.GeneratedNodes, ns.Key) } + trimmedSC.Init() return trimmedSC, nil } diff --git a/backend/domain/workflow/internal/canvas/adaptor/to_schema.go b/backend/domain/workflow/internal/canvas/adaptor/to_schema.go index 63aa4790..d8d87561 100644 --- a/backend/domain/workflow/internal/canvas/adaptor/to_schema.go +++ b/backend/domain/workflow/internal/canvas/adaptor/to_schema.go @@ -446,7 +446,7 @@ func PruneIsolatedNodes(nodes []*vo.Node, edges []*vo.Edge, parentNode *vo.Node) func parseBatchMode(n *vo.Node) ( batchN *vo.Node, // the new batch node - enabled bool, // whether the node has enabled batch mode + enabled bool, // whether the node has enabled batch mode err error) { if n.Data == nil || n.Data.Inputs == nil { return nil, false, nil diff --git a/backend/domain/workflow/internal/canvas/examples/chatflow/chat_run_with_interrupt.json b/backend/domain/workflow/internal/canvas/examples/chatflow/chat_run_with_interrupt.json new file mode 100644 index 00000000..dcf277ac --- /dev/null +++ b/backend/domain/workflow/internal/canvas/examples/chatflow/chat_run_with_interrupt.json @@ -0,0 +1,275 @@ +{ + "nodes": [ + { + "id": "100001", + "type": "1", + "meta": { + "position": { + "x": 180, + "y": 79.2 + } + }, + "data": { + "outputs": [ + { + "type": "string", + "name": "USER_INPUT", + "required": false + }, + { + "type": "string", + "name": "CONVERSATION_NAME", + "required": false, + "description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。", + "defaultValue": "dhl" + } + ], + "nodeMeta": { + "title": "开始", + "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg", + "description": "工作流的起始节点,用于设定启动工作流需要的信息", + "subTitle": "" + }, + "trigger_parameters": [] + } + }, + { + "id": "900001", + "type": "2", + "meta": { + "position": { + "x": 2020, + "y": 66.2 + } + }, + "data": { + "nodeMeta": { + "title": "结束", + "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg", + "description": "工作流的最终节点,用于返回工作流运行后的结果信息", + "subTitle": "" + }, + "inputs": { + "terminatePlan": "useAnswerContent", + "streamingOutput": true, + "inputParameters": [ + { + "name": "output", + "input": { + "type": "string", + "value": { + "type": "ref", + "content": { + "source": "block-output", + "blockID": "142077", + "name": "optionContent" + }, + "rawMeta": { + "type": 1 + } + } + } + } + ], + "content": { + "type": "string", + "value": { + "type": "literal", + "content": "{{output}}" + } + } + } + } + }, + { + "id": "190196", + "type": "30", + "meta": { + "position": { + "x": 640, + "y": 78.5 + } + }, + "data": { + "outputs": [ + { + "type": "string", + "name": "input", + "required": true + } + ], + "nodeMeta": { + "title": "输入", + "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Input-v2.jpg", + "description": "支持中间过程的信息输入", + "mainColor": "#5C62FF", + "subTitle": "输入" + }, + "inputs": { + "outputSchema": "[{\"type\":\"string\",\"name\":\"input\",\"required\":true}]" + } + } + }, + { + "id": "133775", + "type": "18", + "meta": { + "position": { + "x": 1100, + "y": 39 + } + }, + "data": { + "inputs": { + "llmParam": { + "modelType": 1001, + "modelName": "Doubao-Seed-1.6", + "generationDiversity": "balance", + "temperature": 0.8, + "maxTokens": 4096, + "topP": 0.7, + "responseFormat": 2, + "systemPrompt": "" + }, + "inputParameters": [], + "extra_output": false, + "answer_type": "text", + "option_type": "static", + "dynamic_option": { + "type": "string", + "value": { + "type": "ref", + "content": { + "source": "block-output", + "blockID": "", + "name": "" + } + } + }, + "question": "你好", + "options": [ + { + "name": "" + }, + { + "name": "" + } + ], + "limit": 3 + }, + "nodeMeta": { + "title": "问答", + "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg", + "description": "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式", + "mainColor": "#3071F2", + "subTitle": "问答" + }, + "outputs": [ + { + "type": "string", + "name": "USER_RESPONSE", + "required": true, + "description": "用户本轮对话输入内容" + } + ] + } + }, + { + "id": "142077", + "type": "18", + "meta": { + "position": { + "x": 1560, + "y": 0 + } + }, + "data": { + "inputs": { + "llmParam": { + "modelType": 1001, + "modelName": "Doubao-Seed-1.6", + "generationDiversity": "balance", + "temperature": 0.8, + "maxTokens": 4096, + "topP": 0.7, + "responseFormat": 2, + "systemPrompt": "" + }, + "inputParameters": [], + "extra_output": false, + "answer_type": "option", + "option_type": "static", + "dynamic_option": { + "type": "string", + "value": { + "type": "ref", + "content": { + "source": "block-output", + "blockID": "", + "name": "" + } + } + }, + "question": "请选择", + "options": [ + { + "name": "A" + }, + { + "name": "B" + } + ], + "limit": 3 + }, + "nodeMeta": { + "title": "问答_1", + "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg", + "description": "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式", + "mainColor": "#3071F2", + "subTitle": "问答" + }, + "outputs": [ + { + "type": "string", + "name": "optionId", + "required": false + }, + { + "type": "string", + "name": "optionContent", + "required": false + } + ] + } + } + ], + "edges": [ + { + "sourceNodeID": "100001", + "targetNodeID": "190196" + }, + { + "sourceNodeID": "142077", + "targetNodeID": "900001", + "sourcePortID": "branch_0" + }, + { + "sourceNodeID": "142077", + "targetNodeID": "900001", + "sourcePortID": "branch_1" + }, + { + "sourceNodeID": "142077", + "targetNodeID": "900001", + "sourcePortID": "default" + }, + { + "sourceNodeID": "190196", + "targetNodeID": "133775" + }, + { + "sourceNodeID": "133775", + "targetNodeID": "142077" + } + ] +} diff --git a/backend/domain/workflow/internal/canvas/examples/chatflow/llm_chat.json b/backend/domain/workflow/internal/canvas/examples/chatflow/llm_chat.json new file mode 100644 index 00000000..7b28b91b --- /dev/null +++ b/backend/domain/workflow/internal/canvas/examples/chatflow/llm_chat.json @@ -0,0 +1,397 @@ +{ + "nodes": [ + { + "blocks": [], + "data": { + "nodeMeta": { + "description": "工作流的起始节点,用于设定启动工作流需要的信息", + "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg", + "subTitle": "", + "title": "开始" + }, + "outputs": [ + { + "name": "USER_INPUT", + "required": false, + "type": "string" + }, + { + "defaultValue": "Default", + "description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。", + "name": "CONVERSATION_NAME", + "required": false, + "type": "string" + } + ], + "trigger_parameters": [] + }, + "edges": null, + "id": "100001", + "meta": { + "position": { + "x": 0, + "y": 0 + } + }, + "type": "1" + }, + { + "blocks": [], + "data": { + "inputs": { + "content": { + "type": "string", + "value": { + "content": "{{output}}", + "type": "literal" + } + }, + "inputParameters": [ + { + "input": { + "type": "string", + "value": { + "content": { + "blockID": "123887", + "name": "output", + "source": "block-output" + }, + "rawMeta": { + "type": 1 + }, + "type": "ref" + } + }, + "name": "output" + } + ], + "streamingOutput": true, + "terminatePlan": "useAnswerContent" + }, + "nodeMeta": { + "description": "工作流的最终节点,用于返回工作流运行后的结果信息", + "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg", + "subTitle": "", + "title": "结束" + } + }, + "edges": null, + "id": "900001", + "meta": { + "position": { + "x": 926, + "y": -13 + } + }, + "type": "2" + }, + { + "blocks": [], + "data": { + "inputs": { + "fcParamVar": { + "knowledgeFCParam": {} + }, + "inputParameters": [ + { + "input": { + "type": "string", + "value": { + "content": { + "blockID": "100001", + "name": "USER_INPUT", + "source": "block-output" + }, + "rawMeta": { + "type": 1 + }, + "type": "ref" + } + }, + "name": "input" + } + ], + "llmParam": [ + { + "input": { + "type": "integer", + "value": { + "content": "1737521813", + "rawMeta": { + "type": 2 + }, + "type": "literal" + } + }, + "name": "modelType" + }, + { + "input": { + "type": "string", + "value": { + "content": "豆包·1.5·Pro·32k", + "rawMeta": { + "type": 1 + }, + "type": "literal" + } + }, + "name": "modleName" + }, + { + "input": { + "type": "string", + "value": { + "content": "balance", + "rawMeta": { + "type": 1 + }, + "type": "literal" + } + }, + "name": "generationDiversity" + }, + { + "input": { + "type": "float", + "value": { + "content": "0.8", + "rawMeta": { + "type": 4 + }, + "type": "literal" + } + }, + "name": "temperature" + }, + { + "input": { + "type": "integer", + "value": { + "content": "4096", + "rawMeta": { + "type": 2 + }, + "type": "literal" + } + }, + "name": "maxTokens" + }, + { + "input": { + "type": "boolean", + "value": { + "content": false, + "rawMeta": { + "type": 3 + }, + "type": "literal" + } + }, + "name": "spCurrentTime" + }, + { + "input": { + "type": "boolean", + "value": { + "content": false, + "rawMeta": { + "type": 3 + }, + "type": "literal" + } + }, + "name": "spAntiLeak" + }, + { + "input": { + "type": "boolean", + "value": { + "content": false, + "rawMeta": { + "type": 3 + }, + "type": "literal" + } + }, + "name": "prefixCache" + }, + { + "input": { + "type": "integer", + "value": { + "content": "2", + "rawMeta": { + "type": 2 + }, + "type": "literal" + } + }, + "name": "responseFormat" + }, + { + "input": { + "type": "string", + "value": { + "content": "{{input}}", + "rawMeta": { + "type": 1 + }, + "type": "literal" + } + }, + "name": "prompt" + }, + { + "input": { + "type": "boolean", + "value": { + "content": false, + "rawMeta": { + "type": 3 + }, + "type": "literal" + } + }, + "name": "enableChatHistory" + }, + { + "input": { + "type": "integer", + "value": { + "content": "3", + "rawMeta": { + "type": 2 + }, + "type": "literal" + } + }, + "name": "chatHistoryRound" + }, + { + "input": { + "type": "string", + "value": { + "content": "", + "rawMeta": { + "type": 1 + }, + "type": "literal" + } + }, + "name": "systemPrompt" + }, + { + "input": { + "type": "string", + "value": { + "content": "", + "rawMeta": { + "type": 1 + }, + "type": "literal" + } + }, + "name": "stableSystemPrompt" + }, + { + "input": { + "type": "boolean", + "value": { + "content": false, + "rawMeta": { + "type": 3 + }, + "type": "literal" + } + }, + "name": "canContinue" + }, + { + "input": { + "type": "string", + "value": { + "content": "", + "rawMeta": { + "type": 1 + }, + "type": "literal" + } + }, + "name": "loopPromptVersion" + }, + { + "input": { + "type": "string", + "value": { + "content": "", + "rawMeta": { + "type": 1 + }, + "type": "literal" + } + }, + "name": "loopPromptName" + }, + { + "input": { + "type": "string", + "value": { + "content": "", + "rawMeta": { + "type": 1 + }, + "type": "literal" + } + }, + "name": "loopPromptId" + } + ], + "settingOnError": { + "processType": 1, + "retryTimes": 0, + "timeoutMs": 180000 + } + }, + "nodeMeta": { + "description": "调用大语言模型,使用变量和提示词生成回复", + "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg", + "mainColor": "#5C62FF", + "subTitle": "大模型", + "title": "大模型" + }, + "outputs": [ + { + "name": "output", + "type": "string" + } + ], + "version": "3" + }, + "edges": null, + "id": "123887", + "meta": { + "position": { + "x": 463, + "y": -39 + } + }, + "type": "3" + } + ], + "edges": [ + { + "sourceNodeID": "100001", + "targetNodeID": "123887", + "sourcePortID": "" + }, + { + "sourceNodeID": "123887", + "targetNodeID": "900001", + "sourcePortID": "" + } + ], + "versions": { + "loop": "v2" + } +} diff --git a/backend/domain/workflow/internal/canvas/examples/chatflow/llm_chat_with_history.json b/backend/domain/workflow/internal/canvas/examples/chatflow/llm_chat_with_history.json new file mode 100644 index 00000000..c8894503 --- /dev/null +++ b/backend/domain/workflow/internal/canvas/examples/chatflow/llm_chat_with_history.json @@ -0,0 +1,397 @@ +{ + "nodes": [ + { + "blocks": [], + "data": { + "nodeMeta": { + "description": "工作流的起始节点,用于设定启动工作流需要的信息", + "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg", + "subTitle": "", + "title": "开始" + }, + "outputs": [ + { + "name": "USER_INPUT", + "required": false, + "type": "string" + }, + { + "defaultValue": "Default", + "description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。", + "name": "CONVERSATION_NAME", + "required": false, + "type": "string" + } + ], + "trigger_parameters": [] + }, + "edges": null, + "id": "100001", + "meta": { + "position": { + "x": 0, + "y": 0 + } + }, + "type": "1" + }, + { + "blocks": [], + "data": { + "inputs": { + "content": { + "type": "string", + "value": { + "content": "{{output}}", + "type": "literal" + } + }, + "inputParameters": [ + { + "input": { + "type": "string", + "value": { + "content": { + "blockID": "123887", + "name": "output", + "source": "block-output" + }, + "rawMeta": { + "type": 1 + }, + "type": "ref" + } + }, + "name": "output" + } + ], + "streamingOutput": true, + "terminatePlan": "useAnswerContent" + }, + "nodeMeta": { + "description": "工作流的最终节点,用于返回工作流运行后的结果信息", + "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg", + "subTitle": "", + "title": "结束" + } + }, + "edges": null, + "id": "900001", + "meta": { + "position": { + "x": 926, + "y": -13 + } + }, + "type": "2" + }, + { + "blocks": [], + "data": { + "inputs": { + "fcParamVar": { + "knowledgeFCParam": {} + }, + "inputParameters": [ + { + "input": { + "type": "string", + "value": { + "content": { + "blockID": "100001", + "name": "USER_INPUT", + "source": "block-output" + }, + "rawMeta": { + "type": 1 + }, + "type": "ref" + } + }, + "name": "input" + } + ], + "llmParam": [ + { + "input": { + "type": "integer", + "value": { + "content": "1737521813", + "rawMeta": { + "type": 2 + }, + "type": "literal" + } + }, + "name": "modelType" + }, + { + "input": { + "type": "string", + "value": { + "content": "豆包·1.5·Pro·32k", + "rawMeta": { + "type": 1 + }, + "type": "literal" + } + }, + "name": "modleName" + }, + { + "input": { + "type": "string", + "value": { + "content": "balance", + "rawMeta": { + "type": 1 + }, + "type": "literal" + } + }, + "name": "generationDiversity" + }, + { + "input": { + "type": "float", + "value": { + "content": "0.8", + "rawMeta": { + "type": 4 + }, + "type": "literal" + } + }, + "name": "temperature" + }, + { + "input": { + "type": "integer", + "value": { + "content": "4096", + "rawMeta": { + "type": 2 + }, + "type": "literal" + } + }, + "name": "maxTokens" + }, + { + "input": { + "type": "boolean", + "value": { + "content": false, + "rawMeta": { + "type": 3 + }, + "type": "literal" + } + }, + "name": "spCurrentTime" + }, + { + "input": { + "type": "boolean", + "value": { + "content": false, + "rawMeta": { + "type": 3 + }, + "type": "literal" + } + }, + "name": "spAntiLeak" + }, + { + "input": { + "type": "boolean", + "value": { + "content": false, + "rawMeta": { + "type": 3 + }, + "type": "literal" + } + }, + "name": "prefixCache" + }, + { + "input": { + "type": "integer", + "value": { + "content": "2", + "rawMeta": { + "type": 2 + }, + "type": "literal" + } + }, + "name": "responseFormat" + }, + { + "input": { + "type": "string", + "value": { + "content": "{{input}}", + "rawMeta": { + "type": 1 + }, + "type": "literal" + } + }, + "name": "prompt" + }, + { + "input": { + "type": "boolean", + "value": { + "content": true, + "rawMeta": { + "type": 3 + }, + "type": "literal" + } + }, + "name": "enableChatHistory" + }, + { + "input": { + "type": "integer", + "value": { + "content": "3", + "rawMeta": { + "type": 2 + }, + "type": "literal" + } + }, + "name": "chatHistoryRound" + }, + { + "input": { + "type": "string", + "value": { + "content": "", + "rawMeta": { + "type": 1 + }, + "type": "literal" + } + }, + "name": "systemPrompt" + }, + { + "input": { + "type": "string", + "value": { + "content": "", + "rawMeta": { + "type": 1 + }, + "type": "literal" + } + }, + "name": "stableSystemPrompt" + }, + { + "input": { + "type": "boolean", + "value": { + "content": false, + "rawMeta": { + "type": 3 + }, + "type": "literal" + } + }, + "name": "canContinue" + }, + { + "input": { + "type": "string", + "value": { + "content": "", + "rawMeta": { + "type": 1 + }, + "type": "literal" + } + }, + "name": "loopPromptVersion" + }, + { + "input": { + "type": "string", + "value": { + "content": "", + "rawMeta": { + "type": 1 + }, + "type": "literal" + } + }, + "name": "loopPromptName" + }, + { + "input": { + "type": "string", + "value": { + "content": "", + "rawMeta": { + "type": 1 + }, + "type": "literal" + } + }, + "name": "loopPromptId" + } + ], + "settingOnError": { + "processType": 1, + "retryTimes": 0, + "timeoutMs": 180000 + } + }, + "nodeMeta": { + "description": "调用大语言模型,使用变量和提示词生成回复", + "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg", + "mainColor": "#5C62FF", + "subTitle": "大模型", + "title": "大模型" + }, + "outputs": [ + { + "name": "output", + "type": "string" + } + ], + "version": "3" + }, + "edges": null, + "id": "123887", + "meta": { + "position": { + "x": 463, + "y": -39 + } + }, + "type": "3" + } + ], + "edges": [ + { + "sourceNodeID": "100001", + "targetNodeID": "123887", + "sourcePortID": "" + }, + { + "sourceNodeID": "123887", + "targetNodeID": "900001", + "sourcePortID": "" + } + ], + "versions": { + "loop": "v2" + } +} diff --git a/backend/domain/workflow/internal/nodes/conversation/conversationhistory.go b/backend/domain/workflow/internal/nodes/conversation/conversationhistory.go index 35ef1b2a..91d3d10a 100644 --- a/backend/domain/workflow/internal/nodes/conversation/conversationhistory.go +++ b/backend/domain/workflow/internal/nodes/conversation/conversationhistory.go @@ -148,7 +148,7 @@ func (ch *ConversationHistory) Invoke(ctx context.Context, input map[string]any) runIDs, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, &crossmessage.GetLatestRunIDsRequest{ ConversationID: conversationID, UserID: userID, - AppID: *appID, + BizID: *appID, Rounds: rounds, InitRunID: initRunID, SectionID: sectionID, diff --git a/backend/domain/workflow/internal/nodes/conversation/createconversation.go b/backend/domain/workflow/internal/nodes/conversation/createconversation.go index 671dcb49..2b3c7129 100644 --- a/backend/domain/workflow/internal/nodes/conversation/createconversation.go +++ b/backend/domain/workflow/internal/nodes/conversation/createconversation.go @@ -109,7 +109,7 @@ func (c *CreateConversation) Invoke(ctx context.Context, input map[string]any) ( if existed { cID, _, existed, err := workflow.GetRepository().GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{ - AppID: ptr.From(appID), + BizID: ptr.From(appID), TemplateID: template.TemplateID, UserID: userID, ConnectorID: connectorID, @@ -125,7 +125,7 @@ func (c *CreateConversation) Invoke(ctx context.Context, input map[string]any) ( } cID, _, existed, err := workflow.GetRepository().GetOrCreateDynamicConversation(ctx, env, conversationIDGenerator, &vo.CreateDynamicConversation{ - AppID: ptr.From(appID), + BizID: ptr.From(appID), UserID: userID, ConnectorID: connectorID, Name: conversationName, diff --git a/backend/domain/workflow/internal/nodes/conversation/createmessage.go b/backend/domain/workflow/internal/nodes/conversation/createmessage.go index 16b038e4..0579a740 100644 --- a/backend/domain/workflow/internal/nodes/conversation/createmessage.go +++ b/backend/domain/workflow/internal/nodes/conversation/createmessage.go @@ -98,7 +98,7 @@ func (c *CreateMessage) getConversationIDByName(ctx context.Context, env vo.Env, var conversationID int64 if isExist { cID, _, _, err := workflow.GetRepository().GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{ - AppID: ptr.From(appID), + BizID: ptr.From(appID), TemplateID: template.TemplateID, UserID: userID, ConnectorID: connectorID, @@ -150,7 +150,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s var conversationID int64 var err error - var resolvedAppID int64 + var bizID int64 if appID == nil { if conversationName != "Default" { return nil, vo.WrapError(errno.ErrOnlyDefaultConversationAllowInAgentScenario, errors.New("conversation node only allow in application")) @@ -167,13 +167,13 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s }, nil } conversationID = *execCtx.ExeCfg.ConversationID - resolvedAppID = *agentID + bizID = *agentID } else { conversationID, err = c.getConversationIDByName(ctx, env, appID, version, conversationName, userID, connectorID) if err != nil { return nil, err } - resolvedAppID = *appID + bizID = *appID } if conversationID == 0 { @@ -209,7 +209,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s if role == "user" { // For user messages, always create a new run and store the ID in the context. runRecord, err := crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{ - AgentID: resolvedAppID, + AgentID: bizID, ConversationID: conversationID, UserID: strconv.FormatInt(userID, 10), ConnectorID: connectorID, @@ -244,7 +244,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s runIDs, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, &crossmessage.GetLatestRunIDsRequest{ ConversationID: conversationID, UserID: userID, - AppID: resolvedAppID, + BizID: bizID, Rounds: 1, }) if err != nil { @@ -254,7 +254,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s runID = runIDs[0] } else { runRecord, err := crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{ - AgentID: resolvedAppID, + AgentID: bizID, ConversationID: conversationID, UserID: strconv.FormatInt(userID, 10), ConnectorID: connectorID, @@ -273,7 +273,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s Content: content, ContentType: model.ContentType("text"), UserID: strconv.FormatInt(userID, 10), - AgentID: resolvedAppID, + AgentID: bizID, RunID: runID, SectionID: sectionID, } diff --git a/backend/domain/workflow/internal/nodes/conversation/messagelist.go b/backend/domain/workflow/internal/nodes/conversation/messagelist.go index be50af48..db8611fc 100644 --- a/backend/domain/workflow/internal/nodes/conversation/messagelist.go +++ b/backend/domain/workflow/internal/nodes/conversation/messagelist.go @@ -115,7 +115,7 @@ func (m *MessageList) Invoke(ctx context.Context, input map[string]any) (map[str var conversationID int64 var err error - var resolvedAppID int64 + var bizID int64 if appID == nil { if conversationName != "Default" { return nil, vo.WrapError(errno.ErrOnlyDefaultConversationAllowInAgentScenario, errors.New("conversation node only allow in application")) @@ -129,18 +129,18 @@ func (m *MessageList) Invoke(ctx context.Context, input map[string]any) (map[str }, nil } conversationID = *execCtx.ExeCfg.ConversationID - resolvedAppID = *agentID + bizID = *agentID } else { conversationID, err = m.getConversationIDByName(ctx, env, appID, version, conversationName, userID, connectorID) if err != nil { return nil, err } - resolvedAppID = *appID + bizID = *appID } req := &crossmessage.MessageListRequest{ UserID: userID, - AppID: resolvedAppID, + BizID: bizID, ConversationID: conversationID, } diff --git a/backend/domain/workflow/internal/nodes/llm/llm.go b/backend/domain/workflow/internal/nodes/llm/llm.go index b1652eab..b45f29a6 100644 --- a/backend/domain/workflow/internal/nodes/llm/llm.go +++ b/backend/domain/workflow/internal/nodes/llm/llm.go @@ -100,9 +100,9 @@ const ( ReasoningOutputKey = "reasoning_content" ) -const knowledgeUserPromptTemplate = `根据引用的内容回答问题: - 1.如果引用的内容里面包含 的标签, 标签里的 src 字段表示图片地址, 需要在回答问题的时候展示出去, 输出格式为"![图片名称](图片地址)" 。 - 2.如果引用的内容不包含 的标签, 你回答问题时不需要展示图片 。 +const knowledgeUserPromptTemplate = `根据引用的内容回答问题: + 1.如果引用的内容里面包含 的标签, 标签里的 src 字段表示图片地址, 需要在回答问题的时候展示出去, 输出格式为"![图片名称](图片地址)" 。 + 2.如果引用的内容不包含 的标签, 你回答问题时不需要展示图片 。 例如: 如果内容为一只小猫,你的输出应为:![一只小猫](https://example.com/image.jpg)。 如果内容为一只小猫 和 一只小狗 和 一只小牛,你的输出应为:![一只小猫](https://example.com/image1.jpg) 和 ![一只小狗](https://example.com/image2.jpg) 和 ![一只小牛](https://example.com/image3.jpg) @@ -290,7 +290,7 @@ func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (* c.AssociateStartNodeUserInputFields = make(map[string]struct{}) for _, info := range ns.InputSources { if len(info.Path) == 1 && info.Source.Ref != nil && info.Source.Ref.FromNodeKey == entity.EntryNodeKey { - if compose.FromFieldPath(info.Source.Ref.FromPath).Equals(compose.FromField("USER_INPUT")) { + if compose.FromFieldPath(info.Source.Ref.FromPath).Equals(compose.FromField(vo.UserInputKey)) { c.AssociateStartNodeUserInputFields[info.Path[0]] = struct{}{} } } diff --git a/backend/domain/workflow/internal/nodes/node.go b/backend/domain/workflow/internal/nodes/node.go index 634da62f..575123fe 100644 --- a/backend/domain/workflow/internal/nodes/node.go +++ b/backend/domain/workflow/internal/nodes/node.go @@ -192,8 +192,3 @@ type StreamGenerator interface { FieldStreamType(path compose.FieldPath, ns *schema.NodeSchema, sc *schema.WorkflowSchema) (schema.FieldStreamType, error) } - -type ChatHistoryAware interface { - ChatHistoryEnabled() bool - ChatHistoryRounds() int64 -} diff --git a/backend/domain/workflow/internal/repo/conversation_repository.go b/backend/domain/workflow/internal/repo/conversation_repository.go index a3ae07be..85d75d94 100644 --- a/backend/domain/workflow/internal/repo/conversation_repository.go +++ b/backend/domain/workflow/internal/repo/conversation_repository.go @@ -432,7 +432,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env appDynamicConversationDraft := r.query.AppDynamicConversationDraft ret, err := appDynamicConversationDraft.WithContext(ctx).Where( - appDynamicConversationDraft.AppID.Eq(meta.AppID), + appDynamicConversationDraft.AppID.Eq(meta.BizID), appDynamicConversationDraft.ConnectorID.Eq(meta.ConnectorID), appDynamicConversationDraft.UserID.Eq(meta.UserID), appDynamicConversationDraft.Name.Eq(meta.Name), @@ -452,7 +452,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err) } - conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID) + conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID) if err != nil { return 0, 0, false, err } @@ -464,7 +464,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env err = r.query.AppDynamicConversationDraft.WithContext(ctx).Create(&model.AppDynamicConversationDraft{ ID: id, - AppID: meta.AppID, + AppID: meta.BizID, Name: meta.Name, UserID: meta.UserID, ConnectorID: meta.ConnectorID, @@ -479,7 +479,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env } else if env == vo.Online { appDynamicConversationOnline := r.query.AppDynamicConversationOnline ret, err := appDynamicConversationOnline.WithContext(ctx).Where( - appDynamicConversationOnline.AppID.Eq(meta.AppID), + appDynamicConversationOnline.AppID.Eq(meta.BizID), appDynamicConversationOnline.ConnectorID.Eq(meta.ConnectorID), appDynamicConversationOnline.UserID.Eq(meta.UserID), appDynamicConversationOnline.Name.Eq(meta.Name), @@ -498,7 +498,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err) } - conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID) + conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID) if err != nil { return 0, 0, false, err } @@ -509,7 +509,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env err = r.query.AppDynamicConversationOnline.WithContext(ctx).Create(&model.AppDynamicConversationOnline{ ID: id, - AppID: meta.AppID, + AppID: meta.BizID, Name: meta.Name, UserID: meta.UserID, ConnectorID: meta.ConnectorID, @@ -586,7 +586,7 @@ func (r *RepositoryImpl) getOrCreateDraftStaticConversation(ctx context.Context, return cs[0].ConversationID, cInfo.SectionID, true, nil } - conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID) + conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID) if err != nil { return 0, 0, false, err } @@ -627,7 +627,7 @@ func (r *RepositoryImpl) getOrCreateOnlineStaticConversation(ctx context.Context return cs[0].ConversationID, cInfo.SectionID, true, nil } - conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID) + conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID) if err != nil { return 0, 0, false, err } @@ -841,7 +841,7 @@ func (r *RepositoryImpl) CopyTemplateConversationByAppID(ctx context.Context, ap } -func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error) { +func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error) { if env == vo.Draft { appStaticConversationDraft := r.query.AppStaticConversationDraft ret, err := appStaticConversationDraft.WithContext(ctx).Where( @@ -857,7 +857,7 @@ func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.E appConversationTemplateDraft := r.query.AppConversationTemplateDraft template, err := appConversationTemplateDraft.WithContext(ctx).Where( appConversationTemplateDraft.TemplateID.Eq(ret.TemplateID), - appConversationTemplateDraft.AppID.Eq(appID), + appConversationTemplateDraft.AppID.Eq(bizID), ).First() if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -881,7 +881,7 @@ func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.E appConversationTemplateOnline := r.query.AppConversationTemplateOnline template, err := appConversationTemplateOnline.WithContext(ctx).Where( appConversationTemplateOnline.TemplateID.Eq(ret.TemplateID), - appConversationTemplateOnline.AppID.Eq(appID), + appConversationTemplateOnline.AppID.Eq(bizID), ).First() if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -894,11 +894,11 @@ func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.E return "", false, fmt.Errorf("unknown env %v", env) } -func (r *RepositoryImpl) GetDynamicConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error) { +func (r *RepositoryImpl) GetDynamicConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error) { if env == vo.Draft { appDynamicConversationDraft := r.query.AppDynamicConversationDraft ret, err := appDynamicConversationDraft.WithContext(ctx).Where( - appDynamicConversationDraft.AppID.Eq(appID), + appDynamicConversationDraft.AppID.Eq(bizID), appDynamicConversationDraft.ConnectorID.Eq(connectorID), appDynamicConversationDraft.ConversationID.Eq(conversationID), ).First() @@ -918,7 +918,7 @@ func (r *RepositoryImpl) GetDynamicConversationByID(ctx context.Context, env vo. } else if env == vo.Online { appDynamicConversationOnline := r.query.AppDynamicConversationOnline ret, err := appDynamicConversationOnline.WithContext(ctx).Where( - appDynamicConversationOnline.AppID.Eq(appID), + appDynamicConversationOnline.AppID.Eq(bizID), appDynamicConversationOnline.ConnectorID.Eq(connectorID), appDynamicConversationOnline.ConversationID.Eq(conversationID), ).First() diff --git a/backend/domain/workflow/internal/schema/node_schema.go b/backend/domain/workflow/internal/schema/node_schema.go index 61b85f0f..e8c07f19 100644 --- a/backend/domain/workflow/internal/schema/node_schema.go +++ b/backend/domain/workflow/internal/schema/node_schema.go @@ -129,3 +129,8 @@ func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) { func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) { s.OutputSources = append(s.OutputSources, info...) } + +type ChatHistoryAware interface { + ChatHistoryEnabled() bool + ChatHistoryRounds() int64 +} diff --git a/backend/domain/workflow/internal/schema/workflow_schema.go b/backend/domain/workflow/internal/schema/workflow_schema.go index 22400823..a2efc2cc 100644 --- a/backend/domain/workflow/internal/schema/workflow_schema.go +++ b/backend/domain/workflow/internal/schema/workflow_schema.go @@ -38,6 +38,7 @@ type WorkflowSchema struct { compositeNodes []*CompositeNode // won't serialize this requireCheckPoint bool // won't serialize this requireStreaming bool + historyRounds int64 once sync.Once } @@ -69,15 +70,22 @@ func (w *WorkflowSchema) Init() { w.doGetCompositeNodes() + historyRounds := int64(0) for _, node := range w.Nodes { if node.Type == entity.NodeTypeSubWorkflow { node.SubWorkflowSchema.Init() + historyRounds = max(historyRounds, node.SubWorkflowSchema.HistoryRounds()) if node.SubWorkflowSchema.requireCheckPoint { w.requireCheckPoint = true break } } + chatHistoryAware, ok := node.Configs.(ChatHistoryAware) + if ok && chatHistoryAware.ChatHistoryEnabled() { + historyRounds = max(historyRounds, chatHistoryAware.ChatHistoryRounds()) + } + if rc, ok := node.Configs.(RequireCheckpoint); ok { if rc.RequireCheckpoint() { w.requireCheckPoint = true @@ -86,6 +94,7 @@ func (w *WorkflowSchema) Init() { } } + w.historyRounds = historyRounds w.requireStreaming = w.doRequireStreaming() }) } @@ -122,6 +131,12 @@ func (w *WorkflowSchema) RequireStreaming() bool { return w.requireStreaming } +func (w *WorkflowSchema) HistoryRounds() int64 { return w.historyRounds } + +func (w *WorkflowSchema) SetHistoryRounds(historyRounds int64) { + w.historyRounds = historyRounds +} + func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) { if w.Hierarchy == nil { return nil diff --git a/backend/domain/workflow/service/conversation_impl.go b/backend/domain/workflow/service/conversation_impl.go index ccec12ea..53e08600 100644 --- a/backend/domain/workflow/service/conversation_impl.go +++ b/backend/domain/workflow/service/conversation_impl.go @@ -248,7 +248,7 @@ func (c *conversationImpl) findReplaceWorkflowByConversationName(ctx context.Con if err != nil { return false, err } - if v.Name == "CONVERSATION_NAME" && v.DefaultValue == name { + if v.Name == vo.ConversationNameKey && v.DefaultValue == name { return true, nil } } @@ -296,7 +296,7 @@ func (c *conversationImpl) replaceWorkflowsConversationName(ctx context.Context, if err != nil { return err } - if v.Name == "CONVERSATION_NAME" { + if v.Name == vo.ConversationNameKey { v.DefaultValue = conversionName } startNode.Data.Outputs[idx] = v @@ -351,18 +351,18 @@ func (c *conversationImpl) DeleteDynamicConversation(ctx context.Context, env vo return c.repo.DeleteDynamicConversation(ctx, env, templateID) } -func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, int64, error) { +func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.Env, bizID, connectorID, userID int64, conversationName string) (int64, int64, error) { t, existed, err := c.repo.GetConversationTemplate(ctx, env, vo.GetConversationTemplatePolicy{ - AppID: ptr.Of(appID), + AppID: ptr.Of(bizID), Name: ptr.Of(conversationName), }) if err != nil { return 0, 0, err } - conversationIDGenerator := workflow.ConversationIDGenerator(func(ctx context.Context, appID int64, userID, connectorID int64) (*conventity.Conversation, error) { + conversationIDGenerator := workflow.ConversationIDGenerator(func(ctx context.Context, bizID int64, userID, connectorID int64) (*conventity.Conversation, error) { return crossconversation.DefaultSVC().CreateConversation(ctx, &conventity.CreateMeta{ - AgentID: appID, + AgentID: bizID, UserID: userID, ConnectorID: connectorID, Scene: common.Scene_SceneWorkflow, @@ -371,7 +371,7 @@ func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.E if existed { conversationID, sectionID, _, err := c.repo.GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{ - AppID: appID, + BizID: bizID, ConnectorID: connectorID, UserID: userID, TemplateID: t.TemplateID, @@ -383,7 +383,7 @@ func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.E } conversationID, sectionID, _, err := c.repo.GetOrCreateDynamicConversation(ctx, env, conversationIDGenerator, &vo.CreateDynamicConversation{ - AppID: appID, + BizID: bizID, ConnectorID: connectorID, UserID: userID, Name: conversationName, @@ -465,8 +465,8 @@ func (c *conversationImpl) GetDynamicConversationByName(ctx context.Context, env return c.repo.GetDynamicConversationByName(ctx, env, appID, connectorID, userID, name) } -func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error) { - sc, existed, err := c.repo.GetStaticConversationByID(ctx, env, appID, connectorID, conversationID) +func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error) { + sc, existed, err := c.repo.GetStaticConversationByID(ctx, env, bizID, connectorID, conversationID) if err != nil { return "", false, err } @@ -474,7 +474,7 @@ func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.E return sc, true, nil } - dc, existed, err := c.repo.GetDynamicConversationByID(ctx, env, appID, connectorID, conversationID) + dc, existed, err := c.repo.GetDynamicConversationByID(ctx, env, bizID, connectorID, conversationID) if err != nil { return "", false, err } diff --git a/backend/domain/workflow/service/executable_impl.go b/backend/domain/workflow/service/executable_impl.go index d0734d40..1d854b63 100644 --- a/backend/domain/workflow/service/executable_impl.go +++ b/backend/domain/workflow/service/executable_impl.go @@ -22,6 +22,8 @@ import ( "fmt" "time" + "github.com/coze-dev/coze-studio/backend/types/consts" + einoCompose "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" @@ -282,6 +284,50 @@ func (i *impl) AsyncExecute(ctx context.Context, config workflowModel.ExecuteCon return executeID, nil } +func (i *impl) handleHistory(ctx context.Context, config *workflowModel.ExecuteConfig, input map[string]any, historyRounds int64, shouldFetchConversationByName bool) error { + if historyRounds <= 0 { + return nil + } + + if shouldFetchConversationByName { + var cID, sID, bizID int64 + var err error + if config.AppID != nil { + bizID = *config.AppID + } else if config.AgentID != nil { + bizID = *config.AgentID + } + for k, v := range input { + if k == vo.ConversationNameKey { + cName, ok := v.(string) + if !ok { + return errors.New("CONVERSATION_NAME must be string") + } + cID, sID, err = i.GetOrCreateConversation(ctx, vo.Draft, bizID, consts.CozeConnectorID, config.Operator, cName) + if err != nil { + return err + } + config.ConversationID = ptr.Of(cID) + config.SectionID = ptr.Of(sID) + } + } + } + + messages, scMessages, err := i.prefetchChatHistory(ctx, *config, historyRounds) + if err != nil { + logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err) + } + + if len(messages) > 0 { + config.ConversationHistory = messages + } + + if len(scMessages) > 0 { + config.ConversationHistorySchemaMessages = scMessages + } + return nil +} + func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workflowModel.ExecuteConfig, input map[string]any) (int64, error) { var ( err error @@ -308,30 +354,6 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf } } - historyRounds := int64(0) - if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow { - - historyRounds, err = getHistoryRoundsFromNode(ctx, wfEntity, nodeID, i.repo) - if err != nil { - return 0, err - } - } - - if historyRounds > 0 { - messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds) - if err != nil { - logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err) - } - - if len(messages) > 0 { - config.ConversationHistory = messages - } - - if len(scMessages) > 0 { - config.ConversationHistorySchemaMessages = scMessages - } - - } c := &vo.Canvas{} if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil { return 0, fmt.Errorf("failed to unmarshal canvas: %w", err) @@ -342,6 +364,17 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf return 0, fmt.Errorf("failed to convert canvas to workflow schema: %w", err) } + historyRounds := int64(0) + if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow { + historyRounds = workflowSC.HistoryRounds() + } + + if historyRounds > 0 { + if err = i.handleHistory(ctx, &config, input, historyRounds, true); err != nil { + return 0, err + } + } + wf, err := compose.NewWorkflowFromNode(ctx, workflowSC, vo.NodeKey(nodeID), einoCompose.WithGraphName(fmt.Sprintf("%d", wfEntity.ID))) if err != nil { return 0, fmt.Errorf("failed to create workflow: %w", err) @@ -417,29 +450,6 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo } } - historyRounds := int64(0) - if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow { - historyRounds, err = i.calculateMaxChatHistoryRounds(ctx, wfEntity, i.repo) - if err != nil { - return nil, err - } - } - - if historyRounds > 0 { - messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds) - if err != nil { - logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err) - } - - if len(messages) > 0 { - config.ConversationHistory = messages - } - - if len(scMessages) > 0 { - config.ConversationHistorySchemaMessages = scMessages - } - - } c := &vo.Canvas{} if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil { return nil, fmt.Errorf("failed to unmarshal canvas: %w", err) @@ -450,6 +460,17 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo return nil, fmt.Errorf("failed to convert canvas to workflow schema: %w", err) } + historyRounds := int64(0) + if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow { + historyRounds = workflowSC.HistoryRounds() + } + + if historyRounds > 0 { + if err = i.handleHistory(ctx, &config, input, historyRounds, false); err != nil { + return nil, err + } + } + var wfOpts []compose.WorkflowOption wfOpts = append(wfOpts, compose.WithIDAsName(wfEntity.ID)) if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 { @@ -997,20 +1018,6 @@ func (i *impl) checkApplicationWorkflowReleaseVersion(ctx context.Context, appID return nil } -const maxHistoryRounds int64 = 30 - -func (i *impl) calculateMaxChatHistoryRounds(ctx context.Context, wfEntity *entity.Workflow, repo workflow.Repository) (int64, error) { - if wfEntity == nil { - return 0, nil - } - - maxRounds, err := getMaxHistoryRoundsRecursively(ctx, wfEntity, repo) - if err != nil { - return 0, err - } - return min(maxRounds, maxHistoryRounds), nil -} - func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.ExecuteConfig, historyRounds int64) ([]*crossmessage.WfMessage, []*schema.Message, error) { convID := config.ConversationID agentID := config.AgentID @@ -1027,11 +1034,11 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe return nil, nil, nil } - var resolvedAppID int64 + var bizID int64 if appID != nil { - resolvedAppID = *appID + bizID = *appID } else if agentID != nil { - resolvedAppID = *agentID + bizID = *agentID } else { logs.CtxWarnf(ctx, "AppID and AgentID are both nil, skipping chat history") return nil, nil, nil @@ -1039,7 +1046,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe runIdsReq := &crossmessage.GetLatestRunIDsRequest{ ConversationID: *convID, - AppID: resolvedAppID, + BizID: bizID, UserID: userID, Rounds: historyRounds + 1, SectionID: *sectionID, @@ -1048,7 +1055,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe runIds, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, runIdsReq) if err != nil { logs.CtxErrorf(ctx, "failed to get latest run ids: %v", err) - return nil, nil, nil + return nil, nil, err } if len(runIds) <= 1 { return []*crossmessage.WfMessage{}, []*schema.Message{}, nil @@ -1061,7 +1068,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe }) if err != nil { logs.CtxErrorf(ctx, "failed to get messages by run ids: %v", err) - return nil, nil, nil + return nil, nil, err } return response.Messages, response.SchemaMessages, nil diff --git a/backend/domain/workflow/service/executable_impl_test.go b/backend/domain/workflow/service/executable_impl_test.go new file mode 100644 index 00000000..e3224d6e --- /dev/null +++ b/backend/domain/workflow/service/executable_impl_test.go @@ -0,0 +1,286 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/cloudwego/eino/schema" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow" + crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message" + messagemock "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message/messagemock" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" + mock_workflow "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow" + "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" +) + +func TestImpl_handleHistory(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t, gomock.WithOverridableExpectations()) + defer ctrl.Finish() + + // Setup for cross-domain service mock + mockMessage := messagemock.NewMockMessage(ctrl) + crossmessage.SetDefaultSVC(mockMessage) + + tests := []struct { + name string + setupMock func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) + config *workflowModel.ExecuteConfig + input map[string]any + historyRounds int64 + shouldFetch bool + expectErr bool + expectedHistory []*crossmessage.WfMessage + expectedSchemaHistory []*schema.Message + }{ + { + name: "historyRounds is zero", + historyRounds: 0, + shouldFetch: true, + config: &workflowModel.ExecuteConfig{}, + setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) { + }, + expectErr: false, + }, + { + name: "shouldFetch is false", + historyRounds: 5, + shouldFetch: false, + config: &workflowModel.ExecuteConfig{ + AppID: ptr.Of(int64(1)), + ConversationID: ptr.Of(int64(100)), + SectionID: ptr.Of(int64(101)), + }, + setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) { + msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{1, 2}, nil).AnyTimes() + msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{ + Messages: []*crossmessage.WfMessage{{ID: 1}}, + SchemaMessages: []*schema.Message{{ + Role: schema.User, + Content: "123", + }}, + }, nil).AnyTimes() + }, + expectErr: false, + expectedHistory: []*crossmessage.WfMessage{{ID: 1}}, + expectedSchemaHistory: []*schema.Message{{ + Role: schema.User, + Content: "123", + }}, + }, + { + name: "fetch conversation by name - conversation exists", + historyRounds: 3, + shouldFetch: true, + config: &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))}, + input: map[string]any{"CONVERSATION_NAME": "test-conv"}, + setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) { + service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "test-conv").Return(int64(200), int64(201), nil).AnyTimes() + msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{3, 4}, nil).AnyTimes() + msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{ + Messages: []*crossmessage.WfMessage{{ID: 2}}, + SchemaMessages: []*schema.Message{{ + Role: schema.Assistant, + Content: "123", + }}, + }, nil).AnyTimes() + repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{ + TemplateID: int64(202), + SpaceID: int64(203), + AppID: int64(204), + }, true, nil).AnyTimes() + repo.EXPECT().GetOrCreateStaticConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, nil).AnyTimes() + }, + expectErr: false, + expectedHistory: []*crossmessage.WfMessage{{ID: 2}}, + expectedSchemaHistory: []*schema.Message{{ + Role: schema.Assistant, + Content: "123", + }}, + }, + { + name: "fetch conversation by name - conversation not exists", + historyRounds: 3, + shouldFetch: true, + config: &workflowModel.ExecuteConfig{AgentID: ptr.Of(int64(2))}, + input: map[string]any{"CONVERSATION_NAME": "new-conv"}, + setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) { + service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "new-conv").Return(int64(300), int64(301), nil).AnyTimes() + msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{5, 6}, nil).AnyTimes() + msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{ + Messages: []*crossmessage.WfMessage{{ID: 3}}, + }, nil).AnyTimes() + repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{ + TemplateID: int64(202), + SpaceID: int64(203), + AppID: int64(204), + }, false, nil).AnyTimes() + repo.EXPECT().GetOrCreateDynamicConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, nil).AnyTimes() + }, + expectErr: false, + expectedHistory: []*crossmessage.WfMessage{{ID: 3}}, + }, + { + name: "input with wrong type for conversation name", + historyRounds: 5, + shouldFetch: true, + config: &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))}, + input: map[string]any{"CONVERSATION_NAME": 12345}, + setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) { + }, + expectErr: true, + }, + { + name: "GetOrCreateConversation returns error", + historyRounds: 5, + shouldFetch: true, + config: &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))}, + input: map[string]any{"CONVERSATION_NAME": "fail-conv"}, + setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) { + service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "fail-conv").Return(int64(0), int64(0), errors.New("db error")).AnyTimes() + repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{ + TemplateID: int64(202), + SpaceID: int64(203), + AppID: int64(204), + }, false, nil).AnyTimes() + repo.EXPECT().GetOrCreateDynamicConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, errors.New("db error")).AnyTimes() + }, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockService := mock_workflow.NewMockService(ctrl) + mockRepo := mock_workflow.NewMockRepository(ctrl) + testImpl := &impl{repo: mockRepo, conversationImpl: &conversationImpl{repo: mockRepo}} + + tt.setupMock(mockService, mockMessage, mockRepo) + + err := testImpl.handleHistory(ctx, tt.config, tt.input, tt.historyRounds, tt.shouldFetch) + + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.expectedHistory != nil { + assert.Equal(t, tt.expectedHistory, tt.config.ConversationHistory) + } else if tt.historyRounds == 0 { + assert.Nil(t, tt.config.ConversationHistory) + } else if tt.expectedSchemaHistory != nil { + assert.Equal(t, tt.expectedSchemaHistory, tt.config.ConversationHistorySchemaMessages) + } + } + }) + } +} + +func TestImpl_prefetchChatHistory(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t, gomock.WithOverridableExpectations()) + defer ctrl.Finish() + + mockMessage := messagemock.NewMockMessage(ctrl) + crossmessage.SetDefaultSVC(mockMessage) + + tests := []struct { + name string + setupMock func(msgSvc *messagemock.MockMessage) + config workflowModel.ExecuteConfig + historyRounds int64 + expectErr bool + }{ + { + name: "SectionID is nil", + config: workflowModel.ExecuteConfig{ + ConversationID: ptr.Of(int64(100)), + AppID: ptr.Of(int64(1)), + }, + historyRounds: 5, + setupMock: func(msgSvc *messagemock.MockMessage) {}, + expectErr: false, + }, + { + name: "ConversationID is nil", + config: workflowModel.ExecuteConfig{ + SectionID: ptr.Of(int64(101)), + AppID: ptr.Of(int64(1)), + }, + historyRounds: 5, + setupMock: func(msgSvc *messagemock.MockMessage) {}, + expectErr: false, + }, + { + name: "AppID and AgentID are both nil", + config: workflowModel.ExecuteConfig{ + ConversationID: ptr.Of(int64(100)), + SectionID: ptr.Of(int64(101)), + }, + historyRounds: 5, + setupMock: func(msgSvc *messagemock.MockMessage) {}, + expectErr: false, + }, + { + name: "GetLatestRunIDs returns error", + config: workflowModel.ExecuteConfig{ + AppID: ptr.Of(int64(1)), + ConversationID: ptr.Of(int64(100)), + SectionID: ptr.Of(int64(101)), + }, + historyRounds: 5, + setupMock: func(msgSvc *messagemock.MockMessage) { + msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return(nil, errors.New("db error")) + }, + expectErr: true, + }, + { + name: "GetMessagesByRunIDs returns error", + config: workflowModel.ExecuteConfig{ + AppID: ptr.Of(int64(1)), + ConversationID: ptr.Of(int64(100)), + SectionID: ptr.Of(int64(101)), + }, + historyRounds: 5, + setupMock: func(msgSvc *messagemock.MockMessage) { + msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{1, 2, 3}, nil) + msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(nil, errors.New("db error")) + }, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testImpl := &impl{} + tt.setupMock(mockMessage) + + _, _, err := testImpl.prefetchChatHistory(ctx, tt.config, tt.historyRounds) + + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/backend/domain/workflow/service/service_impl.go b/backend/domain/workflow/service/service_impl.go index a007064a..50b3ab01 100644 --- a/backend/domain/workflow/service/service_impl.go +++ b/backend/domain/workflow/service/service_impl.go @@ -39,7 +39,6 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/infra/contract/cache" @@ -522,7 +521,7 @@ func isEnableChatHistory(s *schema.NodeSchema) bool { return false } - chatHistoryAware, ok := s.Configs.(nodes.ChatHistoryAware) + chatHistoryAware, ok := s.Configs.(schema.ChatHistoryAware) if !ok { return false } @@ -2171,15 +2170,15 @@ func (i *impl) adaptToChatFlow(ctx context.Context, wID int64) error { vMap[v.Name] = true } - if _, ok := vMap["USER_INPUT"]; !ok { + if _, ok := vMap[vo.UserInputKey]; !ok { startNode.Data.Outputs = append(startNode.Data.Outputs, &vo.Variable{ - Name: "USER_INPUT", + Name: vo.UserInputKey, Type: vo.VariableTypeString, }) } - if _, ok := vMap["CONVERSATION_NAME"]; !ok { + if _, ok := vMap[vo.ConversationNameKey]; !ok { startNode.Data.Outputs = append(startNode.Data.Outputs, &vo.Variable{ - Name: "CONVERSATION_NAME", + Name: vo.ConversationNameKey, Type: vo.VariableTypeString, DefaultValue: "Default", }) diff --git a/backend/domain/workflow/service/utils.go b/backend/domain/workflow/service/utils.go index 572f1320..dd5608f8 100644 --- a/backend/domain/workflow/service/utils.go +++ b/backend/domain/workflow/service/utils.go @@ -22,15 +22,11 @@ import ( "strconv" "strings" - workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow" "github.com/coze-dev/coze-studio/backend/api/model/workflow" - wf "github.com/coze-dev/coze-studio/backend/domain/workflow" - "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/validate" "github.com/coze-dev/coze-studio/backend/domain/workflow/variable" - "github.com/coze-dev/coze-studio/backend/pkg/lang/ternary" "github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/types/errno" ) @@ -201,214 +197,3 @@ func isIncremental(prev version, next version) bool { return next.Patch > prev.Patch } - -func getMaxHistoryRoundsRecursively(ctx context.Context, wfEntity *entity.Workflow, repo wf.Repository) (int64, error) { - visited := make(map[string]struct{}) - maxRounds := int64(0) - err := getMaxHistoryRoundsRecursiveHelper(ctx, wfEntity, repo, visited, &maxRounds) - return maxRounds, err -} - -func getMaxHistoryRoundsRecursiveHelper(ctx context.Context, wfEntity *entity.Workflow, repo wf.Repository, visited map[string]struct{}, maxRounds *int64) error { - visitedKey := fmt.Sprintf("%d:%s", wfEntity.ID, wfEntity.GetVersion()) - if _, ok := visited[visitedKey]; ok { - return nil - } - visited[visitedKey] = struct{}{} - - var canvas vo.Canvas - if err := sonic.UnmarshalString(wfEntity.Canvas, &canvas); err != nil { - return fmt.Errorf("failed to unmarshal canvas for workflow %d: %w", wfEntity.ID, err) - } - - return collectMaxHistoryRounds(ctx, canvas.Nodes, repo, visited, maxRounds) -} - -func collectMaxHistoryRounds(ctx context.Context, nodes []*vo.Node, repo wf.Repository, visited map[string]struct{}, maxRounds *int64) error { - for _, node := range nodes { - if node == nil { - continue - } - - if node.Data != nil && node.Data.Inputs != nil && node.Data.Inputs.ChatHistorySetting != nil && node.Data.Inputs.ChatHistorySetting.EnableChatHistory { - if node.Data.Inputs.ChatHistorySetting.ChatHistoryRound > *maxRounds { - *maxRounds = node.Data.Inputs.ChatHistorySetting.ChatHistoryRound - } - } else if node.Type == entity.NodeTypeLLM.IDStr() && node.Data != nil && node.Data.Inputs != nil && node.Data.Inputs.LLMParam != nil { - param := node.Data.Inputs.LLMParam - bs, _ := sonic.Marshal(param) - llmParam := make(vo.LLMParam, 0) - if err := sonic.Unmarshal(bs, &llmParam); err != nil { - return err - } - var chatHistoryEnabled bool - var chatHistoryRound int64 - for _, param := range llmParam { - switch param.Name { - case "enableChatHistory": - if val, ok := param.Input.Value.Content.(bool); ok { - b := val - chatHistoryEnabled = b - } - case "chatHistoryRound": - if strVal, ok := param.Input.Value.Content.(string); ok { - int64Val, err := strconv.ParseInt(strVal, 10, 64) - if err != nil { - return err - } - chatHistoryRound = int64Val - } - } - } - - if chatHistoryEnabled { - if chatHistoryRound > *maxRounds { - *maxRounds = chatHistoryRound - } - } - } - - isSubWorkflow := node.Type == entity.NodeTypeSubWorkflow.IDStr() && node.Data != nil && node.Data.Inputs != nil - if isSubWorkflow { - workflowIDStr := node.Data.Inputs.WorkflowID - if workflowIDStr == "" { - continue - } - - workflowID, err := strconv.ParseInt(workflowIDStr, 10, 64) - if err != nil { - return fmt.Errorf("invalid workflow ID in sub-workflow node %s: %w", node.ID, err) - } - - subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{ - ID: workflowID, - QType: ternary.IFElse(len(node.Data.Inputs.WorkflowVersion) == 0, workflowModel.FromDraft, workflowModel.FromSpecificVersion), - Version: node.Data.Inputs.WorkflowVersion, - }) - if err != nil { - return fmt.Errorf("failed to get sub-workflow entity %d: %w", workflowID, err) - } - - if err := getMaxHistoryRoundsRecursiveHelper(ctx, subWfEntity, repo, visited, maxRounds); err != nil { - return err - } - } - - if len(node.Blocks) > 0 { - if err := collectMaxHistoryRounds(ctx, node.Blocks, repo, visited, maxRounds); err != nil { - return err - } - } - } - - return nil -} - -func getHistoryRoundsFromNode(ctx context.Context, wfEntity *entity.Workflow, nodeID string, repo wf.Repository) (int64, error) { - if wfEntity == nil { - return 0, nil - } - visited := make(map[string]struct{}) - visitedKey := fmt.Sprintf("%d:%s", wfEntity.ID, wfEntity.GetVersion()) - if _, ok := visited[visitedKey]; ok { - return 0, nil - } - visited[visitedKey] = struct{}{} - maxRounds := int64(0) - c := &vo.Canvas{} - if err := sonic.UnmarshalString(wfEntity.Canvas, c); err != nil { - return 0, fmt.Errorf("failed to unmarshal canvas: %w", err) - } - var ( - n *vo.Node - nodeFinder func(nodes []*vo.Node) *vo.Node - ) - nodeFinder = func(nodes []*vo.Node) *vo.Node { - for i := range nodes { - if nodes[i].ID == nodeID { - return nodes[i] - } - if len(nodes[i].Blocks) > 0 { - if n := nodeFinder(nodes[i].Blocks); n != nil { - return n - } - } - } - return nil - } - - n = nodeFinder(c.Nodes) - if n.Type == entity.NodeTypeLLM.IDStr() { - if n.Data == nil || n.Data.Inputs == nil { - return 0, nil - } - param := n.Data.Inputs.LLMParam - bs, _ := sonic.Marshal(param) - llmParam := make(vo.LLMParam, 0) - if err := sonic.Unmarshal(bs, &llmParam); err != nil { - return 0, err - } - var chatHistoryEnabled bool - var chatHistoryRound int64 - for _, param := range llmParam { - switch param.Name { - case "enableChatHistory": - if val, ok := param.Input.Value.Content.(bool); ok { - b := val - chatHistoryEnabled = b - } - case "chatHistoryRound": - if strVal, ok := param.Input.Value.Content.(string); ok { - int64Val, err := strconv.ParseInt(strVal, 10, 64) - if err != nil { - return 0, err - } - chatHistoryRound = int64Val - } - } - } - if chatHistoryEnabled { - return chatHistoryRound, nil - } - return 0, nil - } - - if n.Type == entity.NodeTypeIntentDetector.IDStr() || n.Type == entity.NodeTypeKnowledgeRetriever.IDStr() { - if n.Data != nil && n.Data.Inputs != nil && n.Data.Inputs.ChatHistorySetting != nil && n.Data.Inputs.ChatHistorySetting.EnableChatHistory { - return n.Data.Inputs.ChatHistorySetting.ChatHistoryRound, nil - } - return 0, nil - } - - if n.Type == entity.NodeTypeSubWorkflow.IDStr() { - if n.Data != nil && n.Data.Inputs != nil { - workflowIDStr := n.Data.Inputs.WorkflowID - if workflowIDStr == "" { - return 0, nil - } - workflowID, err := strconv.ParseInt(workflowIDStr, 10, 64) - if err != nil { - return 0, fmt.Errorf("invalid workflow ID in sub-workflow node %s: %w", n.ID, err) - } - subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{ - ID: workflowID, - QType: ternary.IFElse(len(n.Data.Inputs.WorkflowVersion) == 0, workflowModel.FromDraft, workflowModel.FromSpecificVersion), - Version: n.Data.Inputs.WorkflowVersion, - }) - if err != nil { - return 0, fmt.Errorf("failed to get sub-workflow entity %d: %w", workflowID, err) - } - if err := getMaxHistoryRoundsRecursiveHelper(ctx, subWfEntity, repo, visited, &maxRounds); err != nil { - return 0, err - } - return maxRounds, nil - } - } - - if len(n.Blocks) > 0 { - if err := collectMaxHistoryRounds(ctx, n.Blocks, repo, visited, &maxRounds); err != nil { - return 0, err - } - } - return maxRounds, nil -}