fix(model): Fix the bug that the webserver cannot return model instances

This commit is contained in:
FangYin Cheng
2023-09-14 12:36:18 +08:00
parent 7b64c03d58
commit f304f9709d
5 changed files with 29 additions and 18 deletions

View File

@@ -4,6 +4,7 @@ import logging
from typing import List
from fastapi import APIRouter, FastAPI
from pilot.componet import BaseComponet, ComponetType, SystemApp
from pilot.model.base import ModelInstance
from pilot.model.parameter import ModelControllerParameters
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
@@ -14,7 +15,12 @@ from pilot.utils.api_utils import (
)
class BaseModelController(ABC):
class BaseModelController(BaseComponet, ABC):
name = ComponetType.MODEL_CONTROLLER
def init_app(self, system_app: SystemApp):
pass
@abstractmethod
async def register_instance(self, instance: ModelInstance) -> bool:
"""Register a given model instance"""
@@ -25,7 +31,7 @@ class BaseModelController(ABC):
@abstractmethod
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
self, model_name: str = None, healthy_only: bool = False
) -> List[ModelInstance]:
"""Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
@@ -51,7 +57,7 @@ class LocalModelController(BaseModelController):
return await self.registry.deregister_instance(instance)
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
self, model_name: str = None, healthy_only: bool = False
) -> List[ModelInstance]:
logging.info(
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
@@ -94,7 +100,7 @@ class ModelRegistryClient(_RemoteModelController, ModelRegistry):
@sync_api_remote(path="/api/controller/models")
def sync_get_all_instances(
self, model_name: str, healthy_only: bool = False
self, model_name: str = None, healthy_only: bool = False
) -> List[ModelInstance]:
pass
@@ -110,7 +116,7 @@ class ModelControllerAdapter(BaseModelController):
return await self.backend.deregister_instance(instance)
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
self, model_name: str = None, healthy_only: bool = False
) -> List[ModelInstance]:
return await self.backend.get_all_instances(model_name, healthy_only)