fix: SQL parsing errors occurred in custom database node (#1867)

main
Zhj 2 months ago committed by GitHub
parent b257802c7e
commit 97a023d79e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 81
      backend/domain/workflow/internal/nodes/database/customsql.go
  2. 7
      backend/domain/workflow/internal/nodes/database/customsql_test.go

@ -21,8 +21,8 @@ import (
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
@ -34,6 +34,8 @@ import (
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
var singleQuotesStringRegexp = regexp.MustCompile("[`']\\{\\{([a-zA-Z_][a-zA-Z0-9_]*(?:\\.\\w+|\\[\\d+\\])*)+\\}\\}[`']")
type CustomSQLConfig struct {
DatabaseInfoID int64
SQLTemplate string
@ -111,47 +113,60 @@ func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[strin
return nil, err
}
templateParts := nodes.ParseTemplate(singleQuotesStringRegexp.ReplaceAllString(c.sqlTemplate, "?"))
templateSQL := ""
templateParts := nodes.ParseTemplate(c.sqlTemplate)
sqlParams := make([]database.SQLParam, 0, len(templateParts))
var nilError = errors.New("field is nil")
for _, templatePart := range templateParts {
if !templatePart.IsVariable {
templateSQL += templatePart.Value
continue
}
if len(templateParts) > 0 {
if len(templateParts) == 0 {
templateSQL = templateParts[0].Value
} else {
for _, templatePart := range templateParts {
if !templatePart.IsVariable {
templateSQL += templatePart.Value
continue
}
templateSQL += "?"
val, err := templatePart.Render(inputBytes, nodes.WithNilRender(func() (string, error) {
return "", nilError
}),
nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
b := val.(bool)
if b {
return "1", nil
val, err := templatePart.Render(inputBytes, nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
b := val.(bool)
if b {
return "1", nil
}
return "0", nil
}))
if err != nil {
return nil, err
}
return "0", nil
}))
templateSQL += val
if err != nil {
if !errors.Is(err, nilError) {
return nil, err
}
sqlParams = append(sqlParams, database.SQLParam{
IsNull: true,
})
} else {
sqlParams = append(sqlParams, database.SQLParam{
Value: val,
IsNull: false,
})
}
} else {
return nil, fmt.Errorf("parse template invalid")
}
sqlParamStrings := singleQuotesStringRegexp.FindAllString(c.sqlTemplate, -1)
sqlParams := make([]database.SQLParam, 0, len(sqlParamStrings))
for _, s := range sqlParamStrings {
parts := nodes.ParseTemplate(s)
for _, part := range parts {
if part.IsVariable {
val, err := part.Render(inputBytes, nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
b := val.(bool)
if b {
return "1", nil
}
return "0", nil
}))
if err != nil {
return nil, err
}
sqlParams = append(sqlParams, database.SQLParam{
Value: val,
})
}
}
}
// replace sql template '?' to ?
templateSQL = strings.Replace(templateSQL, "'?'", "?", -1)
templateSQL = strings.Replace(templateSQL, "`?`", "?", -1)
req.SQL = templateSQL
req.Params = sqlParams
response, err := crossdatabase.DefaultSVC().Execute(ctx, req)

@ -61,12 +61,12 @@ func TestCustomSQL_Execute(t *testing.T) {
validate: func(req *database.CustomSQLRequest) {
assert.Equal(t, int64(111), req.DatabaseInfoID)
ps := []database.SQLParam{
{Value: "v1_value"},
{Value: "v2_value"},
{Value: "v3_value"},
{Value: "1"},
}
assert.Equal(t, ps, req.Params)
assert.Equal(t, "select * from v1 where v1 = ? and v2 = ? and v3 = ?", req.SQL)
assert.Equal(t, "select * from v1 where v1 = v1_value and v2 = ? and v3 = ? and v4 = ?", req.SQL)
},
}
@ -86,7 +86,7 @@ func TestCustomSQL_Execute(t *testing.T) {
cfg := &CustomSQLConfig{
DatabaseInfoID: 111,
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`",
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}` and v4 = '{{v4}}'",
}
c1, err := cfg.Build(context.Background(), &schema.NodeSchema{
@ -104,6 +104,7 @@ func TestCustomSQL_Execute(t *testing.T) {
"v1": "v1_value",
"v2": "v2_value",
"v3": "v3_value",
"v4": true,
})
assert.Nil(t, err)

Loading…
Cancel
Save