feat(model): Support database model registry (#1656)

This commit is contained in:
Fangyin Cheng
2024-06-24 19:07:10 +08:00
committed by GitHub
parent c57ee0289b
commit 47d205f676
35 changed files with 2014 additions and 792 deletions

View File

@@ -1,6 +1,6 @@
import logging
from abc import ABC, abstractmethod
from typing import List
from typing import List, Literal, Optional
from fastapi import APIRouter
@@ -8,6 +8,7 @@ from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.model.base import ModelInstance
from dbgpt.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
from dbgpt.model.parameter import ModelControllerParameters
from dbgpt.util.api_utils import APIMixin
from dbgpt.util.api_utils import _api_remote as api_remote
from dbgpt.util.api_utils import _sync_api_remote as sync_api_remote
from dbgpt.util.fastapi import create_app
@@ -46,9 +47,7 @@ class BaseModelController(BaseComponent, ABC):
class LocalModelController(BaseModelController):
def __init__(self, registry: ModelRegistry = None) -> None:
if not registry:
registry = EmbeddedModelRegistry()
def __init__(self, registry: ModelRegistry) -> None:
self.registry = registry
self.deployment = None
@@ -75,9 +74,25 @@ class LocalModelController(BaseModelController):
return await self.registry.send_heartbeat(instance)
class _RemoteModelController(BaseModelController):
def __init__(self, base_url: str) -> None:
self.base_url = base_url
class _RemoteModelController(APIMixin, BaseModelController):
def __init__(
self,
urls: str,
health_check_interval_secs: int = 5,
health_check_timeout_secs: int = 30,
check_health: bool = True,
choice_type: Literal["latest_first", "random"] = "latest_first",
) -> None:
APIMixin.__init__(
self,
urls=urls,
health_check_path="/api/health",
health_check_interval_secs=health_check_interval_secs,
health_check_timeout_secs=health_check_timeout_secs,
check_health=check_health,
choice_type=choice_type,
)
BaseModelController.__init__(self)
@api_remote(path="/api/controller/models", method="POST")
async def register_instance(self, instance: ModelInstance) -> bool:
@@ -139,13 +154,19 @@ controller = ModelControllerAdapter()
def initialize_controller(
app=None, remote_controller_addr: str = None, host: str = None, port: int = None
app=None,
remote_controller_addr: str = None,
host: str = None,
port: int = None,
registry: Optional[ModelRegistry] = None,
):
global controller
if remote_controller_addr:
controller.backend = _RemoteModelController(remote_controller_addr)
else:
controller.backend = LocalModelController()
if not registry:
registry = EmbeddedModelRegistry()
controller.backend = LocalModelController(registry=registry)
if app:
app.include_router(router, prefix="/api", tags=["Model"])
@@ -158,6 +179,12 @@ def initialize_controller(
uvicorn.run(app, host=host, port=port, log_level="info")
@router.get("/health")
async def api_health_check():
"""Health check API."""
return {"status": "ok"}
@router.post("/controller/models")
async def api_register_instance(request: ModelInstance):
return await controller.register_instance(request)
@@ -179,6 +206,87 @@ async def api_model_heartbeat(request: ModelInstance):
return await controller.send_heartbeat(request)
def _create_registry(controller_params: ModelControllerParameters) -> ModelRegistry:
"""Create a model registry based on the controller parameters.
Registry will store the metadata of all model instances, it will be a high
availability service for model instances if you use a database registry now. Also,
we can implement more registry types in the future.
"""
registry_type = controller_params.registry_type.strip()
if controller_params.registry_type == "embedded":
return EmbeddedModelRegistry(
heartbeat_interval_secs=controller_params.heartbeat_interval_secs,
heartbeat_timeout_secs=controller_params.heartbeat_timeout_secs,
)
elif controller_params.registry_type == "database":
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from dbgpt.model.cluster.registry_impl.storage import StorageModelRegistry
try_to_create_db = False
if controller_params.registry_db_type == "mysql":
db_name = controller_params.registry_db_name
db_host = controller_params.registry_db_host
db_port = controller_params.registry_db_port
db_user = controller_params.registry_db_user
db_password = controller_params.registry_db_password
if not db_name:
raise ValueError(
"Registry DB name is required when using MySQL registry."
)
if not db_host:
raise ValueError(
"Registry DB host is required when using MySQL registry."
)
if not db_port:
raise ValueError(
"Registry DB port is required when using MySQL registry."
)
if not db_user:
raise ValueError(
"Registry DB user is required when using MySQL registry."
)
if not db_password:
raise ValueError(
"Registry DB password is required when using MySQL registry."
)
db_url = (
f"mysql+pymysql://{quote(db_user)}:"
f"{urlquote(db_password)}@"
f"{db_host}:"
f"{str(db_port)}/"
f"{db_name}?charset=utf8mb4"
)
elif controller_params.registry_db_type == "sqlite":
db_name = controller_params.registry_db_name
if not db_name:
raise ValueError(
"Registry DB name is required when using SQLite registry."
)
db_url = f"sqlite:///{db_name}"
try_to_create_db = True
else:
raise ValueError(
f"Unsupported registry DB type: {controller_params.registry_db_type}"
)
registry = StorageModelRegistry.from_url(
db_url,
db_name,
pool_size=controller_params.registry_db_pool_size,
max_overflow=controller_params.registry_db_max_overflow,
try_to_create_db=try_to_create_db,
heartbeat_interval_secs=controller_params.heartbeat_interval_secs,
heartbeat_timeout_secs=controller_params.heartbeat_timeout_secs,
)
return registry
else:
raise ValueError(f"Unsupported registry type: {registry_type}")
def run_model_controller():
parser = EnvArgumentParser()
env_prefix = "controller_"
@@ -192,8 +300,11 @@ def run_model_controller():
logging_level=controller_params.log_level,
logger_filename=controller_params.log_file,
)
registry = _create_registry(controller_params)
initialize_controller(host=controller_params.host, port=controller_params.port)
initialize_controller(
host=controller_params.host, port=controller_params.port, registry=registry
)
if __name__ == "__main__":