From 8db69fd6a94045b344cde37fc2ab2549feaa43c2 Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Wed, 3 Sep 2025 14:48:51 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20update=20eino=20to=20fix=20the=20issue?= =?UTF-8?q?=20when=20field=20mappings=20are=20resolved=20re=E2=80=A6=20(#1?= =?UTF-8?q?952)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/ci@backend.yml | 16 ++- .../api/handler/coze/workflow_service_test.go | 19 ++- backend/domain/workflow/entity/message.go | 1 - .../internal/compose/designate_option.go | 12 +- .../domain/workflow/internal/compose/state.go | 1 + .../workflow/internal/compose/workflow_run.go | 2 +- .../internal/compose/workflow_tool.go | 114 ++++++++++++++---- .../workflow/internal/execute/callback.go | 65 +++++++--- .../internal/execute/collect_token.go | 28 +++++ .../domain/workflow/internal/execute/event.go | 21 +--- .../workflow/internal/execute/event_handle.go | 70 +++++------ .../workflow/internal/execute/tool_option.go | 9 +- backend/go.mod | 9 +- backend/go.sum | 12 +- backend/internal/testutil/chat_model.go | 18 ++- 15 files changed, 280 insertions(+), 117 deletions(-) diff --git a/.github/workflows/ci@backend.yml b/.github/workflows/ci@backend.yml index 7115ed7a..fb44283a 100644 --- a/.github/workflows/ci@backend.yml +++ b/.github/workflows/ci@backend.yml @@ -31,7 +31,7 @@ jobs: env: COVERAGE_FILE: coverage.out BREAKDOWN_FILE: main.breakdown - + steps: - uses: actions/checkout@v4 - name: Set up Go @@ -52,7 +52,7 @@ jobs: mysql version: '8.4.5' mysql database: 'opencoze' mysql root password: 'root' - + - name: Verify MySQL Startup run: | echo "Waiting for MySQL to be ready..." @@ -70,8 +70,12 @@ jobs: run: sudo apt-get update && sudo apt-get install -y mysql-client - name: Initialize Database - run: mysql -h 127.0.0.1 -P 3306 -u root -proot opencoze < docker/volumes/mysql/schema.sql - + uses: nick-fields/retry@v3 + with: + timeout_minutes: 10 + max_attempts: 20 + command: mysql -h 127.0.0.1 -P 3306 -u root -proot opencoze < docker/volumes/mysql/schema.sql + - name: Run Go Test run: | modules=`find . -name "go.mod" -exec dirname {} \;` @@ -82,7 +86,7 @@ jobs: for module in $modules; do go work use $module; list=$module"/... "$list; coverpkg=$module"/...,"$coverpkg; done go work sync go test -race -v -coverprofile=${{ env.COVERAGE_FILE }} -gcflags="all=-l -N" -coverpkg=$coverpkg $list - + - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: @@ -118,4 +122,4 @@ jobs: if [[ ! -f "go.work" ]];then go work init;fi for module in $modules; do go work use $module; list=$module"/... "$list; coverpkg=$module"/...,"$coverpkg; done go work sync - go test -race -v -bench=. -benchmem -run=none -gcflags="all=-l -N" $list \ No newline at end of file + go test -race -v -bench=. -benchmem -run=none -gcflags="all=-l -N" $list diff --git a/backend/api/handler/coze/workflow_service_test.go b/backend/api/handler/coze/workflow_service_test.go index bfbd621b..9091c014 100644 --- a/backend/api/handler/coze/workflow_service_test.go +++ b/backend/api/handler/coze/workflow_service_test.go @@ -43,14 +43,15 @@ import ( "github.com/cloudwego/hertz/pkg/common/ut" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/sse" - message0 "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message" - "github.com/coze-dev/coze-studio/backend/domain/workflow/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "gorm.io/driver/mysql" "gorm.io/gorm" + message0 "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message" + "github.com/coze-dev/coze-studio/backend/domain/workflow/config" + "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge" modelknowledge "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message" @@ -755,6 +756,7 @@ func (r *wfTestRunner) getProcess(id, exeID string, opts ...func(options *getPro var nodeType string var token *workflow.TokenAndCost var reason string + var count int for { if nodeEvent != nil { if options.previousInterruptEventID != "" { @@ -770,6 +772,10 @@ func (r *wfTestRunner) getProcess(id, exeID string, opts ...func(options *getPro break } + if count > 1000 { + r.t.Fatal("get process for too long") + } + getProcessResp := getProcess(r.t, r.h, id, exeID) if len(getProcessResp.Data.NodeResults) == 1 { output = getProcessResp.Data.NodeResults[0].Output @@ -803,6 +809,8 @@ func (r *wfTestRunner) getProcess(id, exeID string, opts ...func(options *getPro eventID = nodeEvent.ID } r.t.Logf("getProcess output= %s, status= %v, eventID= %s, nodeType= %s", output, workflowStatus, eventID, nodeType) + + count++ } return &exeResult{ @@ -1624,6 +1632,7 @@ func TestNestedSubWorkflowWithInterrupt(t *testing.T) { post[workflow.DeleteWorkflowResponse](r, &workflow.DeleteWorkflowRequest{ WorkflowID: topID, }) + time.Sleep(time.Second) }() midID := r.load("subworkflow/middle_workflow.json", withID(7494849202016272435)) @@ -1841,6 +1850,7 @@ func TestPublishWorkflow(t *testing.T) { WorkflowID: id, } _ = post[workflow.DeleteWorkflowResponse](r, deleteReq) + time.Sleep(time.Second) }) } @@ -1964,6 +1974,7 @@ func TestSimpleInvokableToolWithReturnVariables(t *testing.T) { post[workflow.DeleteWorkflowResponse](r, &workflow.DeleteWorkflowRequest{ WorkflowID: id, }) + time.Sleep(time.Second) }() exeID := r.testRun(id, map[string]string{ @@ -2628,6 +2639,7 @@ func TestListWorkflowAsToolData(t *testing.T) { WorkflowID: id, } _ = post[workflow.DeleteWorkflowResponse](r, deleteReq) + time.Sleep(time.Second) }) } @@ -2662,6 +2674,7 @@ func TestWorkflowDetailAndDetailInfo(t *testing.T) { WorkflowID: id, } _ = post[workflow.DeleteWorkflowResponse](r, deleteReq) + time.Sleep(time.Second) }) } @@ -4542,6 +4555,8 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) { assert.Equal(t, "v0.0.1", node.Data.Inputs.WorkflowVersion) } } + + time.Sleep(time.Second) }) }) diff --git a/backend/domain/workflow/entity/message.go b/backend/domain/workflow/entity/message.go index 2b6b9080..d24669ad 100644 --- a/backend/domain/workflow/entity/message.go +++ b/backend/domain/workflow/entity/message.go @@ -85,7 +85,6 @@ type ToolResponseInfo struct { FunctionInfo CallID string Response string - Complete bool } type ToolType = workflow.PluginType diff --git a/backend/domain/workflow/internal/compose/designate_option.go b/backend/domain/workflow/internal/compose/designate_option.go index 851b1630..1b350f1d 100644 --- a/backend/domain/workflow/internal/compose/designate_option.go +++ b/backend/domain/workflow/internal/compose/designate_option.go @@ -38,7 +38,7 @@ import ( "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" ) -func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context, []einoCompose.Option, error) { +func (r *WorkflowRunner) designateOptions(ctx context.Context) ([]einoCompose.Option, error) { var ( wb = r.basic exeCfg = r.config @@ -83,13 +83,13 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context, ns, string(key)) if err != nil { - return ctx, nil, err + return nil, err } opts = append(opts, subOpts...) } else if ns.Type == entity.NodeTypeLLM { llmNodeOpts, err := llmToolCallbackOptions(ctx, ns, eventChan, sw) if err != nil { - return ctx, nil, err + return nil, err } opts = append(opts, llmNodeOpts...) @@ -103,7 +103,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context, ns, string(key)) if err != nil { - return ctx, nil, err + return nil, err } for _, subO := range subOpts { opts = append(opts, WrapOpt(subO, parent.Key)) @@ -111,7 +111,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context, } else if ns.Type == entity.NodeTypeLLM { llmNodeOpts, err := llmToolCallbackOptions(ctx, ns, eventChan, sw) if err != nil { - return ctx, nil, err + return nil, err } for _, subO := range llmNodeOpts { opts = append(opts, WrapOpt(subO, parent.Key)) @@ -124,7 +124,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context, opts = append(opts, einoCompose.WithCheckPointID(strconv.FormatInt(executeID, 10))) } - return ctx, opts, nil + return opts, nil } func nodeCallbackOption(key vo.NodeKey, name string, eventChan chan *execute.Event, resumeEvent *entity.InterruptEvent, diff --git a/backend/domain/workflow/internal/compose/state.go b/backend/domain/workflow/internal/compose/state.go index 1acf90fc..042d6bdb 100644 --- a/backend/domain/workflow/internal/compose/state.go +++ b/backend/domain/workflow/internal/compose/state.go @@ -92,6 +92,7 @@ func init() { _ = compose.RegisterSerializableType[*schema.Message]("schema_message") _ = compose.RegisterSerializableType[*crossmessage.WfMessage]("history_messages") _ = compose.RegisterSerializableType[*crossmessage.Content]("content") + _ = compose.RegisterSerializableType[*model.PromptTokenDetails]("prompt_token_details") } diff --git a/backend/domain/workflow/internal/compose/workflow_run.go b/backend/domain/workflow/internal/compose/workflow_run.go index 082043fc..0268c2d6 100644 --- a/backend/domain/workflow/internal/compose/workflow_run.go +++ b/backend/domain/workflow/internal/compose/workflow_run.go @@ -167,7 +167,7 @@ func (r *WorkflowRunner) Prepare(ctx context.Context) ( }() } - ctx, composeOpts, err := r.designateOptions(ctx) + composeOpts, err := r.designateOptions(ctx) if err != nil { return ctx, 0, nil, nil, err } diff --git a/backend/domain/workflow/internal/compose/workflow_tool.go b/backend/domain/workflow/internal/compose/workflow_tool.go index 26cd0c0d..9fcc2d71 100644 --- a/backend/domain/workflow/internal/compose/workflow_tool.go +++ b/backend/domain/workflow/internal/compose/workflow_tool.go @@ -21,6 +21,7 @@ import ( "fmt" "strings" + "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/tool" einoCompose "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" @@ -84,8 +85,10 @@ func resumeOnce(rInfo *entity.ResumeRequest, callID string, allIEs map[string]*e } } -func (wt *workflowTool) prepare(ctx context.Context, rInfo *entity.ResumeRequest, argumentsInJSON string, opts ...tool.Option) ( - cancelCtx context.Context, executeID int64, input map[string]any, callOpts []einoCompose.Option, err error) { +func (wt *workflowTool) prepare(ctx context.Context, rInfo *entity.ResumeRequest, + argumentsInJSON string, opts ...tool.Option) ( + cancelCtx context.Context, executeID int64, input map[string]any, + lastEventChan <-chan *execute.Event, callOpts []einoCompose.Option, err error) { cfg := execute.GetExecuteConfig(opts...) var runOpts []WorkflowRunnerOption @@ -126,11 +129,12 @@ func (wt *workflowTool) prepare(ctx context.Context, rInfo *entity.ResumeRequest } } - cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(wt.wfEntity.GetBasic(), wt.sc, cfg, runOpts...).Prepare(ctx) + cancelCtx, executeID, callOpts, lastEventChan, err = NewWorkflowRunner(wt.wfEntity.GetBasic(), wt.sc, cfg, runOpts...).Prepare(ctx) return } -func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { +func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) ( + contentStr string, err error) { rInfo, allIEs := execute.GetResumeRequest(opts...) var ( previouslyInterrupted bool @@ -145,6 +149,18 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st } } + ctx = callbacks.OnStart(ctx, &tool.CallbackInput{ + ArgumentsInJSON: argumentsInJSON, + Extra: map[string]any{ + execute.ToolCallIDKey: callID, + }, + }) + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + } + }() + if previouslyInterrupted && rInfo.ExecuteID != previousExecuteID { logs.Infof("previous interrupted call ID: %s, previous execute ID: %d, current execute ID: %d. Not resuming, interrupt immediately", callID, previousExecuteID, rInfo.ExecuteID) return "", einoCompose.InterruptAndRerun @@ -152,7 +168,7 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st defer resumeOnce(rInfo, callID, allIEs) - cancelCtx, executeID, in, callOpts, err := i.prepare(ctx, rInfo, argumentsInJSON, opts...) + cancelCtx, executeID, in, _, callOpts, err := i.prepare(ctx, rInfo, argumentsInJSON, opts...) if err != nil { return "", err } @@ -179,7 +195,19 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st } if i.terminatePlan == vo.ReturnVariables { - return sonic.MarshalString(out) + contentStr, err = sonic.MarshalString(out) + if err != nil { + return "", err + } + + _ = callbacks.OnEnd(ctx, &tool.CallbackOutput{ + Response: contentStr, + Extra: map[string]any{ + execute.ToolCallIDKey: callID, + }, + }) + + return contentStr, nil } content, ok := out[answerKey] @@ -187,7 +215,7 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st return "", fmt.Errorf("no answer found when terminate plan is use answer content. out: %v", out) } - contentStr, ok := content.(string) + contentStr, ok = content.(string) if !ok { return "", fmt.Errorf("answer content is not string. content: %v", content) } @@ -196,6 +224,13 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st contentStr = strings.TrimSuffix(contentStr, nodes.KeyIsFinished) } + _ = callbacks.OnEnd(ctx, &tool.CallbackOutput{ + Response: contentStr, + Extra: map[string]any{ + execute.ToolCallIDKey: callID, + }, + }) + return contentStr, nil } @@ -207,6 +242,10 @@ func (i *invokableWorkflow) GetWorkflow() *entity.Workflow { return i.wfEntity } +func (i *invokableWorkflow) IsCallbacksEnabled() bool { + return true +} + type streamableWorkflow struct { workflowTool stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error) @@ -235,12 +274,14 @@ func (s *streamableWorkflow) Info(_ context.Context) (*schema.ToolInfo, error) { return s.info, nil } -func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { +func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) ( + out *schema.StreamReader[string], err error) { rInfo, allIEs := execute.GetResumeRequest(opts...) var ( previouslyInterrupted bool callID = einoCompose.GetToolCallID(ctx) previousExecuteID int64 + toolFinishChan = make(chan struct{}) ) for interruptedCallID := range allIEs { if callID == interruptedCallID { @@ -250,6 +291,20 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON } } + ctx = callbacks.OnStart(ctx, &tool.CallbackInput{ + ArgumentsInJSON: argumentsInJSON, + Extra: map[string]any{ + execute.ToolCallIDKey: callID, + execute.ToolFinishChanKey: toolFinishChan, + }, + }) + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + close(toolFinishChan) + } + }() + if previouslyInterrupted && rInfo.ExecuteID != previousExecuteID { logs.Infof("previous interrupted call ID: %s, previous execute ID: %d, current execute ID: %d. Not resuming, interrupt immediately", callID, previousExecuteID, rInfo.ExecuteID) return nil, einoCompose.InterruptAndRerun @@ -257,7 +312,7 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON defer resumeOnce(rInfo, callID, allIEs) - cancelCtx, executeID, in, callOpts, err := s.prepare(ctx, rInfo, argumentsInJSON, opts...) + cancelCtx, executeID, in, lastEventChan, callOpts, err := s.prepare(ctx, rInfo, argumentsInJSON, opts...) if err != nil { return nil, err } @@ -283,22 +338,35 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON return nil, err } - return schema.StreamReaderWithConvert(outStream, func(in map[string]any) (string, error) { - content, ok := in["output"] - if !ok { - return "", fmt.Errorf("no output found when stream plan is use output content. out: %v", in) + go func() { + for range lastEventChan { } + close(toolFinishChan) + }() + + _, callbackStream := callbacks.OnEndWithStreamOutput(ctx, schema.StreamReaderWithConvert(outStream, + func(in map[string]any) (*tool.CallbackOutput, error) { + content, ok := in["output"] + if !ok { + return nil, fmt.Errorf("no output found when stream plan is use output content. out: %v", in) + } - contentStr, ok := content.(string) - if !ok { - return "", fmt.Errorf("output content is not string. content: %v", content) - } + contentStr, ok := content.(string) + if !ok { + return nil, fmt.Errorf("output content is not string. content: %v", content) + } - if strings.HasSuffix(contentStr, nodes.KeyIsFinished) { - contentStr = strings.TrimSuffix(contentStr, nodes.KeyIsFinished) - } + if strings.HasSuffix(contentStr, nodes.KeyIsFinished) { + contentStr = strings.TrimSuffix(contentStr, nodes.KeyIsFinished) + } - return contentStr, nil + return &tool.CallbackOutput{ + Response: contentStr, + }, nil + })) + + return schema.StreamReaderWithConvert(callbackStream, func(in *tool.CallbackOutput) (string, error) { + return in.Response, nil }), nil } @@ -309,3 +377,7 @@ func (s *streamableWorkflow) TerminatePlan() vo.TerminatePlan { func (s *streamableWorkflow) GetWorkflow() *entity.Workflow { return s.wfEntity } + +func (s *streamableWorkflow) IsCallbacksEnabled() bool { + return true +} diff --git a/backend/domain/workflow/internal/execute/callback.go b/backend/domain/workflow/internal/execute/callback.go index a5c3696d..81a66b90 100644 --- a/backend/domain/workflow/internal/execute/callback.go +++ b/backend/domain/workflow/internal/execute/callback.go @@ -370,10 +370,6 @@ func (w *WorkflowHandler) OnError(ctx context.Context, info *callbacks.RunInfo, interruptEvent.EventType, interruptEvent.NodeKey) } - if c.TokenCollector != nil { // wait until all streaming chunks are collected - _ = c.TokenCollector.wait() - } - done := make(chan struct{}) w.ch <- &Event{ @@ -1271,6 +1267,11 @@ func (n *NodeHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks return ctx } +const ( + ToolCallIDKey = "call_id" + ToolFinishChanKey = "tool_finish_chan" +) + func (t *ToolHandler) OnStart(ctx context.Context, info *callbacks.RunInfo, input *tool.CallbackInput, ) context.Context { @@ -1286,13 +1287,35 @@ func (t *ToolHandler) OnStart(ctx context.Context, info *callbacks.RunInfo, } } + var ( + callID string + toolFinishChan chan struct{} + ) + if input.Extra != nil { + callIDAny, ok := input.Extra[ToolCallIDKey] + if ok { + callID = callIDAny.(string) + } + toolFinishChanAny, ok := input.Extra[ToolFinishChanKey] + if ok { + toolFinishChan = toolFinishChanAny.(chan struct{}) + } + } + + if len(callID) == 0 { + callID = compose.GetToolCallID(ctx) + } + t.ch <- &Event{ Type: FunctionCall, Context: GetExeCtx(ctx), - functionCall: &entity.FunctionCallInfo{ - FunctionInfo: t.info, - CallID: compose.GetToolCallID(ctx), - Arguments: args, + functionCall: &FunctionCallInfo{ + FunctionCallInfo: &entity.FunctionCallInfo{ + FunctionInfo: t.info, + CallID: callID, + Arguments: args, + }, + toolFinishChan: toolFinishChan, }, } @@ -1306,14 +1329,25 @@ func (t *ToolHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, return ctx } + var callID string + if output.Extra != nil { + callIDAny, ok := output.Extra[ToolCallIDKey] + if ok { + callID = callIDAny.(string) + } + } + + if len(callID) == 0 { + callID = compose.GetToolCallID(ctx) + } + t.ch <- &Event{ Type: ToolResponse, Context: GetExeCtx(ctx), toolResponse: &entity.ToolResponseInfo{ FunctionInfo: t.info, - CallID: compose.GetToolCallID(ctx), + CallID: callID, Response: output.Response, - Complete: true, }, } @@ -1352,7 +1386,6 @@ func (t *ToolHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks toolResponse: &entity.ToolResponseInfo{ FunctionInfo: t.info, CallID: callID, - Complete: true, }, } } @@ -1374,7 +1407,7 @@ func (t *ToolHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks Context: c, toolResponse: &entity.ToolResponseInfo{ FunctionInfo: t.info, - CallID: compose.GetToolCallID(ctx), + CallID: callID, Response: chunk.Response, }, } @@ -1398,9 +1431,11 @@ func (t *ToolHandler) OnError(ctx context.Context, info *callbacks.RunInfo, err t.ch <- &Event{ Type: ToolError, Context: GetExeCtx(ctx), - functionCall: &entity.FunctionCallInfo{ - FunctionInfo: t.info, - CallID: compose.GetToolCallID(ctx), + functionCall: &FunctionCallInfo{ + FunctionCallInfo: &entity.FunctionCallInfo{ + FunctionInfo: t.info, + CallID: compose.GetToolCallID(ctx), + }, }, Err: err, } diff --git a/backend/domain/workflow/internal/execute/collect_token.go b/backend/domain/workflow/internal/execute/collect_token.go index 8fc7aed2..7a24788f 100644 --- a/backend/domain/workflow/internal/execute/collect_token.go +++ b/backend/domain/workflow/internal/execute/collect_token.go @@ -26,6 +26,7 @@ import ( callbacks2 "github.com/cloudwego/eino/utils/callbacks" "github.com/coze-dev/coze-studio/backend/pkg/safego" + "github.com/coze-dev/coze-studio/backend/pkg/sonic" ) type TokenCollector struct { @@ -90,6 +91,33 @@ func (t *TokenCollector) finishStreamCounting() { } } +type tokenCollector struct { + Key string + Usage *model.TokenUsage + Parent *TokenCollector +} + +func (t *TokenCollector) MarshalJSON() ([]byte, error) { + t.wait() + return sonic.Marshal(&tokenCollector{ + Key: t.Key, + Usage: t.Usage, + Parent: t.Parent, + }) +} + +func (t *TokenCollector) UnmarshalJSON(bytes []byte) error { + tc := &tokenCollector{} + if err := sonic.Unmarshal(bytes, tc); err != nil { + return err + } + + t.Key = tc.Key + t.Usage = tc.Usage + t.Parent = tc.Parent + return nil +} + func getTokenCollector(ctx context.Context) *TokenCollector { c := GetExeCtx(ctx) if c == nil { diff --git a/backend/domain/workflow/internal/execute/event.go b/backend/domain/workflow/internal/execute/event.go index 301e99eb..95326556 100644 --- a/backend/domain/workflow/internal/execute/event.go +++ b/backend/domain/workflow/internal/execute/event.go @@ -64,7 +64,7 @@ type Event struct { InterruptEvents []*entity.InterruptEvent - functionCall *entity.FunctionCallInfo + functionCall *FunctionCallInfo toolResponse *entity.ToolResponseInfo outputExtractor func(o map[string]any) string @@ -75,6 +75,11 @@ type Event struct { nodeCount int32 } +type FunctionCallInfo struct { + *entity.FunctionCallInfo + toolFinishChan chan struct{} +} + type TokenInfo struct { InputToken int64 OutputToken int64 @@ -104,17 +109,3 @@ func (e *Event) GetResumedEventID() int64 { } return e.Context.RootCtx.ResumeEvent.ID } - -func (e *Event) GetFunctionCallInfo() (*entity.FunctionCallInfo, bool) { - if e.functionCall == nil { - return nil, false - } - return e.functionCall, true -} - -func (e *Event) GetToolResponse() (*entity.ToolResponseInfo, bool) { - if e.toolResponse == nil { - return nil, false - } - return e.toolResponse, true -} diff --git a/backend/domain/workflow/internal/execute/event_handle.go b/backend/domain/workflow/internal/execute/event_handle.go index ec07112a..fde75ed9 100644 --- a/backend/domain/workflow/internal/execute/event_handle.go +++ b/backend/domain/workflow/internal/execute/event_handle.go @@ -672,7 +672,7 @@ func handleEvent(ctx context.Context, event *Event, repo workflow.Repository, ExecuteID: event.RootExecuteID, Role: schema.Assistant, Type: entity.FunctionCall, - FunctionCall: event.functionCall, + FunctionCall: event.functionCall.FunctionCallInfo, }, }, nil) case ToolResponse: @@ -704,8 +704,6 @@ func handleEvent(ctx context.Context, event *Event, repo workflow.Repository, }, }, nil) case ToolError: - // TODO: optimize this log - logs.CtxErrorf(ctx, "received tool error event: %v", event) default: panic("unimplemented event type: " + event.Type) } @@ -715,8 +713,9 @@ func handleEvent(ctx context.Context, event *Event, repo workflow.Repository, type fcCacheKey struct{} type fcInfo struct { - input *entity.FunctionCallInfo - output *entity.ToolResponseInfo + input *entity.FunctionCallInfo + output *entity.ToolResponseInfo + toolFinishChan chan struct{} } func HandleExecuteEvent(ctx context.Context, @@ -772,7 +771,8 @@ func HandleExecuteEvent(ctx context.Context, lastNodeIsDone = true if wfSuccessEvent != nil { if err = setRootWorkflowSuccess(ctx, wfSuccessEvent, repo, sw); err != nil { - logs.CtxErrorf(ctx, "failed to set root workflow success: %v", err) + logs.CtxErrorf(ctx, "failed to set root workflow success for workflow %d: %v", + wfSuccessEvent.RootWorkflowBasic.ID, err) } return wfSuccessEvent } @@ -786,10 +786,12 @@ func HandleExecuteEvent(ctx context.Context, // Add cancellation check timer cancelTicker := time.NewTicker(cancelCheckInterval) defer func() { - logs.CtxInfof(ctx, "[handleExecuteEvent] finish, returned event type: %v, workflow id: %d", + logs.CtxInfof(ctx, "[handleExecuteEvent] cancellable finish, returned event type: %v, workflow id: %d", event.Type, event.Context.RootWorkflowBasic.ID) - cancelTicker.Stop() // Clean up timer waitUntilToolFinish(ctx) + logs.CtxInfof(ctx, "[handleExecuteEvent] cancellable wait until tool finished done, workflow id: %d", + event.Context.RootWorkflowBasic.ID) + cancelTicker.Stop() // Clean up timer if timeoutFn != nil { timeoutFn() } @@ -825,6 +827,9 @@ func HandleExecuteEvent(ctx context.Context, defer func() { logs.CtxInfof(ctx, "[handleExecuteEvent] finish, returned event type: %v, workflow id: %d", event.Type, event.Context.RootWorkflowBasic.ID) + waitUntilToolFinish(ctx) + logs.CtxInfof(ctx, "[handleExecuteEvent] wait until tool finished done, workflow id: %d", + event.Context.RootWorkflowBasic.ID) if timeoutFn != nil { timeoutFn() } @@ -859,29 +864,26 @@ func cacheFunctionCall(ctx context.Context, event *Event) { c[event.NodeKey] = make(map[string]*fcInfo) } c[event.NodeKey][event.functionCall.CallID] = &fcInfo{ - input: event.functionCall, + input: event.functionCall.FunctionCallInfo, + toolFinishChan: event.functionCall.toolFinishChan, } } func cacheToolResponse(ctx context.Context, event *Event) { c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo) - if _, ok := c[event.NodeKey]; !ok { - c[event.NodeKey] = make(map[string]*fcInfo) - } - c[event.NodeKey][event.toolResponse.CallID].output = event.toolResponse } func cacheToolStreamingResponse(ctx context.Context, event *Event) { c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo) - if _, ok := c[event.NodeKey]; !ok { - c[event.NodeKey] = make(map[string]*fcInfo) - } if c[event.NodeKey][event.toolResponse.CallID].output == nil { c[event.NodeKey][event.toolResponse.CallID].output = event.toolResponse + } else { + c[event.NodeKey][event.toolResponse.CallID].output.Response += event.toolResponse.Response } - c[event.NodeKey][event.toolResponse.CallID].output.Response += event.toolResponse.Response - c[event.NodeKey][event.toolResponse.CallID].output.Complete = event.toolResponse.Complete + + logs.CtxInfof(ctx, "receive tool response: %s, callID: %s", + event.toolResponse.Response, event.toolResponse.CallID) } func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo { @@ -890,29 +892,17 @@ func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo { } func waitUntilToolFinish(ctx context.Context) { - var cnt int -outer: - for { - if cnt > 1000 { - return - } - - c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo) - if len(c) == 0 { - return - } - - for _, m := range c { - for _, info := range m { - if info.output == nil { - cnt++ - continue outer - } + c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo) + if len(c) == 0 { + return + } - if !info.output.Complete { - cnt++ - continue outer - } + for _, m := range c { + for _, info := range m { + if info.toolFinishChan != nil { + <-info.toolFinishChan + logs.CtxInfof(ctx, "tool finished, callID: %s, pluginID: %v", info.output.CallID, + info.input.PluginID) } } } diff --git a/backend/domain/workflow/internal/execute/tool_option.go b/backend/domain/workflow/internal/execute/tool_option.go index 05e4cebb..b1be4cd3 100644 --- a/backend/domain/workflow/internal/execute/tool_option.go +++ b/backend/domain/workflow/internal/execute/tool_option.go @@ -26,11 +26,10 @@ import ( ) type workflowToolOption struct { - resumeReq *entity.ResumeRequest - streamContainer *StreamContainer - exeCfg workflowModel.ExecuteConfig - allInterruptEvents map[string]*entity.ToolInterruptEvent - parentTokenCollector *TokenCollector + resumeReq *entity.ResumeRequest + streamContainer *StreamContainer + exeCfg workflowModel.ExecuteConfig + allInterruptEvents map[string]*entity.ToolInterruptEvent } func WithResume(req *entity.ResumeRequest, all map[string]*entity.ToolInterruptEvent) tool.Option { diff --git a/backend/go.mod b/backend/go.mod index b289a322..9dd719e0 100755 --- a/backend/go.mod +++ b/backend/go.mod @@ -17,7 +17,7 @@ require ( github.com/bytedance/gopkg v0.1.3 github.com/bytedance/mockey v1.2.14 github.com/bytedance/sonic v1.14.0 - github.com/cloudwego/eino v0.3.55 + github.com/cloudwego/eino v0.4.8 github.com/cloudwego/eino-ext/components/embedding/ark v0.1.0 github.com/cloudwego/eino-ext/components/embedding/gemini v0.0.0-20250814083140-54b99ff82f8e github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20250728060543-79ec300857b8 @@ -284,3 +284,10 @@ require ( sigs.k8s.io/yaml v1.3.0 // indirect stathat.com/c/consistent v1.0.0 // indirect ) + +require ( + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/eino-contrib/jsonschema v1.0.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect +) diff --git a/backend/go.sum b/backend/go.sum index 5974e599..2a12eaac 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -131,6 +131,8 @@ github.com/aws/smithy-go v1.8.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAm github.com/aws/smithy-go v1.22.4 h1:uqXzVZNuNexwc/xrh6Tb56u89WDlJY6HS+KC0S4QSjw= github.com/aws/smithy-go v1.22.4/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= @@ -146,6 +148,8 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= github.com/bytedance/go-tagexpr/v2 v2.9.2/go.mod h1:5qsx05dYOiUXOUgnQ7w3Oz8BYs2qtM/bJokdLb79wRM= @@ -190,8 +194,8 @@ github.com/clbanning/mxj v1.8.4/go.mod h1:BVjHeAH+rl9rs6f+QIpeRl0tfu10SXn1pUSa5P github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= -github.com/cloudwego/eino v0.3.55 h1:lMZrGtEh0k3qykQTLNXSXuAa98OtF2tS43GMHyvN7nA= -github.com/cloudwego/eino v0.3.55/go.mod h1:wUjz990apdsaOraOXdh6CdhVXq8DJsOvLsVlxNTcNfY= +github.com/cloudwego/eino v0.4.8 h1:wptTU24tQad1mFCHw0+4zSzH+p8dLEBk6HtggPlcvP0= +github.com/cloudwego/eino v0.4.8/go.mod h1:1TDlOmwGSsbCJaWB92w9YLZi2FL0WRZoRcD4eMvqikg= github.com/cloudwego/eino-ext/components/embedding/ark v0.1.0 h1:AuJsMdaTXc+dGUDQp82MifLYK8oiJf4gLQPUETmKISM= github.com/cloudwego/eino-ext/components/embedding/ark v0.1.0/go.mod h1:0FZG/KRBl3hGWkNsm55UaXyVa6PDVIy5u+QvboAB+cY= github.com/cloudwego/eino-ext/components/embedding/gemini v0.0.0-20250814083140-54b99ff82f8e h1:46D2fFDbUysA7kUD5x/wK3huneMEvTQfuWcHqI3M6iQ= @@ -288,6 +292,8 @@ github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3/go.mod h1 github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= +github.com/eino-contrib/jsonschema v1.0.0 h1:dXxbhGNZuI3+xNi8x3JT8AGyoXz6Pff6mRvmpjVl5Ww= +github.com/eino-contrib/jsonschema v1.0.0/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385/go.mod h1:0vRUJqYpeSZifjYj7uP3BG/gKcuzL9xWVV/Y+cK33KM= github.com/elastic/elastic-transport-go/v8 v8.7.0 h1:OgTneVuXP2uip4BA658Xi6Hfw+PeIOod2rY3GVMGoVE= github.com/elastic/elastic-transport-go/v8 v8.7.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk= @@ -1056,6 +1062,8 @@ github.com/volcengine/volc-sdk-golang v1.0.211 h1:FgwD+1phyy+un4Qk2YqooYtp6XpvND github.com/volcengine/volc-sdk-golang v1.0.211/go.mod h1:stZX+EPgv1vF4nZwOlEe8iGcriUPRBKX8zA19gXycOQ= github.com/volcengine/volcengine-go-sdk v1.1.20 h1:+ifZdF7IIIagqF8yVNfk9CmNUl5wgRfU/8orlH+JQhA= github.com/volcengine/volcengine-go-sdk v1.1.20/go.mod h1:EyKoi6t6eZxoPNGr2GdFCZti2Skd7MO3eUzx7TtSvNo= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= diff --git a/backend/internal/testutil/chat_model.go b/backend/internal/testutil/chat_model.go index 49899f83..6f7fba4b 100644 --- a/backend/internal/testutil/chat_model.go +++ b/backend/internal/testutil/chat_model.go @@ -72,7 +72,14 @@ func (q *UTChatModel) Generate(ctx context.Context, in []*schema.Message, _ ...m } if msg.ResponseMeta != nil { - callbackOut.TokenUsage = (*model.TokenUsage)(msg.ResponseMeta.Usage) + callbackOut.TokenUsage = &model.TokenUsage{ + PromptTokens: msg.ResponseMeta.Usage.PromptTokens, + PromptTokenDetails: model.PromptTokenDetails{ + CachedTokens: msg.ResponseMeta.Usage.PromptTokenDetails.CachedTokens, + }, + CompletionTokens: msg.ResponseMeta.Usage.CompletionTokens, + TotalTokens: msg.ResponseMeta.Usage.TotalTokens, + } } _ = callbacks.OnEnd(ctx, callbackOut) @@ -112,7 +119,14 @@ func (q *UTChatModel) Stream(ctx context.Context, in []*schema.Message, _ ...mod } if t.ResponseMeta != nil { - callbackOut.TokenUsage = (*model.TokenUsage)(t.ResponseMeta.Usage) + callbackOut.TokenUsage = &model.TokenUsage{ + PromptTokens: t.ResponseMeta.Usage.PromptTokens, + PromptTokenDetails: model.PromptTokenDetails{ + CachedTokens: t.ResponseMeta.Usage.PromptTokenDetails.CachedTokens, + }, + CompletionTokens: t.ResponseMeta.Usage.CompletionTokens, + TotalTokens: t.ResponseMeta.Usage.TotalTokens, + } } return callbackOut, nil