扣子智能体
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 
coze_studio/backend/domain/workflow/service/executable_impl_test.go

286 lines
11 KiB

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package service
import (
"context"
"errors"
"testing"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
messagemock "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message/messagemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
mock_workflow "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
// TestImpl_handleHistory runs impl.handleHistory against a table of cases and
// checks the returned error together with the history fields written onto the
// ExecuteConfig passed in.
//
// A single gomock controller (with overridable expectations) and a single
// message mock are shared by every subtest; the message mock is installed as
// the crossmessage default service before the table runs.
func TestImpl_handleHistory(t *testing.T) {
	ctx := context.Background()
	ctrl := gomock.NewController(t, gomock.WithOverridableExpectations())
	defer ctrl.Finish()
	// Setup for cross-domain service mock
	mockMessage := messagemock.NewMockMessage(ctrl)
	crossmessage.SetDefaultSVC(mockMessage)
	tests := []struct {
		name                  string
		setupMock             func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository)
		config                *workflowModel.ExecuteConfig
		input                 map[string]any
		historyRounds         int64
		shouldFetch           bool
		expectErr             bool
		expectedHistory       []*crossmessage.WfMessage
		expectedSchemaHistory []*schema.Message
	}{
		{
			name:          "historyRounds is zero",
			historyRounds: 0,
			shouldFetch:   true,
			config:        &workflowModel.ExecuteConfig{},
			setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
			},
			expectErr: false,
		},
		{
			name:          "shouldFetch is false",
			historyRounds: 5,
			shouldFetch:   false,
			config: &workflowModel.ExecuteConfig{
				AppID:          ptr.Of(int64(1)),
				ConversationID: ptr.Of(int64(100)),
				SectionID:      ptr.Of(int64(101)),
			},
			setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
				msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{1, 2}, nil).AnyTimes()
				msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{
					Messages: []*crossmessage.WfMessage{{ID: 1}},
					SchemaMessages: []*schema.Message{{
						Role:    schema.User,
						Content: "123",
					}},
				}, nil).AnyTimes()
			},
			expectErr:       false,
			expectedHistory: []*crossmessage.WfMessage{{ID: 1}},
			expectedSchemaHistory: []*schema.Message{{
				Role:    schema.User,
				Content: "123",
			}},
		},
		{
			name:          "fetch conversation by name - conversation exists",
			historyRounds: 3,
			shouldFetch:   true,
			config:        &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))},
			input:         map[string]any{"CONVERSATION_NAME": "test-conv"},
			setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
				service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "test-conv").Return(int64(200), int64(201), nil).AnyTimes()
				msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{3, 4}, nil).AnyTimes()
				msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{
					Messages: []*crossmessage.WfMessage{{ID: 2}},
					SchemaMessages: []*schema.Message{{
						Role:    schema.Assistant,
						Content: "123",
					}},
				}, nil).AnyTimes()
				repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{
					TemplateID: int64(202),
					SpaceID:    int64(203),
					AppID:      int64(204),
				}, true, nil).AnyTimes()
				repo.EXPECT().GetOrCreateStaticConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, nil).AnyTimes()
			},
			expectErr:       false,
			expectedHistory: []*crossmessage.WfMessage{{ID: 2}},
			expectedSchemaHistory: []*schema.Message{{
				Role:    schema.Assistant,
				Content: "123",
			}},
		},
		{
			name:          "fetch conversation by name - conversation not exists",
			historyRounds: 3,
			shouldFetch:   true,
			config:        &workflowModel.ExecuteConfig{AgentID: ptr.Of(int64(2))},
			input:         map[string]any{"CONVERSATION_NAME": "new-conv"},
			setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
				service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "new-conv").Return(int64(300), int64(301), nil).AnyTimes()
				msgSvc.EXPECT().GetLatestRunIDs(gomock.Any(), gomock.Any()).Return([]int64{5, 6}, nil).AnyTimes()
				msgSvc.EXPECT().GetMessagesByRunIDs(gomock.Any(), gomock.Any()).Return(&crossmessage.GetMessagesByRunIDsResponse{
					Messages: []*crossmessage.WfMessage{{ID: 3}},
				}, nil).AnyTimes()
				repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{
					TemplateID: int64(202),
					SpaceID:    int64(203),
					AppID:      int64(204),
				}, false, nil).AnyTimes()
				repo.EXPECT().GetOrCreateDynamicConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, nil).AnyTimes()
			},
			expectErr:       false,
			expectedHistory: []*crossmessage.WfMessage{{ID: 3}},
		},
		{
			name:          "input with wrong type for conversation name",
			historyRounds: 5,
			shouldFetch:   true,
			config:        &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))},
			input:         map[string]any{"CONVERSATION_NAME": 12345},
			setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
			},
			expectErr: true,
		},
		{
			name:          "GetOrCreateConversation returns error",
			historyRounds: 5,
			shouldFetch:   true,
			config:        &workflowModel.ExecuteConfig{AppID: ptr.Of(int64(1))},
			input:         map[string]any{"CONVERSATION_NAME": "fail-conv"},
			setupMock: func(service *mock_workflow.MockService, msgSvc *messagemock.MockMessage, repo *mock_workflow.MockRepository) {
				service.EXPECT().GetOrCreateConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "fail-conv").Return(int64(0), int64(0), errors.New("db error")).AnyTimes()
				repo.EXPECT().GetConversationTemplate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&entity.ConversationTemplate{
					TemplateID: int64(202),
					SpaceID:    int64(203),
					AppID:      int64(204),
				}, false, nil).AnyTimes()
				repo.EXPECT().GetOrCreateDynamicConversation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(205), int64(206), true, errors.New("db error")).AnyTimes()
			},
			expectErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// NOTE(review): mockService receives expectations from setupMock
			// but is not referenced by testImpl below; only mockRepo and the
			// shared message mock are wired in. Confirm this is intended.
			mockService := mock_workflow.NewMockService(ctrl)
			mockRepo := mock_workflow.NewMockRepository(ctrl)
			testImpl := &impl{repo: mockRepo, conversationImpl: &conversationImpl{repo: mockRepo}}
			tt.setupMock(mockService, mockMessage, mockRepo)
			err := testImpl.handleHistory(ctx, tt.config, tt.input, tt.historyRounds, tt.shouldFetch)
			if tt.expectErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			// The checks below are deliberately independent. Chaining them
			// with else-if left the schema-history assertion unreachable,
			// because every case that sets expectedSchemaHistory also sets
			// expectedHistory.
			if tt.historyRounds == 0 {
				assert.Nil(t, tt.config.ConversationHistory)
			}
			if tt.expectedHistory != nil {
				assert.Equal(t, tt.expectedHistory, tt.config.ConversationHistory)
			}
			if tt.expectedSchemaHistory != nil {
				assert.Equal(t, tt.expectedSchemaHistory, tt.config.ConversationHistorySchemaMessages)
			}
		})
	}
}
// TestImpl_prefetchChatHistory calls impl.prefetchChatHistory on a zero-value
// impl for each table entry and checks only whether an error is returned; the
// two non-error return values are discarded.
//
// The message mock is installed as the crossmessage default service and is
// shared across all subtests, as is the gomock controller.
func TestImpl_prefetchChatHistory(t *testing.T) {
	ctrl := gomock.NewController(t, gomock.WithOverridableExpectations())
	defer ctrl.Finish()

	msgMock := messagemock.NewMockMessage(ctrl)
	crossmessage.SetDefaultSVC(msgMock)

	background := context.Background()

	// noExpectations is used by cases that register nothing on the mock.
	// Presumably prefetchChatHistory returns before reaching the message
	// service for those configs — verify against the implementation.
	noExpectations := func(*messagemock.MockMessage) {}

	type scenario struct {
		name          string
		config        workflowModel.ExecuteConfig
		historyRounds int64
		arrange       func(msgSvc *messagemock.MockMessage)
		expectErr     bool
	}

	scenarios := []scenario{
		{
			name:          "SectionID is nil",
			historyRounds: 5,
			config: workflowModel.ExecuteConfig{
				ConversationID: ptr.Of(int64(100)),
				AppID:          ptr.Of(int64(1)),
			},
			arrange:   noExpectations,
			expectErr: false,
		},
		{
			name:          "ConversationID is nil",
			historyRounds: 5,
			config: workflowModel.ExecuteConfig{
				SectionID: ptr.Of(int64(101)),
				AppID:     ptr.Of(int64(1)),
			},
			arrange:   noExpectations,
			expectErr: false,
		},
		{
			name:          "AppID and AgentID are both nil",
			historyRounds: 5,
			config: workflowModel.ExecuteConfig{
				ConversationID: ptr.Of(int64(100)),
				SectionID:      ptr.Of(int64(101)),
			},
			arrange:   noExpectations,
			expectErr: false,
		},
		{
			name:          "GetLatestRunIDs returns error",
			historyRounds: 5,
			config: workflowModel.ExecuteConfig{
				AppID:          ptr.Of(int64(1)),
				ConversationID: ptr.Of(int64(100)),
				SectionID:      ptr.Of(int64(101)),
			},
			arrange: func(msgSvc *messagemock.MockMessage) {
				msgSvc.EXPECT().
					GetLatestRunIDs(gomock.Any(), gomock.Any()).
					Return(nil, errors.New("db error"))
			},
			expectErr: true,
		},
		{
			name:          "GetMessagesByRunIDs returns error",
			historyRounds: 5,
			config: workflowModel.ExecuteConfig{
				AppID:          ptr.Of(int64(1)),
				ConversationID: ptr.Of(int64(100)),
				SectionID:      ptr.Of(int64(101)),
			},
			arrange: func(msgSvc *messagemock.MockMessage) {
				// Run IDs resolve successfully; the message lookup then fails.
				msgSvc.EXPECT().
					GetLatestRunIDs(gomock.Any(), gomock.Any()).
					Return([]int64{1, 2, 3}, nil)
				msgSvc.EXPECT().
					GetMessagesByRunIDs(gomock.Any(), gomock.Any()).
					Return(nil, errors.New("db error"))
			},
			expectErr: true,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			subject := &impl{}
			sc.arrange(msgMock)

			_, _, err := subject.prefetchChatHistory(background, sc.config, sc.historyRounds)
			if sc.expectErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}