mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-13 05:01:25 +00:00
feat(model): Support database model registry (#1656)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
@@ -8,6 +8,7 @@ from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
from dbgpt.model.base import ModelInstance
|
||||
from dbgpt.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
|
||||
from dbgpt.model.parameter import ModelControllerParameters
|
||||
from dbgpt.util.api_utils import APIMixin
|
||||
from dbgpt.util.api_utils import _api_remote as api_remote
|
||||
from dbgpt.util.api_utils import _sync_api_remote as sync_api_remote
|
||||
from dbgpt.util.fastapi import create_app
|
||||
@@ -46,9 +47,7 @@ class BaseModelController(BaseComponent, ABC):
|
||||
|
||||
|
||||
class LocalModelController(BaseModelController):
|
||||
def __init__(self, registry: ModelRegistry = None) -> None:
|
||||
if not registry:
|
||||
registry = EmbeddedModelRegistry()
|
||||
def __init__(self, registry: ModelRegistry) -> None:
|
||||
self.registry = registry
|
||||
self.deployment = None
|
||||
|
||||
@@ -75,9 +74,25 @@ class LocalModelController(BaseModelController):
|
||||
return await self.registry.send_heartbeat(instance)
|
||||
|
||||
|
||||
class _RemoteModelController(BaseModelController):
|
||||
def __init__(self, base_url: str) -> None:
|
||||
self.base_url = base_url
|
||||
class _RemoteModelController(APIMixin, BaseModelController):
|
||||
def __init__(
|
||||
self,
|
||||
urls: str,
|
||||
health_check_interval_secs: int = 5,
|
||||
health_check_timeout_secs: int = 30,
|
||||
check_health: bool = True,
|
||||
choice_type: Literal["latest_first", "random"] = "latest_first",
|
||||
) -> None:
|
||||
APIMixin.__init__(
|
||||
self,
|
||||
urls=urls,
|
||||
health_check_path="/api/health",
|
||||
health_check_interval_secs=health_check_interval_secs,
|
||||
health_check_timeout_secs=health_check_timeout_secs,
|
||||
check_health=check_health,
|
||||
choice_type=choice_type,
|
||||
)
|
||||
BaseModelController.__init__(self)
|
||||
|
||||
@api_remote(path="/api/controller/models", method="POST")
|
||||
async def register_instance(self, instance: ModelInstance) -> bool:
|
||||
@@ -139,13 +154,19 @@ controller = ModelControllerAdapter()
|
||||
|
||||
|
||||
def initialize_controller(
|
||||
app=None, remote_controller_addr: str = None, host: str = None, port: int = None
|
||||
app=None,
|
||||
remote_controller_addr: str = None,
|
||||
host: str = None,
|
||||
port: int = None,
|
||||
registry: Optional[ModelRegistry] = None,
|
||||
):
|
||||
global controller
|
||||
if remote_controller_addr:
|
||||
controller.backend = _RemoteModelController(remote_controller_addr)
|
||||
else:
|
||||
controller.backend = LocalModelController()
|
||||
if not registry:
|
||||
registry = EmbeddedModelRegistry()
|
||||
controller.backend = LocalModelController(registry=registry)
|
||||
|
||||
if app:
|
||||
app.include_router(router, prefix="/api", tags=["Model"])
|
||||
@@ -158,6 +179,12 @@ def initialize_controller(
|
||||
uvicorn.run(app, host=host, port=port, log_level="info")
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def api_health_check():
|
||||
"""Health check API."""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post("/controller/models")
|
||||
async def api_register_instance(request: ModelInstance):
|
||||
return await controller.register_instance(request)
|
||||
@@ -179,6 +206,87 @@ async def api_model_heartbeat(request: ModelInstance):
|
||||
return await controller.send_heartbeat(request)
|
||||
|
||||
|
||||
def _create_registry(controller_params: ModelControllerParameters) -> ModelRegistry:
|
||||
"""Create a model registry based on the controller parameters.
|
||||
|
||||
Registry will store the metadata of all model instances, it will be a high
|
||||
availability service for model instances if you use a database registry now. Also,
|
||||
we can implement more registry types in the future.
|
||||
"""
|
||||
registry_type = controller_params.registry_type.strip()
|
||||
if controller_params.registry_type == "embedded":
|
||||
return EmbeddedModelRegistry(
|
||||
heartbeat_interval_secs=controller_params.heartbeat_interval_secs,
|
||||
heartbeat_timeout_secs=controller_params.heartbeat_timeout_secs,
|
||||
)
|
||||
elif controller_params.registry_type == "database":
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote_plus as urlquote
|
||||
|
||||
from dbgpt.model.cluster.registry_impl.storage import StorageModelRegistry
|
||||
|
||||
try_to_create_db = False
|
||||
|
||||
if controller_params.registry_db_type == "mysql":
|
||||
db_name = controller_params.registry_db_name
|
||||
db_host = controller_params.registry_db_host
|
||||
db_port = controller_params.registry_db_port
|
||||
db_user = controller_params.registry_db_user
|
||||
db_password = controller_params.registry_db_password
|
||||
if not db_name:
|
||||
raise ValueError(
|
||||
"Registry DB name is required when using MySQL registry."
|
||||
)
|
||||
if not db_host:
|
||||
raise ValueError(
|
||||
"Registry DB host is required when using MySQL registry."
|
||||
)
|
||||
if not db_port:
|
||||
raise ValueError(
|
||||
"Registry DB port is required when using MySQL registry."
|
||||
)
|
||||
if not db_user:
|
||||
raise ValueError(
|
||||
"Registry DB user is required when using MySQL registry."
|
||||
)
|
||||
if not db_password:
|
||||
raise ValueError(
|
||||
"Registry DB password is required when using MySQL registry."
|
||||
)
|
||||
db_url = (
|
||||
f"mysql+pymysql://{quote(db_user)}:"
|
||||
f"{urlquote(db_password)}@"
|
||||
f"{db_host}:"
|
||||
f"{str(db_port)}/"
|
||||
f"{db_name}?charset=utf8mb4"
|
||||
)
|
||||
elif controller_params.registry_db_type == "sqlite":
|
||||
db_name = controller_params.registry_db_name
|
||||
if not db_name:
|
||||
raise ValueError(
|
||||
"Registry DB name is required when using SQLite registry."
|
||||
)
|
||||
db_url = f"sqlite:///{db_name}"
|
||||
try_to_create_db = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported registry DB type: {controller_params.registry_db_type}"
|
||||
)
|
||||
|
||||
registry = StorageModelRegistry.from_url(
|
||||
db_url,
|
||||
db_name,
|
||||
pool_size=controller_params.registry_db_pool_size,
|
||||
max_overflow=controller_params.registry_db_max_overflow,
|
||||
try_to_create_db=try_to_create_db,
|
||||
heartbeat_interval_secs=controller_params.heartbeat_interval_secs,
|
||||
heartbeat_timeout_secs=controller_params.heartbeat_timeout_secs,
|
||||
)
|
||||
return registry
|
||||
else:
|
||||
raise ValueError(f"Unsupported registry type: {registry_type}")
|
||||
|
||||
|
||||
def run_model_controller():
|
||||
parser = EnvArgumentParser()
|
||||
env_prefix = "controller_"
|
||||
@@ -192,8 +300,11 @@ def run_model_controller():
|
||||
logging_level=controller_params.log_level,
|
||||
logger_filename=controller_params.log_file,
|
||||
)
|
||||
registry = _create_registry(controller_params)
|
||||
|
||||
initialize_controller(host=controller_params.host, port=controller_params.port)
|
||||
initialize_controller(
|
||||
host=controller_params.host, port=controller_params.port, registry=registry
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user