feat: init improvements (#174)

main
N3ko 3 months ago committed by GitHub
parent b48c4c2792
commit 8137b0aee5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 25
      backend/application/knowledge/init.go
  2. 2
      backend/conf/model/template/model_template_openai.yaml
  3. 7
      backend/go.mod
  4. 10
      backend/go.sum
  5. 38
      backend/infra/impl/document/ocr/veocr/ve_ocr.go
  6. 21
      backend/infra/impl/embedding/ark/ark.go
  7. 32
      backend/infra/impl/embedding/wrap/ollama.go
  8. 5
      backend/infra/impl/modelmgr/static/modelmgr.go
  9. 8
      docker/.env.example

@ -26,6 +26,7 @@ import (
"time"
"github.com/cloudwego/eino-ext/components/embedding/ark"
ollamaEmb "github.com/cloudwego/eino-ext/components/embedding/ollama"
"github.com/cloudwego/eino-ext/components/embedding/openai"
ao "github.com/cloudwego/eino-ext/components/model/ark"
"github.com/cloudwego/eino-ext/components/model/deepseek"
@ -111,6 +112,9 @@ func InitService(c *ServiceComponents) (*KnowledgeApplicationService, error) {
case "ve":
ocrAK := os.Getenv("VE_OCR_AK")
ocrSK := os.Getenv("VE_OCR_SK")
if ocrAK == "" || ocrSK == "" {
logs.Warnf("[ve_ocr] ak / sk not configured, ocr might not work well")
}
inst := visual.NewInstance()
inst.Client.SetAccessKey(ocrAK)
inst.Client.SetSecretKey(ocrSK)
@ -346,6 +350,27 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
if err != nil {
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
}
case "ollama":
var (
ollamaEmbeddingBaseURL = os.Getenv("OLLAMA_EMBEDDING_BASE_URL")
ollamaEmbeddingModel = os.Getenv("OLLAMA_EMBEDDING_MODEL")
ollamaEmbeddingDims = os.Getenv("OLLAMA_EMBEDDING_DIMS")
)
dims, err := strconv.ParseInt(ollamaEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init ollama embedding dims failed, err=%w", err)
}
emb, err = wrap.NewOllamaEmbedder(ctx, &ollamaEmb.EmbeddingConfig{
BaseURL: ollamaEmbeddingBaseURL,
Model: ollamaEmbeddingModel,
}, dims)
if err != nil {
return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
}
default:
return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
}

@ -157,7 +157,7 @@ meta:
top_k: 0
stop: []
openai:
by_azure: true
by_azure: false
api_version: ""
response_format:
type: text

@ -12,7 +12,7 @@ require (
github.com/apache/thrift v0.21.0
github.com/bytedance/mockey v1.2.14
github.com/bytedance/sonic v1.13.2
github.com/cloudwego/eino v0.3.51
github.com/cloudwego/eino v0.3.55
github.com/cloudwego/eino-ext/components/model/ark v0.1.15
github.com/cloudwego/eino-ext/components/model/claude v0.1.1
github.com/cloudwego/eino-ext/components/model/deepseek v0.0.0-20250715055739-0d0e28441a2f
@ -55,6 +55,7 @@ require github.com/alicebob/miniredis/v2 v2.34.0
require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/cloudwego/eino-ext/components/embedding/ark v0.0.0-20250522060253-ddb617598b09
github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20250728060543-79ec300857b8
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20250522060253-ddb617598b09
github.com/cloudwego/eino-ext/components/model/gemini v0.1.2
github.com/cloudwego/eino-ext/components/model/ollama v0.0.0-20250610035057-2c4e7c8488a5
@ -66,7 +67,7 @@ require (
github.com/jinzhu/copier v0.4.0
github.com/mattn/go-shellwords v1.0.12
github.com/nsqio/go-nsq v1.1.0
github.com/ollama/ollama v0.6.5
github.com/ollama/ollama v0.9.6
github.com/rbretecher/go-postman-collection v0.9.0
github.com/volcengine/ve-tos-golang-sdk/v2 v2.7.17
github.com/yuin/goldmark v1.4.13
@ -246,7 +247,7 @@ require (
github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect
github.com/volcengine/volcengine-go-sdk v1.1.20 // indirect
github.com/volcengine/volcengine-go-sdk v1.1.20
github.com/x448/float16 v0.8.4 // indirect
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect
github.com/xuri/efp v0.0.0-20240408161823-9ad904a10d6d // indirect

@ -938,10 +938,12 @@ github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp
github.com/clbanning/mxj v1.8.4/go.mod h1:BVjHeAH+rl9rs6f+QIpeRl0tfu10SXn1pUSa5PVGJng=
github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4=
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/eino v0.3.51 h1:emSaDu49v9EEJYOusL42Li/VL5QBSyBvhxO9ZcKPZvs=
github.com/cloudwego/eino v0.3.51/go.mod h1:wUjz990apdsaOraOXdh6CdhVXq8DJsOvLsVlxNTcNfY=
github.com/cloudwego/eino v0.3.55 h1:lMZrGtEh0k3qykQTLNXSXuAa98OtF2tS43GMHyvN7nA=
github.com/cloudwego/eino v0.3.55/go.mod h1:wUjz990apdsaOraOXdh6CdhVXq8DJsOvLsVlxNTcNfY=
github.com/cloudwego/eino-ext/components/embedding/ark v0.0.0-20250522060253-ddb617598b09 h1:hZScBE/Etiji2RqjlABcAkq6n1uzYPu+jo4GV5TF8Hc=
github.com/cloudwego/eino-ext/components/embedding/ark v0.0.0-20250522060253-ddb617598b09/go.mod h1:pLtH5BZKgb7/bB8+P3W5/f1d46gTl9K77+08j88Gb4k=
github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20250728060543-79ec300857b8 h1:uJrs6SmfYnca8A+k9+3qJ4MYwYHMncUlGac1mYQT+Ak=
github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20250728060543-79ec300857b8/go.mod h1:nav79aUcd+UR24dLA+7l7RcHCMlg26zbDAKvjONdrw0=
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20250522060253-ddb617598b09 h1:C8RjF193iguUuevkuv0q4SC+XGlM/DlJEgic7l8OUAI=
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20250522060253-ddb617598b09/go.mod h1:S09z/CAQNyx+AbgfJRQXLUAYlPpxQWWLVuQxO34F90A=
github.com/cloudwego/eino-ext/components/model/ark v0.1.15 h1:ydOvtEK67VI5DvNgg64eTxbjxMYhGBMOVP2okaZKk18=
@ -1616,8 +1618,8 @@ github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+
github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg=
github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U=
github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U=
github.com/ollama/ollama v0.6.5 h1:vXKkVX57ql/1ZzMw4SVK866Qfd6pjwEcITVyEpF0QXQ=
github.com/ollama/ollama v0.6.5/go.mod h1:pGgtoNyc9DdM6oZI6yMfI6jTk2Eh4c36c2GpfQCH7PY=
github.com/ollama/ollama v0.9.6 h1:HZNJmB52pMt6zLkGkkheBuXBXM5478eiSAj7GR75AMc=
github.com/ollama/ollama v0.9.6/go.mod h1:zLwx3iZ3AI4Rc/egsrx3u1w4RU2MHQ/Ylxse48jvyt4=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=

@ -18,12 +18,16 @@ package veocr
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno"
"github.com/volcengine/volc-sdk-golang/service/visual"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
)
@ -52,10 +56,14 @@ func (o *ocrImpl) FromBase64(ctx context.Context, b64 string) ([]string, error)
resp, statusCode, err := o.config.Client.OCRNormal(form)
if err != nil {
return nil, err
return nil, o.handleError(fmt.Errorf("[ve_ocr][FromBase64] OCRNormal err: %w", err))
}
if statusCode != http.StatusOK {
return nil, fmt.Errorf("[FromBase64] failed, status code=%d", statusCode)
err = fmt.Errorf("[ve_ocr][FromBase64] OCRNormal failed, status code=%d", statusCode)
if statusCode == http.StatusBadRequest {
return nil, errorx.WrapByCode(err, errno.ErrKnowledgeNonRetryableCode)
}
return nil, err
}
return resp.Data.LineTexts, nil
@ -67,10 +75,14 @@ func (o *ocrImpl) FromURL(ctx context.Context, url string) ([]string, error) {
resp, statusCode, err := o.config.Client.OCRNormal(form)
if err != nil {
return nil, err
return nil, o.handleError(fmt.Errorf("[ve_ocr][FromURL] OCRNormal error: %w", err))
}
if statusCode != http.StatusOK {
return nil, fmt.Errorf("[FromBase64] failed, status code=%d", statusCode)
err = fmt.Errorf("[ve_ocr][FromURL] OCRNormal failed, status code=%d", statusCode)
if statusCode == http.StatusBadRequest {
return nil, errorx.WrapByCode(err, errno.ErrKnowledgeNonRetryableCode)
}
return nil, err
}
return resp.Data.LineTexts, nil
@ -94,3 +106,21 @@ func (o *ocrImpl) newForm() url.Values {
}
return form
}
func (o *ocrImpl) handleError(err error) error {
var (
apiErr = &model.APIError{}
reqErr = &model.RequestError{}
)
if errors.As(err, &apiErr) {
if apiErr.HTTPStatusCode >= http.StatusInternalServerError ||
apiErr.HTTPStatusCode == http.StatusTooManyRequests {
return err
}
} else if errors.As(err, &reqErr) {
if reqErr.HTTPStatusCode >= http.StatusInternalServerError {
return err
}
}
return errorx.WrapByCode(err, errno.ErrKnowledgeNonRetryableCode)
}

@ -18,11 +18,16 @@ package ark
import (
"context"
"errors"
"fmt"
"math"
"net/http"
"github.com/cloudwego/eino-ext/components/embedding/ark"
"github.com/cloudwego/eino/components/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
@ -51,7 +56,21 @@ func (d embWrap) EmbedStrings(ctx context.Context, texts []string, opts ...embed
}
normed, err := d.slicedNormL2(partResult)
if err != nil {
return nil, err
var (
apiErr = &model.APIError{}
reqErr = &model.RequestError{}
)
if errors.As(err, &apiErr) {
if apiErr.HTTPStatusCode >= http.StatusInternalServerError ||
apiErr.HTTPStatusCode == http.StatusTooManyRequests {
return nil, err
}
} else if errors.As(err, &reqErr) {
if reqErr.HTTPStatusCode >= http.StatusInternalServerError {
return nil, err
}
}
return nil, errorx.WrapByCode(err, errno.ErrKnowledgeNonRetryableCode)
}
resp = append(resp, normed...)
}

@ -0,0 +1,32 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package wrap
import (
"context"
"github.com/cloudwego/eino-ext/components/embedding/ollama"
contract "github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
)
func NewOllamaEmbedder(ctx context.Context, config *ollama.EmbeddingConfig, dimensions int64) (contract.Embedder, error) {
emb, err := ollama.NewEmbedder(ctx, config)
if err != nil {
return nil, err
}
return &denseOnlyWrap{dims: dimensions, Embedder: emb}, nil
}

@ -24,9 +24,14 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
func NewModelMgr(staticModels []*modelmgr.Model) (modelmgr.Manager, error) {
if len(staticModels) == 0 {
logs.Warnf("[NewModelMgr] no static models found, please check if the config has been loaded correctly")
}
mapping := make(map[int64]*modelmgr.Model, len(staticModels))
for i := range staticModels {
mapping[staticModels[i].ID] = staticModels[i]

@ -90,8 +90,8 @@ export VIKING_DB_MODEL_NAME="" # if vikingdb model name is not set, you need to
# Settings for Embedding
# The Embedding model relied on by knowledge base vectorization does not need to be configured
# if the vector database comes with built-in Embedding functionality (such as VikingDB). Currently,
# Coze Studio supports three access methods: openai, ark, and custom http. Users can simply choose one of them when using
# embedding type: openai / ark / http
# Coze Studio supports three access methods: openai, ark, ollama, and custom http. Users can simply choose one of them when using
# embedding type: openai / ark / ollama / http
export EMBEDDING_TYPE="ark"
# openai embedding
export OPENAI_EMBEDDING_BASE_URL="" # (string) OpenAI base_url
@ -108,6 +108,10 @@ export ARK_EMBEDDING_AK=""
export ARK_EMBEDDING_DIMS="2048"
export ARK_EMBEDDING_BASE_URL=""
# ollama embedding
export OLLAMA_EMBEDDING_BASE_URL=""
export OLLAMA_EMBEDDING_MODEL=""
export OLLAMA_EMBEDDING_DIMS=""
# http embedding
export HTTP_EMBEDDING_ADDR="http://127.0.0.1:6543"

Loading…
Cancel
Save