feat(infra): add object tagging support for PutObject and ListObjects (#1845)

main
Ryo 2 months ago committed by GitHub
parent aae865dafb
commit 16bd3b5628
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 12
      backend/infra/contract/storage/option.go
  2. 14
      backend/infra/contract/storage/storage.go
  3. 50
      backend/infra/impl/storage/minio/minio.go
  4. 106
      backend/infra/impl/storage/s3/s3.go
  5. 44
      backend/infra/impl/storage/tos/tos.go
  6. 12
      backend/internal/mock/infra/contract/storage/storage_mock.go

@ -38,11 +38,23 @@ type PutOption struct {
ContentDisposition *string ContentDisposition *string
ContentLanguage *string ContentLanguage *string
Expires *time.Time Expires *time.Time
Tagging map[string]string
ObjectSize int64 ObjectSize int64
} }
type PutOptFn func(option *PutOption) type PutOptFn func(option *PutOption)
func WithTagging(tag map[string]string) PutOptFn {
return func(o *PutOption) {
if len(tag) > 0 {
o.Tagging = make(map[string]string, len(tag))
for k, v := range tag {
o.Tagging[k] = v
}
}
}
}
func WithContentType(v string) PutOptFn { func WithContentType(v string) PutOptFn {
return func(o *PutOption) { return func(o *PutOption) {
o.ContentType = &v o.ContentType = &v

@ -24,15 +24,20 @@ import (
//go:generate mockgen -destination ../../../internal/mock/infra/contract/storage/storage_mock.go -package mock -source storage.go Factory //go:generate mockgen -destination ../../../internal/mock/infra/contract/storage/storage_mock.go -package mock -source storage.go Factory
type Storage interface { type Storage interface {
// PutObject puts the object with the specified key.
PutObject(ctx context.Context, objectKey string, content []byte, opts ...PutOptFn) error PutObject(ctx context.Context, objectKey string, content []byte, opts ...PutOptFn) error
// PutObjectWithReader puts the object with the specified key.
PutObjectWithReader(ctx context.Context, objectKey string, content io.Reader, opts ...PutOptFn) error PutObjectWithReader(ctx context.Context, objectKey string, content io.Reader, opts ...PutOptFn) error
// GetObject returns the object with the specified key.
GetObject(ctx context.Context, objectKey string) ([]byte, error) GetObject(ctx context.Context, objectKey string) ([]byte, error)
// DeleteObject deletes the object with the specified key.
DeleteObject(ctx context.Context, objectKey string) error DeleteObject(ctx context.Context, objectKey string) error
// GetObjectUrl returns a presigned URL for the object.
// The URL is valid for the specified duration.
GetObjectUrl(ctx context.Context, objectKey string, opts ...GetOptFn) (string, error) GetObjectUrl(ctx context.Context, objectKey string, opts ...GetOptFn) (string, error)
// ListObjects returns all objects with the specified prefix. // ListAllObjects returns all objects with the specified prefix.
// It may return a large number of objects, consider using ListObjectsPaginated for better performance. // It may return a large number of objects, consider using ListObjectsPaginated for better performance.
ListObjects(ctx context.Context, prefix string) ([]*FileInfo, error) ListAllObjects(ctx context.Context, prefix string, withTagging bool) ([]*FileInfo, error)
// ListObjectsPaginated returns objects with pagination support. // ListObjectsPaginated returns objects with pagination support.
// Use this method when dealing with large number of objects. // Use this method when dealing with large number of objects.
ListObjectsPaginated(ctx context.Context, input *ListObjectsPaginatedInput) (*ListObjectsPaginatedOutput, error) ListObjectsPaginated(ctx context.Context, input *ListObjectsPaginatedInput) (*ListObjectsPaginatedOutput, error)
@ -50,6 +55,8 @@ type ListObjectsPaginatedInput struct {
Prefix string Prefix string
PageSize int PageSize int
Cursor string Cursor string
// Include objects tagging in the listing
WithTagging bool
} }
type ListObjectsPaginatedOutput struct { type ListObjectsPaginatedOutput struct {
@ -64,4 +71,5 @@ type FileInfo struct {
LastModified time.Time LastModified time.Time
ETag string ETag string
Size int64 Size int64
Tagging map[string]string
} }

@ -21,7 +21,6 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"log"
"math/rand" "math/rand"
"net/url" "net/url"
"time" "time"
@ -99,34 +98,45 @@ func (m *minioClient) test() {
ctx := context.Background() ctx := context.Background()
objectName := fmt.Sprintf("test-file-%d.txt", rand.Int()) objectName := fmt.Sprintf("test-file-%d.txt", rand.Int())
m.ListObjects(ctx, "") err := m.PutObject(ctx, objectName, []byte("hello content"),
storage.WithContentType("text/plain"), storage.WithTagging(map[string]string{
"uid": "7543149965070155780",
"conversation_id": "7543149965070155781",
"type": "user",
}))
if err != nil {
logs.CtxErrorf(ctx, "upload file failed: %v", err)
}
err := m.PutObject(ctx, objectName, []byte("hello content"), storage.WithContentType("text/plain")) logs.CtxInfof(ctx, "upload file success")
files, err := m.ListAllObjects(ctx, "test-file-", true)
if err != nil { if err != nil {
log.Fatalf("upload file failed: %v", err) logs.CtxErrorf(ctx, "list objects failed: %v", err)
} }
log.Printf("upload file success")
logs.CtxInfof(ctx, "list objects success, files.len: %v", len(files))
url, err := m.GetObjectUrl(ctx, objectName) url, err := m.GetObjectUrl(ctx, objectName)
if err != nil { if err != nil {
log.Fatalf("get file url failed: %v", err) logs.CtxErrorf(ctx, "get file url failed: %v", err)
} }
log.Printf("get file url success, url: %s", url) logs.CtxInfof(ctx, "get file url success, url: %s", url)
content, err := m.GetObject(ctx, objectName) content, err := m.GetObject(ctx, objectName)
if err != nil { if err != nil {
log.Fatalf("download file failed: %v", err) logs.CtxErrorf(ctx, "download file failed: %v", err)
} }
log.Printf("download file success, content: %s", string(content)) logs.CtxInfof(ctx, "download file success, content: %s", string(content))
err = m.DeleteObject(ctx, objectName) err = m.DeleteObject(ctx, objectName)
if err != nil { if err != nil {
log.Fatalf("delete object failed: %v", err) logs.CtxErrorf(ctx, "delete object failed: %v", err)
} }
log.Printf("delete object success") logs.CtxInfof(ctx, "delete object success")
} }
func (m *minioClient) PutObject(ctx context.Context, objectKey string, content []byte, opts ...storage.PutOptFn) error { func (m *minioClient) PutObject(ctx context.Context, objectKey string, content []byte, opts ...storage.PutOptFn) error {
@ -161,6 +171,10 @@ func (m *minioClient) PutObjectWithReader(ctx context.Context, objectKey string,
minioOpts.Expires = *option.Expires minioOpts.Expires = *option.Expires
} }
if option.Tagging != nil {
minioOpts.UserTags = option.Tagging
}
_, err := m.client.PutObject(ctx, m.bucketName, objectKey, _, err := m.client.PutObject(ctx, m.bucketName, objectKey,
content, option.ObjectSize, minioOpts) content, option.ObjectSize, minioOpts)
if err != nil { if err != nil {
@ -223,7 +237,7 @@ func (m *minioClient) ListObjectsPaginated(ctx context.Context, input *storage.L
return nil, fmt.Errorf("page size must be positive") return nil, fmt.Errorf("page size must be positive")
} }
files, err := m.ListObjects(ctx, input.Prefix) files, err := m.ListAllObjects(ctx, input.Prefix, input.WithTagging)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -235,10 +249,11 @@ func (m *minioClient) ListObjectsPaginated(ctx context.Context, input *storage.L
}, nil }, nil
} }
func (m *minioClient) ListObjects(ctx context.Context, prefix string) ([]*storage.FileInfo, error) { func (m *minioClient) ListAllObjects(ctx context.Context, prefix string, withTagging bool) ([]*storage.FileInfo, error) {
opts := minio.ListObjectsOptions{ opts := minio.ListObjectsOptions{
Prefix: prefix, Prefix: prefix,
Recursive: true, Recursive: true,
WithMetadata: withTagging,
} }
objectCh := m.client.ListObjects(ctx, m.bucketName, opts) objectCh := m.client.ListObjects(ctx, m.bucketName, opts)
@ -248,14 +263,17 @@ func (m *minioClient) ListObjects(ctx context.Context, prefix string) ([]*storag
if object.Err != nil { if object.Err != nil {
return nil, object.Err return nil, object.Err
} }
files = append(files, &storage.FileInfo{ files = append(files, &storage.FileInfo{
Key: object.Key, Key: object.Key,
LastModified: object.LastModified, LastModified: object.LastModified,
ETag: object.ETag, ETag: object.ETag,
Size: object.Size, Size: object.Size,
Tagging: object.UserTags,
}) })
logs.CtxDebugf(ctx, "key = %s, lastModified = %s, eTag = %s, size = %d", object.Key, object.LastModified, object.ETag, object.Size) logs.CtxDebugf(ctx, "key = %s, lastModified = %s, eTag = %s, size = %d, tagging = %v",
object.Key, object.LastModified, object.ETag, object.Size, object.UserTags)
} }
return files, nil return files, nil

@ -21,16 +21,19 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"net/url"
"time" "time"
"github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage" "github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/infra/impl/storage/proxy" "github.com/coze-dev/coze-studio/backend/infra/impl/storage/proxy"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/taskgroup"
) )
type s3Client struct { type s3Client struct {
@ -178,6 +181,11 @@ func (t *s3Client) PutObjectWithReader(ctx context.Context, objectKey string, co
input.ContentLength = aws.Int64(option.ObjectSize) input.ContentLength = aws.Int64(option.ObjectSize)
} }
if option.Tagging != nil {
tagging := mapToQueryParams(option.Tagging)
input.Tagging = aws.String(tagging)
}
// upload object // upload object
_, err := client.PutObject(ctx, input) _, err := client.PutObject(ctx, input)
return err return err
@ -239,49 +247,36 @@ func (t *s3Client) GetObjectUrl(ctx context.Context, objectKey string, opts ...s
return req.URL, nil return req.URL, nil
} }
func (t *s3Client) ListObjects(ctx context.Context, prefix string) ([]*storage.FileInfo, error) { func (t *s3Client) ListAllObjects(ctx context.Context, prefix string, withTagging bool) ([]*storage.FileInfo, error) {
client := t.client
bucket := t.bucketName
const ( const (
DefaultPageSize = 100 DefaultPageSize = 100
MaxListObjects = 10000 MaxListObjects = 10000
) )
input := &s3.ListObjectsV2Input{
Bucket: aws.String(bucket),
Prefix: aws.String(prefix),
MaxKeys: aws.Int32(DefaultPageSize),
}
paginator := s3.NewListObjectsV2Paginator(client, input)
var files []*storage.FileInfo var files []*storage.FileInfo
for paginator.HasMorePages() { var cursor string
page, err := paginator.NextPage(ctx) for {
output, err := t.ListObjectsPaginated(ctx, &storage.ListObjectsPaginatedInput{
Prefix: prefix,
PageSize: DefaultPageSize,
WithTagging: withTagging,
Cursor: cursor,
})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get page, %v", err) return nil, err
} }
for _, obj := range page.Contents {
f := &storage.FileInfo{}
if obj.Key != nil {
f.Key = *obj.Key
}
if obj.LastModified != nil {
f.LastModified = *obj.LastModified
}
if obj.ETag != nil {
f.ETag = *obj.ETag
}
if obj.Size != nil {
f.Size = *obj.Size
}
files = append(files, f)
} cursor = output.Cursor
files = append(files, output.Files...)
if len(files) >= MaxListObjects { if len(files) >= MaxListObjects {
logs.CtxErrorf(ctx, "[ListObjects] max list objects reached, total: %d", len(files)) logs.CtxErrorf(ctx, "list objects failed, max list objects: %d", MaxListObjects)
break
}
if !output.IsTruncated {
break break
} }
} }
@ -340,5 +335,52 @@ func (t *s3Client) ListObjectsPaginated(ctx context.Context, input *storage.List
output.Cursor = *p.NextContinuationToken output.Cursor = *p.NextContinuationToken
} }
if input.WithTagging {
taskGroup := taskgroup.NewTaskGroup(ctx, 5)
for idx := range files {
f := files[idx]
taskGroup.Go(func() error {
tagging, err := client.GetObjectTagging(ctx, &s3.GetObjectTaggingInput{
Bucket: aws.String(bucket),
Key: aws.String(f.Key),
})
if err != nil {
return err
}
f.Tagging = tagsToMap(tagging.TagSet)
return nil
})
}
if err := taskGroup.Wait(); err != nil {
return nil, err
}
}
return output, nil return output, nil
} }
func mapToQueryParams(tagging map[string]string) string {
if len(tagging) == 0 {
return ""
}
params := url.Values{}
for k, v := range tagging {
params.Set(k, v)
}
return params.Encode()
}
func tagsToMap(tags []types.Tag) map[string]string {
if len(tags) == 0 {
return nil
}
m := make(map[string]string, len(tags))
for _, tag := range tags {
if tag.Key != nil && tag.Value != nil {
m[*tag.Key] = *tag.Value
}
}
return m
}

@ -73,15 +73,20 @@ func getTosClient(ctx context.Context, ak, sk, bucketName, endpoint, region stri
func (t *tosClient) test() { func (t *tosClient) test() {
// test list objects // test list objects
ctx := context.Background() ctx := context.Background()
t.ListObjects(ctx, "")
// test upload // test upload
objectKey := fmt.Sprintf("test-%s.txt", time.Now().Format("20060102150405")) objectKey := fmt.Sprintf("test-%s.txt", time.Now().Format("20060102150405"))
err := t.PutObject(context.Background(), objectKey, []byte("hello world")) err := t.PutObject(context.Background(), objectKey, []byte("hello world"), storage.WithTagging(map[string]string{
"uid": "7543149965070155780",
"conversation_id": "7543149965070155781",
"type": "user",
}))
if err != nil { if err != nil {
logs.CtxErrorf(context.Background(), "PutObject failed, objectKey: %s, err: %v", objectKey, err) logs.CtxErrorf(context.Background(), "PutObject failed, objectKey: %s, err: %v", objectKey, err)
} }
t.ListAllObjects(ctx, "", true)
// test download // test download
content, err := t.GetObject(context.Background(), objectKey) content, err := t.GetObject(context.Background(), objectKey)
if err != nil { if err != nil {
@ -175,6 +180,10 @@ func (t *tosClient) PutObjectWithReader(ctx context.Context, objectKey string, c
input.ContentLength = option.ObjectSize input.ContentLength = option.ObjectSize
} }
if len(option.Tagging) > 0 {
input.Meta = option.Tagging
}
_, err := client.PutObjectV2(ctx, input) _, err := client.PutObjectV2(ctx, input)
return err return err
@ -251,9 +260,10 @@ func (t *tosClient) ListObjectsPaginated(ctx context.Context, input *storage.Lis
output, err := t.client.ListObjectsV2(ctx, &tos.ListObjectsV2Input{ output, err := t.client.ListObjectsV2(ctx, &tos.ListObjectsV2Input{
Bucket: t.bucketName, Bucket: t.bucketName,
ListObjectsInput: tos.ListObjectsInput{ ListObjectsInput: tos.ListObjectsInput{
MaxKeys: int(input.PageSize), MaxKeys: int(input.PageSize),
Marker: input.Cursor, Marker: input.Cursor,
Prefix: input.Prefix, Prefix: input.Prefix,
FetchMeta: input.WithTagging,
}, },
}) })
if err != nil { if err != nil {
@ -267,11 +277,23 @@ func (t *tosClient) ListObjectsPaginated(ctx context.Context, input *storage.Lis
continue continue
} }
var tagging map[string]string
if obj.Meta != nil {
obj.Meta.Range(func(key, value string) bool {
if tagging == nil {
tagging = make(map[string]string)
}
tagging[key] = value
return true
})
}
files = append(files, &storage.FileInfo{ files = append(files, &storage.FileInfo{
Key: obj.Key, Key: obj.Key,
LastModified: obj.LastModified, LastModified: obj.LastModified,
ETag: obj.ETag, ETag: obj.ETag,
Size: obj.Size, Size: obj.Size,
Tagging: tagging,
}) })
} }
@ -282,7 +304,7 @@ func (t *tosClient) ListObjectsPaginated(ctx context.Context, input *storage.Lis
}, nil }, nil
} }
func (t *tosClient) ListObjects(ctx context.Context, prefix string) ([]*storage.FileInfo, error) { func (t *tosClient) ListAllObjects(ctx context.Context, prefix string, withTagging bool) ([]*storage.FileInfo, error) {
const ( const (
DefaultPageSize = 100 DefaultPageSize = 100
MaxListObjects = 10000 MaxListObjects = 10000
@ -293,16 +315,18 @@ func (t *tosClient) ListObjects(ctx context.Context, prefix string) ([]*storage.
for { for {
output, err := t.ListObjectsPaginated(ctx, &storage.ListObjectsPaginatedInput{ output, err := t.ListObjectsPaginated(ctx, &storage.ListObjectsPaginatedInput{
Prefix: prefix, Prefix: prefix,
PageSize: DefaultPageSize, PageSize: DefaultPageSize,
Cursor: cursor, Cursor: cursor,
WithTagging: withTagging,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("list objects failed, prefix = %v, err: %v", prefix, err) return nil, fmt.Errorf("list objects failed, prefix = %v, err: %v", prefix, err)
} }
for _, object := range output.Files { for _, object := range output.Files {
logs.CtxDebugf(ctx, "key = %s, lastModified = %s, eTag = %s, size = %d", object.Key, object.LastModified, object.ETag, object.Size) logs.CtxDebugf(ctx, "key = %s, lastModified = %s, eTag = %s, size = %d, tagging = %v",
object.Key, object.LastModified, object.ETag, object.Size, object.Tagging)
files = append(files, object) files = append(files, object)
} }

@ -91,19 +91,19 @@ func (mr *MockStorageMockRecorder) GetObjectUrl(ctx, objectKey any, opts ...any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetObjectUrl", reflect.TypeOf((*MockStorage)(nil).GetObjectUrl), varargs...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetObjectUrl", reflect.TypeOf((*MockStorage)(nil).GetObjectUrl), varargs...)
} }
// ListObjects mocks base method. // ListAllObjects mocks base method.
func (m *MockStorage) ListObjects(ctx context.Context, prefix string) ([]*storage.FileInfo, error) { func (m *MockStorage) ListAllObjects(ctx context.Context, prefix string, withTagging bool) ([]*storage.FileInfo, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListObjects", ctx, prefix) ret := m.ctrl.Call(m, "ListAllObjects", ctx, prefix, withTagging)
ret0, _ := ret[0].([]*storage.FileInfo) ret0, _ := ret[0].([]*storage.FileInfo)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// ListObjects indicates an expected call of ListObjects. // ListAllObjects indicates an expected call of ListAllObjects.
func (mr *MockStorageMockRecorder) ListObjects(ctx, prefix any) *gomock.Call { func (mr *MockStorageMockRecorder) ListAllObjects(ctx, prefix, withTagging any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListObjects", reflect.TypeOf((*MockStorage)(nil).ListObjects), ctx, prefix) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAllObjects", reflect.TypeOf((*MockStorage)(nil).ListAllObjects), ctx, prefix, withTagging)
} }
// ListObjectsPaginated mocks base method. // ListObjectsPaginated mocks base method.

Loading…
Cancel
Save