You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
459 lines
13 KiB
459 lines
13 KiB
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package dal
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"time"
|
|
|
|
"github.com/cloudwego/eino/schema"
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
|
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
|
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/model"
|
|
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/query"
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
|
"github.com/coze-dev/coze-studio/backend/types/errno"
|
|
)
|
|
|
|
type MessageDAO struct {
|
|
query *query.Query
|
|
idgen idgen.IDGenerator
|
|
}
|
|
|
|
func NewMessageDAO(db *gorm.DB, idgen idgen.IDGenerator) *MessageDAO {
|
|
return &MessageDAO{
|
|
query: query.Use(db),
|
|
idgen: idgen,
|
|
}
|
|
}
|
|
|
|
func (dao *MessageDAO) PreCreate(ctx context.Context, msg *entity.Message) (*entity.Message, error) {
|
|
poData, err := dao.messageDO2PO(ctx, msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return dao.messagePO2DO(poData), nil
|
|
}
|
|
|
|
func (dao *MessageDAO) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) {
|
|
poData, err := dao.messageDO2PO(ctx, msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
do := dao.query.Message.WithContext(ctx).Debug()
|
|
cErr := do.Create(poData)
|
|
if cErr != nil {
|
|
return nil, cErr
|
|
}
|
|
|
|
return dao.messagePO2DO(poData), nil
|
|
}
|
|
|
|
func (dao *MessageDAO) BatchCreate(ctx context.Context, msg []*entity.Message) ([]*entity.Message, error) {
|
|
poList := make([]*model.Message, 0, len(msg))
|
|
for _, m := range msg {
|
|
po, err := dao.messageDO2PO(ctx, m)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
poList = append(poList, po)
|
|
}
|
|
|
|
do := dao.query.Message.WithContext(ctx).Debug()
|
|
cErr := do.CreateInBatches(poList, len(poList))
|
|
if cErr != nil {
|
|
return nil, cErr
|
|
}
|
|
|
|
return dao.batchMessagePO2DO(poList), nil
|
|
}
|
|
|
|
func (dao *MessageDAO) List(ctx context.Context, listMeta *entity.ListMeta) ([]*entity.Message, bool, error) {
|
|
m := dao.query.Message
|
|
do := m.WithContext(ctx).Debug().Where(m.ConversationID.Eq(listMeta.ConversationID)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
|
|
|
|
if len(listMeta.RunID) > 0 {
|
|
do = do.Where(m.RunID.In(slices.Transform(listMeta.RunID, func(t *int64) int64 {
|
|
return *t
|
|
})...))
|
|
}
|
|
if len(listMeta.MessageType) > 0 {
|
|
do = do.Where(m.MessageType.In(slices.Transform(listMeta.MessageType, func(t *message.MessageType) string {
|
|
return string(*t)
|
|
})...))
|
|
}
|
|
|
|
if listMeta.Limit > 0 {
|
|
do = do.Limit(int(listMeta.Limit) + 1)
|
|
}
|
|
|
|
if listMeta.Cursor > 0 {
|
|
msg, err := m.Where(m.ID.Eq(listMeta.Cursor)).First()
|
|
if err != nil {
|
|
return nil, false, err
|
|
}
|
|
if listMeta.Direction == entity.ScrollPageDirectionPrev {
|
|
do = do.Where(m.CreatedAt.Lt(msg.CreatedAt))
|
|
do = do.Order(m.CreatedAt.Desc())
|
|
} else {
|
|
do = do.Where(m.CreatedAt.Gt(msg.CreatedAt))
|
|
do = do.Order(m.CreatedAt.Asc())
|
|
}
|
|
} else {
|
|
do = do.Order(m.CreatedAt.Desc())
|
|
}
|
|
|
|
messageList, err := do.Find()
|
|
|
|
var hasMore bool
|
|
|
|
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, hasMore, nil
|
|
}
|
|
if err != nil {
|
|
return nil, false, err
|
|
}
|
|
|
|
if len(messageList) > int(listMeta.Limit) {
|
|
hasMore = true
|
|
messageList = messageList[:int(listMeta.Limit)]
|
|
}
|
|
|
|
return dao.batchMessagePO2DO(messageList), hasMore, nil
|
|
}
|
|
|
|
func (dao *MessageDAO) GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error) {
|
|
m := dao.query.Message
|
|
do := m.WithContext(ctx).Debug().Where(m.RunID.In(runIDs...)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
|
|
|
|
if orderBy == "DESC" {
|
|
do = do.Order(m.CreatedAt.Desc())
|
|
} else {
|
|
do = do.Order(m.CreatedAt.Asc())
|
|
}
|
|
poList, err := do.Find()
|
|
|
|
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return dao.batchMessagePO2DO(poList), nil
|
|
}
|
|
|
|
func (dao *MessageDAO) Edit(ctx context.Context, msgID int64, msg *message.Message) (int64, error) {
|
|
m := dao.query.Message
|
|
|
|
originMsg, err := dao.GetByID(ctx, msgID)
|
|
if originMsg == nil {
|
|
return 0, errorx.New(errno.ErrRecordNotFound)
|
|
}
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
columns := dao.buildEditColumns(msg, originMsg)
|
|
do, err := m.WithContext(ctx).Where(m.ID.Eq(msgID)).UpdateColumns(columns)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if do.RowsAffected == 0 {
|
|
return 0, errorx.New(errno.ErrRecordNotFound)
|
|
}
|
|
|
|
return do.RowsAffected, nil
|
|
}
|
|
|
|
func (dao *MessageDAO) buildEditColumns(msg *message.Message, originMsg *entity.Message) map[string]interface{} {
|
|
columns := make(map[string]interface{})
|
|
table := dao.query.Message
|
|
if msg.Content != "" {
|
|
msg.Role = originMsg.Role
|
|
columns[table.Content.ColumnName().String()] = msg.Content
|
|
modelContent, err := dao.buildModelContent(msg)
|
|
if err == nil {
|
|
columns[table.ModelContent.ColumnName().String()] = modelContent
|
|
}
|
|
}
|
|
if msg.MessageType != "" {
|
|
columns[table.MessageType.ColumnName().String()] = msg.MessageType
|
|
}
|
|
if msg.ContentType != "" {
|
|
columns[table.ContentType.ColumnName().String()] = msg.ContentType
|
|
}
|
|
if len(msg.ReasoningContent) > 0 {
|
|
columns[table.ReasoningContent.ColumnName().String()] = msg.ReasoningContent
|
|
}
|
|
|
|
if msg.Position > 0 {
|
|
columns[table.BrokenPosition.ColumnName().String()] = msg.Position
|
|
}
|
|
if msg.Status > 0 {
|
|
columns[table.Status.ColumnName().String()] = msg.Status
|
|
}
|
|
|
|
if len(msg.ModelContent) > 0 {
|
|
columns[table.ModelContent.ColumnName().String()] = msg.ModelContent
|
|
}
|
|
|
|
columns[table.UpdatedAt.ColumnName().String()] = time.Now().UnixMilli()
|
|
if msg.Ext != nil {
|
|
if originMsg.Ext != nil {
|
|
for k, v := range originMsg.Ext {
|
|
msg.Ext[k] = v
|
|
}
|
|
}
|
|
ext, err := sonic.MarshalString(msg.Ext)
|
|
if err == nil {
|
|
columns[table.Ext.ColumnName().String()] = ext
|
|
}
|
|
}
|
|
return columns
|
|
}
|
|
|
|
func (dao *MessageDAO) GetByID(ctx context.Context, msgID int64) (*entity.Message, error) {
|
|
m := dao.query.Message
|
|
do := m.WithContext(ctx).Where(m.ID.Eq(msgID))
|
|
po, err := do.First()
|
|
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return dao.messagePO2DO(po), nil
|
|
}
|
|
|
|
func (dao *MessageDAO) Delete(ctx context.Context, delMeta *entity.DeleteMeta) error {
|
|
if len(delMeta.MessageIDs) == 0 && len(delMeta.RunIDs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
updateColumns := make(map[string]interface{})
|
|
updateColumns["status"] = int32(entity.MessageStatusDeleted)
|
|
m := dao.query.Message
|
|
do := m.WithContext(ctx)
|
|
|
|
if len(delMeta.RunIDs) > 0 {
|
|
do = do.Where(m.RunID.In(delMeta.RunIDs...))
|
|
}
|
|
if len(delMeta.MessageIDs) > 0 {
|
|
do = do.Where(m.ID.In(delMeta.MessageIDs...))
|
|
}
|
|
if delMeta.ConversationID != nil && ptr.From(delMeta.ConversationID) > 0 {
|
|
do = do.Where(m.ConversationID.Eq(*delMeta.ConversationID))
|
|
}
|
|
_, err := do.UpdateColumns(&updateColumns)
|
|
return err
|
|
}
|
|
|
|
func (dao *MessageDAO) messageDO2PO(ctx context.Context, msgDo *entity.Message) (*model.Message, error) {
|
|
var id int64
|
|
if msgDo.ID > 0 {
|
|
id = msgDo.ID
|
|
} else {
|
|
genID, gErr := dao.idgen.GenID(ctx)
|
|
if gErr != nil {
|
|
return nil, gErr
|
|
}
|
|
id = genID
|
|
}
|
|
msgPO := &model.Message{
|
|
ID: id,
|
|
ConversationID: msgDo.ConversationID,
|
|
RunID: msgDo.RunID,
|
|
AgentID: msgDo.AgentID,
|
|
SectionID: msgDo.SectionID,
|
|
UserID: msgDo.UserID,
|
|
Role: string(msgDo.Role),
|
|
ContentType: string(msgDo.ContentType),
|
|
MessageType: string(msgDo.MessageType),
|
|
DisplayContent: msgDo.DisplayContent,
|
|
Content: msgDo.Content,
|
|
BrokenPosition: msgDo.Position,
|
|
Status: int32(entity.MessageStatusAvailable),
|
|
CreatedAt: time.Now().UnixMilli(),
|
|
UpdatedAt: time.Now().UnixMilli(),
|
|
ReasoningContent: msgDo.ReasoningContent,
|
|
}
|
|
if msgDo.CreatedAt > 0 {
|
|
msgPO.CreatedAt = msgDo.CreatedAt
|
|
}
|
|
if msgDo.UpdatedAt > 0 {
|
|
msgPO.UpdatedAt = msgDo.UpdatedAt
|
|
}
|
|
|
|
if msgDo.ModelContent != "" {
|
|
msgPO.ModelContent = msgDo.ModelContent
|
|
} else {
|
|
mc, err := dao.buildModelContent(msgDo)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
msgPO.ModelContent = mc
|
|
}
|
|
|
|
ext, err := json.Marshal(msgDo.Ext)
|
|
if err != nil {
|
|
return nil, errorx.WrapByCode(err, errno.ErrConversationJsonMarshal)
|
|
}
|
|
msgPO.Ext = string(ext)
|
|
|
|
return msgPO, nil
|
|
}
|
|
|
|
func (dao *MessageDAO) buildModelContent(msgDO *entity.Message) (string, error) {
|
|
modelContent := msgDO.ModelContent
|
|
if modelContent != "" {
|
|
return modelContent, nil
|
|
}
|
|
|
|
modelContentObj := &schema.Message{
|
|
Role: msgDO.Role,
|
|
Name: msgDO.Name,
|
|
}
|
|
if msgDO.Content == "" && len(msgDO.MultiContent) == 0 {
|
|
return "", nil
|
|
}
|
|
|
|
var multiContent []schema.ChatMessagePart
|
|
for _, contentData := range msgDO.MultiContent {
|
|
if contentData.Type == message.InputTypeText {
|
|
if len(msgDO.Content) == 0 && len(contentData.Text) > 0 {
|
|
msgDO.Content = contentData.Text
|
|
}
|
|
continue
|
|
}
|
|
one := schema.ChatMessagePart{}
|
|
switch contentData.Type {
|
|
case message.InputTypeImage:
|
|
one.Type = schema.ChatMessagePartTypeImageURL
|
|
one.ImageURL = &schema.ChatMessageImageURL{
|
|
URL: contentData.FileData[0].Url,
|
|
URI: contentData.FileData[0].URI,
|
|
}
|
|
case message.InputTypeFile:
|
|
one.Type = schema.ChatMessagePartTypeFileURL
|
|
one.FileURL = &schema.ChatMessageFileURL{
|
|
URL: contentData.FileData[0].Url,
|
|
URI: contentData.FileData[0].URI,
|
|
}
|
|
case message.InputTypeVideo:
|
|
one.Type = schema.ChatMessagePartTypeVideoURL
|
|
one.VideoURL = &schema.ChatMessageVideoURL{
|
|
URL: contentData.FileData[0].Url,
|
|
}
|
|
case message.InputTypeAudio:
|
|
one.Type = schema.ChatMessagePartTypeAudioURL
|
|
one.AudioURL = &schema.ChatMessageAudioURL{
|
|
URL: contentData.FileData[0].Url,
|
|
URI: contentData.FileData[0].URI,
|
|
}
|
|
}
|
|
multiContent = append(multiContent, one)
|
|
}
|
|
if len(multiContent) > 0 {
|
|
if len(msgDO.Content) > 0 {
|
|
multiContent = append(multiContent, schema.ChatMessagePart{
|
|
Type: schema.ChatMessagePartTypeText,
|
|
Text: msgDO.Content,
|
|
})
|
|
}
|
|
} else {
|
|
modelContentObj.Content = msgDO.Content
|
|
}
|
|
|
|
modelContentObj.MultiContent = multiContent
|
|
|
|
mcObjByte, err := json.Marshal(modelContentObj)
|
|
if err != nil {
|
|
return "", errorx.WrapByCode(err, errno.ErrConversationJsonMarshal)
|
|
}
|
|
|
|
return string(mcObjByte), nil
|
|
}
|
|
|
|
func (dao *MessageDAO) batchMessagePO2DO(msgPOs []*model.Message) []*entity.Message {
|
|
return slices.Transform(msgPOs, func(msgPO *model.Message) *entity.Message {
|
|
msgDO := &entity.Message{
|
|
ID: msgPO.ID,
|
|
AgentID: msgPO.AgentID,
|
|
ConversationID: msgPO.ConversationID,
|
|
SectionID: msgPO.SectionID,
|
|
UserID: msgPO.UserID,
|
|
RunID: msgPO.RunID,
|
|
Role: schema.RoleType(msgPO.Role),
|
|
ContentType: message.ContentType(msgPO.ContentType),
|
|
MessageType: message.MessageType(msgPO.MessageType),
|
|
Position: msgPO.BrokenPosition,
|
|
ModelContent: msgPO.ModelContent,
|
|
Content: msgPO.Content,
|
|
Status: message.MessageStatus(msgPO.Status),
|
|
DisplayContent: msgPO.DisplayContent,
|
|
CreatedAt: msgPO.CreatedAt,
|
|
UpdatedAt: msgPO.UpdatedAt,
|
|
ReasoningContent: msgPO.ReasoningContent,
|
|
}
|
|
|
|
var ext map[string]string
|
|
err := json.Unmarshal([]byte(msgPO.Ext), &ext)
|
|
if err == nil {
|
|
msgDO.Ext = ext
|
|
}
|
|
|
|
return msgDO
|
|
})
|
|
}
|
|
|
|
func (dao *MessageDAO) messagePO2DO(msgPO *model.Message) *entity.Message {
|
|
msgDO := &entity.Message{
|
|
ID: msgPO.ID,
|
|
AgentID: msgPO.AgentID,
|
|
ConversationID: msgPO.ConversationID,
|
|
SectionID: msgPO.SectionID,
|
|
UserID: msgPO.UserID,
|
|
RunID: msgPO.RunID,
|
|
Role: schema.RoleType(msgPO.Role),
|
|
ContentType: message.ContentType(msgPO.ContentType),
|
|
MessageType: message.MessageType(msgPO.MessageType),
|
|
ModelContent: msgPO.ModelContent,
|
|
Content: msgPO.Content,
|
|
DisplayContent: msgPO.DisplayContent,
|
|
CreatedAt: msgPO.CreatedAt,
|
|
UpdatedAt: msgPO.UpdatedAt,
|
|
}
|
|
|
|
var ext map[string]string
|
|
err := json.Unmarshal([]byte(msgPO.Ext), &ext)
|
|
if err == nil {
|
|
msgDO.Ext = ext
|
|
}
|
|
|
|
return msgDO
|
|
}
|
|
|