mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-11 03:33:28 +00:00
fix(model): Fix the bug that the webserver cannot return model instances
This commit is contained in:
@@ -4,6 +4,7 @@ import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from pilot.componet import BaseComponet, ComponetType, SystemApp
|
||||
from pilot.model.base import ModelInstance
|
||||
from pilot.model.parameter import ModelControllerParameters
|
||||
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
|
||||
@@ -14,7 +15,12 @@ from pilot.utils.api_utils import (
|
||||
)
|
||||
|
||||
|
||||
class BaseModelController(ABC):
|
||||
class BaseModelController(BaseComponet, ABC):
|
||||
name = ComponetType.MODEL_CONTROLLER
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def register_instance(self, instance: ModelInstance) -> bool:
|
||||
"""Register a given model instance"""
|
||||
@@ -25,7 +31,7 @@ class BaseModelController(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
self, model_name: str = None, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
"""Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
|
||||
|
||||
@@ -51,7 +57,7 @@ class LocalModelController(BaseModelController):
|
||||
return await self.registry.deregister_instance(instance)
|
||||
|
||||
async def get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
self, model_name: str = None, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
logging.info(
|
||||
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
|
||||
@@ -94,7 +100,7 @@ class ModelRegistryClient(_RemoteModelController, ModelRegistry):
|
||||
|
||||
@sync_api_remote(path="/api/controller/models")
|
||||
def sync_get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
self, model_name: str = None, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
pass
|
||||
|
||||
@@ -110,7 +116,7 @@ class ModelControllerAdapter(BaseModelController):
|
||||
return await self.backend.deregister_instance(instance)
|
||||
|
||||
async def get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
self, model_name: str = None, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
return await self.backend.get_all_instances(model_name, healthy_only)
|
||||
|
||||
|
Reference in New Issue
Block a user