|
|
|
@ -44,7 +44,7 @@ class BaseModelController(BaseComponent, ABC): |
|
|
|
|
|
|
|
|
|
@abstractmethod |
|
|
|
|
async def get_all_instances( |
|
|
|
|
self, model_name: str = None, healthy_only: bool = False, type_name: Optional[str] = None |
|
|
|
|
self, model_name: str = None, healthy_only: bool = False, type_name: Optional[str] = None,name: Optional[str] = None |
|
|
|
|
) -> List[ModelInstance]: |
|
|
|
|
"""Fetch all instances of a given model. |
|
|
|
|
|
|
|
|
@ -73,7 +73,7 @@ class LocalModelController(BaseModelController): |
|
|
|
|
|
|
|
|
|
async def get_all_instances( |
|
|
|
|
self, model_name: str = None, healthy_only: bool = False, |
|
|
|
|
type_name: Optional[str] = None |
|
|
|
|
type_name: Optional[str] = None,name: Optional[str] = None |
|
|
|
|
) -> List[ModelInstance]: |
|
|
|
|
logger.info( |
|
|
|
|
f"Get all instances with {model_name}, healthy_only: {healthy_only}" |
|
|
|
@ -81,7 +81,8 @@ class LocalModelController(BaseModelController): |
|
|
|
|
if not model_name: |
|
|
|
|
return await self.registry.get_all_model_instances( |
|
|
|
|
healthy_only=healthy_only, |
|
|
|
|
type_name=type_name |
|
|
|
|
type_name=type_name, |
|
|
|
|
name=name |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
return await self.registry.get_all_instances(model_name, healthy_only) |
|
|
|
@ -120,7 +121,7 @@ class _RemoteModelController(APIMixin, BaseModelController): |
|
|
|
|
|
|
|
|
|
@api_remote(path="/api/controller/models") |
|
|
|
|
async def get_all_instances( |
|
|
|
|
self, model_name: str = None, healthy_only: bool = False,type_name: Optional[str] = None |
|
|
|
|
self, model_name: str = None, healthy_only: bool = False,type_name: Optional[str] = None,name: Optional[str] = None |
|
|
|
|
) -> List[ModelInstance]: |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
@ -132,7 +133,7 @@ class _RemoteModelController(APIMixin, BaseModelController): |
|
|
|
|
class ModelRegistryClient(_RemoteModelController, ModelRegistry): |
|
|
|
|
async def get_all_model_instances( |
|
|
|
|
self, healthy_only: bool = False, |
|
|
|
|
type_name: Optional[str] = None |
|
|
|
|
type_name: Optional[str] = None,name: Optional[str] = None |
|
|
|
|
) -> List[ModelInstance]: |
|
|
|
|
return await self.get_all_instances(healthy_only=healthy_only,type_name=type_name) |
|
|
|
|
|
|
|
|
@ -154,9 +155,9 @@ class ModelControllerAdapter(BaseModelController): |
|
|
|
|
return await self.backend.deregister_instance(instance) |
|
|
|
|
|
|
|
|
|
async def get_all_instances( |
|
|
|
|
self, model_name: str = None, healthy_only: bool = False,type_name: Optional[str] = None |
|
|
|
|
self, model_name: str = None, healthy_only: bool = False,type_name: Optional[str] = None,name: Optional[str] = None |
|
|
|
|
) -> List[ModelInstance]: |
|
|
|
|
return await self.backend.get_all_instances(model_name, healthy_only,type_name) |
|
|
|
|
return await self.backend.get_all_instances(model_name, healthy_only,type_name,name) |
|
|
|
|
|
|
|
|
|
async def send_heartbeat(self, instance: ModelInstance) -> bool: |
|
|
|
|
return await self.backend.send_heartbeat(instance) |
|
|
|
|