fix(backend_plugin): plugin common header and parameter default value (#181)

main
mrh997 3 months ago committed by GitHub
parent 8137b0aee5
commit 53345f58c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 1
      backend/api/model/crossdomain/plugin/toolinfo.go
  2. 21
      backend/application/base/pluginutil/api.go
  3. 1
      backend/domain/plugin/entity/plugin.go
  4. 187
      backend/domain/plugin/service/exec_tool.go
  5. 49
      backend/domain/plugin/service/exec_tool_test.go
  6. 11
      backend/domain/plugin/service/plugin_draft.go
  7. 2
      backend/domain/plugin/service/plugin_oauth.go

@ -263,6 +263,7 @@ func toAPIParameter(paramMeta paramMetaInfo, sc *openapi3.Schema) (*common.APIPa
}
if sc.Default != nil {
apiParam.GlobalDefault = ptr.Of(fmt.Sprintf("%v", sc.Default))
apiParam.LocalDefault = ptr.Of(fmt.Sprintf("%v", sc.Default))
}

@ -20,9 +20,8 @@ import (
"net/http"
"strconv"
"github.com/getkin/kin-openapi/openapi3"
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
"github.com/getkin/kin-openapi/openapi3"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
@ -33,14 +32,6 @@ import (
func APIParamsToOpenapiOperation(reqParams, respParams []*common.APIParameter) (*openapi3.Operation, error) {
op := &openapi3.Operation{}
if reqParams != nil && len(reqParams) == 0 {
op.Parameters = []*openapi3.ParameterRef{}
op.RequestBody = entity.DefaultOpenapi3RequestBody()
}
if respParams != nil && len(respParams) == 0 {
op.Responses = entity.DefaultOpenapi3Responses()
}
hasSetReqBody := false
hasSetParams := false
@ -136,6 +127,16 @@ func APIParamsToOpenapiOperation(reqParams, respParams []*common.APIParameter) (
}
}
if op.Parameters == nil {
op.Parameters = []*openapi3.ParameterRef{}
}
if op.RequestBody == nil {
op.RequestBody = entity.DefaultOpenapi3RequestBody()
}
if op.Responses == nil {
op.Responses = entity.DefaultOpenapi3Responses()
}
return op, nil
}

@ -157,7 +157,6 @@ func NewDefaultPluginManifest() *PluginManifest {
Value: "Coze/1.0",
},
},
model.ParamInPath: {},
model.ParamInQuery: {},
},
}

@ -29,6 +29,7 @@ import (
"github.com/bytedance/sonic"
"github.com/getkin/kin-openapi/openapi3"
"github.com/tidwall/sjson"
einoCompose "github.com/cloudwego/eino/compose"
@ -479,11 +480,6 @@ func (t *toolExecutor) execute(ctx context.Context, argumentsInJson string) (res
return nil, err
}
requestStr, err := sonic.MarshalString(args)
if err != nil {
return nil, err
}
httpReq, err := t.buildHTTPRequest(ctx, args)
if err != nil {
return nil, err
@ -504,18 +500,29 @@ func (t *toolExecutor) execute(ctx context.Context, argumentsInJson string) (res
}
var reqBodyBytes []byte
if httpReq.Body != nil {
reqBodyBytes, err = io.ReadAll(httpReq.Body)
if httpReq.GetBody != nil {
reqBody, err := httpReq.GetBody()
if err != nil {
return nil, err
}
defer reqBody.Close()
reqBodyBytes, err = io.ReadAll(reqBody)
if err != nil {
return nil, err
}
}
requestStr, err := genRequestString(httpReq, reqBodyBytes)
if err != nil {
return nil, err
}
restyReq := t.svc.httpCli.NewRequest()
restyReq.Header = httpReq.Header
restyReq.Method = httpReq.Method
restyReq.URL = httpReq.URL.String()
if len(reqBodyBytes) > 0 {
if reqBodyBytes != nil {
restyReq.SetBody(reqBodyBytes)
}
restyReq.SetContext(ctx)
@ -559,6 +566,46 @@ func (t *toolExecutor) execute(ctx context.Context, argumentsInJson string) (res
}, nil
}
func genRequestString(req *http.Request, body []byte) (string, error) {
type Request struct {
Path string `json:"path"`
Header map[string]string `json:"header"`
Query map[string]string `json:"query"`
Body *[]byte `json:"body"`
}
req_ := &Request{
Path: req.URL.Path,
Header: map[string]string{},
Query: map[string]string{},
}
if len(req.Header) > 0 {
for k, v := range req.Header {
req_.Header[k] = v[0]
}
}
if len(req.URL.Query()) > 0 {
for k, v := range req.URL.Query() {
req_.Query[k] = v[0]
}
}
requestStr, err := sonic.MarshalString(req_)
if err != nil {
return "", fmt.Errorf("[genRequestString] marshal failed, err=%s", err)
}
if body != nil {
requestStr, err = sjson.SetRaw(requestStr, "body", string(body))
if err != nil {
return "", fmt.Errorf("[genRequestString] set body failed, err=%s", err)
}
}
return requestStr, nil
}
func (t *toolExecutor) preprocessArgumentsInJson(ctx context.Context, argumentsInJson string) (args map[string]any, err error) {
args, err = t.prepareArguments(ctx, argumentsInJson)
if err != nil {
@ -653,23 +700,13 @@ func (t *toolExecutor) buildHTTPRequest(ctx context.Context, argMaps map[string]
return nil, err
}
reqURL, err := locArgs.buildHTTPRequestURL(ctx, rawURL)
if err != nil {
return nil, err
}
httpReq, err = http.NewRequestWithContext(ctx, tool.GetMethod(), reqURL.String(), nil)
if err != nil {
return nil, err
}
commonParams := t.plugin.Manifest.CommonParams
header, err := locArgs.buildHTTPRequestHeader(ctx)
reqURL, err := locArgs.buildHTTPRequestURL(ctx, rawURL, commonParams)
if err != nil {
return nil, err
}
httpReq.Header = header
bodyArgs := map[string]any{}
for k, v := range argMaps {
if _, ok := locArgs.header[k]; ok {
@ -684,13 +721,27 @@ func (t *toolExecutor) buildHTTPRequest(ctx context.Context, argMaps map[string]
bodyArgs[k] = v
}
bodyBytes, contentType, err := t.buildRequestBody(ctx, tool.Operation, bodyArgs)
commonBody := commonParams[model.ParamInBody]
bodyBytes, contentType, err := t.buildRequestBody(ctx, tool.Operation, bodyArgs, commonBody)
if err != nil {
return nil, err
}
httpReq, err = http.NewRequestWithContext(ctx, tool.GetMethod(), reqURL.String(), bytes.NewBuffer(bodyBytes))
if err != nil {
return nil, err
}
commonHeader := commonParams[model.ParamInHeader]
header, err := locArgs.buildHTTPRequestHeader(ctx, commonHeader)
if err != nil {
return nil, err
}
httpReq.Header = header
if len(bodyBytes) > 0 {
httpReq.Header.Set("Content-Type", contentType)
httpReq.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
return httpReq, nil
@ -698,13 +749,6 @@ func (t *toolExecutor) buildHTTPRequest(ctx context.Context, argMaps map[string]
func (t *toolExecutor) prepareArguments(_ context.Context, argumentsInJson string) (map[string]any, error) {
args := map[string]any{}
for loc, params := range t.plugin.Manifest.CommonParams {
for _, p := range params {
if loc != model.ParamInBody {
args[p.Name] = p.Value
}
}
}
decoder := sonic.ConfigDefault.NewDecoder(bytes.NewBufferString(argumentsInJson))
decoder.UseNumber()
@ -1175,7 +1219,9 @@ type valueWithSchema struct {
paramSchema *openapi3.Parameter
}
func (l *locationArguments) buildHTTPRequestURL(_ context.Context, rawURL string) (reqURL *url.URL, err error) {
func (l *locationArguments) buildHTTPRequestURL(_ context.Context, rawURL string,
commonParams map[model.HTTPParamLocation][]*common.CommonParamSchema) (reqURL *url.URL, err error) {
if len(l.path) > 0 {
for k, v := range l.path {
vStr, err := encoder.EncodeParameter(v.paramSchema, v.argValue)
@ -1186,9 +1232,8 @@ func (l *locationArguments) buildHTTPRequestURL(_ context.Context, rawURL string
}
}
encodeQuery := ""
query := url.Values{}
if len(l.query) > 0 {
query := url.Values{}
for k, val := range l.query {
switch v := val.argValue.(type) {
case []any:
@ -1199,10 +1244,18 @@ func (l *locationArguments) buildHTTPRequestURL(_ context.Context, rawURL string
query.Add(k, encoder.MustString(v))
}
}
}
encodeQuery = query.Encode()
commonQuery := commonParams[model.ParamInQuery]
for _, v := range commonQuery {
if _, ok := l.query[v.Name]; ok {
continue
}
query.Add(v.Name, v.Value)
}
encodeQuery := query.Encode()
reqURL, err = url.Parse(rawURL)
if err != nil {
return nil, err
@ -1217,7 +1270,7 @@ func (l *locationArguments) buildHTTPRequestURL(_ context.Context, rawURL string
return reqURL, nil
}
func (l *locationArguments) buildHTTPRequestHeader(_ context.Context) (http.Header, error) {
func (l *locationArguments) buildHTTPRequestHeader(_ context.Context, commonHeaders []*common.CommonParamSchema) (http.Header, error) {
header := http.Header{}
if len(l.header) > 0 {
for k, v := range l.header {
@ -1232,44 +1285,64 @@ func (l *locationArguments) buildHTTPRequestHeader(_ context.Context) (http.Head
}
}
for _, h := range commonHeaders {
if header.Get(h.Name) != "" {
continue
}
header.Add(h.Name, h.Value)
}
return header, nil
}
func (t *toolExecutor) buildRequestBody(ctx context.Context, op *model.Openapi3Operation, bodyArgs map[string]any) (body []byte, contentType string, err error) {
func (t *toolExecutor) buildRequestBody(ctx context.Context, op *model.Openapi3Operation, bodyArgs map[string]any,
commonBody []*common.CommonParamSchema) (body []byte, contentType string, err error) {
var bodyMap map[string]any
contentType, bodySchema := t.getReqBodySchema(op)
if bodySchema == nil || bodySchema.Value == nil {
return nil, "", nil
}
if bodySchema != nil && len(bodySchema.Value.Properties) > 0 {
bodyMap, err = t.injectRequestBodyDefaultValue(ctx, bodySchema.Value, bodyArgs)
if err != nil {
return nil, "", err
}
if len(bodySchema.Value.Properties) == 0 {
return nil, "", nil
}
for paramName, prop := range bodySchema.Value.Properties {
value, ok := bodyMap[paramName]
if !ok {
continue
}
bodyMap, err := t.injectRequestBodyDefaultValue(ctx, bodySchema.Value, bodyArgs)
if err != nil {
return nil, "", err
}
_value, err := encoder.TryFixValueType(paramName, prop, value)
if err != nil {
return nil, "", err
}
for paramName, prop := range bodySchema.Value.Properties {
value, ok := bodyMap[paramName]
if !ok {
continue
bodyMap[paramName] = _value
}
_value, err := encoder.TryFixValueType(paramName, prop, value)
body, err = encoder.EncodeBodyWithContentType(contentType, bodyMap)
if err != nil {
return nil, "", err
return nil, "", fmt.Errorf("[buildRequestBody] EncodeBodyWithContentType failed, err=%v", err)
}
}
bodyMap[paramName] = _value
commonBody_ := make([]*common.CommonParamSchema, 0, len(commonBody))
for _, v := range commonBody {
if _, ok := bodyMap[v.Name]; ok {
continue
}
commonBody_ = append(commonBody_, v)
}
reqBodyStr, err := encoder.EncodeBodyWithContentType(contentType, bodyMap)
if err != nil {
return nil, "", fmt.Errorf("[buildRequestBody] EncodeBodyWithContentType failed, err=%v", err)
for _, v := range commonBody_ {
body, err = sjson.SetRawBytes(body, v.Name, []byte(v.Value))
if err != nil {
return nil, "", fmt.Errorf("[buildRequestBody] SetRawBytes failed, err=%v", err)
}
}
return reqBodyStr, contentType, nil
return body, contentType, nil
}
func (t *toolExecutor) injectRequestBodyDefaultValue(ctx context.Context, sc *openapi3.Schema, vals map[string]any) (newVals map[string]any, err error) {
@ -1327,7 +1400,7 @@ func (t *toolExecutor) injectRequestBodyDefaultValue(ctx context.Context, sc *op
}
func (t *toolExecutor) getReqBodySchema(op *model.Openapi3Operation) (string, *openapi3.SchemaRef) {
if op.RequestBody == nil || op.RequestBody.Value == nil || len(op.RequestBody.Value.Content) == 0 {
if op.RequestBody == nil || len(op.RequestBody.Value.Content) == 0 {
return "", nil
}

@ -0,0 +1,49 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package service
import (
"net/http"
"net/url"
"testing"
. "github.com/bytedance/mockey"
"github.com/stretchr/testify/assert"
)
func TestGenRequestString(t *testing.T) {
PatchConvey("", t, func() {
requestStr, err := genRequestString(&http.Request{
Header: http.Header{
"Content-Type": []string{"application/json"},
},
Method: http.MethodPost,
URL: &url.URL{Path: "/test"},
}, []byte(`{"a": 1}`))
assert.NoError(t, err)
assert.Equal(t, `{"header":{"Content-Type":["application/json"]},"query":null,"path":"/test","body":{"a": 1}}`, requestStr)
})
PatchConvey("", t, func() {
var body []byte
requestStr, err := genRequestString(&http.Request{
URL: &url.URL{Path: "/test"},
}, body)
assert.NoError(t, err)
assert.Equal(t, `{"header":null,"query":null,"path":"/test","body":null}`, requestStr)
})
}

@ -46,6 +46,7 @@ import (
func (p *pluginServiceImpl) CreateDraftPlugin(ctx context.Context, req *CreateDraftPluginRequest) (pluginID int64, err error) {
mf := entity.NewDefaultPluginManifest()
mf.CommonParams = map[model.HTTPParamLocation][]*plugin_develop_common.CommonParamSchema{}
mf.NameForHuman = req.Name
mf.NameForModel = req.Name
mf.DescriptionForHuman = req.Desc
@ -65,11 +66,11 @@ func (p *pluginServiceImpl) CreateDraftPlugin(ctx context.Context, req *CreateDr
return 0, fmt.Errorf("invalid location '%s'", loc.String())
}
for _, param := range params {
mParams := mf.CommonParams[location]
mParams = append(mParams, &plugin_develop_common.CommonParamSchema{
Name: param.Name,
Value: param.Value,
})
mf.CommonParams[location] = append(mf.CommonParams[location],
&plugin_develop_common.CommonParamSchema{
Name: param.Name,
Value: param.Value,
})
}
}

@ -194,7 +194,7 @@ func (p *pluginServiceImpl) getAccessTokenByAuthorizationCode(ctx context.Contex
meta := ci.Meta
info, exist, err := p.oauthRepo.GetAuthorizationCode(ctx, ci.Meta)
if err != nil {
return "", errorx.Wrapf(err, "GetAuthorizationCode failed, userID=%s, pluginID=%d, isDraft=%p",
return "", errorx.Wrapf(err, "GetAuthorizationCode failed, userID=%s, pluginID=%d, isDraft=%t",
meta.UserID, meta.PluginID, meta.IsDraft)
}
if !exist {

Loading…
Cancel
Save