fix: When agents use data tables, different users influence each other (#565)

main
liuyunchao-1998 3 months ago committed by GitHub
parent 8b91a640b9
commit 60285ca014
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 11
      backend/domain/memory/database/service/database_impl.go
  2. 10
      backend/infra/contract/sqlparser/sql_parser.go
  3. 79
      backend/infra/impl/sqlparser/sql_parser.go
  4. 203
      backend/infra/impl/sqlparser/sql_parser_test.go

@ -1075,7 +1075,16 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe
if err != nil {
return nil, fmt.Errorf("parse sql failed: %v", err)
}
// add rw mode
if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && len(req.UserID) != 0 {
switch operation {
case sqlparsercontract.OperationTypeSelect, sqlparsercontract.OperationTypeUpdate, sqlparsercontract.OperationTypeDelete:
parsedSQL, err = sqlparser.NewSQLParser().AppendSQLFilter(parsedSQL, sqlparsercontract.SQLFilterOpAnd, fmt.Sprintf("%s = '%s'", database.DefaultUidColName, req.UserID))
if err != nil {
return nil, fmt.Errorf("append sql filter failed: %v", err)
}
}
}
insertResult := make([]map[string]interface{}, 0)
if operation == sqlparsercontract.OperationTypeInsert {
cid := consts.CozeConnectorID

@ -48,6 +48,13 @@ const (
OperationTypeUnknown OperationType = "UNKNOWN"
)
type SQLFilterOp string
const (
SQLFilterOpAnd SQLFilterOp = "AND"
SQLFilterOpOr SQLFilterOp = "OR"
)
// SQLParser defines the interface for parsing and modifying SQL statements
type SQLParser interface {
// ParseAndModifySQL parses SQL and replaces table/column names according to the provided message
@ -64,4 +71,7 @@ type SQLParser interface {
// GetInsertDataNums extracts the number of rows to be inserted from a SQL statement. Only supports single-table insert.
GetInsertDataNums(sql string) (int, error)
// AppendSQLFilter appends a filter condition to the SQL statement.
AppendSQLFilter(sql string, op SQLFilterOp, filter string) (string, error)
}

@ -24,6 +24,7 @@ import (
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/format"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/opcode"
_ "github.com/pingcap/tidb/pkg/parser/test_driver"
"github.com/coze-dev/coze-studio/backend/infra/contract/sqlparser"
@ -411,3 +412,81 @@ func (p *Impl) GetInsertDataNums(sql string) (int, error) {
return len(insert.Lists), nil
}
func (p *Impl) AppendSQLFilter(sql string, op sqlparser.SQLFilterOp, filter string) (string, error) {
if sql == "" {
return "", fmt.Errorf("empty SQL statement")
}
if op == "" || (op != sqlparser.SQLFilterOpAnd && op != sqlparser.SQLFilterOpOr) {
return "", fmt.Errorf("invalid filter operator: %s", op)
}
if filter == "" {
return "", fmt.Errorf("empty filter condition")
}
stmtNode, err := p.parser.ParseOneStmt(sql, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
if err != nil {
return "", fmt.Errorf("failed to parse SQL: %v", err)
}
// extract WHERE clause
var originalWhere ast.ExprNode
switch stmt := stmtNode.(type) {
case *ast.SelectStmt:
originalWhere = stmt.Where
case *ast.UpdateStmt:
originalWhere = stmt.Where
case *ast.DeleteStmt:
originalWhere = stmt.Where
default:
return "", fmt.Errorf("append filter condition failed: only support SELECT/UPDATE/DELETE")
}
tmpSQL := fmt.Sprintf("SELECT * FROM tmp WHERE %s", filter)
tmpNode, err := p.parser.ParseOneStmt(tmpSQL, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
if err != nil {
return "", fmt.Errorf("parse filter condition failed: %v", err)
}
newExpr := tmpNode.(*ast.SelectStmt).Where
mergedExpr := mergeExpr(originalWhere, newExpr, op)
// update AST
switch stmt := stmtNode.(type) {
case *ast.SelectStmt:
stmt.Where = mergedExpr
case *ast.UpdateStmt:
stmt.Where = mergedExpr
case *ast.DeleteStmt:
stmt.Where = mergedExpr
}
// regenerate SQL
var sb strings.Builder
flags := format.RestoreStringSingleQuotes | format.RestoreStringWithoutCharset | format.RestoreNameBackQuotes
restoreCtx := format.NewRestoreCtx(flags, &sb)
if err := stmtNode.Restore(restoreCtx); err != nil {
return "", fmt.Errorf("gen SQL failed: %v", err)
}
return sb.String(), nil
}
func mergeExpr(left, right ast.ExprNode, op sqlparser.SQLFilterOp) ast.ExprNode {
if left == nil {
return right
}
if right == nil {
return left
}
switch op {
case sqlparser.SQLFilterOpAnd:
return &ast.BinaryOperationExpr{
Op: opcode.LogicAnd,
L: left,
R: right,
}
case sqlparser.SQLFilterOpOr:
return &ast.BinaryOperationExpr{
Op: opcode.LogicOr,
L: left,
R: right,
}
default:
return nil
}
}

@ -548,3 +548,206 @@ func TestGetInsertDataNums(t *testing.T) {
})
}
}
func TestAppendSQLFilter(t *testing.T) {
parser := NewSQLParser().(*Impl)
tests := []struct {
name string
sql string
condition string
connector string
want string
wantErr bool
errContains string
}{
// tset - SELECT
{
name: "SELECT - add AND to existing WHERE",
sql: "SELECT * FROM users WHERE age > 18",
condition: "status = 'active'",
connector: "AND",
want: " select * from `users` where `age`>18 and `status`='active'",
},
{
name: "SELECT - add OR to existing WHERE",
sql: "SELECT * FROM products WHERE price < 100",
condition: "category = 'electronics'",
connector: "OR",
want: "select * from `products` where `price`<100 or `category`='electronics'",
},
{
name: "SELECT - add AND to multiple conditions",
sql: "SELECT * FROM orders WHERE total > 50 AND status = 'completed'",
condition: "customer_id = 123",
connector: "AND",
want: "select * from `orders` where `total`>50 and `status`='completed' and `customer_id`=123",
},
{
name: "SELECT - add condition without WHERE",
sql: "SELECT id, name FROM customers",
condition: "is_verified = 1",
connector: "AND",
want: "select `id`,`name` from `customers` where `is_verified`=1",
},
// tset - UPDATE
{
name: "UPDATE - add AND condition",
sql: "UPDATE users SET last_login = NOW() WHERE id = 42",
condition: "is_active = true",
connector: "AND",
want: "update `users` set `last_login`=now() where `id`=42 and `is_active`=true",
},
{
name: "UPDATE - add OR condition without WHERE",
sql: "UPDATE products SET discount = 0.1",
condition: "inventory > 100",
connector: "OR",
want: "update `products` set `discount`=0.1 where `inventory`>100",
},
// tset - DELETE
{
name: "DELETE - add AND condition",
sql: "DELETE FROM logs WHERE created_at < '2023-01-01'",
condition: "severity = 'DEBUG'",
connector: "AND",
want: "delete from `logs` where `created_at`<'2023-01-01' and `severity`='debug'",
},
{
name: "DELETE - add OR condition",
sql: "DELETE FROM sessions WHERE expires_at < NOW()",
condition: "invalid = true",
connector: "OR",
want: "delete from `sessions` where `expires_at`<now() or `invalid`=true",
},
// tset - complex expr
{
name: "Complex condition with parentheses",
sql: "SELECT * FROM orders WHERE `status` = 'shipped'",
condition: "(total > 100 OR priority = 1)",
connector: "AND",
want: "select * from `orders` where `status`='shipped' and (`total`>100 or `priority`=1)",
},
{
name: "Add condition to existing parentheses",
sql: "SELECT * FROM users WHERE (age > 18 OR parent_consent = true) AND country = 'US'",
condition: "is_verified = 1",
connector: "AND",
want: "select * from `users` where (`age`>18 or `parent_consent`=true) and `country`='us' and `is_verified`=1",
},
// tset - JOIN
{
name: "SELECT with JOIN",
sql: "SELECT u.name, o.total FROM users u JOIN orders o ON u.id = o.user_id WHERE u.country = 'US'",
condition: "o.status = 'completed'",
connector: "AND",
want: "select `u`.`name`,`o`.`total` from `users` as `u` join `orders` as `o` on `u`.`id`=`o`.`user_id` where `u`.`country`='us' and `o`.`status`='completed'",
},
{
name: "SELECT with multiple joins",
sql: "SELECT p.name, c.category_name FROM products p JOIN categories c ON p.category_id = c.id WHERE p.price < 50",
condition: "c.parent_id = 1",
connector: "AND",
want: "select `p`.`name`,`c`.`category_name` from `products` as `p` join `categories` as `c` on `p`.`category_id`=`c`.`id` where `p`.`price`<50 and `c`.`parent_id`=1",
},
// test - case sensitive
{
name: "Mixed case connector",
sql: "SELECT * FROM users WHERE age > 18",
condition: "status = 'active'",
connector: "aNd",
want: "",
wantErr: true,
errContains: "invalid filter operator",
},
{
name: "Mixed case condition",
sql: "SELECT * FROM products",
condition: "CaTegorY = 'ELECTRONICS'",
connector: "AND",
want: "select * from `products` where `category`='electronics'",
},
// test - error case
{
name: "Empty SQL",
sql: "",
condition: "id = 1",
connector: "AND",
wantErr: true,
errContains: "empty SQL statement",
},
{
name: "Empty condition",
sql: "SELECT * FROM users",
condition: "",
connector: "AND",
wantErr: true,
errContains: "empty filter condition",
},
{
name: "Invalid connector",
sql: "SELECT * FROM users",
condition: "is_active = true",
connector: "",
wantErr: true,
errContains: "invalid filter operator",
},
{
name: "Unsupported statement type",
sql: "CREATE TABLE users (id INT, name VARCHAR(255))",
condition: "id > 0",
connector: "AND",
wantErr: true,
errContains: "only support SELECT/UPDATE/DELETE",
},
{
name: "Malformed SQL",
sql: "SELECTZ * FRON users",
condition: "id = 1",
connector: "AND",
wantErr: true,
errContains: "failed to parse SQL",
},
{
name: "Malformed condition",
sql: "SELECT * FROM users",
condition: "id ==",
connector: "AND",
wantErr: true,
errContains: "parse filter condition failed",
},
}
// run case
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parser.AppendSQLFilter(tt.sql, sqlparser.SQLFilterOp(tt.connector), tt.condition)
if tt.wantErr {
if err == nil {
t.Fatal("Expected error, got nil")
}
if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Expected error to contain %q, got %q", tt.errContains, err.Error())
}
return
}
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
normalizedResult := strings.ToLower(strings.Join(strings.Fields(result), " "))
normalizedWant := strings.ToLower(strings.Join(strings.Fields(tt.want), " "))
if !strings.EqualFold(normalizedResult, normalizedWant) {
t.Errorf("Result mismatch:\nWant: %s\nGot: %s", normalizedWant, normalizedResult)
}
})
}
}

Loading…
Cancel
Save