|
|
@ -23,9 +23,13 @@ import ( |
|
|
|
"os" |
|
|
|
"os" |
|
|
|
"strconv" |
|
|
|
"strconv" |
|
|
|
"strings" |
|
|
|
"strings" |
|
|
|
|
|
|
|
"time" |
|
|
|
|
|
|
|
|
|
|
|
"gorm.io/gorm" |
|
|
|
"gorm.io/gorm" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"github.com/cloudwego/eino-ext/components/embedding/ollama" |
|
|
|
|
|
|
|
"github.com/cloudwego/eino-ext/components/embedding/openai" |
|
|
|
|
|
|
|
"github.com/milvus-io/milvus/client/v2/milvusclient" |
|
|
|
"github.com/volcengine/volc-sdk-golang/service/visual" |
|
|
|
"github.com/volcengine/volc-sdk-golang/service/visual" |
|
|
|
|
|
|
|
|
|
|
|
"github.com/coze-dev/coze-studio/backend/application/internal" |
|
|
|
"github.com/coze-dev/coze-studio/backend/application/internal" |
|
|
@ -34,6 +38,8 @@ import ( |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser" |
|
|
|
|
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore" |
|
|
|
|
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis" |
|
|
@ -41,14 +47,22 @@ import ( |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/veocr" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/veocr" |
|
|
|
builtinParser "github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/ppstructure" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/ppstructure" |
|
|
|
|
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch" |
|
|
|
|
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus" |
|
|
|
|
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb" |
|
|
|
|
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/ark" |
|
|
|
|
|
|
|
embeddingHttp "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/http" |
|
|
|
|
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/wrap" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/es" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/es" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/idgen" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/idgen" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/imagex/veimagex" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/imagex/veimagex" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/mysql" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/mysql" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/storage" |
|
|
|
"github.com/coze-dev/coze-studio/backend/infra/impl/storage" |
|
|
|
|
|
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv" |
|
|
|
|
|
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" |
|
|
|
"github.com/coze-dev/coze-studio/backend/pkg/logs" |
|
|
|
"github.com/coze-dev/coze-studio/backend/pkg/logs" |
|
|
|
"github.com/coze-dev/coze-studio/backend/types/consts" |
|
|
|
"github.com/coze-dev/coze-studio/backend/types/consts" |
|
|
|
) |
|
|
|
) |
|
|
@ -66,6 +80,7 @@ type AppDependencies struct { |
|
|
|
CodeRunner coderunner.Runner |
|
|
|
CodeRunner coderunner.Runner |
|
|
|
OCR ocr.OCR |
|
|
|
OCR ocr.OCR |
|
|
|
ParserManager parser.Manager |
|
|
|
ParserManager parser.Manager |
|
|
|
|
|
|
|
SearchStoreManagers []searchstore.Manager |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func Init(ctx context.Context) (*AppDependencies, error) { |
|
|
|
func Init(ctx context.Context) (*AppDependencies, error) { |
|
|
@ -122,15 +137,37 @@ func Init(ctx context.Context) (*AppDependencies, error) { |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
} |
|
|
|
deps.ParserManager, err = initParserManager(deps.TOSClient, deps.OCR, imageAnnotationModel) |
|
|
|
|
|
|
|
|
|
|
|
deps.ParserManager = initParserManager(deps.TOSClient, deps.OCR, imageAnnotationModel) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
deps.SearchStoreManagers, err = initSearchStoreManagers(ctx, deps.ESClient) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, err |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
deps.SearchStoreManagers, err = initSearchStoreManagers(ctx, deps.ESClient) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, err |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return deps, nil |
|
|
|
return deps, nil |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func initImageX(ctx context.Context) (imagex.ImageX, error) { |
|
|
|
func initSearchStoreManagers(ctx context.Context, es es.Client) ([]searchstore.Manager, error) { |
|
|
|
|
|
|
|
// es full text search
|
|
|
|
|
|
|
|
esSearchstoreManager := elasticsearch.NewManager(&elasticsearch.ManagerConfig{Client: es}) |
|
|
|
|
|
|
|
|
|
|
|
uploadComponentType := os.Getenv(consts.FileUploadComponentType) |
|
|
|
// vector search
|
|
|
|
|
|
|
|
mgr, err := getVectorStore(ctx) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("init vector store failed, err=%w", err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return []searchstore.Manager{esSearchstoreManager, mgr}, nil |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func initImageX(ctx context.Context) (imagex.ImageX, error) { |
|
|
|
|
|
|
|
uploadComponentType := os.Getenv(consts.FileUploadComponentType) |
|
|
|
if uploadComponentType != consts.FileUploadComponentTypeImagex { |
|
|
|
if uploadComponentType != consts.FileUploadComponentTypeImagex { |
|
|
|
return storage.NewImagex(ctx) |
|
|
|
return storage.NewImagex(ctx) |
|
|
|
} |
|
|
|
} |
|
|
@ -230,12 +267,10 @@ func initOCR() ocr.OCR { |
|
|
|
return ocr |
|
|
|
return ocr |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func initParserManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationModel chatmodel.BaseChatModel) (parser.Manager, error) { |
|
|
|
func initParserManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationModel chatmodel.BaseChatModel) parser.Manager { |
|
|
|
var parserManager parser.Manager |
|
|
|
var parserManager parser.Manager |
|
|
|
parserType := os.Getenv(consts.ParserType) |
|
|
|
parserType := os.Getenv(consts.ParserType) |
|
|
|
switch parserType { |
|
|
|
switch parserType { |
|
|
|
case "builtin": |
|
|
|
|
|
|
|
parserManager = builtinParser.NewManager(storage, ocr, imageAnnotationModel) |
|
|
|
|
|
|
|
case "paddleocr": |
|
|
|
case "paddleocr": |
|
|
|
url := os.Getenv(consts.PPStructureAPIURL) |
|
|
|
url := os.Getenv(consts.PPStructureAPIURL) |
|
|
|
client := &http.Client{} |
|
|
|
client := &http.Client{} |
|
|
@ -245,8 +280,243 @@ func initParserManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationMode |
|
|
|
} |
|
|
|
} |
|
|
|
parserManager = ppstructure.NewManager(apiConfig, ocr, storage, imageAnnotationModel) |
|
|
|
parserManager = ppstructure.NewManager(apiConfig, ocr, storage, imageAnnotationModel) |
|
|
|
default: |
|
|
|
default: |
|
|
|
return nil, fmt.Errorf("unexpected document parser type, type=%s", parserType) |
|
|
|
parserManager = builtin.NewManager(storage, ocr, imageAnnotationModel) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return parserManager |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func getVectorStore(ctx context.Context) (searchstore.Manager, error) { |
|
|
|
|
|
|
|
vsType := os.Getenv("VECTOR_STORE_TYPE") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
switch vsType { |
|
|
|
|
|
|
|
case "milvus": |
|
|
|
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, time.Second*5) |
|
|
|
|
|
|
|
defer cancel() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
milvusAddr := os.Getenv("MILVUS_ADDR") |
|
|
|
|
|
|
|
user := os.Getenv("MILVUS_USER") |
|
|
|
|
|
|
|
password := os.Getenv("MILVUS_PASSWORD") |
|
|
|
|
|
|
|
mc, err := milvusclient.New(ctx, &milvusclient.ClientConfig{ |
|
|
|
|
|
|
|
Address: milvusAddr, |
|
|
|
|
|
|
|
Username: user, |
|
|
|
|
|
|
|
Password: password, |
|
|
|
|
|
|
|
}) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("init milvus client failed, err=%w", err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
emb, err := getEmbedding(ctx) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("init milvus embedding failed, err=%w", err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mgr, err := milvus.NewManager(&milvus.ManagerConfig{ |
|
|
|
|
|
|
|
Client: mc, |
|
|
|
|
|
|
|
Embedding: emb, |
|
|
|
|
|
|
|
EnableHybrid: ptr.Of(true), |
|
|
|
|
|
|
|
}) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("init milvus vector store failed, err=%w", err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return mgr, nil |
|
|
|
|
|
|
|
case "vikingdb": |
|
|
|
|
|
|
|
var ( |
|
|
|
|
|
|
|
host = os.Getenv("VIKING_DB_HOST") |
|
|
|
|
|
|
|
region = os.Getenv("VIKING_DB_REGION") |
|
|
|
|
|
|
|
ak = os.Getenv("VIKING_DB_AK") |
|
|
|
|
|
|
|
sk = os.Getenv("VIKING_DB_SK") |
|
|
|
|
|
|
|
scheme = os.Getenv("VIKING_DB_SCHEME") |
|
|
|
|
|
|
|
modelName = os.Getenv("VIKING_DB_MODEL_NAME") |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
if ak == "" || sk == "" { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("invalid vikingdb ak / sk") |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
if host == "" { |
|
|
|
|
|
|
|
host = "api-vikingdb.volces.com" |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
if region == "" { |
|
|
|
|
|
|
|
region = "cn-beijing" |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
if scheme == "" { |
|
|
|
|
|
|
|
scheme = "https" |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
var embConfig *vikingdb.VikingEmbeddingConfig |
|
|
|
|
|
|
|
if modelName != "" { |
|
|
|
|
|
|
|
embName := vikingdb.VikingEmbeddingModelName(modelName) |
|
|
|
|
|
|
|
if embName.Dimensions() == 0 { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("embedding model not support, model_name=%s", modelName) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
embConfig = &vikingdb.VikingEmbeddingConfig{ |
|
|
|
|
|
|
|
UseVikingEmbedding: true, |
|
|
|
|
|
|
|
EnableHybrid: embName.SupportStatus() == embedding.SupportDenseAndSparse, |
|
|
|
|
|
|
|
ModelName: embName, |
|
|
|
|
|
|
|
ModelVersion: embName.ModelVersion(), |
|
|
|
|
|
|
|
DenseWeight: ptr.Of(0.2), |
|
|
|
|
|
|
|
BuiltinEmbedding: nil, |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} else { |
|
|
|
|
|
|
|
builtinEmbedding, err := getEmbedding(ctx) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("builtint embedding init failed, err=%w", err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
embConfig = &vikingdb.VikingEmbeddingConfig{ |
|
|
|
|
|
|
|
UseVikingEmbedding: false, |
|
|
|
|
|
|
|
EnableHybrid: false, |
|
|
|
|
|
|
|
BuiltinEmbedding: builtinEmbedding, |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
svc := vikingdb.NewVikingDBService(host, region, ak, sk, scheme) |
|
|
|
|
|
|
|
mgr, err := vikingdb.NewManager(&vikingdb.ManagerConfig{ |
|
|
|
|
|
|
|
Service: svc, |
|
|
|
|
|
|
|
IndexingConfig: nil, // use default config
|
|
|
|
|
|
|
|
EmbeddingConfig: embConfig, |
|
|
|
|
|
|
|
}) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("init vikingdb manager failed, err=%w", err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return mgr, nil |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default: |
|
|
|
|
|
|
|
return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func getEmbedding(ctx context.Context) (embedding.Embedder, error) { |
|
|
|
|
|
|
|
var batchSize int |
|
|
|
|
|
|
|
if bs, err := strconv.ParseInt(os.Getenv("EMBEDDING_MAX_BATCH_SIZE"), 10, 64); err != nil { |
|
|
|
|
|
|
|
logs.CtxWarnf(ctx, "EMBEDDING_MAX_BATCH_SIZE not set / invalid, using default batchSize=100") |
|
|
|
|
|
|
|
batchSize = 100 |
|
|
|
|
|
|
|
} else { |
|
|
|
|
|
|
|
batchSize = int(bs) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
var emb embedding.Embedder |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
switch os.Getenv("EMBEDDING_TYPE") { |
|
|
|
|
|
|
|
case "openai": |
|
|
|
|
|
|
|
var ( |
|
|
|
|
|
|
|
openAIEmbeddingBaseURL = os.Getenv("OPENAI_EMBEDDING_BASE_URL") |
|
|
|
|
|
|
|
openAIEmbeddingModel = os.Getenv("OPENAI_EMBEDDING_MODEL") |
|
|
|
|
|
|
|
openAIEmbeddingApiKey = os.Getenv("OPENAI_EMBEDDING_API_KEY") |
|
|
|
|
|
|
|
openAIEmbeddingByAzure = os.Getenv("OPENAI_EMBEDDING_BY_AZURE") |
|
|
|
|
|
|
|
openAIEmbeddingApiVersion = os.Getenv("OPENAI_EMBEDDING_API_VERSION") |
|
|
|
|
|
|
|
openAIEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_DIMS") |
|
|
|
|
|
|
|
openAIRequestEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_REQUEST_DIMS") |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
byAzure, err := strconv.ParseBool(openAIEmbeddingByAzure) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("init openai embedding by_azure failed, err=%w", err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dims, err := strconv.ParseInt(openAIEmbeddingDims, 10, 64) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("init openai embedding dims failed, err=%w", err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
openAICfg := &openai.EmbeddingConfig{ |
|
|
|
|
|
|
|
APIKey: openAIEmbeddingApiKey, |
|
|
|
|
|
|
|
ByAzure: byAzure, |
|
|
|
|
|
|
|
BaseURL: openAIEmbeddingBaseURL, |
|
|
|
|
|
|
|
APIVersion: openAIEmbeddingApiVersion, |
|
|
|
|
|
|
|
Model: openAIEmbeddingModel, |
|
|
|
|
|
|
|
// Dimensions: ptr.Of(int(dims)),
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
reqDims := conv.StrToInt64D(openAIRequestEmbeddingDims, 0) |
|
|
|
|
|
|
|
if reqDims > 0 { |
|
|
|
|
|
|
|
// some openai model not support request dims
|
|
|
|
|
|
|
|
openAICfg.Dimensions = ptr.Of(int(reqDims)) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims, batchSize) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("init openai embedding failed, err=%w", err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
case "ark": |
|
|
|
|
|
|
|
var ( |
|
|
|
|
|
|
|
arkEmbeddingBaseURL = os.Getenv("ARK_EMBEDDING_BASE_URL") |
|
|
|
|
|
|
|
arkEmbeddingModel = os.Getenv("ARK_EMBEDDING_MODEL") |
|
|
|
|
|
|
|
arkEmbeddingApiKey = os.Getenv("ARK_EMBEDDING_API_KEY") |
|
|
|
|
|
|
|
// deprecated: use ARK_EMBEDDING_API_KEY instead
|
|
|
|
|
|
|
|
// ARK_EMBEDDING_AK will be removed in the future
|
|
|
|
|
|
|
|
arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK") |
|
|
|
|
|
|
|
arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS") |
|
|
|
|
|
|
|
arkEmbeddingAPIType = os.Getenv("ARK_EMBEDDING_API_TYPE") |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
apiType := ark.APITypeText |
|
|
|
|
|
|
|
if arkEmbeddingAPIType != "" { |
|
|
|
|
|
|
|
if t := ark.APIType(arkEmbeddingAPIType); t != ark.APITypeText && t != ark.APITypeMultiModal { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("init ark embedding api_type failed, invalid api_type=%s", t) |
|
|
|
|
|
|
|
} else { |
|
|
|
|
|
|
|
apiType = t |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
emb, err = ark.NewArkEmbedder(ctx, &ark.EmbeddingConfig{ |
|
|
|
|
|
|
|
APIKey: func() string { |
|
|
|
|
|
|
|
if arkEmbeddingApiKey != "" { |
|
|
|
|
|
|
|
return arkEmbeddingApiKey |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
return arkEmbeddingAK |
|
|
|
|
|
|
|
}(), |
|
|
|
|
|
|
|
Model: arkEmbeddingModel, |
|
|
|
|
|
|
|
BaseURL: arkEmbeddingBaseURL, |
|
|
|
|
|
|
|
APIType: &apiType, |
|
|
|
|
|
|
|
}, dims, batchSize) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
case "ollama": |
|
|
|
|
|
|
|
var ( |
|
|
|
|
|
|
|
ollamaEmbeddingBaseURL = os.Getenv("OLLAMA_EMBEDDING_BASE_URL") |
|
|
|
|
|
|
|
ollamaEmbeddingModel = os.Getenv("OLLAMA_EMBEDDING_MODEL") |
|
|
|
|
|
|
|
ollamaEmbeddingDims = os.Getenv("OLLAMA_EMBEDDING_DIMS") |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dims, err := strconv.ParseInt(ollamaEmbeddingDims, 10, 64) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("init ollama embedding dims failed, err=%w", err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
emb, err = wrap.NewOllamaEmbedder(ctx, &ollama.EmbeddingConfig{ |
|
|
|
|
|
|
|
BaseURL: ollamaEmbeddingBaseURL, |
|
|
|
|
|
|
|
Model: ollamaEmbeddingModel, |
|
|
|
|
|
|
|
}, dims, batchSize) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("init ollama embedding failed, err=%w", err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
case "http": |
|
|
|
|
|
|
|
var ( |
|
|
|
|
|
|
|
httpEmbeddingBaseURL = os.Getenv("HTTP_EMBEDDING_ADDR") |
|
|
|
|
|
|
|
httpEmbeddingDims = os.Getenv("HTTP_EMBEDDING_DIMS") |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
dims, err := strconv.ParseInt(httpEmbeddingDims, 10, 64) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("init http embedding dims failed, err=%w", err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
emb, err = embeddingHttp.NewEmbedding(httpEmbeddingBaseURL, dims, batchSize) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return nil, fmt.Errorf("init http embedding failed, err=%w", err) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default: |
|
|
|
|
|
|
|
return nil, fmt.Errorf("init knowledge embedding failed, type not configured") |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return parserManager, nil |
|
|
|
return emb, nil |
|
|
|
} |
|
|
|
} |
|
|
|