mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-13 05:01:25 +00:00
refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
0
dbgpt/model/cluster/controller/__init__.py
Normal file
0
dbgpt/model/cluster/controller/__init__.py
Normal file
201
dbgpt/model/cluster/controller/controller.py
Normal file
201
dbgpt/model/cluster/controller/controller.py
Normal file
@@ -0,0 +1,201 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
from dbgpt.model.base import ModelInstance
|
||||
from dbgpt.model.parameter import ModelControllerParameters
|
||||
from dbgpt.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
|
||||
from dbgpt.util.parameter_utils import EnvArgumentParser
|
||||
from dbgpt.util.api_utils import (
|
||||
_api_remote as api_remote,
|
||||
_sync_api_remote as sync_api_remote,
|
||||
)
|
||||
from dbgpt.util.utils import setup_logging, setup_http_service_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseModelController(BaseComponent, ABC):
    """Abstract contract that every model controller implements.

    A controller registers and deregisters model instances, lists them and
    accepts their heartbeats.
    """

    # Component identifier taken from ``ComponentType``.
    name = ComponentType.MODEL_CONTROLLER

    def init_app(self, system_app: SystemApp):
        """Component initialisation hook; this base class does nothing here."""

    @abstractmethod
    async def register_instance(self, instance: ModelInstance) -> bool:
        """Register the given model instance.

        Args:
            instance: The model instance to register.

        Returns:
            A boolean result reported by the implementation.
        """

    @abstractmethod
    async def deregister_instance(self, instance: ModelInstance) -> bool:
        """Deregister the given model instance.

        Args:
            instance: The model instance to deregister.

        Returns:
            A boolean result reported by the implementation.
        """

    @abstractmethod
    async def get_all_instances(
        self, model_name: str = None, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Fetch the instances of a model.

        Args:
            model_name: Name of the model whose instances are requested.
            healthy_only: When true, only healthy instances are returned.

        Returns:
            The matching model instances.
        """

    @abstractmethod
    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        """Send a heartbeat for the given model instance.

        Heartbeats can be used to verify that the instance is still alive
        and functioning.

        Args:
            instance: The model instance the heartbeat belongs to.

        Returns:
            A boolean result reported by the implementation.
        """

    async def model_apply(self) -> bool:
        """Optional operation; not implemented by the base class.

        Raises:
            NotImplementedError: Always, unless a subclass overrides it.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class LocalModelController(BaseModelController):
    """Controller that delegates every operation to a ``ModelRegistry``."""

    def __init__(self, registry: ModelRegistry = None) -> None:
        """Create the controller.

        Args:
            registry: Registry that backs this controller. When omitted, a
                new ``EmbeddedModelRegistry`` is created.
        """
        # Test against None explicitly so that a caller-supplied registry is
        # never replaced merely because it evaluates as falsy.
        if registry is None:
            registry = EmbeddedModelRegistry()
        self.registry = registry
        # Not used anywhere else in this class.
        self.deployment = None

    async def register_instance(self, instance: ModelInstance) -> bool:
        """Register ``instance`` with the backing registry."""
        return await self.registry.register_instance(instance)

    async def deregister_instance(self, instance: ModelInstance) -> bool:
        """Deregister ``instance`` from the backing registry."""
        return await self.registry.deregister_instance(instance)

    async def get_all_instances(
        self, model_name: str = None, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Return instances from the backing registry.

        Args:
            model_name: Model to look up. When empty or ``None`` the
                instances of all models are requested instead.
            healthy_only: When true, only healthy instances are returned.

        Returns:
            The instances reported by the registry.
        """
        # Lazy %-style arguments: the message is only built when the INFO
        # level is actually enabled.
        logger.info(
            "Get all instances with %s, healthy_only: %s", model_name, healthy_only
        )
        if not model_name:
            return await self.registry.get_all_model_instances(
                healthy_only=healthy_only
            )
        return await self.registry.get_all_instances(model_name, healthy_only)

    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        """Forward a heartbeat for ``instance`` to the backing registry."""
        return await self.registry.send_heartbeat(instance)
|
||||
|
||||
|
||||
class _RemoteModelController(BaseModelController):
    """Controller whose methods are declared as remote API calls.

    Each method body is empty; the ``api_remote`` decorator is given the
    path and HTTP method and presumably performs the request against
    ``base_url`` (the decorator is defined outside this file - confirm).
    """

    def __init__(self, base_url: str) -> None:
        """Remember the address of the remote controller.

        Args:
            base_url: Address of the remote controller service.
        """
        self.base_url = base_url

    @api_remote(path="/api/controller/models", method="POST")
    async def register_instance(self, instance: ModelInstance) -> bool:
        """Register ``instance`` through the remote controller."""

    @api_remote(path="/api/controller/models", method="DELETE")
    async def deregister_instance(self, instance: ModelInstance) -> bool:
        """Deregister ``instance`` through the remote controller."""

    @api_remote(path="/api/controller/models")
    async def get_all_instances(
        self, model_name: str = None, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Fetch model instances through the remote controller."""

    @api_remote(path="/api/controller/heartbeat", method="POST")
    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        """Send a heartbeat for ``instance`` through the remote controller."""
|
||||
|
||||
|
||||
class ModelRegistryClient(_RemoteModelController, ModelRegistry):
    """Remote controller exposed through the ``ModelRegistry`` interface."""

    async def get_all_model_instances(
        self, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Return the instances of all models.

        Implemented by calling ``get_all_instances`` without a model name.

        Args:
            healthy_only: When true, only healthy instances are requested.
        """
        instances = await self.get_all_instances(healthy_only=healthy_only)
        return instances

    @sync_api_remote(path="/api/controller/models")
    def sync_get_all_instances(
        self, model_name: str = None, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Synchronous counterpart of ``get_all_instances``.

        The body is empty; the ``sync_api_remote`` decorator is expected to
        supply the behaviour (defined outside this file - confirm).
        """
|
||||
|
||||
|
||||
class ModelControllerAdapter(BaseModelController):
    """Controller facade that forwards every call to a swappable backend."""

    def __init__(self, backend: BaseModelController = None) -> None:
        """Create the adapter.

        Args:
            backend: Controller that does the actual work. It may be left
                unset here and assigned to ``self.backend`` later.
        """
        self.backend = backend

    async def register_instance(self, instance: ModelInstance) -> bool:
        """Forward the registration of ``instance`` to the backend."""
        outcome = await self.backend.register_instance(instance)
        return outcome

    async def deregister_instance(self, instance: ModelInstance) -> bool:
        """Forward the deregistration of ``instance`` to the backend."""
        outcome = await self.backend.deregister_instance(instance)
        return outcome

    async def get_all_instances(
        self, model_name: str = None, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Forward the instance lookup to the backend.

        Args:
            model_name: Name of the model whose instances are requested.
            healthy_only: When true, only healthy instances are requested.
        """
        instances = await self.backend.get_all_instances(model_name, healthy_only)
        return instances

    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        """Forward the heartbeat of ``instance`` to the backend."""
        outcome = await self.backend.send_heartbeat(instance)
        return outcome

    async def model_apply(self) -> bool:
        """Forward ``model_apply`` to the backend."""
        outcome = await self.backend.model_apply()
        return outcome
|
||||
|
||||
|
||||
# Router that collects the controller HTTP endpoints defined in this module.
router = APIRouter()

# Module-level adapter used by the endpoints; it is created without a
# backend, which ``initialize_controller`` assigns later.
controller = ModelControllerAdapter()
|
||||
|
||||
|
||||
def initialize_controller(
    app=None, remote_controller_addr: str = None, host: str = None, port: int = None
):
    """Select the controller backend and expose the controller endpoints.

    Args:
        app: Existing application to mount the router on. When it is not
            given, a new ``FastAPI`` application is created and served with
            uvicorn.
        remote_controller_addr: Address of a remote controller. When given,
            the backend is a ``_RemoteModelController``; otherwise a
            ``LocalModelController`` is used.
        host: Host passed to ``uvicorn.run`` (standalone mode only).
        port: Port passed to ``uvicorn.run`` (standalone mode only).
    """
    # Only an attribute of the module-level adapter is assigned here.
    controller.backend = (
        _RemoteModelController(remote_controller_addr)
        if remote_controller_addr
        else LocalModelController()
    )

    if app:
        # Embedded mode: mount the endpoints on the caller's application.
        app.include_router(router, prefix="/api", tags=["Model"])
        return

    # Standalone mode: uvicorn is only imported when it is actually needed.
    import uvicorn

    setup_http_service_logging()
    app = FastAPI()
    app.include_router(router, prefix="/api", tags=["Model"])
    uvicorn.run(app, host=host, port=port, log_level="info")
|
||||
|
||||
|
||||
@router.post("/controller/models")
async def api_register_instance(request: ModelInstance):
    """POST endpoint that registers the model instance given in the request."""
    outcome = await controller.register_instance(request)
    return outcome
|
||||
|
||||
|
||||
@router.delete("/controller/models")
async def api_deregister_instance(model_name: str, host: str, port: int):
    """DELETE endpoint that deregisters a model instance.

    Args:
        model_name: Name of the model the instance serves.
        host: Host of the instance.
        port: Port of the instance.
    """
    return await controller.deregister_instance(
        ModelInstance(model_name=model_name, host=host, port=port)
    )
|
||||
|
||||
|
||||
@router.get("/controller/models")
async def api_get_all_instances(model_name: str = None, healthy_only: bool = False):
    """GET endpoint that lists model instances.

    Args:
        model_name: Optional name of the model to list instances for.
        healthy_only: When true, only healthy instances are requested.
    """
    instances = await controller.get_all_instances(
        model_name, healthy_only=healthy_only
    )
    return instances
|
||||
|
||||
|
||||
@router.post("/controller/heartbeat")
async def api_model_heartbeat(request: ModelInstance):
    """POST endpoint that records a heartbeat for the given model instance."""
    outcome = await controller.send_heartbeat(request)
    return outcome
|
||||
|
||||
|
||||
def run_model_controller():
    """Parse the controller parameters, set up logging and start the service.

    Parameters are parsed into ``ModelControllerParameters`` with the
    environment prefix ``controller_``.
    """
    params: ModelControllerParameters = EnvArgumentParser().parse_args_into_dataclass(
        ModelControllerParameters,
        env_prefixes=["controller_"],
    )

    setup_logging(
        "dbgpt",
        logging_level=params.log_level,
        logger_filename=params.log_file,
    )

    # No ``app`` is passed, so a standalone server is started.
    initialize_controller(host=params.host, port=params.port)
|
||||
|
||||
|
||||
# Script entry point: start the model controller when run directly.
if __name__ == "__main__":
    run_model_controller()
|
0
dbgpt/model/cluster/controller/ray_controller.py
Normal file
0
dbgpt/model/cluster/controller/ray_controller.py
Normal file
0
dbgpt/model/cluster/controller/tests/__init__.py
Normal file
0
dbgpt/model/cluster/controller/tests/__init__.py
Normal file
142
dbgpt/model/cluster/controller/tests/test_registry.py
Normal file
142
dbgpt/model/cluster/controller/tests/test_registry.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import asyncio
|
||||
from dbgpt.model.base import ModelInstance
|
||||
from dbgpt.model.cluster.registry import EmbeddedModelRegistry
|
||||
|
||||
|
||||
@pytest.fixture
def model_registry():
    """Provide a new ``EmbeddedModelRegistry`` built with default arguments."""
    registry = EmbeddedModelRegistry()
    return registry
|
||||
|
||||
|
||||
@pytest.fixture
def model_instance():
    """Provide a sample ``ModelInstance`` for ``test_model`` at 192.168.1.1:5000."""
    instance = ModelInstance(
        model_name="test_model",
        host="192.168.1.1",
        port=5000,
    )
    return instance
|
||||
|
||||
|
||||
# Async function to test the registry
|
||||
@pytest.mark.asyncio
async def test_register_instance(model_registry, model_instance):
    """Registering an instance succeeds and stores exactly one entry."""
    # Identity check instead of ``== True``: the result must be the boolean
    # ``True`` itself, not merely a value that compares equal to it.
    assert await model_registry.register_instance(model_instance) is True
    assert len(model_registry.registry[model_instance.model_name]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_deregister_instance(model_registry, model_instance):
    """Deregistering a registered instance succeeds and marks it unhealthy."""
    await model_registry.register_instance(model_instance)
    # Identity check instead of ``== True``: the result must be the boolean
    # ``True`` itself, not merely a value that compares equal to it.
    assert await model_registry.deregister_instance(model_instance) is True
    # The entry is still present in the registry but no longer healthy.
    assert not model_registry.registry[model_instance.model_name][0].healthy
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_all_instances(model_registry, model_instance):
    """Instances can be listed with and without the ``healthy_only`` filter."""
    await model_registry.register_instance(model_instance)
    name = model_instance.model_name

    everything = await model_registry.get_all_instances(name)
    assert len(everything) == 1

    healthy = await model_registry.get_all_instances(name, healthy_only=True)
    assert len(healthy) == 1

    # The fixture object is mutated directly; the test expects the registry
    # to observe this change (presumably it holds the same object - confirm).
    model_instance.healthy = False
    healthy = await model_registry.get_all_instances(name, healthy_only=True)
    assert len(healthy) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_select_one_health_instance(model_registry, model_instance):
    """A healthy, enabled instance is selected after registration."""
    await model_registry.register_instance(model_instance)
    name = model_instance.model_name

    chosen = await model_registry.select_one_health_instance(name)

    assert chosen is not None
    assert chosen.healthy
    assert chosen.enabled
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_heartbeat(model_registry, model_instance):
    """A heartbeat succeeds, refreshes ``last_heartbeat`` and keeps it healthy."""
    await model_registry.register_instance(model_instance)

    # Backdate the heartbeat so that a refreshed timestamp is strictly newer.
    last_heartbeat = datetime.now() - timedelta(seconds=10)
    model_instance.last_heartbeat = last_heartbeat

    # Identity checks instead of ``== True``: both values must be the boolean
    # ``True`` itself, not merely values that compare equal to it.
    assert await model_registry.send_heartbeat(model_instance) is True
    assert (
        model_registry.registry[model_instance.model_name][0].last_heartbeat
        > last_heartbeat
    )
    assert model_registry.registry[model_instance.model_name][0].healthy is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_heartbeat_timeout(model_registry, model_instance):
    """An instance whose heartbeat is overdue becomes unhealthy."""
    # A dedicated registry replaces the fixture one; the two positional
    # arguments are presumably the heartbeat interval and timeout in
    # seconds (confirm against ``EmbeddedModelRegistry``).
    registry = EmbeddedModelRegistry(1, 1)
    await registry.register_instance(model_instance)

    # Backdate the stored heartbeat to just beyond the timeout.
    overdue_by = timedelta(seconds=registry.heartbeat_timeout_secs + 1)
    stored = registry.registry[model_instance.model_name][0]
    stored.last_heartbeat = datetime.now() - overdue_by

    # Wait for longer than one heartbeat interval before checking.
    await asyncio.sleep(registry.heartbeat_interval_secs + 1)

    assert not registry.registry[model_instance.model_name][0].healthy
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_multiple_instances(model_registry, model_instance):
    """Two instances of the same model on different hosts are both listed."""
    second_instance = ModelInstance(
        model_name="test_model",
        host="192.168.1.2",
        port=5000,
    )

    for candidate in (model_instance, second_instance):
        await model_registry.register_instance(candidate)

    listed = await model_registry.get_all_instances(model_instance.model_name)
    assert len(listed) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_same_model_name_different_ip_port(model_registry):
    """Same model name with different host and port yields two entries."""
    first = ModelInstance(model_name="test_model", host="192.168.1.1", port=5000)
    second = ModelInstance(model_name="test_model", host="192.168.1.2", port=6000)

    for candidate in (first, second):
        await model_registry.register_instance(candidate)

    listed = await model_registry.get_all_instances("test_model")
    assert len(listed) == 2

    left, right = listed[0], listed[1]
    assert left.host != right.host
    assert left.port != right.port
|
Reference in New Issue
Block a user