mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-01 15:20:03 +00:00
66 lines
2.1 KiB
Python
66 lines
2.1 KiB
Python
import logging
|
|
from typing import List
|
|
|
|
from fastapi import APIRouter
|
|
from pilot.model.base import ModelInstance
|
|
from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry
|
|
|
|
|
|
class ModelController:
    """Controller that delegates model-instance bookkeeping to a ``ModelRegistry``.

    Registration, deregistration, lookups and heartbeats are all forwarded
    to ``self.registry``; ``model_apply`` is not implemented yet.
    """

    def __init__(self, registry: ModelRegistry = None) -> None:
        """Create a controller backed by ``registry``.

        Args:
            registry: Registry to delegate to. When omitted (``None``), a new
                ``EmbeddedModelRegistry`` is created and used instead.
        """
        # Compare against None explicitly: a caller-supplied registry must never
        # be silently swapped out just because it evaluates as falsy (e.g. a
        # registry implementation that defines __len__ or __bool__).
        if registry is None:
            registry = EmbeddedModelRegistry()
        self.registry = registry
        # Not read or written anywhere else in this class; presumably reserved
        # for future deployment support — TODO confirm.
        self.deployment = None

    async def register_instance(self, instance: ModelInstance) -> bool:
        """Register ``instance`` with the registry and return the registry's result."""
        return await self.registry.register_instance(instance)

    async def deregister_instance(self, instance: ModelInstance) -> bool:
        """Remove ``instance`` from the registry and return the registry's result."""
        return await self.registry.deregister_instance(instance)

    async def get_all_instances(
        self, model_name: str, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Return the registered instances of ``model_name``.

        Args:
            model_name: Name of the model whose instances are requested.
            healthy_only: Forwarded unchanged to the registry; presumably
                restricts the result to instances it considers healthy.

        Returns:
            Whatever list the registry's ``get_all_instances`` returns.
        """
        # Lazy %-style arguments: the message is only formatted when INFO
        # logging is actually enabled (same text as the previous f-string).
        logging.info(
            "Get all instances with %s, healthy_only: %s", model_name, healthy_only
        )
        return await self.registry.get_all_instances(model_name, healthy_only)

    async def get_all_model_instances(self) -> List[ModelInstance]:
        """Return every instance known to the registry, across all models."""
        return await self.registry.get_all_model_instances()

    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        """Forward a heartbeat for ``instance`` to the registry and return its result."""
        return await self.registry.send_heartbeat(instance)

    async def model_apply(self) -> bool:
        """Placeholder for applying a model; not implemented.

        Raises:
            NotImplementedError: Always.
        """
        # TODO: implement model application.
        raise NotImplementedError
|
|
|
|
|
|
# Single shared controller; each route handler below delegates to it.
controller = ModelController()

# FastAPI router on which the controller's HTTP endpoints are registered.
router = APIRouter()
|
|
|
|
|
|
# POST /controller/models — register the model instance carried in the request
# body; responds with the controller's register_instance result.
@router.post("/controller/models")
async def api_register_instance(request: ModelInstance):
    registered = await controller.register_instance(request)
    return registered
|
|
|
|
|
|
# DELETE /controller/models — deregister the model instance carried in the
# request body; responds with the controller's deregister_instance result.
# NOTE(review): this DELETE relies on a request body, which some HTTP clients
# and proxies drop — confirm callers send it.
@router.delete("/controller/models")
async def api_deregister_instance(request: ModelInstance):
    removed = await controller.deregister_instance(request)
    return removed
|
|
|
|
|
|
# GET /controller/models — list registered model instances.
# With a non-empty ``model_name`` query parameter, only that model's instances
# are requested (forwarding ``healthy_only``); otherwise every instance known to
# the controller is returned.
# NOTE(review): ``healthy_only`` is ignored when ``model_name`` is missing or
# empty — confirm that is intended.
@router.get("/controller/models")
async def api_get_all_instances(model_name: str = None, healthy_only: bool = False):
    if model_name:
        return await controller.get_all_instances(model_name, healthy_only=healthy_only)
    return await controller.get_all_model_instances()
|
|
|
|
|
|
# POST /controller/heartbeat — forward a heartbeat for the instance carried in
# the request body; responds with the controller's send_heartbeat result.
@router.post("/controller/heartbeat")
async def api_model_heartbeat(request: ModelInstance):
    acknowledged = await controller.send_heartbeat(request)
    return acknowledged
|