增加模型的名称查询

master
zc 4 months ago
parent f37441704e
commit b2a250fa80
  1. 15
      packages/dbgpt-core/src/dbgpt/model/cluster/controller/controller.py
  2. 10
      packages/dbgpt-core/src/dbgpt/model/cluster/registry.py
  3. 2
      packages/dbgpt-core/src/dbgpt/model/cluster/registry_impl/storage.py
  4. 2
      packages/dbgpt-serve/src/dbgpt_serve/agent/db/gpts_app.py
  5. 6
      packages/dbgpt-serve/src/dbgpt_serve/model/api/endpoints.py

@ -44,7 +44,7 @@ class BaseModelController(BaseComponent, ABC):
@abstractmethod
async def get_all_instances(
self, model_name: str = None, healthy_only: bool = False, type_name: Optional[str] = None
self, model_name: str = None, healthy_only: bool = False, type_name: Optional[str] = None,name: Optional[str] = None
) -> List[ModelInstance]:
"""Fetch all instances of a given model.
@ -73,7 +73,7 @@ class LocalModelController(BaseModelController):
async def get_all_instances(
self, model_name: str = None, healthy_only: bool = False,
type_name: Optional[str] = None
type_name: Optional[str] = None,name: Optional[str] = None
) -> List[ModelInstance]:
logger.info(
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
@ -81,7 +81,8 @@ class LocalModelController(BaseModelController):
if not model_name:
return await self.registry.get_all_model_instances(
healthy_only=healthy_only,
type_name=type_name
type_name=type_name,
name=name
)
else:
return await self.registry.get_all_instances(model_name, healthy_only)
@ -120,7 +121,7 @@ class _RemoteModelController(APIMixin, BaseModelController):
@api_remote(path="/api/controller/models")
async def get_all_instances(
self, model_name: str = None, healthy_only: bool = False,type_name: Optional[str] = None
self, model_name: str = None, healthy_only: bool = False,type_name: Optional[str] = None,name: Optional[str] = None
) -> List[ModelInstance]:
pass
@ -132,7 +133,7 @@ class _RemoteModelController(APIMixin, BaseModelController):
class ModelRegistryClient(_RemoteModelController, ModelRegistry):
async def get_all_model_instances(
self, healthy_only: bool = False,
type_name: Optional[str] = None
type_name: Optional[str] = None,name: Optional[str] = None
) -> List[ModelInstance]:
return await self.get_all_instances(healthy_only=healthy_only,type_name=type_name)
@ -154,9 +155,9 @@ class ModelControllerAdapter(BaseModelController):
return await self.backend.deregister_instance(instance)
async def get_all_instances(
self, model_name: str = None, healthy_only: bool = False,type_name: Optional[str] = None
self, model_name: str = None, healthy_only: bool = False,type_name: Optional[str] = None,name: Optional[str] = None
) -> List[ModelInstance]:
return await self.backend.get_all_instances(model_name, healthy_only,type_name)
return await self.backend.get_all_instances(model_name, healthy_only,type_name,name)
async def send_heartbeat(self, instance: ModelInstance) -> bool:
return await self.backend.send_heartbeat(instance)

@ -85,7 +85,7 @@ class ModelRegistry(BaseComponent, ABC):
@abstractmethod
async def get_all_model_instances(
self, healthy_only: bool = False,
type_name: Optional[str] = None
type_name: Optional[str] = None,name: Optional[str] = None
) -> List[ModelInstance]:
"""
Fetch all instances of all models, Optionally, fetch only the healthy instances.
@ -219,7 +219,7 @@ class EmbeddedModelRegistry(ModelRegistry):
async def get_all_model_instances(
self, healthy_only: bool = False,
type_name: Optional[str] = None
type_name: Optional[str] = None,name: Optional[str] = None
) -> List[ModelInstance]:
logger.debug("Current registry metadata:\n{self.registry}")
instances = list(itertools.chain(*self.registry.values()))
@ -230,6 +230,12 @@ class EmbeddedModelRegistry(ModelRegistry):
ins for ins in instances
if hasattr(ins, 'type_name') and ins.type_name == type_name
]
# 名称模糊匹配
if name:
instances = [
ins for ins in instances
if name.lower() in ins.model_name.lower()
]
return instances
async def send_heartbeat(self, instance: ModelInstance) -> bool:

@ -338,7 +338,7 @@ class StorageModelRegistry(ModelRegistry):
return [ModelInstanceStorageItem.to_model_instance(ins) for ins in instances]
async def get_all_model_instances(
self, healthy_only: bool = False,type_name: Optional[str] = None
self, healthy_only: bool = False,type_name: Optional[str] = None,name: Optional[str] = None
) -> List[ModelInstance]:
"""Get all model instances.

@ -701,7 +701,7 @@ class GptsAppDao(BaseDao):
# 这里假设你有一个字典表DAO类可以查询
dict_dao = TypeDao() # 需要实现这个类
dict_item = dict_dao.select_type_by_value(app_info.app_type)
type_label = dict_item.label if dict_item else None
type_label = dict_item.type_label if dict_item else None
return {
"app_code": app_info.app_code,

@ -140,14 +140,14 @@ async def model_params(worker_manager: WorkerManager = Depends(get_worker_manage
@router.get("/models")
async def model_list(controller: BaseModelController = Depends(get_model_controller),type_name: Optional[str] = None):
async def model_list(controller: BaseModelController = Depends(get_model_controller),type_name: Optional[str] = None,name: Optional[str] = None):
try:
responses = []
managers = await controller.get_all_instances(
model_name="WorkerManager@service", healthy_only=True, type_name=type_name
model_name="WorkerManager@service", healthy_only=True
)
manager_map = dict(map(lambda manager: (manager.host, manager), managers))
models = await controller.get_all_instances(type_name=type_name)
models = await controller.get_all_instances(type_name=type_name,name=name)
for model in models:
worker_name, worker_type = model.model_name.split("@")
if worker_type in WorkerType.values():

Loading…
Cancel
Save