mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-27 22:07:48 +00:00
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com> Co-authored-by: csunny <cfqsunny@163.com>
342 lines
12 KiB
Python
342 lines
12 KiB
Python
import logging
|
|
import os
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Literal, Optional
|
|
|
|
from fastapi import APIRouter
|
|
|
|
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
|
from dbgpt.model.base import ModelInstance
|
|
from dbgpt.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
|
|
from dbgpt.model.parameter import ModelControllerParameters
|
|
from dbgpt.util.api_utils import APIMixin
|
|
from dbgpt.util.api_utils import _api_remote as api_remote
|
|
from dbgpt.util.api_utils import _sync_api_remote as sync_api_remote
|
|
from dbgpt.util.fastapi import create_app
|
|
from dbgpt.util.parameter_utils import EnvArgumentParser
|
|
from dbgpt.util.tracer.tracer_impl import initialize_tracer, root_tracer
|
|
from dbgpt.util.utils import setup_http_service_logging, setup_logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BaseModelController(BaseComponent, ABC):
|
|
name = ComponentType.MODEL_CONTROLLER
|
|
|
|
def init_app(self, system_app: SystemApp):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def register_instance(self, instance: ModelInstance) -> bool:
|
|
"""Register a given model instance"""
|
|
|
|
@abstractmethod
|
|
async def deregister_instance(self, instance: ModelInstance) -> bool:
|
|
"""Deregister a given model instance."""
|
|
|
|
@abstractmethod
|
|
async def get_all_instances(
|
|
self, model_name: str = None, healthy_only: bool = False
|
|
) -> List[ModelInstance]:
|
|
"""Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
|
|
|
|
@abstractmethod
|
|
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
|
"""Send a heartbeat for a given model instance. This can be used to verify if the instance is still alive and functioning."""
|
|
|
|
async def model_apply(self) -> bool:
|
|
raise NotImplementedError
|
|
|
|
|
|
class LocalModelController(BaseModelController):
|
|
def __init__(self, registry: ModelRegistry) -> None:
|
|
self.registry = registry
|
|
self.deployment = None
|
|
|
|
async def register_instance(self, instance: ModelInstance) -> bool:
|
|
return await self.registry.register_instance(instance)
|
|
|
|
async def deregister_instance(self, instance: ModelInstance) -> bool:
|
|
return await self.registry.deregister_instance(instance)
|
|
|
|
async def get_all_instances(
|
|
self, model_name: str = None, healthy_only: bool = False
|
|
) -> List[ModelInstance]:
|
|
logger.info(
|
|
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
|
|
)
|
|
if not model_name:
|
|
return await self.registry.get_all_model_instances(
|
|
healthy_only=healthy_only
|
|
)
|
|
else:
|
|
return await self.registry.get_all_instances(model_name, healthy_only)
|
|
|
|
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
|
return await self.registry.send_heartbeat(instance)
|
|
|
|
|
|
class _RemoteModelController(APIMixin, BaseModelController):
|
|
def __init__(
|
|
self,
|
|
urls: str,
|
|
health_check_interval_secs: int = 5,
|
|
health_check_timeout_secs: int = 30,
|
|
check_health: bool = True,
|
|
choice_type: Literal["latest_first", "random"] = "latest_first",
|
|
) -> None:
|
|
APIMixin.__init__(
|
|
self,
|
|
urls=urls,
|
|
health_check_path="/api/health",
|
|
health_check_interval_secs=health_check_interval_secs,
|
|
health_check_timeout_secs=health_check_timeout_secs,
|
|
check_health=check_health,
|
|
choice_type=choice_type,
|
|
)
|
|
BaseModelController.__init__(self)
|
|
|
|
@api_remote(path="/api/controller/models", method="POST")
|
|
async def register_instance(self, instance: ModelInstance) -> bool:
|
|
pass
|
|
|
|
@api_remote(path="/api/controller/models", method="DELETE")
|
|
async def deregister_instance(self, instance: ModelInstance) -> bool:
|
|
pass
|
|
|
|
@api_remote(path="/api/controller/models")
|
|
async def get_all_instances(
|
|
self, model_name: str = None, healthy_only: bool = False
|
|
) -> List[ModelInstance]:
|
|
pass
|
|
|
|
@api_remote(path="/api/controller/heartbeat", method="POST")
|
|
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
|
pass
|
|
|
|
|
|
class ModelRegistryClient(_RemoteModelController, ModelRegistry):
|
|
async def get_all_model_instances(
|
|
self, healthy_only: bool = False
|
|
) -> List[ModelInstance]:
|
|
return await self.get_all_instances(healthy_only=healthy_only)
|
|
|
|
@sync_api_remote(path="/api/controller/models")
|
|
def sync_get_all_instances(
|
|
self, model_name: str = None, healthy_only: bool = False
|
|
) -> List[ModelInstance]:
|
|
pass
|
|
|
|
|
|
class ModelControllerAdapter(BaseModelController):
|
|
def __init__(self, backend: BaseModelController = None) -> None:
|
|
self.backend = backend
|
|
|
|
async def register_instance(self, instance: ModelInstance) -> bool:
|
|
return await self.backend.register_instance(instance)
|
|
|
|
async def deregister_instance(self, instance: ModelInstance) -> bool:
|
|
return await self.backend.deregister_instance(instance)
|
|
|
|
async def get_all_instances(
|
|
self, model_name: str = None, healthy_only: bool = False
|
|
) -> List[ModelInstance]:
|
|
return await self.backend.get_all_instances(model_name, healthy_only)
|
|
|
|
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
|
return await self.backend.send_heartbeat(instance)
|
|
|
|
async def model_apply(self) -> bool:
|
|
return await self.backend.model_apply()
|
|
|
|
|
|
router = APIRouter()
|
|
|
|
controller = ModelControllerAdapter()
|
|
|
|
|
|
def initialize_controller(
|
|
app=None,
|
|
remote_controller_addr: str = None,
|
|
host: str = None,
|
|
port: int = None,
|
|
registry: Optional[ModelRegistry] = None,
|
|
controller_params: Optional[ModelControllerParameters] = None,
|
|
system_app: Optional[SystemApp] = None,
|
|
):
|
|
global controller
|
|
if remote_controller_addr:
|
|
controller.backend = _RemoteModelController(remote_controller_addr)
|
|
else:
|
|
if not registry:
|
|
registry = EmbeddedModelRegistry()
|
|
controller.backend = LocalModelController(registry=registry)
|
|
|
|
if app:
|
|
app.include_router(router, prefix="/api", tags=["Model"])
|
|
else:
|
|
import uvicorn
|
|
|
|
from dbgpt.configs.model_config import LOGDIR
|
|
|
|
setup_http_service_logging()
|
|
app = create_app()
|
|
if not system_app:
|
|
system_app = SystemApp(app)
|
|
if not controller_params:
|
|
raise ValueError("Controller parameters are required.")
|
|
initialize_tracer(
|
|
os.path.join(LOGDIR, controller_params.tracer_file),
|
|
root_operation_name="DB-GPT-ModelController",
|
|
system_app=system_app,
|
|
tracer_storage_cls=controller_params.tracer_storage_cls,
|
|
enable_open_telemetry=controller_params.tracer_to_open_telemetry,
|
|
otlp_endpoint=controller_params.otel_exporter_otlp_traces_endpoint,
|
|
otlp_insecure=controller_params.otel_exporter_otlp_traces_insecure,
|
|
otlp_timeout=controller_params.otel_exporter_otlp_traces_timeout,
|
|
)
|
|
|
|
app.include_router(router, prefix="/api", tags=["Model"])
|
|
uvicorn.run(app, host=host, port=port, log_level="info")
|
|
|
|
|
|
@router.get("/health")
|
|
async def api_health_check():
|
|
"""Health check API."""
|
|
return {"status": "ok"}
|
|
|
|
|
|
@router.post("/controller/models")
|
|
async def api_register_instance(request: ModelInstance):
|
|
with root_tracer.start_span(
|
|
"dbgpt.model.controller.register_instance", metadata=request.to_dict()
|
|
):
|
|
return await controller.register_instance(request)
|
|
|
|
|
|
@router.delete("/controller/models")
|
|
async def api_deregister_instance(model_name: str, host: str, port: int):
|
|
instance = ModelInstance(model_name=model_name, host=host, port=port)
|
|
with root_tracer.start_span(
|
|
"dbgpt.model.controller.deregister_instance", metadata=instance.to_dict()
|
|
):
|
|
return await controller.deregister_instance(instance)
|
|
|
|
|
|
@router.get("/controller/models")
|
|
async def api_get_all_instances(model_name: str = None, healthy_only: bool = False):
|
|
return await controller.get_all_instances(model_name, healthy_only=healthy_only)
|
|
|
|
|
|
@router.post("/controller/heartbeat")
|
|
async def api_model_heartbeat(request: ModelInstance):
|
|
return await controller.send_heartbeat(request)
|
|
|
|
|
|
def _create_registry(controller_params: ModelControllerParameters) -> ModelRegistry:
|
|
"""Create a model registry based on the controller parameters.
|
|
|
|
Registry will store the metadata of all model instances, it will be a high
|
|
availability service for model instances if you use a database registry now. Also,
|
|
we can implement more registry types in the future.
|
|
"""
|
|
registry_type = controller_params.registry_type.strip()
|
|
if controller_params.registry_type == "embedded":
|
|
return EmbeddedModelRegistry(
|
|
heartbeat_interval_secs=controller_params.heartbeat_interval_secs,
|
|
heartbeat_timeout_secs=controller_params.heartbeat_timeout_secs,
|
|
)
|
|
elif controller_params.registry_type == "database":
|
|
from urllib.parse import quote
|
|
from urllib.parse import quote_plus as urlquote
|
|
|
|
from dbgpt.model.cluster.registry_impl.storage import StorageModelRegistry
|
|
|
|
try_to_create_db = False
|
|
|
|
if controller_params.registry_db_type == "mysql":
|
|
db_name = controller_params.registry_db_name
|
|
db_host = controller_params.registry_db_host
|
|
db_port = controller_params.registry_db_port
|
|
db_user = controller_params.registry_db_user
|
|
db_password = controller_params.registry_db_password
|
|
if not db_name:
|
|
raise ValueError(
|
|
"Registry DB name is required when using MySQL registry."
|
|
)
|
|
if not db_host:
|
|
raise ValueError(
|
|
"Registry DB host is required when using MySQL registry."
|
|
)
|
|
if not db_port:
|
|
raise ValueError(
|
|
"Registry DB port is required when using MySQL registry."
|
|
)
|
|
if not db_user:
|
|
raise ValueError(
|
|
"Registry DB user is required when using MySQL registry."
|
|
)
|
|
if not db_password:
|
|
raise ValueError(
|
|
"Registry DB password is required when using MySQL registry."
|
|
)
|
|
db_url = (
|
|
f"mysql+pymysql://{quote(db_user)}:"
|
|
f"{urlquote(db_password)}@"
|
|
f"{db_host}:"
|
|
f"{str(db_port)}/"
|
|
f"{db_name}?charset=utf8mb4"
|
|
)
|
|
elif controller_params.registry_db_type == "sqlite":
|
|
db_name = controller_params.registry_db_name
|
|
if not db_name:
|
|
raise ValueError(
|
|
"Registry DB name is required when using SQLite registry."
|
|
)
|
|
db_url = f"sqlite:///{db_name}"
|
|
try_to_create_db = True
|
|
else:
|
|
raise ValueError(
|
|
f"Unsupported registry DB type: {controller_params.registry_db_type}"
|
|
)
|
|
|
|
registry = StorageModelRegistry.from_url(
|
|
db_url,
|
|
db_name,
|
|
pool_size=controller_params.registry_db_pool_size,
|
|
max_overflow=controller_params.registry_db_max_overflow,
|
|
try_to_create_db=try_to_create_db,
|
|
heartbeat_interval_secs=controller_params.heartbeat_interval_secs,
|
|
heartbeat_timeout_secs=controller_params.heartbeat_timeout_secs,
|
|
)
|
|
return registry
|
|
else:
|
|
raise ValueError(f"Unsupported registry type: {registry_type}")
|
|
|
|
|
|
def run_model_controller():
|
|
parser = EnvArgumentParser()
|
|
env_prefix = "controller_"
|
|
controller_params: ModelControllerParameters = parser.parse_args_into_dataclass(
|
|
ModelControllerParameters,
|
|
env_prefixes=[env_prefix],
|
|
)
|
|
|
|
setup_logging(
|
|
"dbgpt",
|
|
logging_level=controller_params.log_level,
|
|
logger_filename=controller_params.log_file,
|
|
)
|
|
registry = _create_registry(controller_params)
|
|
|
|
initialize_controller(
|
|
host=controller_params.host,
|
|
port=controller_params.port,
|
|
registry=registry,
|
|
controller_params=controller_params,
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_model_controller()
|