|
|
|
@ -21,8 +21,8 @@ import ( |
|
|
|
|
"errors" |
|
|
|
|
"fmt" |
|
|
|
|
"reflect" |
|
|
|
|
"regexp" |
|
|
|
|
"strconv" |
|
|
|
|
"strings" |
|
|
|
|
|
|
|
|
|
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database" |
|
|
|
|
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database" |
|
|
|
@ -34,6 +34,8 @@ import ( |
|
|
|
|
"github.com/coze-dev/coze-studio/backend/pkg/sonic" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
var singleQuotesStringRegexp = regexp.MustCompile("[`']\\{\\{([a-zA-Z_][a-zA-Z0-9_]*(?:\\.\\w+|\\[\\d+\\])*)+\\}\\}[`']") |
|
|
|
|
|
|
|
|
|
type CustomSQLConfig struct { |
|
|
|
|
DatabaseInfoID int64 |
|
|
|
|
SQLTemplate string |
|
|
|
@ -111,47 +113,60 @@ func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[strin |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
templateParts := nodes.ParseTemplate(singleQuotesStringRegexp.ReplaceAllString(c.sqlTemplate, "?")) |
|
|
|
|
templateSQL := "" |
|
|
|
|
templateParts := nodes.ParseTemplate(c.sqlTemplate) |
|
|
|
|
sqlParams := make([]database.SQLParam, 0, len(templateParts)) |
|
|
|
|
var nilError = errors.New("field is nil") |
|
|
|
|
if len(templateParts) > 0 { |
|
|
|
|
if len(templateParts) == 0 { |
|
|
|
|
templateSQL = templateParts[0].Value |
|
|
|
|
} else { |
|
|
|
|
for _, templatePart := range templateParts { |
|
|
|
|
if !templatePart.IsVariable { |
|
|
|
|
templateSQL += templatePart.Value |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
templateSQL += "?" |
|
|
|
|
val, err := templatePart.Render(inputBytes, nodes.WithNilRender(func() (string, error) { |
|
|
|
|
return "", nilError |
|
|
|
|
}), |
|
|
|
|
nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) { |
|
|
|
|
val, err := templatePart.Render(inputBytes, nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) { |
|
|
|
|
b := val.(bool) |
|
|
|
|
if b { |
|
|
|
|
return "1", nil |
|
|
|
|
} |
|
|
|
|
return "0", nil |
|
|
|
|
})) |
|
|
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
if !errors.Is(err, nilError) { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
sqlParams = append(sqlParams, database.SQLParam{ |
|
|
|
|
IsNull: true, |
|
|
|
|
}) |
|
|
|
|
templateSQL += val |
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
} else { |
|
|
|
|
return nil, fmt.Errorf("parse template invalid") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
sqlParamStrings := singleQuotesStringRegexp.FindAllString(c.sqlTemplate, -1) |
|
|
|
|
sqlParams := make([]database.SQLParam, 0, len(sqlParamStrings)) |
|
|
|
|
for _, s := range sqlParamStrings { |
|
|
|
|
parts := nodes.ParseTemplate(s) |
|
|
|
|
for _, part := range parts { |
|
|
|
|
if part.IsVariable { |
|
|
|
|
val, err := part.Render(inputBytes, nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) { |
|
|
|
|
b := val.(bool) |
|
|
|
|
if b { |
|
|
|
|
return "1", nil |
|
|
|
|
} |
|
|
|
|
return "0", nil |
|
|
|
|
})) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
sqlParams = append(sqlParams, database.SQLParam{ |
|
|
|
|
Value: val, |
|
|
|
|
IsNull: false, |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// replace sql template '?' to ?
|
|
|
|
|
templateSQL = strings.Replace(templateSQL, "'?'", "?", -1) |
|
|
|
|
templateSQL = strings.Replace(templateSQL, "`?`", "?", -1) |
|
|
|
|
req.SQL = templateSQL |
|
|
|
|
req.Params = sqlParams |
|
|
|
|
response, err := crossdatabase.DefaultSVC().Execute(ctx, req) |
|
|
|
|