mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-10 03:03:57 +00:00
186 lines
6.2 KiB
Python
186 lines
6.2 KiB
Python
from abc import ABC, abstractmethod
|
|
|
|
import logging
|
|
from typing import List
|
|
|
|
from fastapi import APIRouter, FastAPI
|
|
from pilot.componet import BaseComponet, ComponetType, SystemApp
|
|
from pilot.model.base import ModelInstance
|
|
from pilot.model.parameter import ModelControllerParameters
|
|
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
|
|
from pilot.utils.parameter_utils import EnvArgumentParser
|
|
from pilot.utils.api_utils import (
|
|
_api_remote as api_remote,
|
|
_sync_api_remote as sync_api_remote,
|
|
)
|
|
|
|
|
|
class BaseModelController(BaseComponet, ABC):
    """Abstract interface of a component that keeps track of model instances.

    Concrete controllers must provide registration, deregistration, lookup
    and heartbeat handling. ``model_apply`` is optional and unsupported by
    default.
    """

    name = ComponetType.MODEL_CONTROLLER

    def init_app(self, system_app: SystemApp):
        """Do nothing; this base class has no use for ``system_app``."""

    @abstractmethod
    async def register_instance(self, instance: ModelInstance) -> bool:
        """Add ``instance`` to the set of instances known to the controller."""

    @abstractmethod
    async def deregister_instance(self, instance: ModelInstance) -> bool:
        """Remove ``instance`` from the set of instances known to the controller."""

    @abstractmethod
    async def get_all_instances(
        self, model_name: str = None, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Return the instances of ``model_name``.

        Args:
            model_name: Name of the model whose instances are requested.
            healthy_only: When true, only healthy instances are wanted.
        """

    @abstractmethod
    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        """Record a heartbeat for ``instance``.

        Heartbeats let the controller check that an instance is still alive
        and functioning.
        """

    async def model_apply(self) -> bool:
        """Unsupported by default; always raises ``NotImplementedError``."""
        raise NotImplementedError
|
|
|
|
|
|
class LocalModelController(BaseModelController):
    """Controller that forwards every operation to a ``ModelRegistry`` object."""

    def __init__(self, registry: ModelRegistry = None) -> None:
        """Create the controller.

        Args:
            registry: Registry to delegate to. When omitted, a new
                ``EmbeddedModelRegistry`` is created.
        """
        # Compare against None explicitly: a truthiness test would discard a
        # caller-supplied registry that merely evaluates as falsy.
        if registry is None:
            registry = EmbeddedModelRegistry()
        self.registry = registry
        # Not read by any method of this class.
        self.deployment = None

    async def register_instance(self, instance: ModelInstance) -> bool:
        """Register ``instance`` with the underlying registry."""
        return await self.registry.register_instance(instance)

    async def deregister_instance(self, instance: ModelInstance) -> bool:
        """Deregister ``instance`` from the underlying registry."""
        return await self.registry.deregister_instance(instance)

    async def get_all_instances(
        self, model_name: str = None, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Return instances from the registry.

        Args:
            model_name: Restrict the result to this model. When empty or
                ``None``, instances of every model are returned.
            healthy_only: When true, unhealthy instances are left out.

        Returns:
            The matching instances.
        """
        logging.info(
            f"Get all instances with {model_name}, healthy_only: {healthy_only}"
        )
        if model_name:
            return await self.registry.get_all_instances(model_name, healthy_only)

        instances = await self.registry.get_all_model_instances()
        if healthy_only:
            # The all-models registry call takes no health filter, so apply
            # it here instead of silently ignoring the flag.
            # NOTE(review): assumes ModelInstance exposes a boolean `healthy`
            # attribute; instances without it are kept — confirm against
            # pilot.model.base.
            instances = [
                candidate
                for candidate in instances
                if getattr(candidate, "healthy", True)
            ]
        return instances

    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        """Forward a heartbeat for ``instance`` to the underlying registry."""
        return await self.registry.send_heartbeat(instance)
|
|
|
|
|
|
class _RemoteModelController(BaseModelController):
    """Controller whose operations are declared as remote API endpoints.

    The method bodies are intentionally empty; each method is wrapped by
    ``api_remote`` with the endpoint path and HTTP method to use.
    NOTE(review): presumably the decorator issues the request against
    ``base_url`` — confirm in pilot.utils.api_utils.
    """

    def __init__(self, base_url: str) -> None:
        """Remember the address of the remote controller."""
        self.base_url = base_url

    @api_remote(path="/api/controller/models", method="POST")
    async def register_instance(self, instance: ModelInstance) -> bool:
        """Register ``instance``; declared as POST /api/controller/models."""

    @api_remote(path="/api/controller/models", method="DELETE")
    async def deregister_instance(self, instance: ModelInstance) -> bool:
        """Deregister ``instance``; declared as DELETE /api/controller/models."""

    @api_remote(path="/api/controller/models")
    async def get_all_instances(
        self, model_name: str = None, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """List instances; declared on /api/controller/models."""

    @api_remote(path="/api/controller/heartbeat", method="POST")
    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        """Send a heartbeat; declared as POST /api/controller/heartbeat."""
|
|
|
|
|
|
class ModelRegistryClient(_RemoteModelController, ModelRegistry):
    """Remote controller that can also be used as a ``ModelRegistry``."""

    async def get_all_model_instances(self) -> List[ModelInstance]:
        """Return the result of ``get_all_instances`` called with its defaults."""
        every_instance = await self.get_all_instances()
        return every_instance

    @sync_api_remote(path="/api/controller/models")
    def sync_get_all_instances(
        self, model_name: str = None, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Non-async variant of the instance listing on /api/controller/models.

        The body is intentionally empty; the method is wrapped by
        ``sync_api_remote``.
        """
|
|
|
|
|
|
class ModelControllerAdapter(BaseModelController):
    """Controller that relays every call to a replaceable ``backend``.

    ``backend`` may be left as ``None`` at construction time and assigned
    later; it must be set before any of the relayed methods is awaited.
    """

    def __init__(self, backend: BaseModelController = None) -> None:
        """Store the controller that will actually serve the calls."""
        self.backend = backend

    async def register_instance(self, instance: ModelInstance) -> bool:
        """Relay the registration of ``instance`` to the backend."""
        delegate = self.backend
        return await delegate.register_instance(instance)

    async def deregister_instance(self, instance: ModelInstance) -> bool:
        """Relay the deregistration of ``instance`` to the backend."""
        delegate = self.backend
        return await delegate.deregister_instance(instance)

    async def get_all_instances(
        self, model_name: str = None, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Relay the instance lookup to the backend with the same arguments."""
        delegate = self.backend
        return await delegate.get_all_instances(model_name, healthy_only)

    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        """Relay the heartbeat of ``instance`` to the backend."""
        delegate = self.backend
        return await delegate.send_heartbeat(instance)

    async def model_apply(self) -> bool:
        """Relay ``model_apply`` to the backend."""
        delegate = self.backend
        return await delegate.model_apply()
|
|
|
|
|
|
# Router that collects the controller endpoints declared below; it is mounted
# under the "/api" prefix by initialize_controller().
router = APIRouter()
|
|
|
|
# Module-level adapter used by the route handlers. It is created without a
# backend; initialize_controller() assigns one.
controller = ModelControllerAdapter()
|
|
|
|
|
|
def initialize_controller(
    app=None, remote_controller_addr: str = None, host: str = None, port: int = None
):
    """Select the controller backend and expose the controller routes.

    Args:
        app: Existing application to mount the routes on. When falsy, a new
            ``FastAPI`` application is created and served with uvicorn.
        remote_controller_addr: Address of a remote controller. When given,
            calls go through ``_RemoteModelController``; otherwise a
            ``LocalModelController`` is used.
        host: Host passed to ``uvicorn.run`` (standalone mode only).
        port: Port passed to ``uvicorn.run`` (standalone mode only).
    """
    # The shared adapter is mutated in place, never rebound, so no ``global``
    # declaration is needed.
    controller.backend = (
        _RemoteModelController(remote_controller_addr)
        if remote_controller_addr
        else LocalModelController()
    )

    if app:
        # Embedded mode: the caller owns the application and runs it.
        app.include_router(router, prefix="/api")
        return

    # Standalone mode: uvicorn is imported only when it is actually needed.
    import uvicorn

    standalone_app = FastAPI()
    standalone_app.include_router(router, prefix="/api")
    uvicorn.run(standalone_app, host=host, port=port, log_level="info")
|
|
|
|
|
|
@router.post("/controller/models")
async def api_register_instance(request: ModelInstance):
    # Registers the instance from the request body with the shared controller.
    # Kept as a comment rather than a docstring so the generated API
    # description of the route stays unchanged.
    outcome = await controller.register_instance(request)
    return outcome
|
|
|
|
|
|
@router.delete("/controller/models")
async def api_deregister_instance(model_name: str, host: str, port: int):
    # Rebuilds the instance from the model name, host and port parameters,
    # then asks the shared controller to deregister it.
    # Kept as a comment rather than a docstring so the generated API
    # description of the route stays unchanged.
    return await controller.deregister_instance(
        ModelInstance(model_name=model_name, host=host, port=port)
    )
|
|
|
|
|
|
@router.get("/controller/models")
async def api_get_all_instances(model_name: str = None, healthy_only: bool = False):
    # Lists instances through the shared controller, passing both filters on.
    # Kept as a comment rather than a docstring so the generated API
    # description of the route stays unchanged.
    found = await controller.get_all_instances(model_name, healthy_only=healthy_only)
    return found
|
|
|
|
|
|
@router.post("/controller/heartbeat")
async def api_model_heartbeat(request: ModelInstance):
    # Passes the heartbeat of the instance in the request body to the shared
    # controller.
    # Kept as a comment rather than a docstring so the generated API
    # description of the route stays unchanged.
    acknowledged = await controller.send_heartbeat(request)
    return acknowledged
|
|
|
|
|
|
def run_model_controller():
    """Parse the controller parameters and start the controller standalone.

    Parameters are read into ``ModelControllerParameters`` using the
    ``controller_`` environment prefix; the resulting host and port are
    handed to ``initialize_controller``.
    """
    params: ModelControllerParameters = EnvArgumentParser().parse_args_into_dataclass(
        ModelControllerParameters, env_prefix="controller_"
    )
    initialize_controller(host=params.host, port=params.port)
|
|
|
|
|
|
# Start the controller only when this file is executed as a script.
if __name__ == "__main__":
    run_model_controller()
|