Merge remote-tracking branch 'origin/master'

master
masong 4 months ago
commit ce22155228
  1. 5
      assets/schema/updateSQL.sql
  2. 24
      packages/dbgpt-core/src/dbgpt/model/base.py
  3. 4
      packages/dbgpt-core/src/dbgpt/model/cli.py
  4. 14
      packages/dbgpt-core/src/dbgpt/model/cluster/base.py
  5. 20
      packages/dbgpt-core/src/dbgpt/model/cluster/controller/controller.py
  6. 20
      packages/dbgpt-core/src/dbgpt/model/cluster/registry.py
  7. 2
      packages/dbgpt-core/src/dbgpt/model/cluster/registry_impl/db_storage.py
  8. 8
      packages/dbgpt-core/src/dbgpt/model/cluster/registry_impl/storage.py
  9. 5
      packages/dbgpt-core/src/dbgpt/model/cluster/storage.py
  10. 4
      packages/dbgpt-core/src/dbgpt/model/cluster/worker/manager.py
  11. 2
      packages/dbgpt-serve/src/dbgpt_serve/agent/db/gpts_app.py
  12. 5
      packages/dbgpt-serve/src/dbgpt_serve/model/api/endpoints.py
  13. 1
      packages/dbgpt-serve/src/dbgpt_serve/model/api/schemas.py
  14. 2
      packages/dbgpt-serve/src/dbgpt_serve/model/models/model_adapter.py
  15. 3
      packages/dbgpt-serve/src/dbgpt_serve/model/models/models.py

@ -1,4 +1,9 @@
--2025/6/12 type_name
ALTER TABLE `dbgpt_serve_model`
ADD COLUMN `type_name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NULL DEFAULT NULL COMMENT '分类名称';
---- 2025/12 ms 智能体插件
create table sys_agent_plugin(
`id` bigint(19) AUTO_INCREMENT comment '主键',

@ -1,12 +1,13 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass, field, Field
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
from dbgpt.util.parameter_utils import ParameterDescription
from pydantic import validator
class ModelType:
@ -27,6 +28,7 @@ class ModelInstance:
model_name: str
host: str
port: int
type_name: Optional[str] = None
weight: Optional[float] = 1.0
check_healthy: Optional[bool] = True
healthy: Optional[bool] = False
@ -49,6 +51,26 @@ class ModelInstance:
return self.last_heartbeat
return self.last_heartbeat.strftime("%Y-%m-%d %H:%M:%S")
@validator("type_name", always=True)
def auto_fill_type_name(cls, v, values):
"""根据model_name自动填充type_name"""
if v is not None:
return v
type_mapping = {
"llm": "大语言模型",
"text2vec": "文本向量模型",
"reranker": "重排序模型"
}
if "model_name" in values and values["model_name"]:
try:
model_type = values["model_name"].split("@")[1]
return type_mapping.get(model_type, "未知")
except IndexError:
return "未知"
return v
class WorkerApplyType(str, Enum):
START = "start"

@ -266,6 +266,10 @@ def _remote_model_dynamic_factory() -> Callable[[None], List[Type]]:
},
)
type_name: Optional[str] = field(
default=None
)
return [RemoteModelWorkerParameters, real_params_cls]

@ -1,4 +1,5 @@
from typing import Any, Dict, List, Optional, Union
from pydantic import validator
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.interface.message import ModelMessage
@ -79,6 +80,7 @@ class WorkerStartupRequest(BaseModel):
port: int
model: str
worker_type: WorkerType
type_name: Optional[str] = Field(None, description="自动根据worker_type填充")
params: Dict
delete_after: Optional[bool] = Field(
False,
@ -92,3 +94,15 @@ class WorkerStartupRequest(BaseModel):
user_name: Optional[str] = Field(
None, description="The user name for the worker, used for authentication"
)
@validator("type_name", always=True)
def auto_fill_type_name(cls, v, values):
"""根据worker_type自动填充type_name"""
type_mapping = {
WorkerType.LLM: "大语言模型",
WorkerType.TEXT2VEC: "文本向量模型",
WorkerType.RERANKER: "重排序模型"
}
if "worker_type" in values:
return type_mapping.get(values["worker_type"], "未知类型")
return v

@ -44,7 +44,7 @@ class BaseModelController(BaseComponent, ABC):
@abstractmethod
async def get_all_instances(
self, model_name: str = None, healthy_only: bool = False
self, model_name: str = None, healthy_only: bool = False, type_name: Optional[str] = None,name: Optional[str] = None
) -> List[ModelInstance]:
"""Fetch all instances of a given model.
@ -72,14 +72,17 @@ class LocalModelController(BaseModelController):
return await self.registry.deregister_instance(instance)
async def get_all_instances(
self, model_name: str = None, healthy_only: bool = False
self, model_name: str = None, healthy_only: bool = False,
type_name: Optional[str] = None,name: Optional[str] = None
) -> List[ModelInstance]:
logger.info(
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
)
if not model_name:
return await self.registry.get_all_model_instances(
healthy_only=healthy_only
healthy_only=healthy_only,
type_name=type_name,
name=name
)
else:
return await self.registry.get_all_instances(model_name, healthy_only)
@ -118,7 +121,7 @@ class _RemoteModelController(APIMixin, BaseModelController):
@api_remote(path="/api/controller/models")
async def get_all_instances(
self, model_name: str = None, healthy_only: bool = False
self, model_name: str = None, healthy_only: bool = False,type_name: Optional[str] = None,name: Optional[str] = None
) -> List[ModelInstance]:
pass
@ -129,9 +132,10 @@ class _RemoteModelController(APIMixin, BaseModelController):
class ModelRegistryClient(_RemoteModelController, ModelRegistry):
async def get_all_model_instances(
self, healthy_only: bool = False
self, healthy_only: bool = False,
type_name: Optional[str] = None,name: Optional[str] = None
) -> List[ModelInstance]:
return await self.get_all_instances(healthy_only=healthy_only)
return await self.get_all_instances(healthy_only=healthy_only,type_name=type_name)
@sync_api_remote(path="/api/controller/models")
def sync_get_all_instances(
@ -151,9 +155,9 @@ class ModelControllerAdapter(BaseModelController):
return await self.backend.deregister_instance(instance)
async def get_all_instances(
self, model_name: str = None, healthy_only: bool = False
self, model_name: str = None, healthy_only: bool = False,type_name: Optional[str] = None,name: Optional[str] = None
) -> List[ModelInstance]:
return await self.backend.get_all_instances(model_name, healthy_only)
return await self.backend.get_all_instances(model_name, healthy_only,type_name,name)
async def send_heartbeat(self, instance: ModelInstance) -> bool:
return await self.backend.send_heartbeat(instance)

@ -6,7 +6,7 @@ import time
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Optional
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.model.base import ModelInstance
@ -84,7 +84,8 @@ class ModelRegistry(BaseComponent, ABC):
@abstractmethod
async def get_all_model_instances(
self, healthy_only: bool = False
self, healthy_only: bool = False,
type_name: Optional[str] = None,name: Optional[str] = None
) -> List[ModelInstance]:
"""
Fetch all instances of all models, Optionally, fetch only the healthy instances.
@ -174,7 +175,6 @@ class EmbeddedModelRegistry(ModelRegistry):
model_name = instance.model_name.strip()
host = instance.host.strip()
port = instance.port
instances, exist_ins = self._get_instances(
model_name, host, port, healthy_only=False
)
@ -218,12 +218,24 @@ class EmbeddedModelRegistry(ModelRegistry):
return instances
async def get_all_model_instances(
self, healthy_only: bool = False
self, healthy_only: bool = False,
type_name: Optional[str] = None,name: Optional[str] = None
) -> List[ModelInstance]:
logger.debug("Current registry metadata:\n{self.registry}")
instances = list(itertools.chain(*self.registry.values()))
if healthy_only:
instances = [ins for ins in instances if ins.healthy is True]
if type_name:
instances = [
ins for ins in instances
if hasattr(ins, 'type_name') and ins.type_name == type_name
]
# 名称模糊匹配
if name:
instances = [
ins for ins in instances
if name.lower() in ins.model_name.lower()
]
return instances
async def send_heartbeat(self, instance: ModelInstance) -> bool:

@ -64,6 +64,7 @@ class ModelInstanceEntity(Model):
sys_code = Column(String(128), nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
type_name = Column(String(255), nullable=True, comment="模型分类名称")
class ModelInstanceItemAdapter(
@ -80,6 +81,7 @@ class ModelInstanceItemAdapter(
enabled=item.enabled,
prompt_template=item.prompt_template,
last_heartbeat=item.last_heartbeat,
type_name=item.type_name,
# user_name=item.user_name,
# sys_code=item.sys_code,
)

@ -78,6 +78,7 @@ class ModelInstanceStorageItem(StorageItem):
model_name: str
host: str
port: int
type_name: Optional[str] = None
weight: Optional[float] = 1.0
check_healthy: Optional[bool] = True
healthy: Optional[bool] = False
@ -119,6 +120,7 @@ class ModelInstanceStorageItem(StorageItem):
"enabled": self.enabled,
"prompt_template": self.prompt_template,
"last_heartbeat": last_heartbeat,
"type_name": self.type_name,
}
def from_object(self, item: "ModelInstanceStorageItem") -> None:
@ -132,6 +134,7 @@ class ModelInstanceStorageItem(StorageItem):
self.enabled = item.enabled
self.prompt_template = item.prompt_template
self.last_heartbeat = item.last_heartbeat
self.type_name = item.type_name
@classmethod
def from_model_instance(cls, instance: ModelInstance) -> "ModelInstanceStorageItem":
@ -145,6 +148,7 @@ class ModelInstanceStorageItem(StorageItem):
enabled=instance.enabled,
prompt_template=instance.prompt_template,
last_heartbeat=instance.last_heartbeat,
type_name=instance.type_name,
)
@classmethod
@ -159,6 +163,7 @@ class ModelInstanceStorageItem(StorageItem):
enabled=item.enabled,
prompt_template=item.prompt_template,
last_heartbeat=item.last_heartbeat,
type_name=item.type_name
)
@ -333,11 +338,12 @@ class StorageModelRegistry(ModelRegistry):
return [ModelInstanceStorageItem.to_model_instance(ins) for ins in instances]
async def get_all_model_instances(
self, healthy_only: bool = False
self, healthy_only: bool = False,type_name: Optional[str] = None,name: Optional[str] = None
) -> List[ModelInstance]:
"""Get all model instances.
Args:
type_name:
healthy_only (bool): Whether only get healthy instances. Defaults to False.
Returns:

@ -75,6 +75,9 @@ class ModelStorageItem(StorageItem):
port: int = field(metadata={"help": "The port of the worker"})
model: str = field(metadata={"help": "The model name"})
provider: str = field(metadata={"help": "The provider of the model"})
type_name: str = field(
metadata={"help": "The type name of the model, e.g. openai, llama, etc."}
)
worker_type: str = field(
metadata={"help": "The worker type of the model, e.g. llm, tex2vec, reranker"}
)
@ -159,6 +162,7 @@ class ModelStorageItem(StorageItem):
host=request.host,
port=request.port,
model=request.model,
type_name=request.type_name,
provider=request.params.get("provider"),
worker_type=request.worker_type.value,
enabled=True,
@ -242,6 +246,7 @@ class ModelStorage:
"""
model = ModelStorageItem.from_startup_req(request)
model.enabled = enabled
model.type_name = request.type_name or request.params['type_name']
self._storage.save_or_update(model)
def delete(self, identifier: ModelStorageIdentifier) -> None:

@ -284,6 +284,9 @@ class LocalWorkerManager(WorkerManager):
worker_type = startup_req.worker_type
params = startup_req.params
if "type_name" not in params and hasattr(startup_req, "type_name"):
params["type_name"] = startup_req.type_name
cfg = ConfigurationManager(params)
if worker_type == WorkerType.TEXT2VEC:
deploy_params = cfg.parse_config(EmbeddingDeployModelParameters)
@ -299,6 +302,7 @@ class LocalWorkerManager(WorkerManager):
worker_params: ModelWorkerParameters = ModelWorkerParameters.from_dict(
{
"worker_type": worker_type.value,
"type_name": params.get("type_name"),
},
ignore_extra_fields=True,
)

@ -701,7 +701,7 @@ class GptsAppDao(BaseDao):
# 这里假设你有一个字典表DAO类可以查询
dict_dao = TypeDao() # 需要实现这个类
dict_item = dict_dao.select_type_by_value(app_info.app_type)
type_label = dict_item.label if dict_item else None
type_label = dict_item.type_label if dict_item else None
return {
"app_code": app_info.app_code,

@ -140,14 +140,14 @@ async def model_params(worker_manager: WorkerManager = Depends(get_worker_manage
@router.get("/models")
async def model_list(controller: BaseModelController = Depends(get_model_controller)):
async def model_list(controller: BaseModelController = Depends(get_model_controller),type_name: Optional[str] = None,name: Optional[str] = None):
try:
responses = []
managers = await controller.get_all_instances(
model_name="WorkerManager@service", healthy_only=True
)
manager_map = dict(map(lambda manager: (manager.host, manager), managers))
models = await controller.get_all_instances()
models = await controller.get_all_instances(type_name=type_name,name=name)
for model in models:
worker_name, worker_type = model.model_name.split("@")
if worker_type in WorkerType.values():
@ -166,6 +166,7 @@ async def model_list(controller: BaseModelController = Depends(get_model_control
check_healthy=model.check_healthy,
last_heartbeat=model.str_last_heartbeat,
prompt_template=model.prompt_template,
type_name=model.type_name if model.type_name else "未知",
)
responses.append(response)
return Result.succ(responses)

@ -53,3 +53,4 @@ class ModelResponse(BaseModel):
check_healthy: bool = Field(True, description="Check model health status")
prompt_template: Optional[str] = Field(None, description="Model prompt template")
last_heartbeat: Optional[str] = Field(None, description="Model last heartbeat")
type_name: str = Field(description="模型分类名称")

@ -23,6 +23,7 @@ class ModelStorageAdapter(StorageItemAdapter[ModelStorageItem, ServeEntity]):
host=item.host,
port=item.port,
model=item.model,
type_name=item.type_name,
provider=item.provider,
worker_type=item.worker_type,
enabled=enabled,
@ -41,6 +42,7 @@ class ModelStorageAdapter(StorageItemAdapter[ModelStorageItem, ServeEntity]):
host=model.host,
port=model.port,
model=model.model,
type_name=model.type_name,
provider=model.provider,
worker_type=model.worker_type,
enabled=enabled,

@ -27,6 +27,9 @@ class ServeEntity(Model):
model = Column(String(255), nullable=False, comment="The model name")
provider = Column(String(255), nullable=False, comment="The model provider")
worker_type = Column(String(255), nullable=False, comment="The worker type")
type_name = Column(
String(255), nullable=False, comment="模型分类名称"
)
params = Column(Text, nullable=False, comment="The model parameters, JSON format")
enabled = Column(
Integer,

Loading…
Cancel
Save