fix: update eino to fix the issue when field mappings are resolved re… (#1952)

main
shentongmartin 2 months ago committed by GitHub
parent 6c8dd3a44c
commit 8db69fd6a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 16
      .github/workflows/ci@backend.yml
  2. 19
      backend/api/handler/coze/workflow_service_test.go
  3. 1
      backend/domain/workflow/entity/message.go
  4. 12
      backend/domain/workflow/internal/compose/designate_option.go
  5. 1
      backend/domain/workflow/internal/compose/state.go
  6. 2
      backend/domain/workflow/internal/compose/workflow_run.go
  7. 114
      backend/domain/workflow/internal/compose/workflow_tool.go
  8. 65
      backend/domain/workflow/internal/execute/callback.go
  9. 28
      backend/domain/workflow/internal/execute/collect_token.go
  10. 21
      backend/domain/workflow/internal/execute/event.go
  11. 70
      backend/domain/workflow/internal/execute/event_handle.go
  12. 9
      backend/domain/workflow/internal/execute/tool_option.go
  13. 9
      backend/go.mod
  14. 12
      backend/go.sum
  15. 18
      backend/internal/testutil/chat_model.go

@ -31,7 +31,7 @@ jobs:
env:
COVERAGE_FILE: coverage.out
BREAKDOWN_FILE: main.breakdown
steps:
- uses: actions/checkout@v4
- name: Set up Go
@ -52,7 +52,7 @@ jobs:
mysql version: '8.4.5'
mysql database: 'opencoze'
mysql root password: 'root'
- name: Verify MySQL Startup
run: |
echo "Waiting for MySQL to be ready..."
@ -70,8 +70,12 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y mysql-client
- name: Initialize Database
run: mysql -h 127.0.0.1 -P 3306 -u root -proot opencoze < docker/volumes/mysql/schema.sql
uses: nick-fields/retry@v3
with:
timeout_minutes: 10
max_attempts: 20
command: mysql -h 127.0.0.1 -P 3306 -u root -proot opencoze < docker/volumes/mysql/schema.sql
- name: Run Go Test
run: |
modules=`find . -name "go.mod" -exec dirname {} \;`
@ -82,7 +86,7 @@ jobs:
for module in $modules; do go work use $module; list=$module"/... "$list; coverpkg=$module"/...,"$coverpkg; done
go work sync
go test -race -v -coverprofile=${{ env.COVERAGE_FILE }} -gcflags="all=-l -N" -coverpkg=$coverpkg $list
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
@ -118,4 +122,4 @@ jobs:
if [[ ! -f "go.work" ]];then go work init;fi
for module in $modules; do go work use $module; list=$module"/... "$list; coverpkg=$module"/...,"$coverpkg; done
go work sync
go test -race -v -bench=. -benchmem -run=none -gcflags="all=-l -N" $list
go test -race -v -bench=. -benchmem -run=none -gcflags="all=-l -N" $list

@ -43,14 +43,15 @@ import (
"github.com/cloudwego/hertz/pkg/common/ut"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/protocol/sse"
message0 "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
"github.com/coze-dev/coze-studio/backend/domain/workflow/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"gorm.io/driver/mysql"
"gorm.io/gorm"
message0 "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
"github.com/coze-dev/coze-studio/backend/domain/workflow/config"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
modelknowledge "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
@ -755,6 +756,7 @@ func (r *wfTestRunner) getProcess(id, exeID string, opts ...func(options *getPro
var nodeType string
var token *workflow.TokenAndCost
var reason string
var count int
for {
if nodeEvent != nil {
if options.previousInterruptEventID != "" {
@ -770,6 +772,10 @@ func (r *wfTestRunner) getProcess(id, exeID string, opts ...func(options *getPro
break
}
if count > 1000 {
r.t.Fatal("get process for too long")
}
getProcessResp := getProcess(r.t, r.h, id, exeID)
if len(getProcessResp.Data.NodeResults) == 1 {
output = getProcessResp.Data.NodeResults[0].Output
@ -803,6 +809,8 @@ func (r *wfTestRunner) getProcess(id, exeID string, opts ...func(options *getPro
eventID = nodeEvent.ID
}
r.t.Logf("getProcess output= %s, status= %v, eventID= %s, nodeType= %s", output, workflowStatus, eventID, nodeType)
count++
}
return &exeResult{
@ -1624,6 +1632,7 @@ func TestNestedSubWorkflowWithInterrupt(t *testing.T) {
post[workflow.DeleteWorkflowResponse](r, &workflow.DeleteWorkflowRequest{
WorkflowID: topID,
})
time.Sleep(time.Second)
}()
midID := r.load("subworkflow/middle_workflow.json", withID(7494849202016272435))
@ -1841,6 +1850,7 @@ func TestPublishWorkflow(t *testing.T) {
WorkflowID: id,
}
_ = post[workflow.DeleteWorkflowResponse](r, deleteReq)
time.Sleep(time.Second)
})
}
@ -1964,6 +1974,7 @@ func TestSimpleInvokableToolWithReturnVariables(t *testing.T) {
post[workflow.DeleteWorkflowResponse](r, &workflow.DeleteWorkflowRequest{
WorkflowID: id,
})
time.Sleep(time.Second)
}()
exeID := r.testRun(id, map[string]string{
@ -2628,6 +2639,7 @@ func TestListWorkflowAsToolData(t *testing.T) {
WorkflowID: id,
}
_ = post[workflow.DeleteWorkflowResponse](r, deleteReq)
time.Sleep(time.Second)
})
}
@ -2662,6 +2674,7 @@ func TestWorkflowDetailAndDetailInfo(t *testing.T) {
WorkflowID: id,
}
_ = post[workflow.DeleteWorkflowResponse](r, deleteReq)
time.Sleep(time.Second)
})
}
@ -4542,6 +4555,8 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
assert.Equal(t, "v0.0.1", node.Data.Inputs.WorkflowVersion)
}
}
time.Sleep(time.Second)
})
})

@ -85,7 +85,6 @@ type ToolResponseInfo struct {
FunctionInfo
CallID string
Response string
Complete bool
}
type ToolType = workflow.PluginType

@ -38,7 +38,7 @@ import (
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context, []einoCompose.Option, error) {
func (r *WorkflowRunner) designateOptions(ctx context.Context) ([]einoCompose.Option, error) {
var (
wb = r.basic
exeCfg = r.config
@ -83,13 +83,13 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
ns,
string(key))
if err != nil {
return ctx, nil, err
return nil, err
}
opts = append(opts, subOpts...)
} else if ns.Type == entity.NodeTypeLLM {
llmNodeOpts, err := llmToolCallbackOptions(ctx, ns, eventChan, sw)
if err != nil {
return ctx, nil, err
return nil, err
}
opts = append(opts, llmNodeOpts...)
@ -103,7 +103,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
ns,
string(key))
if err != nil {
return ctx, nil, err
return nil, err
}
for _, subO := range subOpts {
opts = append(opts, WrapOpt(subO, parent.Key))
@ -111,7 +111,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
} else if ns.Type == entity.NodeTypeLLM {
llmNodeOpts, err := llmToolCallbackOptions(ctx, ns, eventChan, sw)
if err != nil {
return ctx, nil, err
return nil, err
}
for _, subO := range llmNodeOpts {
opts = append(opts, WrapOpt(subO, parent.Key))
@ -124,7 +124,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
opts = append(opts, einoCompose.WithCheckPointID(strconv.FormatInt(executeID, 10)))
}
return ctx, opts, nil
return opts, nil
}
func nodeCallbackOption(key vo.NodeKey, name string, eventChan chan *execute.Event, resumeEvent *entity.InterruptEvent,

@ -92,6 +92,7 @@ func init() {
_ = compose.RegisterSerializableType[*schema.Message]("schema_message")
_ = compose.RegisterSerializableType[*crossmessage.WfMessage]("history_messages")
_ = compose.RegisterSerializableType[*crossmessage.Content]("content")
_ = compose.RegisterSerializableType[*model.PromptTokenDetails]("prompt_token_details")
}

@ -167,7 +167,7 @@ func (r *WorkflowRunner) Prepare(ctx context.Context) (
}()
}
ctx, composeOpts, err := r.designateOptions(ctx)
composeOpts, err := r.designateOptions(ctx)
if err != nil {
return ctx, 0, nil, nil, err
}

@ -21,6 +21,7 @@ import (
"fmt"
"strings"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/tool"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
@ -84,8 +85,10 @@ func resumeOnce(rInfo *entity.ResumeRequest, callID string, allIEs map[string]*e
}
}
func (wt *workflowTool) prepare(ctx context.Context, rInfo *entity.ResumeRequest, argumentsInJSON string, opts ...tool.Option) (
cancelCtx context.Context, executeID int64, input map[string]any, callOpts []einoCompose.Option, err error) {
func (wt *workflowTool) prepare(ctx context.Context, rInfo *entity.ResumeRequest,
argumentsInJSON string, opts ...tool.Option) (
cancelCtx context.Context, executeID int64, input map[string]any,
lastEventChan <-chan *execute.Event, callOpts []einoCompose.Option, err error) {
cfg := execute.GetExecuteConfig(opts...)
var runOpts []WorkflowRunnerOption
@ -126,11 +129,12 @@ func (wt *workflowTool) prepare(ctx context.Context, rInfo *entity.ResumeRequest
}
}
cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(wt.wfEntity.GetBasic(), wt.sc, cfg, runOpts...).Prepare(ctx)
cancelCtx, executeID, callOpts, lastEventChan, err = NewWorkflowRunner(wt.wfEntity.GetBasic(), wt.sc, cfg, runOpts...).Prepare(ctx)
return
}
func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (
contentStr string, err error) {
rInfo, allIEs := execute.GetResumeRequest(opts...)
var (
previouslyInterrupted bool
@ -145,6 +149,18 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
}
}
ctx = callbacks.OnStart(ctx, &tool.CallbackInput{
ArgumentsInJSON: argumentsInJSON,
Extra: map[string]any{
execute.ToolCallIDKey: callID,
},
})
defer func() {
if err != nil {
_ = callbacks.OnError(ctx, err)
}
}()
if previouslyInterrupted && rInfo.ExecuteID != previousExecuteID {
logs.Infof("previous interrupted call ID: %s, previous execute ID: %d, current execute ID: %d. Not resuming, interrupt immediately", callID, previousExecuteID, rInfo.ExecuteID)
return "", einoCompose.InterruptAndRerun
@ -152,7 +168,7 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
defer resumeOnce(rInfo, callID, allIEs)
cancelCtx, executeID, in, callOpts, err := i.prepare(ctx, rInfo, argumentsInJSON, opts...)
cancelCtx, executeID, in, _, callOpts, err := i.prepare(ctx, rInfo, argumentsInJSON, opts...)
if err != nil {
return "", err
}
@ -179,7 +195,19 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
}
if i.terminatePlan == vo.ReturnVariables {
return sonic.MarshalString(out)
contentStr, err = sonic.MarshalString(out)
if err != nil {
return "", err
}
_ = callbacks.OnEnd(ctx, &tool.CallbackOutput{
Response: contentStr,
Extra: map[string]any{
execute.ToolCallIDKey: callID,
},
})
return contentStr, nil
}
content, ok := out[answerKey]
@ -187,7 +215,7 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
return "", fmt.Errorf("no answer found when terminate plan is use answer content. out: %v", out)
}
contentStr, ok := content.(string)
contentStr, ok = content.(string)
if !ok {
return "", fmt.Errorf("answer content is not string. content: %v", content)
}
@ -196,6 +224,13 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
contentStr = strings.TrimSuffix(contentStr, nodes.KeyIsFinished)
}
_ = callbacks.OnEnd(ctx, &tool.CallbackOutput{
Response: contentStr,
Extra: map[string]any{
execute.ToolCallIDKey: callID,
},
})
return contentStr, nil
}
@ -207,6 +242,10 @@ func (i *invokableWorkflow) GetWorkflow() *entity.Workflow {
return i.wfEntity
}
func (i *invokableWorkflow) IsCallbacksEnabled() bool {
return true
}
type streamableWorkflow struct {
workflowTool
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error)
@ -235,12 +274,14 @@ func (s *streamableWorkflow) Info(_ context.Context) (*schema.ToolInfo, error) {
return s.info, nil
}
func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) {
func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (
out *schema.StreamReader[string], err error) {
rInfo, allIEs := execute.GetResumeRequest(opts...)
var (
previouslyInterrupted bool
callID = einoCompose.GetToolCallID(ctx)
previousExecuteID int64
toolFinishChan = make(chan struct{})
)
for interruptedCallID := range allIEs {
if callID == interruptedCallID {
@ -250,6 +291,20 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON
}
}
ctx = callbacks.OnStart(ctx, &tool.CallbackInput{
ArgumentsInJSON: argumentsInJSON,
Extra: map[string]any{
execute.ToolCallIDKey: callID,
execute.ToolFinishChanKey: toolFinishChan,
},
})
defer func() {
if err != nil {
_ = callbacks.OnError(ctx, err)
close(toolFinishChan)
}
}()
if previouslyInterrupted && rInfo.ExecuteID != previousExecuteID {
logs.Infof("previous interrupted call ID: %s, previous execute ID: %d, current execute ID: %d. Not resuming, interrupt immediately", callID, previousExecuteID, rInfo.ExecuteID)
return nil, einoCompose.InterruptAndRerun
@ -257,7 +312,7 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON
defer resumeOnce(rInfo, callID, allIEs)
cancelCtx, executeID, in, callOpts, err := s.prepare(ctx, rInfo, argumentsInJSON, opts...)
cancelCtx, executeID, in, lastEventChan, callOpts, err := s.prepare(ctx, rInfo, argumentsInJSON, opts...)
if err != nil {
return nil, err
}
@ -283,22 +338,35 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON
return nil, err
}
return schema.StreamReaderWithConvert(outStream, func(in map[string]any) (string, error) {
content, ok := in["output"]
if !ok {
return "", fmt.Errorf("no output found when stream plan is use output content. out: %v", in)
go func() {
for range lastEventChan {
}
close(toolFinishChan)
}()
_, callbackStream := callbacks.OnEndWithStreamOutput(ctx, schema.StreamReaderWithConvert(outStream,
func(in map[string]any) (*tool.CallbackOutput, error) {
content, ok := in["output"]
if !ok {
return nil, fmt.Errorf("no output found when stream plan is use output content. out: %v", in)
}
contentStr, ok := content.(string)
if !ok {
return "", fmt.Errorf("output content is not string. content: %v", content)
}
contentStr, ok := content.(string)
if !ok {
return nil, fmt.Errorf("output content is not string. content: %v", content)
}
if strings.HasSuffix(contentStr, nodes.KeyIsFinished) {
contentStr = strings.TrimSuffix(contentStr, nodes.KeyIsFinished)
}
if strings.HasSuffix(contentStr, nodes.KeyIsFinished) {
contentStr = strings.TrimSuffix(contentStr, nodes.KeyIsFinished)
}
return contentStr, nil
return &tool.CallbackOutput{
Response: contentStr,
}, nil
}))
return schema.StreamReaderWithConvert(callbackStream, func(in *tool.CallbackOutput) (string, error) {
return in.Response, nil
}), nil
}
@ -309,3 +377,7 @@ func (s *streamableWorkflow) TerminatePlan() vo.TerminatePlan {
func (s *streamableWorkflow) GetWorkflow() *entity.Workflow {
return s.wfEntity
}
func (s *streamableWorkflow) IsCallbacksEnabled() bool {
return true
}

@ -370,10 +370,6 @@ func (w *WorkflowHandler) OnError(ctx context.Context, info *callbacks.RunInfo,
interruptEvent.EventType, interruptEvent.NodeKey)
}
if c.TokenCollector != nil { // wait until all streaming chunks are collected
_ = c.TokenCollector.wait()
}
done := make(chan struct{})
w.ch <- &Event{
@ -1271,6 +1267,11 @@ func (n *NodeHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks
return ctx
}
const (
ToolCallIDKey = "call_id"
ToolFinishChanKey = "tool_finish_chan"
)
func (t *ToolHandler) OnStart(ctx context.Context, info *callbacks.RunInfo,
input *tool.CallbackInput,
) context.Context {
@ -1286,13 +1287,35 @@ func (t *ToolHandler) OnStart(ctx context.Context, info *callbacks.RunInfo,
}
}
var (
callID string
toolFinishChan chan struct{}
)
if input.Extra != nil {
callIDAny, ok := input.Extra[ToolCallIDKey]
if ok {
callID = callIDAny.(string)
}
toolFinishChanAny, ok := input.Extra[ToolFinishChanKey]
if ok {
toolFinishChan = toolFinishChanAny.(chan struct{})
}
}
if len(callID) == 0 {
callID = compose.GetToolCallID(ctx)
}
t.ch <- &Event{
Type: FunctionCall,
Context: GetExeCtx(ctx),
functionCall: &entity.FunctionCallInfo{
FunctionInfo: t.info,
CallID: compose.GetToolCallID(ctx),
Arguments: args,
functionCall: &FunctionCallInfo{
FunctionCallInfo: &entity.FunctionCallInfo{
FunctionInfo: t.info,
CallID: callID,
Arguments: args,
},
toolFinishChan: toolFinishChan,
},
}
@ -1306,14 +1329,25 @@ func (t *ToolHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo,
return ctx
}
var callID string
if output.Extra != nil {
callIDAny, ok := output.Extra[ToolCallIDKey]
if ok {
callID = callIDAny.(string)
}
}
if len(callID) == 0 {
callID = compose.GetToolCallID(ctx)
}
t.ch <- &Event{
Type: ToolResponse,
Context: GetExeCtx(ctx),
toolResponse: &entity.ToolResponseInfo{
FunctionInfo: t.info,
CallID: compose.GetToolCallID(ctx),
CallID: callID,
Response: output.Response,
Complete: true,
},
}
@ -1352,7 +1386,6 @@ func (t *ToolHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks
toolResponse: &entity.ToolResponseInfo{
FunctionInfo: t.info,
CallID: callID,
Complete: true,
},
}
}
@ -1374,7 +1407,7 @@ func (t *ToolHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks
Context: c,
toolResponse: &entity.ToolResponseInfo{
FunctionInfo: t.info,
CallID: compose.GetToolCallID(ctx),
CallID: callID,
Response: chunk.Response,
},
}
@ -1398,9 +1431,11 @@ func (t *ToolHandler) OnError(ctx context.Context, info *callbacks.RunInfo, err
t.ch <- &Event{
Type: ToolError,
Context: GetExeCtx(ctx),
functionCall: &entity.FunctionCallInfo{
FunctionInfo: t.info,
CallID: compose.GetToolCallID(ctx),
functionCall: &FunctionCallInfo{
FunctionCallInfo: &entity.FunctionCallInfo{
FunctionInfo: t.info,
CallID: compose.GetToolCallID(ctx),
},
},
Err: err,
}

@ -26,6 +26,7 @@ import (
callbacks2 "github.com/cloudwego/eino/utils/callbacks"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
type TokenCollector struct {
@ -90,6 +91,33 @@ func (t *TokenCollector) finishStreamCounting() {
}
}
type tokenCollector struct {
Key string
Usage *model.TokenUsage
Parent *TokenCollector
}
func (t *TokenCollector) MarshalJSON() ([]byte, error) {
t.wait()
return sonic.Marshal(&tokenCollector{
Key: t.Key,
Usage: t.Usage,
Parent: t.Parent,
})
}
func (t *TokenCollector) UnmarshalJSON(bytes []byte) error {
tc := &tokenCollector{}
if err := sonic.Unmarshal(bytes, tc); err != nil {
return err
}
t.Key = tc.Key
t.Usage = tc.Usage
t.Parent = tc.Parent
return nil
}
func getTokenCollector(ctx context.Context) *TokenCollector {
c := GetExeCtx(ctx)
if c == nil {

@ -64,7 +64,7 @@ type Event struct {
InterruptEvents []*entity.InterruptEvent
functionCall *entity.FunctionCallInfo
functionCall *FunctionCallInfo
toolResponse *entity.ToolResponseInfo
outputExtractor func(o map[string]any) string
@ -75,6 +75,11 @@ type Event struct {
nodeCount int32
}
type FunctionCallInfo struct {
*entity.FunctionCallInfo
toolFinishChan chan struct{}
}
type TokenInfo struct {
InputToken int64
OutputToken int64
@ -104,17 +109,3 @@ func (e *Event) GetResumedEventID() int64 {
}
return e.Context.RootCtx.ResumeEvent.ID
}
func (e *Event) GetFunctionCallInfo() (*entity.FunctionCallInfo, bool) {
if e.functionCall == nil {
return nil, false
}
return e.functionCall, true
}
func (e *Event) GetToolResponse() (*entity.ToolResponseInfo, bool) {
if e.toolResponse == nil {
return nil, false
}
return e.toolResponse, true
}

@ -672,7 +672,7 @@ func handleEvent(ctx context.Context, event *Event, repo workflow.Repository,
ExecuteID: event.RootExecuteID,
Role: schema.Assistant,
Type: entity.FunctionCall,
FunctionCall: event.functionCall,
FunctionCall: event.functionCall.FunctionCallInfo,
},
}, nil)
case ToolResponse:
@ -704,8 +704,6 @@ func handleEvent(ctx context.Context, event *Event, repo workflow.Repository,
},
}, nil)
case ToolError:
// TODO: optimize this log
logs.CtxErrorf(ctx, "received tool error event: %v", event)
default:
panic("unimplemented event type: " + event.Type)
}
@ -715,8 +713,9 @@ func handleEvent(ctx context.Context, event *Event, repo workflow.Repository,
type fcCacheKey struct{}
type fcInfo struct {
input *entity.FunctionCallInfo
output *entity.ToolResponseInfo
input *entity.FunctionCallInfo
output *entity.ToolResponseInfo
toolFinishChan chan struct{}
}
func HandleExecuteEvent(ctx context.Context,
@ -772,7 +771,8 @@ func HandleExecuteEvent(ctx context.Context,
lastNodeIsDone = true
if wfSuccessEvent != nil {
if err = setRootWorkflowSuccess(ctx, wfSuccessEvent, repo, sw); err != nil {
logs.CtxErrorf(ctx, "failed to set root workflow success: %v", err)
logs.CtxErrorf(ctx, "failed to set root workflow success for workflow %d: %v",
wfSuccessEvent.RootWorkflowBasic.ID, err)
}
return wfSuccessEvent
}
@ -786,10 +786,12 @@ func HandleExecuteEvent(ctx context.Context,
// Add cancellation check timer
cancelTicker := time.NewTicker(cancelCheckInterval)
defer func() {
logs.CtxInfof(ctx, "[handleExecuteEvent] finish, returned event type: %v, workflow id: %d",
logs.CtxInfof(ctx, "[handleExecuteEvent] cancellable finish, returned event type: %v, workflow id: %d",
event.Type, event.Context.RootWorkflowBasic.ID)
cancelTicker.Stop() // Clean up timer
waitUntilToolFinish(ctx)
logs.CtxInfof(ctx, "[handleExecuteEvent] cancellable wait until tool finished done, workflow id: %d",
event.Context.RootWorkflowBasic.ID)
cancelTicker.Stop() // Clean up timer
if timeoutFn != nil {
timeoutFn()
}
@ -825,6 +827,9 @@ func HandleExecuteEvent(ctx context.Context,
defer func() {
logs.CtxInfof(ctx, "[handleExecuteEvent] finish, returned event type: %v, workflow id: %d",
event.Type, event.Context.RootWorkflowBasic.ID)
waitUntilToolFinish(ctx)
logs.CtxInfof(ctx, "[handleExecuteEvent] wait until tool finished done, workflow id: %d",
event.Context.RootWorkflowBasic.ID)
if timeoutFn != nil {
timeoutFn()
}
@ -859,29 +864,26 @@ func cacheFunctionCall(ctx context.Context, event *Event) {
c[event.NodeKey] = make(map[string]*fcInfo)
}
c[event.NodeKey][event.functionCall.CallID] = &fcInfo{
input: event.functionCall,
input: event.functionCall.FunctionCallInfo,
toolFinishChan: event.functionCall.toolFinishChan,
}
}
func cacheToolResponse(ctx context.Context, event *Event) {
c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo)
if _, ok := c[event.NodeKey]; !ok {
c[event.NodeKey] = make(map[string]*fcInfo)
}
c[event.NodeKey][event.toolResponse.CallID].output = event.toolResponse
}
func cacheToolStreamingResponse(ctx context.Context, event *Event) {
c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo)
if _, ok := c[event.NodeKey]; !ok {
c[event.NodeKey] = make(map[string]*fcInfo)
}
if c[event.NodeKey][event.toolResponse.CallID].output == nil {
c[event.NodeKey][event.toolResponse.CallID].output = event.toolResponse
} else {
c[event.NodeKey][event.toolResponse.CallID].output.Response += event.toolResponse.Response
}
c[event.NodeKey][event.toolResponse.CallID].output.Response += event.toolResponse.Response
c[event.NodeKey][event.toolResponse.CallID].output.Complete = event.toolResponse.Complete
logs.CtxInfof(ctx, "receive tool response: %s, callID: %s",
event.toolResponse.Response, event.toolResponse.CallID)
}
func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo {
@ -890,29 +892,17 @@ func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo {
}
func waitUntilToolFinish(ctx context.Context) {
var cnt int
outer:
for {
if cnt > 1000 {
return
}
c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo)
if len(c) == 0 {
return
}
for _, m := range c {
for _, info := range m {
if info.output == nil {
cnt++
continue outer
}
c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo)
if len(c) == 0 {
return
}
if !info.output.Complete {
cnt++
continue outer
}
for _, m := range c {
for _, info := range m {
if info.toolFinishChan != nil {
<-info.toolFinishChan
logs.CtxInfof(ctx, "tool finished, callID: %s, pluginID: %v", info.output.CallID,
info.input.PluginID)
}
}
}

@ -26,11 +26,10 @@ import (
)
type workflowToolOption struct {
resumeReq *entity.ResumeRequest
streamContainer *StreamContainer
exeCfg workflowModel.ExecuteConfig
allInterruptEvents map[string]*entity.ToolInterruptEvent
parentTokenCollector *TokenCollector
resumeReq *entity.ResumeRequest
streamContainer *StreamContainer
exeCfg workflowModel.ExecuteConfig
allInterruptEvents map[string]*entity.ToolInterruptEvent
}
func WithResume(req *entity.ResumeRequest, all map[string]*entity.ToolInterruptEvent) tool.Option {

@ -17,7 +17,7 @@ require (
github.com/bytedance/gopkg v0.1.3
github.com/bytedance/mockey v1.2.14
github.com/bytedance/sonic v1.14.0
github.com/cloudwego/eino v0.3.55
github.com/cloudwego/eino v0.4.8
github.com/cloudwego/eino-ext/components/embedding/ark v0.1.0
github.com/cloudwego/eino-ext/components/embedding/gemini v0.0.0-20250814083140-54b99ff82f8e
github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20250728060543-79ec300857b8
@ -284,3 +284,10 @@ require (
sigs.k8s.io/yaml v1.3.0 // indirect
stathat.com/c/consistent v1.0.0 // indirect
)
require (
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/eino-contrib/jsonschema v1.0.0 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
)

@ -131,6 +131,8 @@ github.com/aws/smithy-go v1.8.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAm
github.com/aws/smithy-go v1.22.4 h1:uqXzVZNuNexwc/xrh6Tb56u89WDlJY6HS+KC0S4QSjw=
github.com/aws/smithy-go v1.22.4/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
@ -146,6 +148,8 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8=
github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE=
github.com/bytedance/go-tagexpr/v2 v2.9.2/go.mod h1:5qsx05dYOiUXOUgnQ7w3Oz8BYs2qtM/bJokdLb79wRM=
@ -190,8 +194,8 @@ github.com/clbanning/mxj v1.8.4/go.mod h1:BVjHeAH+rl9rs6f+QIpeRl0tfu10SXn1pUSa5P
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4=
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/eino v0.3.55 h1:lMZrGtEh0k3qykQTLNXSXuAa98OtF2tS43GMHyvN7nA=
github.com/cloudwego/eino v0.3.55/go.mod h1:wUjz990apdsaOraOXdh6CdhVXq8DJsOvLsVlxNTcNfY=
github.com/cloudwego/eino v0.4.8 h1:wptTU24tQad1mFCHw0+4zSzH+p8dLEBk6HtggPlcvP0=
github.com/cloudwego/eino v0.4.8/go.mod h1:1TDlOmwGSsbCJaWB92w9YLZi2FL0WRZoRcD4eMvqikg=
github.com/cloudwego/eino-ext/components/embedding/ark v0.1.0 h1:AuJsMdaTXc+dGUDQp82MifLYK8oiJf4gLQPUETmKISM=
github.com/cloudwego/eino-ext/components/embedding/ark v0.1.0/go.mod h1:0FZG/KRBl3hGWkNsm55UaXyVa6PDVIy5u+QvboAB+cY=
github.com/cloudwego/eino-ext/components/embedding/gemini v0.0.0-20250814083140-54b99ff82f8e h1:46D2fFDbUysA7kUD5x/wK3huneMEvTQfuWcHqI3M6iQ=
@ -288,6 +292,8 @@ github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3/go.mod h1
github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc=
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M=
github.com/eino-contrib/jsonschema v1.0.0 h1:dXxbhGNZuI3+xNi8x3JT8AGyoXz6Pff6mRvmpjVl5Ww=
github.com/eino-contrib/jsonschema v1.0.0/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4=
github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385/go.mod h1:0vRUJqYpeSZifjYj7uP3BG/gKcuzL9xWVV/Y+cK33KM=
github.com/elastic/elastic-transport-go/v8 v8.7.0 h1:OgTneVuXP2uip4BA658Xi6Hfw+PeIOod2rY3GVMGoVE=
github.com/elastic/elastic-transport-go/v8 v8.7.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk=
@ -1056,6 +1062,8 @@ github.com/volcengine/volc-sdk-golang v1.0.211 h1:FgwD+1phyy+un4Qk2YqooYtp6XpvND
github.com/volcengine/volc-sdk-golang v1.0.211/go.mod h1:stZX+EPgv1vF4nZwOlEe8iGcriUPRBKX8zA19gXycOQ=
github.com/volcengine/volcengine-go-sdk v1.1.20 h1:+ifZdF7IIIagqF8yVNfk9CmNUl5wgRfU/8orlH+JQhA=
github.com/volcengine/volcengine-go-sdk v1.1.20/go.mod h1:EyKoi6t6eZxoPNGr2GdFCZti2Skd7MO3eUzx7TtSvNo=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg=
github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=

@ -72,7 +72,14 @@ func (q *UTChatModel) Generate(ctx context.Context, in []*schema.Message, _ ...m
}
if msg.ResponseMeta != nil {
callbackOut.TokenUsage = (*model.TokenUsage)(msg.ResponseMeta.Usage)
callbackOut.TokenUsage = &model.TokenUsage{
PromptTokens: msg.ResponseMeta.Usage.PromptTokens,
PromptTokenDetails: model.PromptTokenDetails{
CachedTokens: msg.ResponseMeta.Usage.PromptTokenDetails.CachedTokens,
},
CompletionTokens: msg.ResponseMeta.Usage.CompletionTokens,
TotalTokens: msg.ResponseMeta.Usage.TotalTokens,
}
}
_ = callbacks.OnEnd(ctx, callbackOut)
@ -112,7 +119,14 @@ func (q *UTChatModel) Stream(ctx context.Context, in []*schema.Message, _ ...mod
}
if t.ResponseMeta != nil {
callbackOut.TokenUsage = (*model.TokenUsage)(t.ResponseMeta.Usage)
callbackOut.TokenUsage = &model.TokenUsage{
PromptTokens: t.ResponseMeta.Usage.PromptTokens,
PromptTokenDetails: model.PromptTokenDetails{
CachedTokens: t.ResponseMeta.Usage.PromptTokenDetails.CachedTokens,
},
CompletionTokens: t.ResponseMeta.Usage.CompletionTokens,
TotalTokens: t.ResponseMeta.Usage.TotalTokens,
}
}
return callbackOut, nil

Loading…
Cancel
Save