You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
373 lines
10 KiB
373 lines
10 KiB
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package oceanbase
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"gorm.io/driver/mysql"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
type OceanBaseOfficialClient struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
// VectorResult describes a single vector row. It is the element type accepted
// by InsertVectors and the element type returned by SearchVectors.
type VectorResult struct {
	// VectorID is stored in the vector_id column, the table's primary key.
	VectorID string `json:"vector_id"`

	// Content is the text stored in the content column.
	Content string `json:"content"`

	// Metadata is stored in the JSON-typed metadata column and is passed
	// through as a string without being parsed by this package.
	Metadata string `json:"metadata"`

	// Embedding is the vector itself. It is serialized on insert.
	// NOTE(review): SearchVectors reads the embedding column but does not
	// decode it back into this field.
	Embedding []float64 `json:"embedding"`

	// SimilarityScore receives the similarity_score column computed by the
	// search query.
	SimilarityScore float64 `json:"similarity_score"`

	// Distance is not assigned by any code in this file.
	// NOTE(review): confirm whether callers are expected to fill it in.
	Distance float64 `json:"distance"`

	// CreatedAt receives the created_at column when rows are read back.
	CreatedAt time.Time `json:"created_at"`
}
|
|
|
|
// CollectionInfo summarizes what getCollectionInfo reads back about a
// collection table from INFORMATION_SCHEMA.
type CollectionInfo struct {
	// Name is the collection (table) name the lookup was made for.
	Name string `json:"name"`

	// Dimension is the number extracted from the embedding column's type
	// definition, i.e. the N in VECTOR(N).
	Dimension int `json:"dimension"`

	// IndexType is the INDEX_TYPE reported for the embedding index, or "none"
	// when the lookup fails.
	IndexType string `json:"index_type"`
}
|
|
|
|
func NewOceanBaseOfficialClient(dsn string) (*OceanBaseOfficialClient, error) {
|
|
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Info),
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to OceanBase: %v", err)
|
|
}
|
|
|
|
client := &OceanBaseOfficialClient{db: db}
|
|
|
|
if err := client.setVectorParameters(); err != nil {
|
|
log.Printf("Warning: Failed to set vector parameters: %v", err)
|
|
}
|
|
|
|
return client, nil
|
|
}
|
|
|
|
func (c *OceanBaseOfficialClient) setVectorParameters() error {
|
|
params := map[string]string{
|
|
"ob_vector_memory_limit_percentage": "30",
|
|
"ob_query_timeout": "86400000000",
|
|
"max_allowed_packet": "1073741824",
|
|
}
|
|
|
|
for param, value := range params {
|
|
if err := c.db.Exec(fmt.Sprintf("SET GLOBAL %s = %s", param, value)).Error; err != nil {
|
|
log.Printf("Warning: Failed to set %s: %v", param, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *OceanBaseOfficialClient) CreateCollection(ctx context.Context, collectionName string, dimension int) error {
|
|
createTableSQL := fmt.Sprintf(`
|
|
CREATE TABLE IF NOT EXISTS %s (
|
|
vector_id VARCHAR(255) PRIMARY KEY,
|
|
content TEXT NOT NULL,
|
|
metadata JSON,
|
|
embedding VECTOR(%d) NOT NULL,
|
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
|
INDEX idx_created_at (created_at),
|
|
INDEX idx_content (content(100))
|
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
|
|
`, collectionName, dimension)
|
|
|
|
if err := c.db.WithContext(ctx).Exec(createTableSQL).Error; err != nil {
|
|
return fmt.Errorf("failed to create table: %v", err)
|
|
}
|
|
|
|
createIndexSQL := fmt.Sprintf(`
|
|
CREATE VECTOR INDEX idx_%s_embedding ON %s(embedding)
|
|
WITH (distance=cosine, type=hnsw, lib=vsag, m=16, ef_construction=200, ef_search=64)
|
|
`, collectionName, collectionName)
|
|
|
|
if err := c.db.WithContext(ctx).Exec(createIndexSQL).Error; err != nil {
|
|
log.Printf("Warning: Failed to create HNSW vector index, will use exact search: %v", err)
|
|
}
|
|
|
|
log.Printf("Successfully created collection '%s' with dimension %d", collectionName, dimension)
|
|
return nil
|
|
}
|
|
|
|
func (c *OceanBaseOfficialClient) InsertVectors(ctx context.Context, collectionName string, vectors []VectorResult) error {
|
|
if len(vectors) == 0 {
|
|
return nil
|
|
}
|
|
|
|
const batchSize = 100
|
|
for i := 0; i < len(vectors); i += batchSize {
|
|
end := i + batchSize
|
|
if end > len(vectors) {
|
|
end = len(vectors)
|
|
}
|
|
batch := vectors[i:end]
|
|
|
|
if err := c.insertBatch(ctx, collectionName, batch); err != nil {
|
|
return fmt.Errorf("failed to insert vectors batch %d-%d: %v", i, end-1, err)
|
|
}
|
|
}
|
|
|
|
log.Printf("Successfully inserted %d vectors into collection '%s'", len(vectors), collectionName)
|
|
return nil
|
|
}
|
|
|
|
func (c *OceanBaseOfficialClient) insertBatch(ctx context.Context, collectionName string, batch []VectorResult) error {
|
|
placeholders := make([]string, len(batch))
|
|
values := make([]interface{}, 0, len(batch)*5)
|
|
|
|
for j, vector := range batch {
|
|
placeholders[j] = "(?, ?, ?, ?, NOW())"
|
|
values = append(values,
|
|
vector.VectorID,
|
|
vector.Content,
|
|
vector.Metadata,
|
|
c.vectorToString(vector.Embedding),
|
|
)
|
|
}
|
|
|
|
sql := fmt.Sprintf(`
|
|
INSERT INTO %s (vector_id, content, metadata, embedding, created_at)
|
|
VALUES %s
|
|
ON DUPLICATE KEY UPDATE
|
|
content = VALUES(content),
|
|
metadata = VALUES(metadata),
|
|
embedding = VALUES(embedding),
|
|
updated_at = NOW()
|
|
`, collectionName, strings.Join(placeholders, ","))
|
|
|
|
return c.db.WithContext(ctx).Exec(sql, values...).Error
|
|
}
|
|
|
|
func (c *OceanBaseOfficialClient) SearchVectors(
|
|
ctx context.Context,
|
|
collectionName string,
|
|
queryVector []float64,
|
|
topK int,
|
|
threshold float64,
|
|
) ([]VectorResult, error) {
|
|
|
|
var count int64
|
|
if err := c.db.WithContext(ctx).Table(collectionName).Count(&count).Error; err != nil {
|
|
return nil, fmt.Errorf("collection '%s' does not exist: %v", collectionName, err)
|
|
}
|
|
|
|
if count == 0 {
|
|
log.Printf("Collection '%s' is empty", collectionName)
|
|
return []VectorResult{}, nil
|
|
}
|
|
|
|
collectionInfo, err := c.getCollectionInfo(ctx, collectionName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get collection info: %v", err)
|
|
}
|
|
|
|
log.Printf("[Debug] Collection info: name=%s, dimension=%d, index_type=%s",
|
|
collectionName, collectionInfo.Dimension, collectionInfo.IndexType)
|
|
|
|
query, params, err := c.buildOptimizedSearchQuery(collectionName, queryVector, topK)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to build search query: %v", err)
|
|
}
|
|
|
|
log.Printf("[Debug] Built optimized query: %s", query)
|
|
log.Printf("[Debug] Query params count: %d", len(params))
|
|
|
|
var results []VectorResult
|
|
rows, err := c.db.WithContext(ctx).Raw(query, params...).Rows()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to execute search query: %v", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var result VectorResult
|
|
var embeddingStr string
|
|
if err := rows.Scan(
|
|
&result.VectorID,
|
|
&result.Content,
|
|
&result.Metadata,
|
|
&embeddingStr,
|
|
&result.SimilarityScore,
|
|
&result.CreatedAt,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("failed to scan result row: %v", err)
|
|
}
|
|
results = append(results, result)
|
|
}
|
|
|
|
log.Printf("[Debug] Raw search results count: %d", len(results))
|
|
|
|
finalResults := c.postProcessResults(results, topK, threshold)
|
|
|
|
log.Printf("[Debug] Final results count: %d", len(finalResults))
|
|
return finalResults, nil
|
|
}
|
|
|
|
func (c *OceanBaseOfficialClient) buildOptimizedSearchQuery(
|
|
collectionName string,
|
|
queryVector []float64,
|
|
topK int,
|
|
) (string, []interface{}, error) {
|
|
|
|
queryVectorStr := c.vectorToString(queryVector)
|
|
|
|
similarityExpr := "GREATEST(0, LEAST(1, 1 - COSINE_DISTANCE(embedding, ?)))"
|
|
orderBy := "COSINE_DISTANCE(embedding, ?) ASC"
|
|
|
|
query := fmt.Sprintf(`
|
|
SELECT
|
|
vector_id,
|
|
content,
|
|
metadata,
|
|
embedding,
|
|
%s as similarity_score,
|
|
created_at
|
|
FROM %s
|
|
ORDER BY %s
|
|
APPROXIMATE
|
|
LIMIT %d
|
|
`, similarityExpr, collectionName, orderBy, topK*2)
|
|
|
|
params := []interface{}{
|
|
queryVectorStr,
|
|
queryVectorStr,
|
|
}
|
|
|
|
return query, params, nil
|
|
}
|
|
|
|
func (c *OceanBaseOfficialClient) getCollectionInfo(ctx context.Context, collectionName string) (*CollectionInfo, error) {
|
|
var dimension int
|
|
|
|
dimQuery := `
|
|
SELECT
|
|
SUBSTRING_INDEX(SUBSTRING_INDEX(COLUMN_TYPE, '(', -1), ')', 1) as dimension
|
|
FROM INFORMATION_SCHEMA.COLUMNS
|
|
WHERE TABLE_NAME = ? AND COLUMN_NAME = 'embedding'
|
|
`
|
|
|
|
if err := c.db.WithContext(ctx).Raw(dimQuery, collectionName).Scan(&dimension).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to get vector dimension: %v", err)
|
|
}
|
|
|
|
var indexType string
|
|
indexQuery := `
|
|
SELECT INDEX_TYPE
|
|
FROM INFORMATION_SCHEMA.STATISTICS
|
|
WHERE TABLE_NAME = ? AND INDEX_NAME LIKE 'idx_%_embedding'
|
|
`
|
|
|
|
if err := c.db.WithContext(ctx).Raw(indexQuery, collectionName).Scan(&indexType).Error; err != nil {
|
|
indexType = "none"
|
|
}
|
|
|
|
return &CollectionInfo{
|
|
Name: collectionName,
|
|
Dimension: dimension,
|
|
IndexType: indexType,
|
|
}, nil
|
|
}
|
|
|
|
func (c *OceanBaseOfficialClient) vectorToString(vector []float64) string {
|
|
if len(vector) == 0 {
|
|
return "[]"
|
|
}
|
|
|
|
parts := make([]string, len(vector))
|
|
for i, v := range vector {
|
|
parts[i] = fmt.Sprintf("%.6f", v)
|
|
}
|
|
return "[" + strings.Join(parts, ",") + "]"
|
|
}
|
|
|
|
func (c *OceanBaseOfficialClient) postProcessResults(results []VectorResult, topK int, threshold float64) []VectorResult {
|
|
if len(results) == 0 {
|
|
return results
|
|
}
|
|
|
|
filtered := make([]VectorResult, 0, len(results))
|
|
for _, result := range results {
|
|
if result.SimilarityScore >= threshold {
|
|
filtered = append(filtered, result)
|
|
}
|
|
}
|
|
|
|
sort.Slice(filtered, func(i, j int) bool {
|
|
return filtered[i].SimilarityScore > filtered[j].SimilarityScore
|
|
})
|
|
|
|
if len(filtered) > topK {
|
|
filtered = filtered[:topK]
|
|
}
|
|
|
|
log.Printf("[Debug] Post-processed results: %d results with threshold %.3f", len(filtered), threshold)
|
|
return filtered
|
|
}
|
|
|
|
func (c *OceanBaseOfficialClient) GetDB() *gorm.DB {
|
|
return c.db
|
|
}
|
|
|
|
func (c *OceanBaseOfficialClient) DebugCollectionData(ctx context.Context, collectionName string) error {
|
|
var count int64
|
|
if err := c.db.WithContext(ctx).Table(collectionName).Count(&count).Error; err != nil {
|
|
log.Printf("[Debug] Collection '%s' does not exist: %v", collectionName, err)
|
|
return err
|
|
}
|
|
log.Printf("[Debug] Collection '%s' exists with %d vectors", collectionName, count)
|
|
|
|
log.Printf("[Debug] Sample data from collection '%s':", collectionName)
|
|
rows, err := c.db.WithContext(ctx).Raw(`
|
|
SELECT vector_id, content, created_at
|
|
FROM ` + collectionName + `
|
|
ORDER BY created_at DESC
|
|
LIMIT 5
|
|
`).Rows()
|
|
if err != nil {
|
|
log.Printf("[Debug] Failed to get sample data: %v", err)
|
|
} else {
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var vectorID, content string
|
|
var createdAt time.Time
|
|
if err := rows.Scan(&vectorID, &content, &createdAt); err != nil {
|
|
log.Printf("[Debug] Failed to scan sample row: %v", err)
|
|
continue
|
|
}
|
|
log.Printf("[Debug] Sample: ID=%s, Content=%s, Created=%s", vectorID, content[:min(50, len(content))], createdAt)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// min returns the smaller of the two integers a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
|
|