diff --git a/backend/domain/workflow/internal/nodes/database/customsql.go b/backend/domain/workflow/internal/nodes/database/customsql.go index 38b7af39..f1c6cad6 100644 --- a/backend/domain/workflow/internal/nodes/database/customsql.go +++ b/backend/domain/workflow/internal/nodes/database/customsql.go @@ -21,8 +21,8 @@ import ( "errors" "fmt" "reflect" + "regexp" "strconv" - "strings" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database" crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database" @@ -34,6 +34,8 @@ import ( "github.com/coze-dev/coze-studio/backend/pkg/sonic" ) +var singleQuotesStringRegexp = regexp.MustCompile("[`']\\{\\{([a-zA-Z_][a-zA-Z0-9_]*(?:\\.\\w+|\\[\\d+\\])*)+\\}\\}[`']") + type CustomSQLConfig struct { DatabaseInfoID int64 SQLTemplate string @@ -111,47 +113,60 @@ func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[strin return nil, err } + templateParts := nodes.ParseTemplate(singleQuotesStringRegexp.ReplaceAllString(c.sqlTemplate, "?")) templateSQL := "" - templateParts := nodes.ParseTemplate(c.sqlTemplate) - sqlParams := make([]database.SQLParam, 0, len(templateParts)) - var nilError = errors.New("field is nil") - for _, templatePart := range templateParts { - if !templatePart.IsVariable { - templateSQL += templatePart.Value - continue - } + if len(templateParts) > 0 { + if len(templateParts) == 0 { + templateSQL = templateParts[0].Value + } else { + for _, templatePart := range templateParts { + if !templatePart.IsVariable { + templateSQL += templatePart.Value + continue + } - templateSQL += "?" - val, err := templatePart.Render(inputBytes, nodes.WithNilRender(func() (string, error) { - return "", nilError - }), - nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) { - b := val.(bool) - if b { - return "1", nil + val, err := templatePart.Render(inputBytes, nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) { + b := val.(bool) + if b { + return "1", nil + } + return "0", nil + })) + if err != nil { + return nil, err } - return "0", nil - })) + templateSQL += val - if err != nil { - if !errors.Is(err, nilError) { - return nil, err } - sqlParams = append(sqlParams, database.SQLParam{ - IsNull: true, - }) - } else { - sqlParams = append(sqlParams, database.SQLParam{ - Value: val, - IsNull: false, - }) } + } else { + return nil, fmt.Errorf("parse template invalid") + } + + sqlParamStrings := singleQuotesStringRegexp.FindAllString(c.sqlTemplate, -1) + sqlParams := make([]database.SQLParam, 0, len(sqlParamStrings)) + for _, s := range sqlParamStrings { + parts := nodes.ParseTemplate(s) + for _, part := range parts { + if part.IsVariable { + val, err := part.Render(inputBytes, nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) { + b := val.(bool) + if b { + return "1", nil + } + return "0", nil + })) + if err != nil { + return nil, err + } + sqlParams = append(sqlParams, database.SQLParam{ + Value: val, + }) + } + } } - // replace sql template '?' to ? - templateSQL = strings.Replace(templateSQL, "'?'", "?", -1) - templateSQL = strings.Replace(templateSQL, "`?`", "?", -1) req.SQL = templateSQL req.Params = sqlParams response, err := crossdatabase.DefaultSVC().Execute(ctx, req) diff --git a/backend/domain/workflow/internal/nodes/database/customsql_test.go b/backend/domain/workflow/internal/nodes/database/customsql_test.go index 1d7b8255..69ed10b4 100644 --- a/backend/domain/workflow/internal/nodes/database/customsql_test.go +++ b/backend/domain/workflow/internal/nodes/database/customsql_test.go @@ -61,12 +61,12 @@ func TestCustomSQL_Execute(t *testing.T) { validate: func(req *database.CustomSQLRequest) { assert.Equal(t, int64(111), req.DatabaseInfoID) ps := []database.SQLParam{ - {Value: "v1_value"}, {Value: "v2_value"}, {Value: "v3_value"}, + {Value: "1"}, } assert.Equal(t, ps, req.Params) - assert.Equal(t, "select * from v1 where v1 = ? and v2 = ? and v3 = ?", req.SQL) + assert.Equal(t, "select * from v1 where v1 = v1_value and v2 = ? and v3 = ? and v4 = ?", req.SQL) }, } @@ -86,7 +86,7 @@ func TestCustomSQL_Execute(t *testing.T) { cfg := &CustomSQLConfig{ DatabaseInfoID: 111, - SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`", + SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}` and v4 = '{{v4}}'", } c1, err := cfg.Build(context.Background(), &schema.NodeSchema{ @@ -104,6 +104,7 @@ func TestCustomSQL_Execute(t *testing.T) { "v1": "v1_value", "v2": "v2_value", "v3": "v3_value", + "v4": true, }) assert.Nil(t, err)