refactor(knowledge): Move the searchstore manager to app infra (#764)

main
Ryo 2 months ago committed by GitHub
parent 710bbbff2b
commit ff00dcb31b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 20
      backend/application/application.go
  2. 288
      backend/application/base/appinfra/app_infra.go
  3. 285
      backend/application/knowledge/init.go
  4. 4
      backend/infra/impl/document/searchstore/vikingdb/vikingdb_manager.go
  5. 9
      backend/infra/impl/embedding/ark/ark.go
  6. 1
      docker/.env.example

@ -128,6 +128,7 @@ func Init(ctx context.Context) (err error) {
if err != nil { if err != nil {
return fmt.Errorf("Init - initVitalServices failed, err: %v", err) return fmt.Errorf("Init - initVitalServices failed, err: %v", err)
} }
crossconnector.SetDefaultSVC(connectorImpl.InitDomainService(basicServices.connectorSVC.DomainSVC)) crossconnector.SetDefaultSVC(connectorImpl.InitDomainService(basicServices.connectorSVC.DomainSVC))
crossdatabase.SetDefaultSVC(databaseImpl.InitDomainService(primaryServices.memorySVC.DatabaseDomainSVC)) crossdatabase.SetDefaultSVC(databaseImpl.InitDomainService(primaryServices.memorySVC.DatabaseDomainSVC))
crossknowledge.SetDefaultSVC(knowledgeImpl.InitDomainService(primaryServices.knowledgeSVC.DomainSVC)) crossknowledge.SetDefaultSVC(knowledgeImpl.InitDomainService(primaryServices.knowledgeSVC.DomainSVC))
@ -254,16 +255,15 @@ func (b *basicServices) toPluginServiceComponents() *plugin.ServiceComponents {
func (b *basicServices) toKnowledgeServiceComponents(memoryService *memory.MemoryApplicationServices) *knowledge.ServiceComponents { func (b *basicServices) toKnowledgeServiceComponents(memoryService *memory.MemoryApplicationServices) *knowledge.ServiceComponents {
return &knowledge.ServiceComponents{ return &knowledge.ServiceComponents{
DB: b.infra.DB, DB: b.infra.DB,
IDGenSVC: b.infra.IDGenSVC, IDGenSVC: b.infra.IDGenSVC,
Storage: b.infra.TOSClient, Storage: b.infra.TOSClient,
RDB: memoryService.RDBDomainSVC, RDB: memoryService.RDBDomainSVC,
ImageX: b.infra.ImageXClient, SearchStoreManagers: b.infra.SearchStoreManagers,
ES: b.infra.ESClient, EventBus: b.eventbus.resourceEventBus,
EventBus: b.eventbus.resourceEventBus, CacheCli: b.infra.CacheCli,
CacheCli: b.infra.CacheCli, OCR: b.infra.OCR,
OCR: b.infra.OCR, ParserManager: b.infra.ParserManager,
ParserManager: b.infra.ParserManager,
} }
} }

@ -23,9 +23,13 @@ import (
"os" "os"
"strconv" "strconv"
"strings" "strings"
"time"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/cloudwego/eino-ext/components/embedding/ollama"
"github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/volcengine/volc-sdk-golang/service/visual" "github.com/volcengine/volc-sdk-golang/service/visual"
"github.com/coze-dev/coze-studio/backend/application/internal" "github.com/coze-dev/coze-studio/backend/application/internal"
@ -34,6 +38,8 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner" "github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr" "github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser" "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex" "github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis" "github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
@ -41,14 +47,22 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox" "github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr" "github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/veocr" "github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/veocr"
builtinParser "github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin" "github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/ppstructure" "github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/ppstructure"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb"
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/ark"
embeddingHttp "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/http"
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/wrap"
"github.com/coze-dev/coze-studio/backend/infra/impl/es" "github.com/coze-dev/coze-studio/backend/infra/impl/es"
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus" "github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
"github.com/coze-dev/coze-studio/backend/infra/impl/idgen" "github.com/coze-dev/coze-studio/backend/infra/impl/idgen"
"github.com/coze-dev/coze-studio/backend/infra/impl/imagex/veimagex" "github.com/coze-dev/coze-studio/backend/infra/impl/imagex/veimagex"
"github.com/coze-dev/coze-studio/backend/infra/impl/mysql" "github.com/coze-dev/coze-studio/backend/infra/impl/mysql"
"github.com/coze-dev/coze-studio/backend/infra/impl/storage" "github.com/coze-dev/coze-studio/backend/infra/impl/storage"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/consts" "github.com/coze-dev/coze-studio/backend/types/consts"
) )
@ -66,6 +80,7 @@ type AppDependencies struct {
CodeRunner coderunner.Runner CodeRunner coderunner.Runner
OCR ocr.OCR OCR ocr.OCR
ParserManager parser.Manager ParserManager parser.Manager
SearchStoreManagers []searchstore.Manager
} }
func Init(ctx context.Context) (*AppDependencies, error) { func Init(ctx context.Context) (*AppDependencies, error) {
@ -122,15 +137,37 @@ func Init(ctx context.Context) (*AppDependencies, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
deps.ParserManager, err = initParserManager(deps.TOSClient, deps.OCR, imageAnnotationModel)
deps.ParserManager = initParserManager(deps.TOSClient, deps.OCR, imageAnnotationModel)
deps.SearchStoreManagers, err = initSearchStoreManagers(ctx, deps.ESClient)
if err != nil {
return nil, err
}
deps.SearchStoreManagers, err = initSearchStoreManagers(ctx, deps.ESClient)
if err != nil {
return nil, err
}
return deps, nil return deps, nil
} }
func initImageX(ctx context.Context) (imagex.ImageX, error) { func initSearchStoreManagers(ctx context.Context, es es.Client) ([]searchstore.Manager, error) {
// es full text search
esSearchstoreManager := elasticsearch.NewManager(&elasticsearch.ManagerConfig{Client: es})
uploadComponentType := os.Getenv(consts.FileUploadComponentType) // vector search
mgr, err := getVectorStore(ctx)
if err != nil {
return nil, fmt.Errorf("init vector store failed, err=%w", err)
}
return []searchstore.Manager{esSearchstoreManager, mgr}, nil
}
func initImageX(ctx context.Context) (imagex.ImageX, error) {
uploadComponentType := os.Getenv(consts.FileUploadComponentType)
if uploadComponentType != consts.FileUploadComponentTypeImagex { if uploadComponentType != consts.FileUploadComponentTypeImagex {
return storage.NewImagex(ctx) return storage.NewImagex(ctx)
} }
@ -230,12 +267,10 @@ func initOCR() ocr.OCR {
return ocr return ocr
} }
func initParserManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationModel chatmodel.BaseChatModel) (parser.Manager, error) { func initParserManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationModel chatmodel.BaseChatModel) parser.Manager {
var parserManager parser.Manager var parserManager parser.Manager
parserType := os.Getenv(consts.ParserType) parserType := os.Getenv(consts.ParserType)
switch parserType { switch parserType {
case "builtin":
parserManager = builtinParser.NewManager(storage, ocr, imageAnnotationModel)
case "paddleocr": case "paddleocr":
url := os.Getenv(consts.PPStructureAPIURL) url := os.Getenv(consts.PPStructureAPIURL)
client := &http.Client{} client := &http.Client{}
@ -245,8 +280,243 @@ func initParserManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationMode
} }
parserManager = ppstructure.NewManager(apiConfig, ocr, storage, imageAnnotationModel) parserManager = ppstructure.NewManager(apiConfig, ocr, storage, imageAnnotationModel)
default: default:
return nil, fmt.Errorf("unexpected document parser type, type=%s", parserType) parserManager = builtin.NewManager(storage, ocr, imageAnnotationModel)
}
return parserManager
}
func getVectorStore(ctx context.Context) (searchstore.Manager, error) {
vsType := os.Getenv("VECTOR_STORE_TYPE")
switch vsType {
case "milvus":
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
milvusAddr := os.Getenv("MILVUS_ADDR")
user := os.Getenv("MILVUS_USER")
password := os.Getenv("MILVUS_PASSWORD")
mc, err := milvusclient.New(ctx, &milvusclient.ClientConfig{
Address: milvusAddr,
Username: user,
Password: password,
})
if err != nil {
return nil, fmt.Errorf("init milvus client failed, err=%w", err)
}
emb, err := getEmbedding(ctx)
if err != nil {
return nil, fmt.Errorf("init milvus embedding failed, err=%w", err)
}
mgr, err := milvus.NewManager(&milvus.ManagerConfig{
Client: mc,
Embedding: emb,
EnableHybrid: ptr.Of(true),
})
if err != nil {
return nil, fmt.Errorf("init milvus vector store failed, err=%w", err)
}
return mgr, nil
case "vikingdb":
var (
host = os.Getenv("VIKING_DB_HOST")
region = os.Getenv("VIKING_DB_REGION")
ak = os.Getenv("VIKING_DB_AK")
sk = os.Getenv("VIKING_DB_SK")
scheme = os.Getenv("VIKING_DB_SCHEME")
modelName = os.Getenv("VIKING_DB_MODEL_NAME")
)
if ak == "" || sk == "" {
return nil, fmt.Errorf("invalid vikingdb ak / sk")
}
if host == "" {
host = "api-vikingdb.volces.com"
}
if region == "" {
region = "cn-beijing"
}
if scheme == "" {
scheme = "https"
}
var embConfig *vikingdb.VikingEmbeddingConfig
if modelName != "" {
embName := vikingdb.VikingEmbeddingModelName(modelName)
if embName.Dimensions() == 0 {
return nil, fmt.Errorf("embedding model not support, model_name=%s", modelName)
}
embConfig = &vikingdb.VikingEmbeddingConfig{
UseVikingEmbedding: true,
EnableHybrid: embName.SupportStatus() == embedding.SupportDenseAndSparse,
ModelName: embName,
ModelVersion: embName.ModelVersion(),
DenseWeight: ptr.Of(0.2),
BuiltinEmbedding: nil,
}
} else {
builtinEmbedding, err := getEmbedding(ctx)
if err != nil {
return nil, fmt.Errorf("builtint embedding init failed, err=%w", err)
}
embConfig = &vikingdb.VikingEmbeddingConfig{
UseVikingEmbedding: false,
EnableHybrid: false,
BuiltinEmbedding: builtinEmbedding,
}
}
svc := vikingdb.NewVikingDBService(host, region, ak, sk, scheme)
mgr, err := vikingdb.NewManager(&vikingdb.ManagerConfig{
Service: svc,
IndexingConfig: nil, // use default config
EmbeddingConfig: embConfig,
})
if err != nil {
return nil, fmt.Errorf("init vikingdb manager failed, err=%w", err)
}
return mgr, nil
default:
return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType)
}
}
func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
var batchSize int
if bs, err := strconv.ParseInt(os.Getenv("EMBEDDING_MAX_BATCH_SIZE"), 10, 64); err != nil {
logs.CtxWarnf(ctx, "EMBEDDING_MAX_BATCH_SIZE not set / invalid, using default batchSize=100")
batchSize = 100
} else {
batchSize = int(bs)
}
var emb embedding.Embedder
switch os.Getenv("EMBEDDING_TYPE") {
case "openai":
var (
openAIEmbeddingBaseURL = os.Getenv("OPENAI_EMBEDDING_BASE_URL")
openAIEmbeddingModel = os.Getenv("OPENAI_EMBEDDING_MODEL")
openAIEmbeddingApiKey = os.Getenv("OPENAI_EMBEDDING_API_KEY")
openAIEmbeddingByAzure = os.Getenv("OPENAI_EMBEDDING_BY_AZURE")
openAIEmbeddingApiVersion = os.Getenv("OPENAI_EMBEDDING_API_VERSION")
openAIEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_DIMS")
openAIRequestEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_REQUEST_DIMS")
)
byAzure, err := strconv.ParseBool(openAIEmbeddingByAzure)
if err != nil {
return nil, fmt.Errorf("init openai embedding by_azure failed, err=%w", err)
}
dims, err := strconv.ParseInt(openAIEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init openai embedding dims failed, err=%w", err)
}
openAICfg := &openai.EmbeddingConfig{
APIKey: openAIEmbeddingApiKey,
ByAzure: byAzure,
BaseURL: openAIEmbeddingBaseURL,
APIVersion: openAIEmbeddingApiVersion,
Model: openAIEmbeddingModel,
// Dimensions: ptr.Of(int(dims)),
}
reqDims := conv.StrToInt64D(openAIRequestEmbeddingDims, 0)
if reqDims > 0 {
// some openai model not support request dims
openAICfg.Dimensions = ptr.Of(int(reqDims))
}
emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
}
case "ark":
var (
arkEmbeddingBaseURL = os.Getenv("ARK_EMBEDDING_BASE_URL")
arkEmbeddingModel = os.Getenv("ARK_EMBEDDING_MODEL")
arkEmbeddingApiKey = os.Getenv("ARK_EMBEDDING_API_KEY")
// deprecated: use ARK_EMBEDDING_API_KEY instead
// ARK_EMBEDDING_AK will be removed in the future
arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK")
arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS")
arkEmbeddingAPIType = os.Getenv("ARK_EMBEDDING_API_TYPE")
)
dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err)
}
apiType := ark.APITypeText
if arkEmbeddingAPIType != "" {
if t := ark.APIType(arkEmbeddingAPIType); t != ark.APITypeText && t != ark.APITypeMultiModal {
return nil, fmt.Errorf("init ark embedding api_type failed, invalid api_type=%s", t)
} else {
apiType = t
}
}
emb, err = ark.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
APIKey: func() string {
if arkEmbeddingApiKey != "" {
return arkEmbeddingApiKey
}
return arkEmbeddingAK
}(),
Model: arkEmbeddingModel,
BaseURL: arkEmbeddingBaseURL,
APIType: &apiType,
}, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
}
case "ollama":
var (
ollamaEmbeddingBaseURL = os.Getenv("OLLAMA_EMBEDDING_BASE_URL")
ollamaEmbeddingModel = os.Getenv("OLLAMA_EMBEDDING_MODEL")
ollamaEmbeddingDims = os.Getenv("OLLAMA_EMBEDDING_DIMS")
)
dims, err := strconv.ParseInt(ollamaEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init ollama embedding dims failed, err=%w", err)
}
emb, err = wrap.NewOllamaEmbedder(ctx, &ollama.EmbeddingConfig{
BaseURL: ollamaEmbeddingBaseURL,
Model: ollamaEmbeddingModel,
}, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
}
case "http":
var (
httpEmbeddingBaseURL = os.Getenv("HTTP_EMBEDDING_ADDR")
httpEmbeddingDims = os.Getenv("HTTP_EMBEDDING_DIMS")
)
dims, err := strconv.ParseInt(httpEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init http embedding dims failed, err=%w", err)
}
emb, err = embeddingHttp.NewEmbedding(httpEmbeddingBaseURL, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init http embedding failed, err=%w", err)
}
default:
return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
} }
return parserManager, nil return emb, nil
} }

@ -22,16 +22,9 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"time"
"github.com/cloudwego/eino-ext/components/embedding/ark"
ollamaEmb "github.com/cloudwego/eino-ext/components/embedding/ollama"
"github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
"github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/application/internal" "github.com/coze-dev/coze-studio/backend/application/internal"
@ -42,41 +35,29 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr" "github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser" "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore" "github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen" "github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query" "github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb" "github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage" "github.com/coze-dev/coze-studio/backend/infra/contract/storage"
chatmodelImpl "github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel" chatmodelImpl "github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel"
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin" builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf" "github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
sses "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
ssmilvus "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
ssvikingdb "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb"
arkemb "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/ark"
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/http"
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/wrap"
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus" "github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
builtinM2Q "github.com/coze-dev/coze-studio/backend/infra/impl/messages2query/builtin" builtinM2Q "github.com/coze-dev/coze-studio/backend/infra/impl/messages2query/builtin"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/consts" "github.com/coze-dev/coze-studio/backend/types/consts"
) )
type ServiceComponents struct { type ServiceComponents struct {
DB *gorm.DB DB *gorm.DB
IDGenSVC idgen.IDGenerator IDGenSVC idgen.IDGenerator
Storage storage.Storage Storage storage.Storage
RDB rdb.RDB RDB rdb.RDB
ImageX imagex.ImageX EventBus search.ResourceEventBus
ES es.Client CacheCli cache.Cmdable
EventBus search.ResourceEventBus OCR ocr.OCR
CacheCli cache.Cmdable ParserManager parser.Manager
OCR ocr.OCR SearchStoreManagers []searchstore.Manager
ParserManager parser.Manager
} }
func InitService(c *ServiceComponents) (*KnowledgeApplicationService, error) { func InitService(c *ServiceComponents) (*KnowledgeApplicationService, error) {
@ -89,18 +70,6 @@ func InitService(c *ServiceComponents) (*KnowledgeApplicationService, error) {
return nil, fmt.Errorf("init knowledge producer failed, err=%w", err) return nil, fmt.Errorf("init knowledge producer failed, err=%w", err)
} }
var sManagers []searchstore.Manager
// es full text search
sManagers = append(sManagers, sses.NewManager(&sses.ManagerConfig{Client: c.ES}))
// vector search
mgr, err := getVectorStore(ctx)
if err != nil {
return nil, fmt.Errorf("init vector store failed, err=%w", err)
}
sManagers = append(sManagers, mgr)
root, err := os.Getwd() root, err := os.Getwd()
if err != nil { if err != nil {
logs.Warnf("[InitConfig] Failed to get current working directory: %v", err) logs.Warnf("[InitConfig] Failed to get current working directory: %v", err)
@ -142,7 +111,7 @@ func InitService(c *ServiceComponents) (*KnowledgeApplicationService, error) {
IDGen: c.IDGenSVC, IDGen: c.IDGenSVC,
RDB: c.RDB, RDB: c.RDB,
Producer: knowledgeProducer, Producer: knowledgeProducer,
SearchStoreManagers: sManagers, SearchStoreManagers: c.SearchStoreManagers,
ParseManager: c.ParserManager, ParseManager: c.ParserManager,
Storage: c.Storage, Storage: c.Storage,
Rewriter: rewriter, Rewriter: rewriter,
@ -163,240 +132,6 @@ func InitService(c *ServiceComponents) (*KnowledgeApplicationService, error) {
return KnowledgeSVC, nil return KnowledgeSVC, nil
} }
func getVectorStore(ctx context.Context) (searchstore.Manager, error) {
vsType := os.Getenv("VECTOR_STORE_TYPE")
switch vsType {
case "milvus":
cctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
milvusAddr := os.Getenv("MILVUS_ADDR")
user := os.Getenv("MILVUS_USER")
password := os.Getenv("MILVUS_PASSWORD")
mc, err := milvusclient.New(cctx, &milvusclient.ClientConfig{
Address: milvusAddr,
Username: user,
Password: password,
})
if err != nil {
return nil, fmt.Errorf("init milvus client failed, err=%w", err)
}
emb, err := getEmbedding(ctx)
if err != nil {
return nil, fmt.Errorf("init milvus embedding failed, err=%w", err)
}
mgr, err := ssmilvus.NewManager(&ssmilvus.ManagerConfig{
Client: mc,
Embedding: emb,
EnableHybrid: ptr.Of(true),
})
if err != nil {
return nil, fmt.Errorf("init milvus vector store failed, err=%w", err)
}
return mgr, nil
case "vikingdb":
var (
host = os.Getenv("VIKING_DB_HOST")
region = os.Getenv("VIKING_DB_REGION")
ak = os.Getenv("VIKING_DB_AK")
sk = os.Getenv("VIKING_DB_SK")
scheme = os.Getenv("VIKING_DB_SCHEME")
modelName = os.Getenv("VIKING_DB_MODEL_NAME")
)
if ak == "" || sk == "" {
return nil, fmt.Errorf("invalid vikingdb ak / sk")
}
if host == "" {
host = "api-vikingdb.volces.com"
}
if region == "" {
region = "cn-beijing"
}
if scheme == "" {
scheme = "https"
}
var embConfig *ssvikingdb.VikingEmbeddingConfig
if modelName != "" {
embName := ssvikingdb.VikingEmbeddingModelName(modelName)
if embName.Dimensions() == 0 {
return nil, fmt.Errorf("embedding model not support, model_name=%s", modelName)
}
embConfig = &ssvikingdb.VikingEmbeddingConfig{
UseVikingEmbedding: true,
EnableHybrid: embName.SupportStatus() == embedding.SupportDenseAndSparse,
ModelName: embName,
ModelVersion: embName.ModelVersion(),
DenseWeight: ptr.Of(0.2),
BuiltinEmbedding: nil,
}
} else {
builtinEmbedding, err := getEmbedding(ctx)
if err != nil {
return nil, fmt.Errorf("builtint embedding init failed, err=%w", err)
}
embConfig = &ssvikingdb.VikingEmbeddingConfig{
UseVikingEmbedding: false,
EnableHybrid: false,
BuiltinEmbedding: builtinEmbedding,
}
}
svc := vikingdb.NewVikingDBService(host, region, ak, sk, scheme)
mgr, err := ssvikingdb.NewManager(&ssvikingdb.ManagerConfig{
Service: svc,
IndexingConfig: nil, // use default config
EmbeddingConfig: embConfig,
})
if err != nil {
return nil, fmt.Errorf("init vikingdb manager failed, err=%w", err)
}
return mgr, nil
default:
return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType)
}
}
func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
var batchSize int
if bs, err := strconv.ParseInt(os.Getenv("EMBEDDING_MAX_BATCH_SIZE"), 10, 64); err != nil {
logs.CtxWarnf(ctx, "EMBEDDING_MAX_BATCH_SIZE not set / invalid, using default batchSize=100")
batchSize = 100
} else {
batchSize = int(bs)
}
var emb embedding.Embedder
switch os.Getenv("EMBEDDING_TYPE") {
case "openai":
var (
openAIEmbeddingBaseURL = os.Getenv("OPENAI_EMBEDDING_BASE_URL")
openAIEmbeddingModel = os.Getenv("OPENAI_EMBEDDING_MODEL")
openAIEmbeddingApiKey = os.Getenv("OPENAI_EMBEDDING_API_KEY")
openAIEmbeddingByAzure = os.Getenv("OPENAI_EMBEDDING_BY_AZURE")
openAIEmbeddingApiVersion = os.Getenv("OPENAI_EMBEDDING_API_VERSION")
openAIEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_DIMS")
openAIRequestEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_REQUEST_DIMS")
)
byAzure, err := strconv.ParseBool(openAIEmbeddingByAzure)
if err != nil {
return nil, fmt.Errorf("init openai embedding by_azure failed, err=%w", err)
}
dims, err := strconv.ParseInt(openAIEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init openai embedding dims failed, err=%w", err)
}
openAICfg := &openai.EmbeddingConfig{
APIKey: openAIEmbeddingApiKey,
ByAzure: byAzure,
BaseURL: openAIEmbeddingBaseURL,
APIVersion: openAIEmbeddingApiVersion,
Model: openAIEmbeddingModel,
// Dimensions: ptr.Of(int(dims)),
}
reqDims := conv.StrToInt64D(openAIRequestEmbeddingDims, 0)
if reqDims > 0 {
// some openai model not support request dims
openAICfg.Dimensions = ptr.Of(int(reqDims))
}
emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
}
case "ark":
var (
arkEmbeddingBaseURL = os.Getenv("ARK_EMBEDDING_BASE_URL")
arkEmbeddingModel = os.Getenv("ARK_EMBEDDING_MODEL")
arkEmbeddingApiKey = os.Getenv("ARK_EMBEDDING_API_KEY")
// deprecated: use ARK_EMBEDDING_API_KEY instead
// ARK_EMBEDDING_AK will be removed in the future
arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK")
arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS")
arkEmbeddingAPIType = os.Getenv("ARK_EMBEDDING_API_TYPE")
)
dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err)
}
apiType := ark.APITypeText
if arkEmbeddingAPIType != "" {
if t := ark.APIType(arkEmbeddingAPIType); t != ark.APITypeText && t != ark.APITypeMultiModal {
return nil, fmt.Errorf("init ark embedding api_type failed, invalid api_type=%s", t)
} else {
apiType = t
}
}
emb, err = arkemb.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
APIKey: func() string {
if arkEmbeddingApiKey != "" {
return arkEmbeddingApiKey
}
return arkEmbeddingAK
}(),
Model: arkEmbeddingModel,
BaseURL: arkEmbeddingBaseURL,
APIType: &apiType,
}, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
}
case "ollama":
var (
ollamaEmbeddingBaseURL = os.Getenv("OLLAMA_EMBEDDING_BASE_URL")
ollamaEmbeddingModel = os.Getenv("OLLAMA_EMBEDDING_MODEL")
ollamaEmbeddingDims = os.Getenv("OLLAMA_EMBEDDING_DIMS")
)
dims, err := strconv.ParseInt(ollamaEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init ollama embedding dims failed, err=%w", err)
}
emb, err = wrap.NewOllamaEmbedder(ctx, &ollamaEmb.EmbeddingConfig{
BaseURL: ollamaEmbeddingBaseURL,
Model: ollamaEmbeddingModel,
}, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
}
case "http":
var (
httpEmbeddingBaseURL = os.Getenv("HTTP_EMBEDDING_ADDR")
httpEmbeddingDims = os.Getenv("HTTP_EMBEDDING_DIMS")
)
dims, err := strconv.ParseInt(httpEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init http embedding dims failed, err=%w", err)
}
emb, err = http.NewEmbedding(httpEmbeddingBaseURL, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init http embedding failed, err=%w", err)
}
default:
return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
}
return emb, nil
}
func readJinja2PromptTemplate(jsonFilePath string) (prompt.ChatTemplate, error) { func readJinja2PromptTemplate(jsonFilePath string) (prompt.ChatTemplate, error) {
b, err := os.ReadFile(jsonFilePath) b, err := os.ReadFile(jsonFilePath)
if err != nil { if err != nil {

@ -65,6 +65,10 @@ type VikingEmbeddingConfig struct {
BuiltinEmbedding embedding.Embedder BuiltinEmbedding embedding.Embedder
} }
func NewVikingDBService(host string, region string, ak string, sk string, scheme string) *vikingdb.VikingDBService {
return vikingdb.NewVikingDBService(host, region, ak, sk, scheme)
}
func NewManager(config *ManagerConfig) (searchstore.Manager, error) { func NewManager(config *ManagerConfig) (searchstore.Manager, error) {
if config.Service == nil { if config.Service == nil {
return nil, fmt.Errorf("[NewManager] vikingdb service is nil") return nil, fmt.Errorf("[NewManager] vikingdb service is nil")

@ -33,6 +33,15 @@ import (
"github.com/coze-dev/coze-studio/backend/types/errno" "github.com/coze-dev/coze-studio/backend/types/errno"
) )
type EmbeddingConfig = ark.EmbeddingConfig
type APIType = ark.APIType
const (
APITypeText = ark.APITypeText
APITypeMultiModal APIType = ark.APITypeMultiModal
)
func NewArkEmbedder(ctx context.Context, config *ark.EmbeddingConfig, dimensions int64, batchSize int) (contract.Embedder, error) { func NewArkEmbedder(ctx context.Context, config *ark.EmbeddingConfig, dimensions int64, batchSize int) (contract.Embedder, error) {
emb, err := ark.NewEmbedder(ctx, config) emb, err := ark.NewEmbedder(ctx, config)
if err != nil { if err != nil {

@ -3,7 +3,6 @@ export LISTEN_ADDR=":8888"
export LOG_LEVEL="debug" export LOG_LEVEL="debug"
export MAX_REQUEST_BODY_SIZE=1073741824 export MAX_REQUEST_BODY_SIZE=1073741824
export SERVER_HOST="http://localhost${LISTEN_ADDR}" export SERVER_HOST="http://localhost${LISTEN_ADDR}"
export MINIO_PROXY_ENDPOINT=""
export USE_SSL="0" export USE_SSL="0"
export SSL_CERT_FILE="" export SSL_CERT_FILE=""
export SSL_KEY_FILE="" export SSL_KEY_FILE=""

Loading…
Cancel
Save