fix: SQL parsing errors occurred in custom database node (#1867)

main
Zhj 2 months ago committed by GitHub
parent b257802c7e
commit 97a023d79e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 81
      backend/domain/workflow/internal/nodes/database/customsql.go
  2. 7
      backend/domain/workflow/internal/nodes/database/customsql_test.go

@ -21,8 +21,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"regexp"
"strconv" "strconv"
"strings"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database" crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
@ -34,6 +34,8 @@ import (
"github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/pkg/sonic"
) )
var singleQuotesStringRegexp = regexp.MustCompile("[`']\\{\\{([a-zA-Z_][a-zA-Z0-9_]*(?:\\.\\w+|\\[\\d+\\])*)+\\}\\}[`']")
type CustomSQLConfig struct { type CustomSQLConfig struct {
DatabaseInfoID int64 DatabaseInfoID int64
SQLTemplate string SQLTemplate string
@ -111,47 +113,60 @@ func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[strin
return nil, err return nil, err
} }
templateParts := nodes.ParseTemplate(singleQuotesStringRegexp.ReplaceAllString(c.sqlTemplate, "?"))
templateSQL := "" templateSQL := ""
templateParts := nodes.ParseTemplate(c.sqlTemplate) if len(templateParts) > 0 {
sqlParams := make([]database.SQLParam, 0, len(templateParts)) if len(templateParts) == 0 {
var nilError = errors.New("field is nil") templateSQL = templateParts[0].Value
for _, templatePart := range templateParts { } else {
if !templatePart.IsVariable { for _, templatePart := range templateParts {
templateSQL += templatePart.Value if !templatePart.IsVariable {
continue templateSQL += templatePart.Value
} continue
}
templateSQL += "?" val, err := templatePart.Render(inputBytes, nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
val, err := templatePart.Render(inputBytes, nodes.WithNilRender(func() (string, error) { b := val.(bool)
return "", nilError if b {
}), return "1", nil
nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) { }
b := val.(bool) return "0", nil
if b { }))
return "1", nil if err != nil {
return nil, err
} }
return "0", nil templateSQL += val
}))
if err != nil {
if !errors.Is(err, nilError) {
return nil, err
} }
sqlParams = append(sqlParams, database.SQLParam{
IsNull: true,
})
} else {
sqlParams = append(sqlParams, database.SQLParam{
Value: val,
IsNull: false,
})
} }
} else {
return nil, fmt.Errorf("parse template invalid")
}
sqlParamStrings := singleQuotesStringRegexp.FindAllString(c.sqlTemplate, -1)
sqlParams := make([]database.SQLParam, 0, len(sqlParamStrings))
for _, s := range sqlParamStrings {
parts := nodes.ParseTemplate(s)
for _, part := range parts {
if part.IsVariable {
val, err := part.Render(inputBytes, nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
b := val.(bool)
if b {
return "1", nil
}
return "0", nil
}))
if err != nil {
return nil, err
}
sqlParams = append(sqlParams, database.SQLParam{
Value: val,
})
}
}
} }
// replace sql template '?' to ?
templateSQL = strings.Replace(templateSQL, "'?'", "?", -1)
templateSQL = strings.Replace(templateSQL, "`?`", "?", -1)
req.SQL = templateSQL req.SQL = templateSQL
req.Params = sqlParams req.Params = sqlParams
response, err := crossdatabase.DefaultSVC().Execute(ctx, req) response, err := crossdatabase.DefaultSVC().Execute(ctx, req)

@ -61,12 +61,12 @@ func TestCustomSQL_Execute(t *testing.T) {
validate: func(req *database.CustomSQLRequest) { validate: func(req *database.CustomSQLRequest) {
assert.Equal(t, int64(111), req.DatabaseInfoID) assert.Equal(t, int64(111), req.DatabaseInfoID)
ps := []database.SQLParam{ ps := []database.SQLParam{
{Value: "v1_value"},
{Value: "v2_value"}, {Value: "v2_value"},
{Value: "v3_value"}, {Value: "v3_value"},
{Value: "1"},
} }
assert.Equal(t, ps, req.Params) assert.Equal(t, ps, req.Params)
assert.Equal(t, "select * from v1 where v1 = ? and v2 = ? and v3 = ?", req.SQL) assert.Equal(t, "select * from v1 where v1 = v1_value and v2 = ? and v3 = ? and v4 = ?", req.SQL)
}, },
} }
@ -86,7 +86,7 @@ func TestCustomSQL_Execute(t *testing.T) {
cfg := &CustomSQLConfig{ cfg := &CustomSQLConfig{
DatabaseInfoID: 111, DatabaseInfoID: 111,
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`", SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}` and v4 = '{{v4}}'",
} }
c1, err := cfg.Build(context.Background(), &schema.NodeSchema{ c1, err := cfg.Build(context.Background(), &schema.NodeSchema{
@ -104,6 +104,7 @@ func TestCustomSQL_Execute(t *testing.T) {
"v1": "v1_value", "v1": "v1_value",
"v2": "v2_value", "v2": "v2_value",
"v3": "v3_value", "v3": "v3_value",
"v4": true,
}) })
assert.Nil(t, err) assert.Nil(t, err)

Loading…
Cancel
Save