refactor(workflow): Calculate chat history rounds during schema convertion (#1990)

Co-authored-by: zhuangjie.1125 <zhuangjie.1125@bytedance.com>
main
lvxinyu-1117 1 month ago committed by GitHub
parent 4416127d47
commit 4bfce5a8cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 267
      backend/api/handler/coze/workflow_service_test.go
  2. 68
      backend/application/workflow/chatflow.go
  3. 604
      backend/application/workflow/chatflow_test.go
  4. 4
      backend/crossdomain/contract/message/message.go
  5. 1
      backend/crossdomain/contract/upload/upload.go
  6. 73
      backend/crossdomain/contract/upload/uploadmock/upload_mock.go
  7. 4
      backend/crossdomain/impl/message/message.go
  8. 8
      backend/domain/workflow/component_interface.go
  9. 5
      backend/domain/workflow/entity/vo/chatflow.go
  10. 4
      backend/domain/workflow/entity/vo/conversation.go
  11. 1
      backend/domain/workflow/internal/canvas/adaptor/from_node.go
  12. 2
      backend/domain/workflow/internal/canvas/adaptor/to_schema.go
  13. 275
      backend/domain/workflow/internal/canvas/examples/chatflow/chat_run_with_interrupt.json
  14. 397
      backend/domain/workflow/internal/canvas/examples/chatflow/llm_chat.json
  15. 397
      backend/domain/workflow/internal/canvas/examples/chatflow/llm_chat_with_history.json
  16. 2
      backend/domain/workflow/internal/nodes/conversation/conversationhistory.go
  17. 4
      backend/domain/workflow/internal/nodes/conversation/createconversation.go
  18. 16
      backend/domain/workflow/internal/nodes/conversation/createmessage.go
  19. 8
      backend/domain/workflow/internal/nodes/conversation/messagelist.go
  20. 8
      backend/domain/workflow/internal/nodes/llm/llm.go
  21. 5
      backend/domain/workflow/internal/nodes/node.go
  22. 28
      backend/domain/workflow/internal/repo/conversation_repository.go
  23. 5
      backend/domain/workflow/internal/schema/node_schema.go
  24. 15
      backend/domain/workflow/internal/schema/workflow_schema.go
  25. 22
      backend/domain/workflow/service/conversation_impl.go
  26. 141
      backend/domain/workflow/service/executable_impl.go
  27. 286
      backend/domain/workflow/service/executable_impl_test.go
  28. 11
      backend/domain/workflow/service/service_impl.go
  29. 215
      backend/domain/workflow/service/utils.go

@ -228,6 +228,7 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
h.POST("/api/workflow_api/chat_flow_role/delete", DeleteChatFlowRole)
h.POST("/api/workflow_api/chat_flow_role/create", CreateChatFlowRole)
h.GET("/api/workflow_api/chat_flow_role/get", GetChatFlowRole)
h.POST("/v1/workflows/chat", OpenAPIChatFlowRun)
ctrl := gomock.NewController(t, gomock.WithOverridableExpectations())
mockIDGen := mock.NewMockIDGenerator(ctrl)
@ -1082,6 +1083,46 @@ func (r *wfTestRunner) openapiResume(id string, eventID string, resumeData strin
return re
}
func (r *wfTestRunner) openapiChatFlowRun(wfID string, cID, appID, botID *string, input any, additionalMessage []*workflow.EnterMessage) *sse.Reader {
inputStr, _ := sonic.MarshalString(input)
req := &workflow.ChatFlowRunRequest{
WorkflowID: wfID,
Parameters: ptr.Of(inputStr),
AdditionalMessages: additionalMessage,
}
if cID != nil {
req.ConversationID = cID
}
if appID != nil {
req.AppID = appID
}
if botID != nil {
req.BotID = botID
}
m, err := sonic.Marshal(req)
assert.NoError(r.t, err)
c, _ := client.NewClient()
hReq, hResp := protocol.AcquireRequest(), protocol.AcquireResponse()
hReq.SetRequestURI("http://localhost:8888" + "/v1/workflows/chat")
hReq.SetMethod("POST")
hReq.SetBody(m)
hReq.SetHeader("Content-Type", "application/json")
err = c.Do(context.Background(), hReq, hResp)
assert.NoError(r.t, err)
if hResp.StatusCode() != http.StatusOK {
r.t.Errorf("unexpected status code: %d, body: %s", hResp.StatusCode(), string(hResp.Body()))
}
re, err := sse.NewReader(hResp)
assert.NoError(r.t, err)
return re
}
func (r *wfTestRunner) runServer() func() {
go func() {
_ = r.h.Run()
@ -5491,7 +5532,7 @@ func TestConversationOfChatFlow(t *testing.T) {
if err != nil {
return err
}
if v.Name == "CONVERSATION_NAME" {
if v.Name == vo.ConversationNameKey {
v.DefaultValue = cName
}
startNode.Data.Outputs[idx] = v
@ -5522,7 +5563,7 @@ func TestConversationOfChatFlow(t *testing.T) {
for _, vAny := range node.Data.Outputs {
v, err := vo.ParseVariable(vAny)
assert.NoError(t, err)
if v.Name == "CONVERSATION_NAME" {
if v.Name == vo.ConversationNameKey {
assert.Equal(t, v.DefaultValue, updateName)
}
}
@ -5569,7 +5610,7 @@ func TestConversationOfChatFlow(t *testing.T) {
for _, vAny := range node.Data.Outputs {
v, err := vo.ParseVariable(vAny)
assert.NoError(t, err)
if v.Name == "CONVERSATION_NAME" {
if v.Name == vo.ConversationNameKey {
assert.Equal(t, v.DefaultValue, cName+"copy")
}
}
@ -5988,3 +6029,223 @@ func TestConversationHistoryNodes(t *testing.T) {
assert.Equal(t, []any{}, outputMap["history_list"])
})
}
func TestChatFlowRun(t *testing.T) {
mockey.PatchConvey("chat flow run", t, func() {
r := newWfTestRunner(t)
appworkflow.SVC.IDGenerator = r.idGen
defer r.closeFn()
defer r.runServer()()
chatModel1 := &testutil.UTChatModel{
StreamResultProvider: func(_ int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
sr := schema.StreamReaderFromArray([]*schema.Message{
{
Role: schema.Assistant,
Content: "I ",
},
{
Role: schema.Assistant,
Content: "don't know.",
},
})
return sr, nil
},
}
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel1, nil, nil).AnyTimes()
id := r.load("chatflow/llm_chat.json", withMode(workflow.WorkflowMode_ChatFlow))
r.publish(id, "v0.0.1", true)
cID := time.Now().UnixNano()
cIDStr := strconv.FormatInt(cID, 10)
appID := time.Now().UnixNano()
appIDStr := strconv.FormatInt(appID, 10)
// Create conversation first
r.conversation.EXPECT().CreateConversation(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
ID: cID,
}, nil).AnyTimes()
idStr := r.load("conversation_manager/update_dynamic_conversation.json")
r.publish(idStr, "v0.0.1", true)
ret, _ := r.openapiSyncRun(idStr, map[string]string{
"input": "v1",
"new_name": "v2",
}, withRunProjectID(appID))
assert.Equal(t, map[string]any{"conversationId": strconv.FormatInt(cID, 10), "isExisted": false, "isSuccess": true}, ret["obj"])
msg := []*workflow.EnterMessage{
{
Role: "user",
ContentType: "text",
Content: "你好",
},
}
sID := time.Now().UnixNano()
r.conversation.EXPECT().GetByID(gomock.Any(), gomock.Any()).Return(&conventity.Conversation{
ID: cID,
SectionID: sID,
}, nil).AnyTimes()
rID := time.Now().UnixNano()
r.agentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{
ID: rID,
}, nil).AnyTimes()
mID := time.Now().Unix()
r.message.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&message.Message{
ID: mID,
}, nil).AnyTimes()
t.Run("chat flow run in app", func(t *testing.T) {
sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
vo.ConversationNameKey: "Default",
}, msg)
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
return nil
})
assert.NoError(t, err)
})
t.Run("chat flow run in bot", func(t *testing.T) {
botID := time.Now().UnixNano()
botIDStr := strconv.FormatInt(botID, 10)
sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), nil, ptr.Of(botIDStr), map[string]any{
vo.ConversationNameKey: "Default",
}, msg)
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
return nil
})
assert.NoError(t, err)
})
t.Run("chat flow run without cID", func(t *testing.T) {
sseReader := r.openapiChatFlowRun(id, nil, ptr.Of(appIDStr), nil, map[string]any{
vo.ConversationNameKey: "Default",
}, msg)
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
return nil
})
assert.NoError(t, err)
})
t.Run("chat flow run with additional messages", func(t *testing.T) {
additionalMsg := []*workflow.EnterMessage{
{
Role: "user",
ContentType: "text",
Content: "你好, 我叫小明",
},
{
Role: "assistant",
ContentType: "text",
Content: "你好小明, 很高兴认识你",
},
{
Role: "user",
ContentType: "text",
Content: "你好",
},
}
sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
vo.ConversationNameKey: "Default",
}, additionalMsg)
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
return nil
})
assert.NoError(t, err)
})
t.Run("chat flow run with history messages", func(t *testing.T) {
id := r.load("chatflow/llm_chat_with_history.json", withMode(workflow.WorkflowMode_ChatFlow))
r.publish(id, "v0.0.1", true)
r.message.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{rID}, nil).AnyTimes()
r.message.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&message0.GetMessagesByRunIDsResponse{
Messages: []*message0.WfMessage{
{
ID: mID,
Role: schema.User,
Text: ptr.Of("你好"),
},
},
}, nil).AnyTimes()
sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
vo.ConversationNameKey: "Default",
}, msg)
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
return nil
})
assert.NoError(t, err)
})
t.Run("chat flow run with interrupt nodes ", func(t *testing.T) {
// 生成一个携带 input, 问答文本 问答选项的三个中断节点 做测试
id := r.load("chatflow/chat_run_with_interrupt.json", withMode(workflow.WorkflowMode_ChatFlow))
r.publish(id, "v0.0.1", true)
sseReader := r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
vo.ConversationNameKey: "Default",
}, msg)
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
if e.ID == "3" {
assert.Equal(t, e.Type, "conversation.message.completed")
assert.Contains(t, string(e.Data), "7383997384420262000")
}
return nil
})
assert.NoError(t, err)
sseReader = r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
vo.ConversationNameKey: "Default",
}, []*workflow.EnterMessage{
{Role: string(schema.User), Content: "input:1", ContentType: "text"},
})
err = sseReader.ForEach(t.Context(), func(e *sse.Event) error {
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
if e.ID == "4" {
assert.Equal(t, e.Type, "conversation.message.completed")
assert.Contains(t, string(e.Data), "你好")
}
return nil
})
assert.NoError(t, err)
sseReader = r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
vo.ConversationNameKey: "Default",
}, []*workflow.EnterMessage{
{Role: string(schema.User), Content: "hello", ContentType: "text"},
})
err = sseReader.ForEach(t.Context(), func(e *sse.Event) error {
if e.ID == "3" {
assert.Equal(t, e.Type, "conversation.message.completed")
assert.Contains(t, string(e.Data), "question_card_data", "请选择")
}
return nil
})
assert.NoError(t, err)
sseReader = r.openapiChatFlowRun(id, ptr.Of(cIDStr), ptr.Of(appIDStr), nil, map[string]any{
vo.ConversationNameKey: "Default",
}, []*workflow.EnterMessage{
{Role: string(schema.User), Content: "A", ContentType: "text"},
})
err = sseReader.ForEach(t.Context(), func(e *sse.Event) error {
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
if e.ID == "4" {
assert.Equal(t, e.Type, "conversation.message.completed")
assert.Contains(t, string(e.Data), "answer", "A")
}
return nil
})
assert.NoError(t, err)
})
})
}

@ -221,7 +221,7 @@ const (
"id": "5fJt3qKpSz",
"name": "list",
"defaultValue": [
]
}
},
@ -504,7 +504,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
workflowID = mustParseInt64(req.GetWorkflowID())
isDebug = req.GetExecuteMode() == "DEBUG"
appID, agentID *int64
resolveAppID int64
bizID int64
conversationID int64
sectionID int64
version string
@ -521,11 +521,11 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
if req.IsSetAppID() {
appID = ptr.Of(mustParseInt64(req.GetAppID()))
resolveAppID = mustParseInt64(req.GetAppID())
bizID = mustParseInt64(req.GetAppID())
}
if req.IsSetBotID() {
agentID = ptr.Of(mustParseInt64(req.GetBotID()))
resolveAppID = mustParseInt64(req.GetBotID())
bizID = mustParseInt64(req.GetBotID())
}
if appID != nil && agentID != nil {
@ -564,16 +564,16 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
sectionID = cInfo.SectionID
// only trust the conversation name under the app
conversationName, existed, err := GetWorkflowDomainSVC().GetConversationNameByID(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), resolveAppID, connectorID, conversationID)
conversationName, existed, err := GetWorkflowDomainSVC().GetConversationNameByID(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), bizID, connectorID, conversationID)
if err != nil {
return nil, err
}
if !existed {
return nil, fmt.Errorf("conversation not found")
}
parameters["CONVERSATION_NAME"] = conversationName
parameters[vo.ConversationNameKey] = conversationName
} else if req.IsSetConversationID() && req.IsSetBotID() {
parameters["CONVERSATION_NAME"] = "Default"
parameters[vo.ConversationNameKey] = "Default"
conversationID = mustParseInt64(req.GetConversationID())
cInfo, err := crossconversation.DefaultSVC().GetByID(ctx, conversationID)
if err != nil {
@ -581,11 +581,11 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
}
sectionID = cInfo.SectionID
} else {
conversationName, ok := parameters["CONVERSATION_NAME"].(string)
conversationName, ok := parameters[vo.ConversationNameKey].(string)
if !ok {
return nil, fmt.Errorf("conversation name is requried")
}
cID, sID, err := GetWorkflowDomainSVC().GetOrCreateConversation(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), resolveAppID, connectorID, userID, conversationName)
cID, sID, err := GetWorkflowDomainSVC().GetOrCreateConversation(ctx, ternary.IFElse(isDebug, vo.Draft, vo.Online), bizID, connectorID, userID, conversationName)
if err != nil {
return nil, err
}
@ -594,7 +594,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
}
runRecord, err := crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{
AgentID: resolveAppID,
AgentID: bizID,
ConversationID: conversationID,
UserID: strconv.FormatInt(userID, 10),
ConnectorID: connectorID,
@ -606,7 +606,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
roundID := runRecord.ID
userMessage, err := toConversationMessage(ctx, resolveAppID, conversationID, userID, roundID, sectionID, message.MessageTypeQuestion, lastUserMessage)
userMessage, err := toConversationMessage(ctx, bizID, conversationID, userID, roundID, sectionID, message.MessageTypeQuestion, lastUserMessage)
if err != nil {
return nil, err
}
@ -648,7 +648,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
return nil, err
}
return schema.StreamReaderWithConvert(sr, w.convertToChatFlowRunResponseList(ctx, convertToChatFlowInfo{
appID: resolveAppID,
bizID: bizID,
conversationID: conversationID,
roundID: roundID,
workflowID: workflowID,
@ -684,7 +684,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
Cancellable: isDebug,
}
historyMessages, err := makeChatFlowHistoryMessages(ctx, resolveAppID, conversationID, userID, sectionID, connectorID, messages[:len(req.GetAdditionalMessages())-1])
historyMessages, err := makeChatFlowHistoryMessages(ctx, bizID, conversationID, userID, sectionID, connectorID, messages[:len(req.GetAdditionalMessages())-1])
if err != nil {
return nil, err
}
@ -706,7 +706,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
logs.CtxWarnf(ctx, "create history message failed, err=%v", err)
}
}
parameters["USER_INPUT"], err = w.makeChatFlowUserInput(ctx, lastUserMessage)
parameters[vo.UserInputKey], err = w.makeChatFlowUserInput(ctx, lastUserMessage)
if err != nil {
return nil, err
}
@ -717,7 +717,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
}
return schema.StreamReaderWithConvert(sr, w.convertToChatFlowRunResponseList(ctx, convertToChatFlowInfo{
appID: resolveAppID,
bizID: bizID,
conversationID: conversationID,
roundID: roundID,
workflowID: workflowID,
@ -731,7 +731,7 @@ func (w *ApplicationService) OpenAPIChatFlowRun(ctx context.Context, req *workfl
func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Context, info convertToChatFlowInfo) func(msg *entity.Message) (responses []*workflow.ChatFlowRunResponse, err error) {
var (
appID = info.appID
bizID = info.bizID
conversationID = info.conversationID
roundID = info.roundID
workflowID = info.workflowID
@ -798,7 +798,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
ChatID: strconv.FormatInt(roundID, 10),
ConversationID: strconv.FormatInt(conversationID, 10),
SectionID: strconv.FormatInt(sectionID, 10),
BotID: strconv.FormatInt(appID, 10),
BotID: strconv.FormatInt(bizID, 10),
Role: string(schema.Assistant),
Type: "follow_up",
ContentType: "text",
@ -815,7 +815,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
ID: strconv.FormatInt(roundID, 10),
ConversationID: strconv.FormatInt(conversationID, 10),
SectionID: strconv.FormatInt(sectionID, 10),
BotID: strconv.FormatInt(appID, 10),
BotID: strconv.FormatInt(bizID, 10),
Status: vo.Completed,
ExecuteID: strconv.FormatInt(executeID, 10),
Usage: &vo.Usage{
@ -929,7 +929,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
}
_, err = crossmessage.DefaultSVC().Create(ctx, &message.Message{
AgentID: appID,
AgentID: bizID,
RunID: roundID,
SectionID: sectionID,
Content: msgContent,
@ -947,7 +947,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
ChatID: strconv.FormatInt(roundID, 10),
ConversationID: strconv.FormatInt(conversationID, 10),
SectionID: strconv.FormatInt(sectionID, 10),
BotID: strconv.FormatInt(appID, 10),
BotID: strconv.FormatInt(bizID, 10),
Role: string(schema.Assistant),
Type: string(entity.Answer),
ContentType: string(contentType),
@ -1046,7 +1046,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
}
intermediateMessage = &message.Message{
ID: id,
AgentID: appID,
AgentID: bizID,
RunID: roundID,
SectionID: sectionID,
ConversationID: conversationID,
@ -1066,7 +1066,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
ChatID: strconv.FormatInt(roundID, 10),
ConversationID: strconv.FormatInt(conversationID, 10),
SectionID: strconv.FormatInt(sectionID, 10),
BotID: strconv.FormatInt(appID, 10),
BotID: strconv.FormatInt(bizID, 10),
Role: string(dataMessage.Role),
Type: string(dataMessage.Type),
ContentType: string(message.ContentTypeText),
@ -1092,7 +1092,7 @@ func (w *ApplicationService) convertToChatFlowRunResponseList(ctx context.Contex
ChatID: strconv.FormatInt(roundID, 10),
ConversationID: strconv.FormatInt(conversationID, 10),
SectionID: strconv.FormatInt(sectionID, 10),
BotID: strconv.FormatInt(appID, 10),
BotID: strconv.FormatInt(bizID, 10),
Role: string(dataMessage.Role),
Type: string(dataMessage.Type),
ContentType: string(message.ContentTypeText),
@ -1155,9 +1155,9 @@ func (w *ApplicationService) makeChatFlowUserInput(ctx context.Context, message
} else {
return "", fmt.Errorf("invalid message ccontent type %v", message.ContentType)
}
}
func makeChatFlowHistoryMessages(ctx context.Context, appID, conversationID, userID, sectionID, connectorID int64, messages []*workflow.EnterMessage) ([]*message.Message, error) {
func makeChatFlowHistoryMessages(ctx context.Context, bizID, conversationID, userID, sectionID, connectorID int64, messages []*workflow.EnterMessage) ([]*message.Message, error) {
var (
rID int64
@ -1170,7 +1170,7 @@ func makeChatFlowHistoryMessages(ctx context.Context, appID, conversationID, use
for _, msg := range messages {
if msg.Role == userRole {
runRecord, err = crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{
AgentID: appID,
AgentID: bizID,
ConversationID: conversationID,
UserID: strconv.FormatInt(userID, 10),
ConnectorID: connectorID,
@ -1180,13 +1180,15 @@ func makeChatFlowHistoryMessages(ctx context.Context, appID, conversationID, use
return nil, err
}
rID = runRecord.ID
} else if msg.Role == assistantRole && rID == 0 {
continue
} else if msg.Role == assistantRole {
if rID == 0 {
continue
}
} else {
return nil, fmt.Errorf("invalid role type %v", msg.Role)
}
m, err := toConversationMessage(ctx, appID, conversationID, userID, rID, sectionID, ternary.IFElse(msg.Role == userRole, message.MessageTypeQuestion, message.MessageTypeAnswer), msg)
m, err := toConversationMessage(ctx, bizID, conversationID, userID, rID, sectionID, ternary.IFElse(msg.Role == userRole, message.MessageTypeQuestion, message.MessageTypeAnswer), msg)
if err != nil {
return nil, err
}
@ -1274,7 +1276,7 @@ func (w *ApplicationService) OpenAPICreateConversation(ctx context.Context, req
}, nil
}
func toConversationMessage(ctx context.Context, appID, cid, userID, roundID, sectionID int64, messageType message.MessageType, msg *workflow.EnterMessage) (*message.Message, error) {
func toConversationMessage(ctx context.Context, bizID, cid, userID, roundID, sectionID int64, messageType message.MessageType, msg *workflow.EnterMessage) (*message.Message, error) {
type content struct {
Type string `json:"type"`
FileID *string `json:"file_id"`
@ -1284,7 +1286,7 @@ func toConversationMessage(ctx context.Context, appID, cid, userID, roundID, sec
return &message.Message{
Role: schema.User,
ConversationID: cid,
AgentID: appID,
AgentID: bizID,
RunID: roundID,
Content: msg.Content,
ContentType: message.ContentTypeText,
@ -1304,7 +1306,7 @@ func toConversationMessage(ctx context.Context, appID, cid, userID, roundID, sec
Role: schema.User,
MessageType: messageType,
ConversationID: cid,
AgentID: appID,
AgentID: bizID,
UserID: strconv.FormatInt(userID, 10),
RunID: roundID,
ContentType: message.ContentTypeMix,
@ -1432,7 +1434,7 @@ func toSchemaMessage(ctx context.Context, msg *workflow.EnterMessage) (*schema.M
type convertToChatFlowInfo struct {
userMessage *schema.Message
appID int64
bizID int64
conversationID int64
roundID int64
workflowID int64

@ -0,0 +1,604 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package workflow
import (
"context"
"errors"
"strconv"
"testing"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
messageentity "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
crossagentrun "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun/agentrunmock"
crossupload "github.com/coze-dev/coze-studio/backend/crossdomain/contract/upload"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/upload/uploadmock"
agententity "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
uploadentity "github.com/coze-dev/coze-studio/backend/domain/upload/entity"
"github.com/coze-dev/coze-studio/backend/domain/upload/service"
)
func TestApplicationService_makeChatFlowUserInput(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockUpload := uploadmock.NewMockUploader(ctrl)
crossupload.SetDefaultSVC(mockUpload)
tests := []struct {
name string
message *workflow.EnterMessage
setupMock func()
expected string
expectErr bool
}{
{
name: "content type text",
message: &workflow.EnterMessage{
ContentType: "text",
Content: "hello",
},
setupMock: func() {},
expected: "hello",
expectErr: false,
},
{
name: "content type object_string with text",
message: &workflow.EnterMessage{
ContentType: "object_string",
Content: `[{"type": "text", "text": "hello world"}]`,
},
setupMock: func() {},
expected: "hello world",
expectErr: false,
},
{
name: "content type object_string with file",
message: &workflow.EnterMessage{
ContentType: "object_string",
Content: `[{"type": "file", "file_id": "123"}]`,
},
setupMock: func() {
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
File: &uploadentity.File{Url: "https://example.com/file"},
}, nil)
},
expected: "https://example.com/file",
expectErr: false,
},
{
name: "content type object_string with text and file",
message: &workflow.EnterMessage{
ContentType: "object_string",
Content: `[{"type": "text", "text": "see this file"}, {"type": "file", "file_id": "123"}]`,
},
setupMock: func() {
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
File: &uploadentity.File{Url: "https://example.com/file"},
}, nil)
},
expected: "see this file,https://example.com/file",
expectErr: false,
},
{
name: "get file error",
message: &workflow.EnterMessage{
ContentType: "object_string",
Content: `[{"type": "file", "file_id": "123"}]`,
},
setupMock: func() {
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(nil, errors.New("get file error"))
},
expectErr: true,
},
{
name: "file not found",
message: &workflow.EnterMessage{
ContentType: "object_string",
Content: `[{"type": "file", "file_id": "123"}]`,
},
setupMock: func() {
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
File: nil,
}, nil)
},
expectErr: true,
},
{
name: "invalid content type",
message: &workflow.EnterMessage{
ContentType: "invalid",
},
setupMock: func() {},
expectErr: true,
},
{
name: "invalid json",
message: &workflow.EnterMessage{
ContentType: "object_string",
Content: `invalid-json`,
},
setupMock: func() {},
expectErr: true,
},
}
w := &ApplicationService{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupMock()
result, err := w.makeChatFlowUserInput(ctx, tt.message)
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}
func Test_toConversationMessage(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockUpload := uploadmock.NewMockUploader(ctrl)
crossupload.SetDefaultSVC(mockUpload)
bizID, cid, userID, roundID, sectionID := int64(2), int64(1), int64(4), int64(3), int64(5)
tests := []struct {
name string
msg *workflow.EnterMessage
messageType messageentity.MessageType
setupMock func()
expected *messageentity.Message
expectErr bool
}{
{
name: "content type text",
msg: &workflow.EnterMessage{
ContentType: "text",
Content: "hello",
},
messageType: messageentity.MessageTypeQuestion,
setupMock: func() {},
expected: &messageentity.Message{
Role: schema.User,
ConversationID: cid,
AgentID: bizID,
RunID: roundID,
Content: "hello",
ContentType: messageentity.ContentTypeText,
MessageType: messageentity.MessageTypeQuestion,
UserID: strconv.FormatInt(userID, 10),
SectionID: sectionID,
},
expectErr: false,
},
{
name: "content type object_string with text",
msg: &workflow.EnterMessage{
ContentType: "object_string",
Content: `[{"type": "text", "text": "hello"}]`,
},
messageType: messageentity.MessageTypeQuestion,
setupMock: func() {},
expected: &messageentity.Message{
Role: schema.User,
MessageType: messageentity.MessageTypeQuestion,
ConversationID: cid,
AgentID: bizID,
UserID: strconv.FormatInt(userID, 10),
RunID: roundID,
ContentType: messageentity.ContentTypeMix,
MultiContent: []*messageentity.InputMetaData{
{Type: messageentity.InputTypeText, Text: "hello"},
},
SectionID: sectionID,
},
expectErr: false,
},
{
name: "content type object_string with file",
msg: &workflow.EnterMessage{
ContentType: "object_string",
Content: `[{"type": "file", "file_id": "123"}]`,
},
messageType: messageentity.MessageTypeQuestion,
setupMock: func() {
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
File: &uploadentity.File{Url: "https://example.com/file", TosURI: "tos://uri", Name: "file.txt"},
}, nil)
},
expected: &messageentity.Message{
Role: schema.User,
MessageType: messageentity.MessageTypeQuestion,
ConversationID: cid,
AgentID: bizID,
UserID: strconv.FormatInt(userID, 10),
RunID: roundID,
ContentType: messageentity.ContentTypeMix,
MultiContent: []*messageentity.InputMetaData{
{
Type: "file",
FileData: []*messageentity.FileData{
{Url: "https://example.com/file", URI: "tos://uri", Name: "file.txt"},
},
},
},
SectionID: sectionID,
},
expectErr: false,
},
{
name: "get file error",
msg: &workflow.EnterMessage{
ContentType: "object_string",
Content: `[{"type": "file", "file_id": "123"}]`,
},
setupMock: func() {
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(nil, errors.New("get file error"))
},
expectErr: true,
},
{
name: "file not found",
msg: &workflow.EnterMessage{
ContentType: "object_string",
Content: `[{"type": "file", "file_id": "123"}]`,
},
setupMock: func() {
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{}, nil)
},
expectErr: true,
},
{
name: "invalid content type",
msg: &workflow.EnterMessage{
ContentType: "invalid",
},
setupMock: func() {},
expectErr: true,
},
{
name: "invalid json",
msg: &workflow.EnterMessage{
ContentType: "object_string",
Content: "invalid-json",
},
setupMock: func() {},
expectErr: true,
},
{
name: "invalid input type",
msg: &workflow.EnterMessage{
ContentType: "object_string",
Content: `[{"type": "invalid"}]`,
},
setupMock: func() {},
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupMock()
result, err := toConversationMessage(ctx, bizID, cid, userID, roundID, sectionID, tt.messageType, tt.msg)
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}
func Test_toSchemaMessage(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockUpload := uploadmock.NewMockUploader(ctrl)
crossupload.SetDefaultSVC(mockUpload)
tests := []struct {
name string
msg *workflow.EnterMessage
setupMock func()
expected *schema.Message
expectErr bool
}{
{
name: "content type text",
msg: &workflow.EnterMessage{
ContentType: "text",
Content: "hello",
},
setupMock: func() {},
expected: &schema.Message{
Role: schema.User,
Content: "hello",
},
expectErr: false,
},
{
name: "content type object_string with text",
msg: &workflow.EnterMessage{
ContentType: "object_string",
Content: `[{"type": "text", "text": "hello"}]`,
},
setupMock: func() {},
expected: &schema.Message{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{Type: schema.ChatMessagePartTypeText, Text: "hello"},
},
},
expectErr: false,
},
{
name: "content type object_string with image",
msg: &workflow.EnterMessage{
ContentType: "object_string",
Content: `[{"type": "image", "file_id": "123"}]`,
},
setupMock: func() {
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{
File: &uploadentity.File{Url: "https://example.com/image.png"},
}, nil)
},
expected: &schema.Message{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: schema.ChatMessagePartTypeImageURL,
ImageURL: &schema.ChatMessageImageURL{URL: "https://example.com/image.png"},
},
},
},
expectErr: false,
},
{
name: "content type object_string with various file types",
msg: &workflow.EnterMessage{
ContentType: "object_string",
Content: `[{"type": "file", "file_id": "1"}, {"type": "audio", "file_id": "2"}, {"type": "video", "file_id": "3"}]`,
},
setupMock: func() {
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 1}).Return(&service.GetFileResponse{File: &uploadentity.File{Url: "https://example.com/file"}}, nil)
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 2}).Return(&service.GetFileResponse{File: &uploadentity.File{Url: "https://example.com/audio"}}, nil)
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 3}).Return(&service.GetFileResponse{File: &uploadentity.File{Url: "https://example.com/video"}}, nil)
},
expected: &schema.Message{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{Type: schema.ChatMessagePartTypeFileURL, FileURL: &schema.ChatMessageFileURL{URL: "https://example.com/file"}},
{Type: schema.ChatMessagePartTypeAudioURL, AudioURL: &schema.ChatMessageAudioURL{URL: "https://example.com/audio"}},
{Type: schema.ChatMessagePartTypeVideoURL, VideoURL: &schema.ChatMessageVideoURL{URL: "https://example.com/video"}},
},
},
expectErr: false,
},
{
name: "get file error",
msg: &workflow.EnterMessage{
ContentType: "object_string",
Content: `[{"type": "file", "file_id": "123"}]`,
},
setupMock: func() {
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(nil, errors.New("get file error"))
},
expectErr: true,
},
{
name: "file not found",
msg: &workflow.EnterMessage{
ContentType: "object_string",
Content: `[{"type": "file", "file_id": "123"}]`,
},
setupMock: func() {
mockUpload.EXPECT().GetFile(gomock.Any(), &service.GetFileRequest{ID: 123}).Return(&service.GetFileResponse{}, nil)
},
expectErr: true,
},
{
name: "invalid content type",
msg: &workflow.EnterMessage{
ContentType: "invalid",
},
setupMock: func() {},
expectErr: true,
},
{
name: "invalid json",
msg: &workflow.EnterMessage{
ContentType: "object_string",
Content: "invalid-json",
},
setupMock: func() {},
expectErr: true,
},
{
name: "invalid input type",
msg: &workflow.EnterMessage{
ContentType: "object_string",
Content: `[{"type": "invalid"}]`,
},
setupMock: func() {},
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupMock()
result, err := toSchemaMessage(ctx, tt.msg)
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}
func Test_makeChatFlowHistoryMessages(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockAgentRun := agentrunmock.NewMockAgentRun(ctrl)
crossagentrun.SetDefaultSVC(mockAgentRun)
mockUpload := uploadmock.NewMockUploader(ctrl)
crossupload.SetDefaultSVC(mockUpload)
bizID, conversationID, userID, sectionID, connectorID := int64(2), int64(1), int64(3), int64(4), int64(5)
tests := []struct {
name string
messages []*workflow.EnterMessage
setupMock func()
expected []*messageentity.Message
expectErr bool
}{
{
name: "empty messages",
messages: []*workflow.EnterMessage{},
setupMock: func() {},
expected: []*messageentity.Message{},
expectErr: false,
},
{
name: "one user message",
messages: []*workflow.EnterMessage{
{Role: "user", ContentType: "text", Content: "hello"},
},
setupMock: func() {
mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{ID: 100}, nil).Times(1)
},
expected: []*messageentity.Message{
{
Role: schema.User,
ConversationID: conversationID,
AgentID: bizID,
RunID: 100,
Content: "hello",
ContentType: messageentity.ContentTypeText,
MessageType: messageentity.MessageTypeQuestion,
UserID: strconv.FormatInt(userID, 10),
SectionID: sectionID,
},
},
expectErr: false,
},
{
name: "user and assistant message",
messages: []*workflow.EnterMessage{
{Role: "user", ContentType: "text", Content: "hello"},
{Role: "assistant", ContentType: "text", Content: "hi"},
},
setupMock: func() {
mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{ID: 100}, nil).Times(1)
},
expected: []*messageentity.Message{
{
Role: schema.User,
ConversationID: conversationID,
AgentID: bizID,
RunID: 100,
Content: "hello",
ContentType: messageentity.ContentTypeText,
MessageType: messageentity.MessageTypeQuestion,
UserID: strconv.FormatInt(userID, 10),
SectionID: sectionID,
},
{
Role: schema.User,
ConversationID: conversationID,
AgentID: bizID,
RunID: 100,
Content: "hi",
ContentType: messageentity.ContentTypeText,
MessageType: messageentity.MessageTypeAnswer,
UserID: strconv.FormatInt(userID, 10),
SectionID: sectionID,
},
},
expectErr: false,
},
{
name: "only assistant message",
messages: []*workflow.EnterMessage{
{Role: "assistant", ContentType: "text", Content: "hi"},
},
setupMock: func() {},
expected: []*messageentity.Message{},
expectErr: false,
},
{
name: "create run record error",
messages: []*workflow.EnterMessage{
{Role: "user", ContentType: "text", Content: "hello"},
},
setupMock: func() {
mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil, errors.New("db error"))
},
expectErr: true,
},
{
name: "invalid role",
messages: []*workflow.EnterMessage{
{Role: "system", ContentType: "text", Content: "hello"},
},
setupMock: func() {},
expectErr: true,
},
{
name: "toConversationMessage error",
messages: []*workflow.EnterMessage{
{Role: "user", ContentType: "invalid", Content: "hello"},
},
setupMock: func() {
mockAgentRun.EXPECT().Create(gomock.Any(), gomock.Any()).Return(&agententity.RunRecordMeta{ID: 100}, nil)
},
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupMock()
result, err := makeChatFlowHistoryMessages(ctx, bizID, conversationID, userID, sectionID, connectorID, tt.messages)
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}

@ -58,7 +58,7 @@ type MessageListRequest struct {
BeforeID *string
AfterID *string
UserID int64
AppID int64
BizID int64
OrderBy *string
}
@ -88,7 +88,7 @@ type WfMessage struct {
type GetLatestRunIDsRequest struct {
ConversationID int64
UserID int64
AppID int64
BizID int64
Rounds int64
SectionID int64
InitRunID *int64

@ -24,6 +24,7 @@ import (
var defaultSVC Uploader
//go:generate mockgen -destination uploadmock/upload_mock.go --package uploadmock -source upload.go
type Uploader interface {
GetFile(ctx context.Context, req *service.GetFileRequest) (resp *service.GetFileResponse, err error)
}

@ -0,0 +1,73 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Code generated by MockGen. DO NOT EDIT.
// Source: upload.go
//
// Generated by this command:
//
// mockgen -destination uploadmock/upload_mock.go --package uploadmock -source upload.go
//
// Package uploadmock is a generated GoMock package.
package uploadmock
import (
context "context"
reflect "reflect"
service "github.com/coze-dev/coze-studio/backend/domain/upload/service"
gomock "go.uber.org/mock/gomock"
)
// MockUploader is a mock of Uploader interface.
type MockUploader struct {
ctrl *gomock.Controller
recorder *MockUploaderMockRecorder
isgomock struct{}
}
// MockUploaderMockRecorder is the mock recorder for MockUploader.
type MockUploaderMockRecorder struct {
mock *MockUploader
}
// NewMockUploader creates a new mock instance.
func NewMockUploader(ctrl *gomock.Controller) *MockUploader {
mock := &MockUploader{ctrl: ctrl}
mock.recorder = &MockUploaderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockUploader) EXPECT() *MockUploaderMockRecorder {
return m.recorder
}
// GetFile mocks base method.
func (m *MockUploader) GetFile(ctx context.Context, req *service.GetFileRequest) (*service.GetFileResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetFile", ctx, req)
ret0, _ := ret[0].(*service.GetFileResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetFile indicates an expected call of GetFile.
func (mr *MockUploaderMockRecorder) GetFile(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFile", reflect.TypeOf((*MockUploader)(nil).GetFile), ctx, req)
}

@ -54,7 +54,7 @@ func (c *impl) MessageList(ctx context.Context, req *crossmessage.MessageListReq
ConversationID: req.ConversationID,
Limit: int(req.Limit), // Since the value of limit is checked inside the node, the type cast here is safe
UserID: strconv.FormatInt(req.UserID, 10),
AgentID: req.AppID,
AgentID: req.BizID,
OrderBy: req.OrderBy,
}
if req.BeforeID != nil {
@ -96,7 +96,7 @@ func (c *impl) MessageList(ctx context.Context, req *crossmessage.MessageListReq
func (c *impl) GetLatestRunIDs(ctx context.Context, req *crossmessage.GetLatestRunIDsRequest) ([]int64, error) {
listMeta := &agententity.ListRunRecordMeta{
ConversationID: req.ConversationID,
AgentID: req.AppID,
AgentID: req.BizID,
Limit: int32(req.Rounds),
SectionID: req.SectionID,
}

@ -75,11 +75,11 @@ type Conversation interface {
ListDynamicConversation(ctx context.Context, env vo.Env, policy *vo.ListConversationPolicy) ([]*entity.DynamicConversation, error)
ReleaseConversationTemplate(ctx context.Context, appID int64, version string) error
InitApplicationDefaultConversationTemplate(ctx context.Context, spaceID int64, appID int64, userID int64) error
GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, int64, error)
GetOrCreateConversation(ctx context.Context, env vo.Env, bizID, connectorID, userID int64, conversationName string) (int64, int64, error)
UpdateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, error)
GetTemplateByName(ctx context.Context, env vo.Env, appID int64, templateName string) (*entity.ConversationTemplate, bool, error)
GetDynamicConversationByName(ctx context.Context, env vo.Env, appID, connectorID, userID int64, name string) (*entity.DynamicConversation, bool, error)
GetConversationNameByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error)
GetConversationNameByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error)
}
type InterruptEventStore interface {
@ -143,8 +143,8 @@ type ConversationRepository interface {
UpdateStaticConversation(ctx context.Context, env vo.Env, templateID int64, connectorID int64, userID int64, newConversationID int64) error
UpdateDynamicConversation(ctx context.Context, env vo.Env, conversationID, newConversationID int64) error
CopyTemplateConversationByAppID(ctx context.Context, appID int64, toAppID int64) error
GetStaticConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error)
GetDynamicConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error)
GetStaticConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error)
GetDynamicConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error)
}
type WorkflowConfig interface {
GetNodeOfCodeConfig() *config.NodeOfCodeConfig

@ -32,6 +32,11 @@ const (
ChatFlowMessageCompleted ChatFlowEvent = "conversation.message.completed"
)
const (
ConversationNameKey = "CONVERSATION_NAME"
UserInputKey = "USER_INPUT"
)
type Usage struct {
TokenCount *int32 `form:"token_count" json:"token_count,omitempty"`
OutputTokens *int32 `form:"output_count" json:"output_count,omitempty"`

@ -59,14 +59,14 @@ type ListConversationPolicy struct {
}
type CreateStaticConversation struct {
AppID int64
BizID int64
UserID int64
ConnectorID int64
TemplateID int64
}
type CreateDynamicConversation struct {
AppID int64
BizID int64
UserID int64
ConnectorID int64

@ -235,6 +235,7 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
if enabled {
trimmedSC.GeneratedNodes = append(trimmedSC.GeneratedNodes, ns.Key)
}
trimmedSC.Init()
return trimmedSC, nil
}

@ -446,7 +446,7 @@ func PruneIsolatedNodes(nodes []*vo.Node, edges []*vo.Edge, parentNode *vo.Node)
func parseBatchMode(n *vo.Node) (
batchN *vo.Node, // the new batch node
enabled bool, // whether the node has enabled batch mode
enabled bool, // whether the node has enabled batch mode
err error) {
if n.Data == nil || n.Data.Inputs == nil {
return nil, false, nil

@ -0,0 +1,275 @@
{
"nodes": [
{
"id": "100001",
"type": "1",
"meta": {
"position": {
"x": 180,
"y": 79.2
}
},
"data": {
"outputs": [
{
"type": "string",
"name": "USER_INPUT",
"required": false
},
{
"type": "string",
"name": "CONVERSATION_NAME",
"required": false,
"description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。",
"defaultValue": "dhl"
}
],
"nodeMeta": {
"title": "开始",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
"subTitle": ""
},
"trigger_parameters": []
}
},
{
"id": "900001",
"type": "2",
"meta": {
"position": {
"x": 2020,
"y": 66.2
}
},
"data": {
"nodeMeta": {
"title": "结束",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
"subTitle": ""
},
"inputs": {
"terminatePlan": "useAnswerContent",
"streamingOutput": true,
"inputParameters": [
{
"name": "output",
"input": {
"type": "string",
"value": {
"type": "ref",
"content": {
"source": "block-output",
"blockID": "142077",
"name": "optionContent"
},
"rawMeta": {
"type": 1
}
}
}
}
],
"content": {
"type": "string",
"value": {
"type": "literal",
"content": "{{output}}"
}
}
}
}
},
{
"id": "190196",
"type": "30",
"meta": {
"position": {
"x": 640,
"y": 78.5
}
},
"data": {
"outputs": [
{
"type": "string",
"name": "input",
"required": true
}
],
"nodeMeta": {
"title": "输入",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Input-v2.jpg",
"description": "支持中间过程的信息输入",
"mainColor": "#5C62FF",
"subTitle": "输入"
},
"inputs": {
"outputSchema": "[{\"type\":\"string\",\"name\":\"input\",\"required\":true}]"
}
}
},
{
"id": "133775",
"type": "18",
"meta": {
"position": {
"x": 1100,
"y": 39
}
},
"data": {
"inputs": {
"llmParam": {
"modelType": 1001,
"modelName": "Doubao-Seed-1.6",
"generationDiversity": "balance",
"temperature": 0.8,
"maxTokens": 4096,
"topP": 0.7,
"responseFormat": 2,
"systemPrompt": ""
},
"inputParameters": [],
"extra_output": false,
"answer_type": "text",
"option_type": "static",
"dynamic_option": {
"type": "string",
"value": {
"type": "ref",
"content": {
"source": "block-output",
"blockID": "",
"name": ""
}
}
},
"question": "你好",
"options": [
{
"name": ""
},
{
"name": ""
}
],
"limit": 3
},
"nodeMeta": {
"title": "问答",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg",
"description": "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式",
"mainColor": "#3071F2",
"subTitle": "问答"
},
"outputs": [
{
"type": "string",
"name": "USER_RESPONSE",
"required": true,
"description": "用户本轮对话输入内容"
}
]
}
},
{
"id": "142077",
"type": "18",
"meta": {
"position": {
"x": 1560,
"y": 0
}
},
"data": {
"inputs": {
"llmParam": {
"modelType": 1001,
"modelName": "Doubao-Seed-1.6",
"generationDiversity": "balance",
"temperature": 0.8,
"maxTokens": 4096,
"topP": 0.7,
"responseFormat": 2,
"systemPrompt": ""
},
"inputParameters": [],
"extra_output": false,
"answer_type": "option",
"option_type": "static",
"dynamic_option": {
"type": "string",
"value": {
"type": "ref",
"content": {
"source": "block-output",
"blockID": "",
"name": ""
}
}
},
"question": "请选择",
"options": [
{
"name": "A"
},
{
"name": "B"
}
],
"limit": 3
},
"nodeMeta": {
"title": "问答_1",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg",
"description": "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式",
"mainColor": "#3071F2",
"subTitle": "问答"
},
"outputs": [
{
"type": "string",
"name": "optionId",
"required": false
},
{
"type": "string",
"name": "optionContent",
"required": false
}
]
}
}
],
"edges": [
{
"sourceNodeID": "100001",
"targetNodeID": "190196"
},
{
"sourceNodeID": "142077",
"targetNodeID": "900001",
"sourcePortID": "branch_0"
},
{
"sourceNodeID": "142077",
"targetNodeID": "900001",
"sourcePortID": "branch_1"
},
{
"sourceNodeID": "142077",
"targetNodeID": "900001",
"sourcePortID": "default"
},
{
"sourceNodeID": "190196",
"targetNodeID": "133775"
},
{
"sourceNodeID": "133775",
"targetNodeID": "142077"
}
]
}

@ -0,0 +1,397 @@
{
"nodes": [
{
"blocks": [],
"data": {
"nodeMeta": {
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
"subTitle": "",
"title": "开始"
},
"outputs": [
{
"name": "USER_INPUT",
"required": false,
"type": "string"
},
{
"defaultValue": "Default",
"description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。",
"name": "CONVERSATION_NAME",
"required": false,
"type": "string"
}
],
"trigger_parameters": []
},
"edges": null,
"id": "100001",
"meta": {
"position": {
"x": 0,
"y": 0
}
},
"type": "1"
},
{
"blocks": [],
"data": {
"inputs": {
"content": {
"type": "string",
"value": {
"content": "{{output}}",
"type": "literal"
}
},
"inputParameters": [
{
"input": {
"type": "string",
"value": {
"content": {
"blockID": "123887",
"name": "output",
"source": "block-output"
},
"rawMeta": {
"type": 1
},
"type": "ref"
}
},
"name": "output"
}
],
"streamingOutput": true,
"terminatePlan": "useAnswerContent"
},
"nodeMeta": {
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
"subTitle": "",
"title": "结束"
}
},
"edges": null,
"id": "900001",
"meta": {
"position": {
"x": 926,
"y": -13
}
},
"type": "2"
},
{
"blocks": [],
"data": {
"inputs": {
"fcParamVar": {
"knowledgeFCParam": {}
},
"inputParameters": [
{
"input": {
"type": "string",
"value": {
"content": {
"blockID": "100001",
"name": "USER_INPUT",
"source": "block-output"
},
"rawMeta": {
"type": 1
},
"type": "ref"
}
},
"name": "input"
}
],
"llmParam": [
{
"input": {
"type": "integer",
"value": {
"content": "1737521813",
"rawMeta": {
"type": 2
},
"type": "literal"
}
},
"name": "modelType"
},
{
"input": {
"type": "string",
"value": {
"content": "豆包·1.5·Pro·32k",
"rawMeta": {
"type": 1
},
"type": "literal"
}
},
"name": "modleName"
},
{
"input": {
"type": "string",
"value": {
"content": "balance",
"rawMeta": {
"type": 1
},
"type": "literal"
}
},
"name": "generationDiversity"
},
{
"input": {
"type": "float",
"value": {
"content": "0.8",
"rawMeta": {
"type": 4
},
"type": "literal"
}
},
"name": "temperature"
},
{
"input": {
"type": "integer",
"value": {
"content": "4096",
"rawMeta": {
"type": 2
},
"type": "literal"
}
},
"name": "maxTokens"
},
{
"input": {
"type": "boolean",
"value": {
"content": false,
"rawMeta": {
"type": 3
},
"type": "literal"
}
},
"name": "spCurrentTime"
},
{
"input": {
"type": "boolean",
"value": {
"content": false,
"rawMeta": {
"type": 3
},
"type": "literal"
}
},
"name": "spAntiLeak"
},
{
"input": {
"type": "boolean",
"value": {
"content": false,
"rawMeta": {
"type": 3
},
"type": "literal"
}
},
"name": "prefixCache"
},
{
"input": {
"type": "integer",
"value": {
"content": "2",
"rawMeta": {
"type": 2
},
"type": "literal"
}
},
"name": "responseFormat"
},
{
"input": {
"type": "string",
"value": {
"content": "{{input}}",
"rawMeta": {
"type": 1
},
"type": "literal"
}
},
"name": "prompt"
},
{
"input": {
"type": "boolean",
"value": {
"content": false,
"rawMeta": {
"type": 3
},
"type": "literal"
}
},
"name": "enableChatHistory"
},
{
"input": {
"type": "integer",
"value": {
"content": "3",
"rawMeta": {
"type": 2
},
"type": "literal"
}
},
"name": "chatHistoryRound"
},
{
"input": {
"type": "string",
"value": {
"content": "",
"rawMeta": {
"type": 1
},
"type": "literal"
}
},
"name": "systemPrompt"
},
{
"input": {
"type": "string",
"value": {
"content": "",
"rawMeta": {
"type": 1
},
"type": "literal"
}
},
"name": "stableSystemPrompt"
},
{
"input": {
"type": "boolean",
"value": {
"content": false,
"rawMeta": {
"type": 3
},
"type": "literal"
}
},
"name": "canContinue"
},
{
"input": {
"type": "string",
"value": {
"content": "",
"rawMeta": {
"type": 1
},
"type": "literal"
}
},
"name": "loopPromptVersion"
},
{
"input": {
"type": "string",
"value": {
"content": "",
"rawMeta": {
"type": 1
},
"type": "literal"
}
},
"name": "loopPromptName"
},
{
"input": {
"type": "string",
"value": {
"content": "",
"rawMeta": {
"type": 1
},
"type": "literal"
}
},
"name": "loopPromptId"
}
],
"settingOnError": {
"processType": 1,
"retryTimes": 0,
"timeoutMs": 180000
}
},
"nodeMeta": {
"description": "调用大语言模型,使用变量和提示词生成回复",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg",
"mainColor": "#5C62FF",
"subTitle": "大模型",
"title": "大模型"
},
"outputs": [
{
"name": "output",
"type": "string"
}
],
"version": "3"
},
"edges": null,
"id": "123887",
"meta": {
"position": {
"x": 463,
"y": -39
}
},
"type": "3"
}
],
"edges": [
{
"sourceNodeID": "100001",
"targetNodeID": "123887",
"sourcePortID": ""
},
{
"sourceNodeID": "123887",
"targetNodeID": "900001",
"sourcePortID": ""
}
],
"versions": {
"loop": "v2"
}
}

@ -0,0 +1,397 @@
{
"nodes": [
{
"blocks": [],
"data": {
"nodeMeta": {
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
"subTitle": "",
"title": "开始"
},
"outputs": [
{
"name": "USER_INPUT",
"required": false,
"type": "string"
},
{
"defaultValue": "Default",
"description": "本次请求绑定的会话,会自动写入消息、会从该会话读对话历史。",
"name": "CONVERSATION_NAME",
"required": false,
"type": "string"
}
],
"trigger_parameters": []
},
"edges": null,
"id": "100001",
"meta": {
"position": {
"x": 0,
"y": 0
}
},
"type": "1"
},
{
"blocks": [],
"data": {
"inputs": {
"content": {
"type": "string",
"value": {
"content": "{{output}}",
"type": "literal"
}
},
"inputParameters": [
{
"input": {
"type": "string",
"value": {
"content": {
"blockID": "123887",
"name": "output",
"source": "block-output"
},
"rawMeta": {
"type": 1
},
"type": "ref"
}
},
"name": "output"
}
],
"streamingOutput": true,
"terminatePlan": "useAnswerContent"
},
"nodeMeta": {
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
"subTitle": "",
"title": "结束"
}
},
"edges": null,
"id": "900001",
"meta": {
"position": {
"x": 926,
"y": -13
}
},
"type": "2"
},
{
"blocks": [],
"data": {
"inputs": {
"fcParamVar": {
"knowledgeFCParam": {}
},
"inputParameters": [
{
"input": {
"type": "string",
"value": {
"content": {
"blockID": "100001",
"name": "USER_INPUT",
"source": "block-output"
},
"rawMeta": {
"type": 1
},
"type": "ref"
}
},
"name": "input"
}
],
"llmParam": [
{
"input": {
"type": "integer",
"value": {
"content": "1737521813",
"rawMeta": {
"type": 2
},
"type": "literal"
}
},
"name": "modelType"
},
{
"input": {
"type": "string",
"value": {
"content": "豆包·1.5·Pro·32k",
"rawMeta": {
"type": 1
},
"type": "literal"
}
},
"name": "modleName"
},
{
"input": {
"type": "string",
"value": {
"content": "balance",
"rawMeta": {
"type": 1
},
"type": "literal"
}
},
"name": "generationDiversity"
},
{
"input": {
"type": "float",
"value": {
"content": "0.8",
"rawMeta": {
"type": 4
},
"type": "literal"
}
},
"name": "temperature"
},
{
"input": {
"type": "integer",
"value": {
"content": "4096",
"rawMeta": {
"type": 2
},
"type": "literal"
}
},
"name": "maxTokens"
},
{
"input": {
"type": "boolean",
"value": {
"content": false,
"rawMeta": {
"type": 3
},
"type": "literal"
}
},
"name": "spCurrentTime"
},
{
"input": {
"type": "boolean",
"value": {
"content": false,
"rawMeta": {
"type": 3
},
"type": "literal"
}
},
"name": "spAntiLeak"
},
{
"input": {
"type": "boolean",
"value": {
"content": false,
"rawMeta": {
"type": 3
},
"type": "literal"
}
},
"name": "prefixCache"
},
{
"input": {
"type": "integer",
"value": {
"content": "2",
"rawMeta": {
"type": 2
},
"type": "literal"
}
},
"name": "responseFormat"
},
{
"input": {
"type": "string",
"value": {
"content": "{{input}}",
"rawMeta": {
"type": 1
},
"type": "literal"
}
},
"name": "prompt"
},
{
"input": {
"type": "boolean",
"value": {
"content": true,
"rawMeta": {
"type": 3
},
"type": "literal"
}
},
"name": "enableChatHistory"
},
{
"input": {
"type": "integer",
"value": {
"content": "3",
"rawMeta": {
"type": 2
},
"type": "literal"
}
},
"name": "chatHistoryRound"
},
{
"input": {
"type": "string",
"value": {
"content": "",
"rawMeta": {
"type": 1
},
"type": "literal"
}
},
"name": "systemPrompt"
},
{
"input": {
"type": "string",
"value": {
"content": "",
"rawMeta": {
"type": 1
},
"type": "literal"
}
},
"name": "stableSystemPrompt"
},
{
"input": {
"type": "boolean",
"value": {
"content": false,
"rawMeta": {
"type": 3
},
"type": "literal"
}
},
"name": "canContinue"
},
{
"input": {
"type": "string",
"value": {
"content": "",
"rawMeta": {
"type": 1
},
"type": "literal"
}
},
"name": "loopPromptVersion"
},
{
"input": {
"type": "string",
"value": {
"content": "",
"rawMeta": {
"type": 1
},
"type": "literal"
}
},
"name": "loopPromptName"
},
{
"input": {
"type": "string",
"value": {
"content": "",
"rawMeta": {
"type": 1
},
"type": "literal"
}
},
"name": "loopPromptId"
}
],
"settingOnError": {
"processType": 1,
"retryTimes": 0,
"timeoutMs": 180000
}
},
"nodeMeta": {
"description": "调用大语言模型,使用变量和提示词生成回复",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg",
"mainColor": "#5C62FF",
"subTitle": "大模型",
"title": "大模型"
},
"outputs": [
{
"name": "output",
"type": "string"
}
],
"version": "3"
},
"edges": null,
"id": "123887",
"meta": {
"position": {
"x": 463,
"y": -39
}
},
"type": "3"
}
],
"edges": [
{
"sourceNodeID": "100001",
"targetNodeID": "123887",
"sourcePortID": ""
},
{
"sourceNodeID": "123887",
"targetNodeID": "900001",
"sourcePortID": ""
}
],
"versions": {
"loop": "v2"
}
}

@ -148,7 +148,7 @@ func (ch *ConversationHistory) Invoke(ctx context.Context, input map[string]any)
runIDs, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, &crossmessage.GetLatestRunIDsRequest{
ConversationID: conversationID,
UserID: userID,
AppID: *appID,
BizID: *appID,
Rounds: rounds,
InitRunID: initRunID,
SectionID: sectionID,

@ -109,7 +109,7 @@ func (c *CreateConversation) Invoke(ctx context.Context, input map[string]any) (
if existed {
cID, _, existed, err := workflow.GetRepository().GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
AppID: ptr.From(appID),
BizID: ptr.From(appID),
TemplateID: template.TemplateID,
UserID: userID,
ConnectorID: connectorID,
@ -125,7 +125,7 @@ func (c *CreateConversation) Invoke(ctx context.Context, input map[string]any) (
}
cID, _, existed, err := workflow.GetRepository().GetOrCreateDynamicConversation(ctx, env, conversationIDGenerator, &vo.CreateDynamicConversation{
AppID: ptr.From(appID),
BizID: ptr.From(appID),
UserID: userID,
ConnectorID: connectorID,
Name: conversationName,

@ -98,7 +98,7 @@ func (c *CreateMessage) getConversationIDByName(ctx context.Context, env vo.Env,
var conversationID int64
if isExist {
cID, _, _, err := workflow.GetRepository().GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
AppID: ptr.From(appID),
BizID: ptr.From(appID),
TemplateID: template.TemplateID,
UserID: userID,
ConnectorID: connectorID,
@ -150,7 +150,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
var conversationID int64
var err error
var resolvedAppID int64
var bizID int64
if appID == nil {
if conversationName != "Default" {
return nil, vo.WrapError(errno.ErrOnlyDefaultConversationAllowInAgentScenario, errors.New("conversation node only allow in application"))
@ -167,13 +167,13 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
}, nil
}
conversationID = *execCtx.ExeCfg.ConversationID
resolvedAppID = *agentID
bizID = *agentID
} else {
conversationID, err = c.getConversationIDByName(ctx, env, appID, version, conversationName, userID, connectorID)
if err != nil {
return nil, err
}
resolvedAppID = *appID
bizID = *appID
}
if conversationID == 0 {
@ -209,7 +209,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
if role == "user" {
// For user messages, always create a new run and store the ID in the context.
runRecord, err := crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{
AgentID: resolvedAppID,
AgentID: bizID,
ConversationID: conversationID,
UserID: strconv.FormatInt(userID, 10),
ConnectorID: connectorID,
@ -244,7 +244,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
runIDs, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, &crossmessage.GetLatestRunIDsRequest{
ConversationID: conversationID,
UserID: userID,
AppID: resolvedAppID,
BizID: bizID,
Rounds: 1,
})
if err != nil {
@ -254,7 +254,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
runID = runIDs[0]
} else {
runRecord, err := crossagentrun.DefaultSVC().Create(ctx, &agententity.AgentRunMeta{
AgentID: resolvedAppID,
AgentID: bizID,
ConversationID: conversationID,
UserID: strconv.FormatInt(userID, 10),
ConnectorID: connectorID,
@ -273,7 +273,7 @@ func (c *CreateMessage) Invoke(ctx context.Context, input map[string]any) (map[s
Content: content,
ContentType: model.ContentType("text"),
UserID: strconv.FormatInt(userID, 10),
AgentID: resolvedAppID,
AgentID: bizID,
RunID: runID,
SectionID: sectionID,
}

@ -115,7 +115,7 @@ func (m *MessageList) Invoke(ctx context.Context, input map[string]any) (map[str
var conversationID int64
var err error
var resolvedAppID int64
var bizID int64
if appID == nil {
if conversationName != "Default" {
return nil, vo.WrapError(errno.ErrOnlyDefaultConversationAllowInAgentScenario, errors.New("conversation node only allow in application"))
@ -129,18 +129,18 @@ func (m *MessageList) Invoke(ctx context.Context, input map[string]any) (map[str
}, nil
}
conversationID = *execCtx.ExeCfg.ConversationID
resolvedAppID = *agentID
bizID = *agentID
} else {
conversationID, err = m.getConversationIDByName(ctx, env, appID, version, conversationName, userID, connectorID)
if err != nil {
return nil, err
}
resolvedAppID = *appID
bizID = *appID
}
req := &crossmessage.MessageListRequest{
UserID: userID,
AppID: resolvedAppID,
BizID: bizID,
ConversationID: conversationID,
}

@ -100,9 +100,9 @@ const (
ReasoningOutputKey = "reasoning_content"
)
const knowledgeUserPromptTemplate = `根据引用的内容回答问题:
1.如果引用的内容里面包含 <img src=""> 的标签, 标签里的 src 字段表示图片地址, 需要在回答问题的时候展示出去, 输出格式为"![图片名称](图片地址)"
2.如果引用的内容不包含 <img src=""> 的标签, 你回答问题时不需要展示图片
const knowledgeUserPromptTemplate = `根据引用的内容回答问题:
1.如果引用的内容里面包含 <img src=""> 的标签, 标签里的 src 字段表示图片地址, 需要在回答问题的时候展示出去, 输出格式为"![图片名称](图片地址)"
2.如果引用的内容不包含 <img src=""> 的标签, 你回答问题时不需要展示图片
例如
如果内容为<img src="https://example.com/image.jpg">一只小猫你的输出应为![一只小猫](https://example.com/image.jpg)。
如果内容为<img src="https://example.com/image1.jpg">一只小猫 <img src="https://example.com/image2.jpg">一只小狗 <img src="https://example.com/image3.jpg">一只小牛你的输出应为![一只小猫](https://example.com/image1.jpg) 和 ![一只小狗](https://example.com/image2.jpg) 和 ![一只小牛](https://example.com/image3.jpg)
@ -290,7 +290,7 @@ func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*
c.AssociateStartNodeUserInputFields = make(map[string]struct{})
for _, info := range ns.InputSources {
if len(info.Path) == 1 && info.Source.Ref != nil && info.Source.Ref.FromNodeKey == entity.EntryNodeKey {
if compose.FromFieldPath(info.Source.Ref.FromPath).Equals(compose.FromField("USER_INPUT")) {
if compose.FromFieldPath(info.Source.Ref.FromPath).Equals(compose.FromField(vo.UserInputKey)) {
c.AssociateStartNodeUserInputFields[info.Path[0]] = struct{}{}
}
}

@ -192,8 +192,3 @@ type StreamGenerator interface {
FieldStreamType(path compose.FieldPath, ns *schema.NodeSchema,
sc *schema.WorkflowSchema) (schema.FieldStreamType, error)
}
type ChatHistoryAware interface {
ChatHistoryEnabled() bool
ChatHistoryRounds() int64
}

@ -432,7 +432,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
appDynamicConversationDraft := r.query.AppDynamicConversationDraft
ret, err := appDynamicConversationDraft.WithContext(ctx).Where(
appDynamicConversationDraft.AppID.Eq(meta.AppID),
appDynamicConversationDraft.AppID.Eq(meta.BizID),
appDynamicConversationDraft.ConnectorID.Eq(meta.ConnectorID),
appDynamicConversationDraft.UserID.Eq(meta.UserID),
appDynamicConversationDraft.Name.Eq(meta.Name),
@ -452,7 +452,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
}
conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID)
if err != nil {
return 0, 0, false, err
}
@ -464,7 +464,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
err = r.query.AppDynamicConversationDraft.WithContext(ctx).Create(&model.AppDynamicConversationDraft{
ID: id,
AppID: meta.AppID,
AppID: meta.BizID,
Name: meta.Name,
UserID: meta.UserID,
ConnectorID: meta.ConnectorID,
@ -479,7 +479,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
} else if env == vo.Online {
appDynamicConversationOnline := r.query.AppDynamicConversationOnline
ret, err := appDynamicConversationOnline.WithContext(ctx).Where(
appDynamicConversationOnline.AppID.Eq(meta.AppID),
appDynamicConversationOnline.AppID.Eq(meta.BizID),
appDynamicConversationOnline.ConnectorID.Eq(meta.ConnectorID),
appDynamicConversationOnline.UserID.Eq(meta.UserID),
appDynamicConversationOnline.Name.Eq(meta.Name),
@ -498,7 +498,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
return 0, 0, false, vo.WrapError(errno.ErrDatabaseError, err)
}
conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID)
if err != nil {
return 0, 0, false, err
}
@ -509,7 +509,7 @@ func (r *RepositoryImpl) GetOrCreateDynamicConversation(ctx context.Context, env
err = r.query.AppDynamicConversationOnline.WithContext(ctx).Create(&model.AppDynamicConversationOnline{
ID: id,
AppID: meta.AppID,
AppID: meta.BizID,
Name: meta.Name,
UserID: meta.UserID,
ConnectorID: meta.ConnectorID,
@ -586,7 +586,7 @@ func (r *RepositoryImpl) getOrCreateDraftStaticConversation(ctx context.Context,
return cs[0].ConversationID, cInfo.SectionID, true, nil
}
conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID)
if err != nil {
return 0, 0, false, err
}
@ -627,7 +627,7 @@ func (r *RepositoryImpl) getOrCreateOnlineStaticConversation(ctx context.Context
return cs[0].ConversationID, cInfo.SectionID, true, nil
}
conv, err := idGen(ctx, meta.AppID, meta.UserID, meta.ConnectorID)
conv, err := idGen(ctx, meta.BizID, meta.UserID, meta.ConnectorID)
if err != nil {
return 0, 0, false, err
}
@ -841,7 +841,7 @@ func (r *RepositoryImpl) CopyTemplateConversationByAppID(ctx context.Context, ap
}
func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error) {
func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error) {
if env == vo.Draft {
appStaticConversationDraft := r.query.AppStaticConversationDraft
ret, err := appStaticConversationDraft.WithContext(ctx).Where(
@ -857,7 +857,7 @@ func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.E
appConversationTemplateDraft := r.query.AppConversationTemplateDraft
template, err := appConversationTemplateDraft.WithContext(ctx).Where(
appConversationTemplateDraft.TemplateID.Eq(ret.TemplateID),
appConversationTemplateDraft.AppID.Eq(appID),
appConversationTemplateDraft.AppID.Eq(bizID),
).First()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@ -881,7 +881,7 @@ func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.E
appConversationTemplateOnline := r.query.AppConversationTemplateOnline
template, err := appConversationTemplateOnline.WithContext(ctx).Where(
appConversationTemplateOnline.TemplateID.Eq(ret.TemplateID),
appConversationTemplateOnline.AppID.Eq(appID),
appConversationTemplateOnline.AppID.Eq(bizID),
).First()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@ -894,11 +894,11 @@ func (r *RepositoryImpl) GetStaticConversationByID(ctx context.Context, env vo.E
return "", false, fmt.Errorf("unknown env %v", env)
}
func (r *RepositoryImpl) GetDynamicConversationByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error) {
func (r *RepositoryImpl) GetDynamicConversationByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (*entity.DynamicConversation, bool, error) {
if env == vo.Draft {
appDynamicConversationDraft := r.query.AppDynamicConversationDraft
ret, err := appDynamicConversationDraft.WithContext(ctx).Where(
appDynamicConversationDraft.AppID.Eq(appID),
appDynamicConversationDraft.AppID.Eq(bizID),
appDynamicConversationDraft.ConnectorID.Eq(connectorID),
appDynamicConversationDraft.ConversationID.Eq(conversationID),
).First()
@ -918,7 +918,7 @@ func (r *RepositoryImpl) GetDynamicConversationByID(ctx context.Context, env vo.
} else if env == vo.Online {
appDynamicConversationOnline := r.query.AppDynamicConversationOnline
ret, err := appDynamicConversationOnline.WithContext(ctx).Where(
appDynamicConversationOnline.AppID.Eq(appID),
appDynamicConversationOnline.AppID.Eq(bizID),
appDynamicConversationOnline.ConnectorID.Eq(connectorID),
appDynamicConversationOnline.ConversationID.Eq(conversationID),
).First()

@ -129,3 +129,8 @@ func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) {
func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) {
s.OutputSources = append(s.OutputSources, info...)
}
type ChatHistoryAware interface {
ChatHistoryEnabled() bool
ChatHistoryRounds() int64
}

@ -38,6 +38,7 @@ type WorkflowSchema struct {
compositeNodes []*CompositeNode // won't serialize this
requireCheckPoint bool // won't serialize this
requireStreaming bool
historyRounds int64
once sync.Once
}
@ -69,15 +70,22 @@ func (w *WorkflowSchema) Init() {
w.doGetCompositeNodes()
historyRounds := int64(0)
for _, node := range w.Nodes {
if node.Type == entity.NodeTypeSubWorkflow {
node.SubWorkflowSchema.Init()
historyRounds = max(historyRounds, node.SubWorkflowSchema.HistoryRounds())
if node.SubWorkflowSchema.requireCheckPoint {
w.requireCheckPoint = true
break
}
}
chatHistoryAware, ok := node.Configs.(ChatHistoryAware)
if ok && chatHistoryAware.ChatHistoryEnabled() {
historyRounds = max(historyRounds, chatHistoryAware.ChatHistoryRounds())
}
if rc, ok := node.Configs.(RequireCheckpoint); ok {
if rc.RequireCheckpoint() {
w.requireCheckPoint = true
@ -86,6 +94,7 @@ func (w *WorkflowSchema) Init() {
}
}
w.historyRounds = historyRounds
w.requireStreaming = w.doRequireStreaming()
})
}
@ -122,6 +131,12 @@ func (w *WorkflowSchema) RequireStreaming() bool {
return w.requireStreaming
}
func (w *WorkflowSchema) HistoryRounds() int64 { return w.historyRounds }
func (w *WorkflowSchema) SetHistoryRounds(historyRounds int64) {
w.historyRounds = historyRounds
}
func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) {
if w.Hierarchy == nil {
return nil

@ -248,7 +248,7 @@ func (c *conversationImpl) findReplaceWorkflowByConversationName(ctx context.Con
if err != nil {
return false, err
}
if v.Name == "CONVERSATION_NAME" && v.DefaultValue == name {
if v.Name == vo.ConversationNameKey && v.DefaultValue == name {
return true, nil
}
}
@ -296,7 +296,7 @@ func (c *conversationImpl) replaceWorkflowsConversationName(ctx context.Context,
if err != nil {
return err
}
if v.Name == "CONVERSATION_NAME" {
if v.Name == vo.ConversationNameKey {
v.DefaultValue = conversionName
}
startNode.Data.Outputs[idx] = v
@ -351,18 +351,18 @@ func (c *conversationImpl) DeleteDynamicConversation(ctx context.Context, env vo
return c.repo.DeleteDynamicConversation(ctx, env, templateID)
}
func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, int64, error) {
func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.Env, bizID, connectorID, userID int64, conversationName string) (int64, int64, error) {
t, existed, err := c.repo.GetConversationTemplate(ctx, env, vo.GetConversationTemplatePolicy{
AppID: ptr.Of(appID),
AppID: ptr.Of(bizID),
Name: ptr.Of(conversationName),
})
if err != nil {
return 0, 0, err
}
conversationIDGenerator := workflow.ConversationIDGenerator(func(ctx context.Context, appID int64, userID, connectorID int64) (*conventity.Conversation, error) {
conversationIDGenerator := workflow.ConversationIDGenerator(func(ctx context.Context, bizID int64, userID, connectorID int64) (*conventity.Conversation, error) {
return crossconversation.DefaultSVC().CreateConversation(ctx, &conventity.CreateMeta{
AgentID: appID,
AgentID: bizID,
UserID: userID,
ConnectorID: connectorID,
Scene: common.Scene_SceneWorkflow,
@ -371,7 +371,7 @@ func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.E
if existed {
conversationID, sectionID, _, err := c.repo.GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
AppID: appID,
BizID: bizID,
ConnectorID: connectorID,
UserID: userID,
TemplateID: t.TemplateID,
@ -383,7 +383,7 @@ func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.E
}
conversationID, sectionID, _, err := c.repo.GetOrCreateDynamicConversation(ctx, env, conversationIDGenerator, &vo.CreateDynamicConversation{
AppID: appID,
BizID: bizID,
ConnectorID: connectorID,
UserID: userID,
Name: conversationName,
@ -465,8 +465,8 @@ func (c *conversationImpl) GetDynamicConversationByName(ctx context.Context, env
return c.repo.GetDynamicConversationByName(ctx, env, appID, connectorID, userID, name)
}
func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error) {
sc, existed, err := c.repo.GetStaticConversationByID(ctx, env, appID, connectorID, conversationID)
func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.Env, bizID, connectorID, conversationID int64) (string, bool, error) {
sc, existed, err := c.repo.GetStaticConversationByID(ctx, env, bizID, connectorID, conversationID)
if err != nil {
return "", false, err
}
@ -474,7 +474,7 @@ func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.E
return sc, true, nil
}
dc, existed, err := c.repo.GetDynamicConversationByID(ctx, env, appID, connectorID, conversationID)
dc, existed, err := c.repo.GetDynamicConversationByID(ctx, env, bizID, connectorID, conversationID)
if err != nil {
return "", false, err
}

@ -22,6 +22,8 @@ import (
"fmt"
"time"
"github.com/coze-dev/coze-studio/backend/types/consts"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
@ -282,6 +284,50 @@ func (i *impl) AsyncExecute(ctx context.Context, config workflowModel.ExecuteCon
return executeID, nil
}
func (i *impl) handleHistory(ctx context.Context, config *workflowModel.ExecuteConfig, input map[string]any, historyRounds int64, shouldFetchConversationByName bool) error {
if historyRounds <= 0 {
return nil
}
if shouldFetchConversationByName {
var cID, sID, bizID int64
var err error
if config.AppID != nil {
bizID = *config.AppID
} else if config.AgentID != nil {
bizID = *config.AgentID
}
for k, v := range input {
if k == vo.ConversationNameKey {
cName, ok := v.(string)
if !ok {
return errors.New("CONVERSATION_NAME must be string")
}
cID, sID, err = i.GetOrCreateConversation(ctx, vo.Draft, bizID, consts.CozeConnectorID, config.Operator, cName)
if err != nil {
return err
}
config.ConversationID = ptr.Of(cID)
config.SectionID = ptr.Of(sID)
}
}
}
messages, scMessages, err := i.prefetchChatHistory(ctx, *config, historyRounds)
if err != nil {
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
}
if len(messages) > 0 {
config.ConversationHistory = messages
}
if len(scMessages) > 0 {
config.ConversationHistorySchemaMessages = scMessages
}
return nil
}
func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workflowModel.ExecuteConfig, input map[string]any) (int64, error) {
var (
err error
@ -308,30 +354,6 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
}
}
historyRounds := int64(0)
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
historyRounds, err = getHistoryRoundsFromNode(ctx, wfEntity, nodeID, i.repo)
if err != nil {
return 0, err
}
}
if historyRounds > 0 {
messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds)
if err != nil {
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
}
if len(messages) > 0 {
config.ConversationHistory = messages
}
if len(scMessages) > 0 {
config.ConversationHistorySchemaMessages = scMessages
}
}
c := &vo.Canvas{}
if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
return 0, fmt.Errorf("failed to unmarshal canvas: %w", err)
@ -342,6 +364,17 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
return 0, fmt.Errorf("failed to convert canvas to workflow schema: %w", err)
}
historyRounds := int64(0)
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
historyRounds = workflowSC.HistoryRounds()
}
if historyRounds > 0 {
if err = i.handleHistory(ctx, &config, input, historyRounds, true); err != nil {
return 0, err
}
}
wf, err := compose.NewWorkflowFromNode(ctx, workflowSC, vo.NodeKey(nodeID), einoCompose.WithGraphName(fmt.Sprintf("%d", wfEntity.ID)))
if err != nil {
return 0, fmt.Errorf("failed to create workflow: %w", err)
@ -417,29 +450,6 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
}
}
historyRounds := int64(0)
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
historyRounds, err = i.calculateMaxChatHistoryRounds(ctx, wfEntity, i.repo)
if err != nil {
return nil, err
}
}
if historyRounds > 0 {
messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds)
if err != nil {
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
}
if len(messages) > 0 {
config.ConversationHistory = messages
}
if len(scMessages) > 0 {
config.ConversationHistorySchemaMessages = scMessages
}
}
c := &vo.Canvas{}
if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
return nil, fmt.Errorf("failed to unmarshal canvas: %w", err)
@ -450,6 +460,17 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
return nil, fmt.Errorf("failed to convert canvas to workflow schema: %w", err)
}
historyRounds := int64(0)
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
historyRounds = workflowSC.HistoryRounds()
}
if historyRounds > 0 {
if err = i.handleHistory(ctx, &config, input, historyRounds, false); err != nil {
return nil, err
}
}
var wfOpts []compose.WorkflowOption
wfOpts = append(wfOpts, compose.WithIDAsName(wfEntity.ID))
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
@ -997,20 +1018,6 @@ func (i *impl) checkApplicationWorkflowReleaseVersion(ctx context.Context, appID
return nil
}
const maxHistoryRounds int64 = 30
func (i *impl) calculateMaxChatHistoryRounds(ctx context.Context, wfEntity *entity.Workflow, repo workflow.Repository) (int64, error) {
if wfEntity == nil {
return 0, nil
}
maxRounds, err := getMaxHistoryRoundsRecursively(ctx, wfEntity, repo)
if err != nil {
return 0, err
}
return min(maxRounds, maxHistoryRounds), nil
}
func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.ExecuteConfig, historyRounds int64) ([]*crossmessage.WfMessage, []*schema.Message, error) {
convID := config.ConversationID
agentID := config.AgentID
@ -1027,11 +1034,11 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
return nil, nil, nil
}
var resolvedAppID int64
var bizID int64
if appID != nil {
resolvedAppID = *appID
bizID = *appID
} else if agentID != nil {
resolvedAppID = *agentID
bizID = *agentID
} else {
logs.CtxWarnf(ctx, "AppID and AgentID are both nil, skipping chat history")
return nil, nil, nil
@ -1039,7 +1046,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
runIdsReq := &crossmessage.GetLatestRunIDsRequest{
ConversationID: *convID,
AppID: resolvedAppID,
BizID: bizID,
UserID: userID,
Rounds: historyRounds + 1,
SectionID: *sectionID,
@ -1048,7 +1055,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
runIds, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, runIdsReq)
if err != nil {
logs.CtxErrorf(ctx, "failed to get latest run ids: %v", err)
return nil, nil, nil
return nil, nil, err
}
if len(runIds) <= 1 {
return []*crossmessage.WfMessage{}, []*schema.Message{}, nil
@ -1061,7 +1068,7 @@ func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.Exe
})
if err != nil {
logs.CtxErrorf(ctx, "failed to get messages by run ids: %v", err)
return nil, nil, nil
return nil, nil, err
}
return response.Messages, response.SchemaMessages, nil

@ -0,0 +1,286 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package service
import (
"context"
"errors"
"testing"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
messagemock "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message/messagemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
mock_workflow "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
func TestImpl_handleHistory(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t, gomock.WithOverridableExpectations())
defer ctrl.Finish()
// Setup for cross-domain service mock
mockMessage := messagemock.NewMockMessage(ctrl)
crossmessage.SetDefaultSVC(mockMessage)
tests := []struct {
name string
setupMock func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository)
config *workflowModel.ExecuteConfig
input map[string]any
historyRounds int64
shouldFetch bool
expectErr bool
expectedHistory []*crossmessage.WfMessage
expectedSchemaHistory []*schema.Message
}{
{
name: "historyRounds is zero",
historyRounds: 0,
shouldFetch: true,
config: &workflowModel.ExecuteConfig{},
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
},
expectErr: false,
},
{
name: "shouldFetch is false",
historyRounds: 5,
shouldFetch: false,
config: &workflowModel.ExecuteConfig{
AppID: ptr.Of(int64(1)),
ConversationID: ptr.Of(int64(100)),
SectionID: ptr.Of(int64(101)),
},
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{1, 2}, nil).AnyTimes()
msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{
Messages: []*crossmessage.WfMessage{{ID: 1}},
SchemaMessages: []*schema.Message{{
Role: schema.User,
Content: "123",
}},
}, nil).AnyTimes()
},
expectErr: false,
expectedHistory: []*crossmessage.WfMessage{{ID: 1}},
expectedSchemaHistory: []*schema.Message{{
Role: schema.User,
Content: "123",
}},
},
{
name: "fetch conversation by name - conversation exists",
historyRounds: 3,
shouldFetch: true,
config: &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))},
input: map[string]any{"CONVERSATION_NAME": "test-conv"},
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "test-conv").Return(int64(200), int64(201), nil).AnyTimes()
msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{3, 4}, nil).AnyTimes()
msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{
Messages: []*crossmessage.WfMessage{{ID: 2}},
SchemaMessages: []*schema.Message{{
Role: schema.Assistant,
Content: "123",
}},
}, nil).AnyTimes()
repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{
TemplateID: int64(202),
SpaceID: int64(203),
AppID: int64(204),
}, true, nil).AnyTimes()
repo.EXPECT().GetOrCreateStaticConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, nil).AnyTimes()
},
expectErr: false,
expectedHistory: []*crossmessage.WfMessage{{ID: 2}},
expectedSchemaHistory: []*schema.Message{{
Role: schema.Assistant,
Content: "123",
}},
},
{
name: "fetch conversation by name - conversation not exists",
historyRounds: 3,
shouldFetch: true,
config: &workflowModel.ExecuteConfig{AgentID: ptr.Of(int64(2))},
input: map[string]any{"CONVERSATION_NAME": "new-conv"},
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "new-conv").Return(int64(300), int64(301), nil).AnyTimes()
msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{5, 6}, nil).AnyTimes()
msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{
Messages: []*crossmessage.WfMessage{{ID: 3}},
}, nil).AnyTimes()
repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{
TemplateID: int64(202),
SpaceID: int64(203),
AppID: int64(204),
}, false, nil).AnyTimes()
repo.EXPECT().GetOrCreateDynamicConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, nil).AnyTimes()
},
expectErr: false,
expectedHistory: []*crossmessage.WfMessage{{ID: 3}},
},
{
name: "input with wrong type for conversation name",
historyRounds: 5,
shouldFetch: true,
config: &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))},
input: map[string]any{"CONVERSATION_NAME": 12345},
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
},
expectErr: true,
},
{
name: "GetOrCreateConversation returns error",
historyRounds: 5,
shouldFetch: true,
config: &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))},
input: map[string]any{"CONVERSATION_NAME": "fail-conv"},
setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "fail-conv").Return(int64(0), int64(0), errors.New("db error")).AnyTimes()
repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{
TemplateID: int64(202),
SpaceID: int64(203),
AppID: int64(204),
}, false, nil).AnyTimes()
repo.EXPECT().GetOrCreateDynamicConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, errors.New("db error")).AnyTimes()
},
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockService := mock_workflow.NewMockService(ctrl)
mockRepo := mock_workflow.NewMockRepository(ctrl)
testImpl := &impl{repo: mockRepo, conversationImpl: &conversationImpl{repo: mockRepo}}
tt.setupMock(mockService, mockMessage, mockRepo)
err := testImpl.handleHistory(ctx, tt.config, tt.input, tt.historyRounds, tt.shouldFetch)
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if tt.expectedHistory != nil {
assert.Equal(t, tt.expectedHistory, tt.config.ConversationHistory)
} else if tt.historyRounds == 0 {
assert.Nil(t, tt.config.ConversationHistory)
} else if tt.expectedSchemaHistory != nil {
assert.Equal(t, tt.expectedSchemaHistory, tt.config.ConversationHistorySchemaMessages)
}
}
})
}
}
func TestImpl_prefetchChatHistory(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t, gomock.WithOverridableExpectations())
defer ctrl.Finish()
mockMessage := messagemock.NewMockMessage(ctrl)
crossmessage.SetDefaultSVC(mockMessage)
tests := []struct {
name string
setupMock func(msgSvc *messagemock.MockMessage)
config workflowModel.ExecuteConfig
historyRounds int64
expectErr bool
}{
{
name: "SectionID is nil",
config: workflowModel.ExecuteConfig{
ConversationID: ptr.Of(int64(100)),
AppID: ptr.Of(int64(1)),
},
historyRounds: 5,
setupMock: func(msgSvc *messagemock.MockMessage) {},
expectErr: false,
},
{
name: "ConversationID is nil",
config: workflowModel.ExecuteConfig{
SectionID: ptr.Of(int64(101)),
AppID: ptr.Of(int64(1)),
},
historyRounds: 5,
setupMock: func(msgSvc *messagemock.MockMessage) {},
expectErr: false,
},
{
name: "AppID and AgentID are both nil",
config: workflowModel.ExecuteConfig{
ConversationID: ptr.Of(int64(100)),
SectionID: ptr.Of(int64(101)),
},
historyRounds: 5,
setupMock: func(msgSvc *messagemock.MockMessage) {},
expectErr: false,
},
{
name: "GetLatestRunIDs returns error",
config: workflowModel.ExecuteConfig{
AppID: ptr.Of(int64(1)),
ConversationID: ptr.Of(int64(100)),
SectionID: ptr.Of(int64(101)),
},
historyRounds: 5,
setupMock: func(msgSvc *messagemock.MockMessage) {
msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return(nil, errors.New("db error"))
},
expectErr: true,
},
{
name: "GetMessagesByRunIDs returns error",
config: workflowModel.ExecuteConfig{
AppID: ptr.Of(int64(1)),
ConversationID: ptr.Of(int64(100)),
SectionID: ptr.Of(int64(101)),
},
historyRounds: 5,
setupMock: func(msgSvc *messagemock.MockMessage) {
msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{1, 2, 3}, nil)
msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(nil, errors.New("db error"))
},
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testImpl := &impl{}
tt.setupMock(mockMessage)
_, _, err := testImpl.prefetchChatHistory(ctx, tt.config, tt.historyRounds)
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

@ -39,7 +39,6 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
@ -522,7 +521,7 @@ func isEnableChatHistory(s *schema.NodeSchema) bool {
return false
}
chatHistoryAware, ok := s.Configs.(nodes.ChatHistoryAware)
chatHistoryAware, ok := s.Configs.(schema.ChatHistoryAware)
if !ok {
return false
}
@ -2171,15 +2170,15 @@ func (i *impl) adaptToChatFlow(ctx context.Context, wID int64) error {
vMap[v.Name] = true
}
if _, ok := vMap["USER_INPUT"]; !ok {
if _, ok := vMap[vo.UserInputKey]; !ok {
startNode.Data.Outputs = append(startNode.Data.Outputs, &vo.Variable{
Name: "USER_INPUT",
Name: vo.UserInputKey,
Type: vo.VariableTypeString,
})
}
if _, ok := vMap["CONVERSATION_NAME"]; !ok {
if _, ok := vMap[vo.ConversationNameKey]; !ok {
startNode.Data.Outputs = append(startNode.Data.Outputs, &vo.Variable{
Name: "CONVERSATION_NAME",
Name: vo.ConversationNameKey,
Type: vo.VariableTypeString,
DefaultValue: "Default",
})

@ -22,15 +22,11 @@ import (
"strconv"
"strings"
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
wf "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/validate"
"github.com/coze-dev/coze-studio/backend/domain/workflow/variable"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
@ -201,214 +197,3 @@ func isIncremental(prev version, next version) bool {
return next.Patch > prev.Patch
}
func getMaxHistoryRoundsRecursively(ctx context.Context, wfEntity *entity.Workflow, repo wf.Repository) (int64, error) {
visited := make(map[string]struct{})
maxRounds := int64(0)
err := getMaxHistoryRoundsRecursiveHelper(ctx, wfEntity, repo, visited, &maxRounds)
return maxRounds, err
}
func getMaxHistoryRoundsRecursiveHelper(ctx context.Context, wfEntity *entity.Workflow, repo wf.Repository, visited map[string]struct{}, maxRounds *int64) error {
visitedKey := fmt.Sprintf("%d:%s", wfEntity.ID, wfEntity.GetVersion())
if _, ok := visited[visitedKey]; ok {
return nil
}
visited[visitedKey] = struct{}{}
var canvas vo.Canvas
if err := sonic.UnmarshalString(wfEntity.Canvas, &canvas); err != nil {
return fmt.Errorf("failed to unmarshal canvas for workflow %d: %w", wfEntity.ID, err)
}
return collectMaxHistoryRounds(ctx, canvas.Nodes, repo, visited, maxRounds)
}
func collectMaxHistoryRounds(ctx context.Context, nodes []*vo.Node, repo wf.Repository, visited map[string]struct{}, maxRounds *int64) error {
for _, node := range nodes {
if node == nil {
continue
}
if node.Data != nil && node.Data.Inputs != nil && node.Data.Inputs.ChatHistorySetting != nil && node.Data.Inputs.ChatHistorySetting.EnableChatHistory {
if node.Data.Inputs.ChatHistorySetting.ChatHistoryRound > *maxRounds {
*maxRounds = node.Data.Inputs.ChatHistorySetting.ChatHistoryRound
}
} else if node.Type == entity.NodeTypeLLM.IDStr() && node.Data != nil && node.Data.Inputs != nil && node.Data.Inputs.LLMParam != nil {
param := node.Data.Inputs.LLMParam
bs, _ := sonic.Marshal(param)
llmParam := make(vo.LLMParam, 0)
if err := sonic.Unmarshal(bs, &llmParam); err != nil {
return err
}
var chatHistoryEnabled bool
var chatHistoryRound int64
for _, param := range llmParam {
switch param.Name {
case "enableChatHistory":
if val, ok := param.Input.Value.Content.(bool); ok {
b := val
chatHistoryEnabled = b
}
case "chatHistoryRound":
if strVal, ok := param.Input.Value.Content.(string); ok {
int64Val, err := strconv.ParseInt(strVal, 10, 64)
if err != nil {
return err
}
chatHistoryRound = int64Val
}
}
}
if chatHistoryEnabled {
if chatHistoryRound > *maxRounds {
*maxRounds = chatHistoryRound
}
}
}
isSubWorkflow := node.Type == entity.NodeTypeSubWorkflow.IDStr() && node.Data != nil && node.Data.Inputs != nil
if isSubWorkflow {
workflowIDStr := node.Data.Inputs.WorkflowID
if workflowIDStr == "" {
continue
}
workflowID, err := strconv.ParseInt(workflowIDStr, 10, 64)
if err != nil {
return fmt.Errorf("invalid workflow ID in sub-workflow node %s: %w", node.ID, err)
}
subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{
ID: workflowID,
QType: ternary.IFElse(len(node.Data.Inputs.WorkflowVersion) == 0, workflowModel.FromDraft, workflowModel.FromSpecificVersion),
Version: node.Data.Inputs.WorkflowVersion,
})
if err != nil {
return fmt.Errorf("failed to get sub-workflow entity %d: %w", workflowID, err)
}
if err := getMaxHistoryRoundsRecursiveHelper(ctx, subWfEntity, repo, visited, maxRounds); err != nil {
return err
}
}
if len(node.Blocks) > 0 {
if err := collectMaxHistoryRounds(ctx, node.Blocks, repo, visited, maxRounds); err != nil {
return err
}
}
}
return nil
}
func getHistoryRoundsFromNode(ctx context.Context, wfEntity *entity.Workflow, nodeID string, repo wf.Repository) (int64, error) {
if wfEntity == nil {
return 0, nil
}
visited := make(map[string]struct{})
visitedKey := fmt.Sprintf("%d:%s", wfEntity.ID, wfEntity.GetVersion())
if _, ok := visited[visitedKey]; ok {
return 0, nil
}
visited[visitedKey] = struct{}{}
maxRounds := int64(0)
c := &vo.Canvas{}
if err := sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
return 0, fmt.Errorf("failed to unmarshal canvas: %w", err)
}
var (
n *vo.Node
nodeFinder func(nodes []*vo.Node) *vo.Node
)
nodeFinder = func(nodes []*vo.Node) *vo.Node {
for i := range nodes {
if nodes[i].ID == nodeID {
return nodes[i]
}
if len(nodes[i].Blocks) > 0 {
if n := nodeFinder(nodes[i].Blocks); n != nil {
return n
}
}
}
return nil
}
n = nodeFinder(c.Nodes)
if n.Type == entity.NodeTypeLLM.IDStr() {
if n.Data == nil || n.Data.Inputs == nil {
return 0, nil
}
param := n.Data.Inputs.LLMParam
bs, _ := sonic.Marshal(param)
llmParam := make(vo.LLMParam, 0)
if err := sonic.Unmarshal(bs, &llmParam); err != nil {
return 0, err
}
var chatHistoryEnabled bool
var chatHistoryRound int64
for _, param := range llmParam {
switch param.Name {
case "enableChatHistory":
if val, ok := param.Input.Value.Content.(bool); ok {
b := val
chatHistoryEnabled = b
}
case "chatHistoryRound":
if strVal, ok := param.Input.Value.Content.(string); ok {
int64Val, err := strconv.ParseInt(strVal, 10, 64)
if err != nil {
return 0, err
}
chatHistoryRound = int64Val
}
}
}
if chatHistoryEnabled {
return chatHistoryRound, nil
}
return 0, nil
}
if n.Type == entity.NodeTypeIntentDetector.IDStr() || n.Type == entity.NodeTypeKnowledgeRetriever.IDStr() {
if n.Data != nil && n.Data.Inputs != nil && n.Data.Inputs.ChatHistorySetting != nil && n.Data.Inputs.ChatHistorySetting.EnableChatHistory {
return n.Data.Inputs.ChatHistorySetting.ChatHistoryRound, nil
}
return 0, nil
}
if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
if n.Data != nil && n.Data.Inputs != nil {
workflowIDStr := n.Data.Inputs.WorkflowID
if workflowIDStr == "" {
return 0, nil
}
workflowID, err := strconv.ParseInt(workflowIDStr, 10, 64)
if err != nil {
return 0, fmt.Errorf("invalid workflow ID in sub-workflow node %s: %w", n.ID, err)
}
subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{
ID: workflowID,
QType: ternary.IFElse(len(n.Data.Inputs.WorkflowVersion) == 0, workflowModel.FromDraft, workflowModel.FromSpecificVersion),
Version: n.Data.Inputs.WorkflowVersion,
})
if err != nil {
return 0, fmt.Errorf("failed to get sub-workflow entity %d: %w", workflowID, err)
}
if err := getMaxHistoryRoundsRecursiveHelper(ctx, subWfEntity, repo, visited, &maxRounds); err != nil {
return 0, err
}
return maxRounds, nil
}
}
if len(n.Blocks) > 0 {
if err := collectMaxHistoryRounds(ctx, n.Blocks, repo, visited, &maxRounds); err != nil {
return 0, err
}
}
return maxRounds, nil
}

Loading…
Cancel
Save