Merge remote-tracking branch 'origin/master'

master
masong 5 months ago
commit de058b5ecb
  1. 32
      packages/dbgpt-app/src/dbgpt_app/knowledge/api.py
  2. 34
      packages/dbgpt-serve/src/dbgpt_serve/agent/db/gpts_app.py
  3. 101
      packages/dbgpt-system/src/dbgpt_system/upload_image/controller.py

@ -1,9 +1,12 @@
import logging
import os
import shutil
from http.client import HTTPException
from typing import List
from fastapi import APIRouter, Depends, File, Form, UploadFile
from starlette.responses import StreamingResponse
from dbgpt_serve.utils.auth import UserRequest, get_user_from_headers
from dbgpt._private.config import Config
@ -581,3 +584,32 @@ async def document_summary(request: DocumentSummaryRequest):
)
except Exception as e:
return Result.failed(code="E000X", msg=f"document summary error {e}")
@router.get("/file/preview/{file_uri:path}") # 使用path接收可能含/的URI
async def preview_file(file_uri: str):
try:
fs = get_fs()
# 1. 直接使用完整URI获取文件
file_data, metadata = fs.get_file(file_uri)
# 2. 确定媒体类型
file_ext = file_uri.split('.')[-1].lower()
media_type = f"image/{file_ext}" if file_ext in ["jpg", "jpeg", "png", "gif"] else "application/octet-stream"
# 3. 返回文件流
return StreamingResponse(
file_data,
media_type=media_type,
headers={"Content-Disposition": f"inline; filename={metadata.file_name}"}
)
except FileNotFoundError:
raise HTTPException(404, "文件不存在")
except Exception as e:
raise HTTPException(500, f"文件读取失败: {str(e)}")

@ -5,7 +5,7 @@ from datetime import datetime
from enum import Enum
from itertools import groupby
from typing import Any, Dict, List, Optional, Union
# from dbgpt_system.sys_type.type_db import TypeDao
from dbgpt_system.sys_type.type_db import TypeDao
from sqlalchemy import (
@ -85,7 +85,7 @@ class GptsAppDetail(BaseModel):
app_name=d["app_name"],
agent_name=d["agent_name"],
node_id=d["node_id"],
# icon=d.get("icon"),
icon=d.get("icon"),
resources=AgentResource.from_json_list_str(d.get("resources", None)),
prompt_template=d.get("prompt_template", None),
llm_strategy=d.get("llm_strategy", None),
@ -136,7 +136,7 @@ class GptsApp(BaseModel):
owner_avatar_url: Optional[str] = None
recommend_questions: Optional[List[RecommendQuestion]] = []
admins: List[str] = Field(default_factory=list)
# app_type: Optional[str] = None
app_type: Optional[str] = None
# By default, keep the last two rounds of conversation records as the context
keep_start_rounds: int = 0
@ -202,7 +202,7 @@ class GptsAppQuery(GptsApp):
app_codes: Optional[List[str]] = []
hot_map: Optional[Dict[str, int]] = {}
need_owner_info: Optional[str] = "true"
# app_type: Optional[str] = None
app_type: Optional[str] = None
class GptsAppResponse(BaseModel):
@ -336,7 +336,7 @@ class GptsAppEntity(Model):
)
admins = Column(Text, nullable=True, comment="administrators")
# app_type = Column(String(10), nullable=True, comment="分类")
app_type = Column(String(10), nullable=True, comment="分类")
__table_args__ = (UniqueConstraint("app_name", name="uk_gpts_app"),)
@ -631,6 +631,8 @@ class GptsAppDao(BaseDao):
app_qry = app_qry.filter(
GptsAppEntity.published == query.published.lower()
)
if query.app_type:
app_qry = app_qry.filter(GptsAppEntity.app_type == query.app_type)
if query.app_codes:
app_qry = app_qry.filter(GptsAppEntity.app_code.in_(query.app_codes))
total_count = app_qry.count()
@ -694,12 +696,12 @@ class GptsAppDao(BaseDao):
recommend_questions: List[RecommendQuestionEntity] = None,
):
# type_label = None
# if app_info.app_type:
# # 这里假设你有一个字典表DAO类可以查询
# dict_dao = TypeDao() # 需要实现这个类
# dict_item = dict_dao.select_type_by_value(app_info.app_type)
# type_label = dict_item.label if dict_item else None
type_label = None
if app_info.app_type:
# 这里假设你有一个字典表DAO类可以查询
dict_dao = TypeDao() # 需要实现这个类
dict_item = dict_dao.select_type_by_value(app_info.app_type)
type_label = dict_item.label if dict_item else None
return {
"app_code": app_info.app_code,
@ -710,7 +712,7 @@ class GptsAppDao(BaseDao):
"team_context": _load_team_context(
app_info.team_mode, app_info.team_context
),
# "icon": app_info.icon,
"icon": app_info.icon,
"user_code": app_info.user_code,
"sys_code": app_info.sys_code,
"is_collected": "true" if app_info.app_code in app_collects else "false",
@ -735,8 +737,8 @@ class GptsAppDao(BaseDao):
else []
),
"admins": [],
# "app_type": app_info.app_type,
# "type_label": type_label, #用于前端显示,type对应的汉字标签
"app_type": app_info.app_type,
"type_label": type_label, #用于前端显示,type对应的汉字标签
}
def _group_app_details(self, app_codes, session):
@ -879,7 +881,7 @@ class GptsAppDao(BaseDao):
param_need=(
json.dumps(gpts_app.param_need) if gpts_app.param_need else None
),
# app_type=gpts_app.app_type
app_type=gpts_app.app_type
)
session.add(app_entity)
@ -951,7 +953,7 @@ class GptsAppDao(BaseDao):
app_entity.param_need = json.dumps(gpts_app.param_need)
app_entity.keep_start_rounds = gpts_app.keep_start_rounds
app_entity.keep_end_rounds = gpts_app.keep_end_rounds
# app_entity.app_type = gpts_app.app_type
app_entity.app_type = gpts_app.app_type
session.merge(app_entity)
old_details = session.query(GptsAppDetailEntity).filter(

@ -1,56 +1,69 @@
from fastapi import FastAPI, UploadFile, File, HTTPException,APIRouter
from fastapi import APIRouter, UploadFile, File, HTTPException, Depends
from fastapi.responses import JSONResponse
from pathlib import Path
from dbgpt_serve.utils.auth import UserRequest, get_user_from_headers
from dbgpt.core.interface.file import FileStorageClient
from dbgpt._private.config import Config
import uuid
from fastapi.staticfiles import StaticFiles
import io
router = APIRouter()
CFG = Config()
# 配置
UPLOAD_DIR = "static/web/_next/static/uploads" # 上传文件保存目录
FILE_VISIT_PATH = "_next/static/uploads" # 上传文件保存目录
ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "gif"} # 允许的文件类型
MAX_FILE_SIZE = 16 * 1024 * 1024 # 16MB 文件大小限制
BASE_URL = "http://192.168.130.191:5670" # 你的服务域名
ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "gif"}
MAX_FILE_SIZE = 16 * 1024 * 1024 # 16MB
# 确保上传目录存在
Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
app = FastAPI()
# 挂载静态文件目录(添加到FastAPI应用中)
app.mount("/_next/static", StaticFiles(directory="static"), name="static")
def get_fs() -> FileStorageClient:
return FileStorageClient.get_instance(CFG.SYSTEM_APP)
def allowed_file(filename: str) -> bool:
return "." in filename and \
filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS
def is_allowed_file(filename: str) -> bool:
return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS
@router.post("/images")
async def upload_image(file: UploadFile = File(...)):
# 1. 验证文件类型
if not allowed_file(file.filename):
raise HTTPException(status_code=400, detail="只允许上传图片文件 (PNG/JPG/JPEG/GIF)")
# 2. 验证文件大小
file_size = 0
for chunk in file.file:
file_size += len(chunk)
if file_size > MAX_FILE_SIZE:
raise HTTPException(status_code=413, detail="文件大小超过 16MB 限制")
# 3. 生成唯一文件名
file_ext = file.filename.rsplit(".", 1)[1].lower()
unique_filename = f"{uuid.uuid4().hex}.{file_ext}"
save_path = Path(UPLOAD_DIR) / unique_filename
# 4. 保存文件
with open(save_path, "wb") as buffer:
file.file.seek(0) # 回到文件开头
buffer.write(file.file.read())
# 5. 返回访问 URL
file_url = f"{BASE_URL}/{FILE_VISIT_PATH}/{unique_filename}"
return JSONResponse(
status_code=200,
content={"message": "文件上传成功", "url": file_url}
)
async def upload_image(
file: UploadFile = File(...),
user_info: UserRequest = Depends(get_user_from_headers),
):
"""完全匹配FileStorageClient.save_file参数要求的图片上传接口"""
try:
# 1. 验证文件类型和大小
if not is_allowed_file(file.filename):
raise HTTPException(status_code=400, detail="仅支持PNG/JPG/JPEG/GIF图片")
file_content = await file.read()
if len(file_content) > MAX_FILE_SIZE:
raise HTTPException(status_code=413, detail="图片大小超过16MB限制")
# 2. 准备参数
fs = get_fs()
file_ext = file.filename.split(".")[-1]
file_name = f"img_{uuid.uuid4().hex}.{file_ext}"
bucket = "dbgpt_logo_file" # 设置存储桶名称
storage_type = "distributed" # 设置存储类型,如local、distributed等
# 3. 调用save_file方法,传递所有必需参数
file_uri = fs.save_file(
bucket=bucket,
file_name=file_name,
file_data=io.BytesIO(file_content),
storage_type=storage_type
)
# 4. 返回结果
return JSONResponse(
status_code=200,
content={
"success": True,
"url": f"/file/preview/{file_uri}",
"filename": file_name,
"message": "上传成功"
}
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
Loading…
Cancel
Save