mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-16 22:51:24 +00:00
feat(model): Support database model registry (#1656)
This commit is contained in:
@@ -26,6 +26,7 @@ from dbgpt.util.parameter_utils import (
|
||||
build_lazy_click_command,
|
||||
)
|
||||
|
||||
# Your can set environment variable CONTROLLER_ADDRESS to set the default address
|
||||
MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
|
||||
|
||||
logger = logging.getLogger("dbgpt_cli")
|
||||
|
@@ -1,22 +1,11 @@
|
||||
import importlib.metadata as metadata
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from aioresponses import aioresponses
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from httpx import AsyncClient, HTTPError
|
||||
from httpx import ASGITransport, AsyncClient, HTTPError
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.model.cluster.apiserver.api import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatMessage,
|
||||
DeltaMessage,
|
||||
ModelList,
|
||||
UsageInfo,
|
||||
api_settings,
|
||||
initialize_apiserver,
|
||||
)
|
||||
@@ -56,12 +45,13 @@ async def client(request, system_app: SystemApp):
|
||||
if api_settings:
|
||||
# Clear global api keys
|
||||
api_settings.api_keys = []
|
||||
async with AsyncClient(app=app, base_url="http://test", headers=headers) as client:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app), base_url="http://test", headers=headers
|
||||
) as client:
|
||||
async with _new_cluster(**param) as cluster:
|
||||
worker_manager, model_registry = cluster
|
||||
system_app.register(_DefaultWorkerManagerFactory, worker_manager)
|
||||
system_app.register_instance(model_registry)
|
||||
# print(f"Instances {model_registry.registry}")
|
||||
initialize_apiserver(None, app, system_app, api_keys=api_keys)
|
||||
yield client
|
||||
|
||||
@@ -113,7 +103,11 @@ async def test_chat_completions(client: AsyncClient, expected_messages):
|
||||
"Hello world.",
|
||||
"abc",
|
||||
),
|
||||
({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
|
||||
(
|
||||
{"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]},
|
||||
"你好,我是张三。",
|
||||
"abc",
|
||||
),
|
||||
],
|
||||
indirect=["client"],
|
||||
)
|
||||
@@ -160,7 +154,11 @@ async def test_chat_completions_with_openai_lib_async_no_stream(
|
||||
"Hello world.",
|
||||
"abc",
|
||||
),
|
||||
({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
|
||||
(
|
||||
{"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]},
|
||||
"你好,我是张三。",
|
||||
"abc",
|
||||
),
|
||||
],
|
||||
indirect=["client"],
|
||||
)
|
||||
|
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
@@ -8,6 +8,7 @@ from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
from dbgpt.model.base import ModelInstance
|
||||
from dbgpt.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
|
||||
from dbgpt.model.parameter import ModelControllerParameters
|
||||
from dbgpt.util.api_utils import APIMixin
|
||||
from dbgpt.util.api_utils import _api_remote as api_remote
|
||||
from dbgpt.util.api_utils import _sync_api_remote as sync_api_remote
|
||||
from dbgpt.util.fastapi import create_app
|
||||
@@ -46,9 +47,7 @@ class BaseModelController(BaseComponent, ABC):
|
||||
|
||||
|
||||
class LocalModelController(BaseModelController):
|
||||
def __init__(self, registry: ModelRegistry = None) -> None:
|
||||
if not registry:
|
||||
registry = EmbeddedModelRegistry()
|
||||
def __init__(self, registry: ModelRegistry) -> None:
|
||||
self.registry = registry
|
||||
self.deployment = None
|
||||
|
||||
@@ -75,9 +74,25 @@ class LocalModelController(BaseModelController):
|
||||
return await self.registry.send_heartbeat(instance)
|
||||
|
||||
|
||||
class _RemoteModelController(BaseModelController):
|
||||
def __init__(self, base_url: str) -> None:
|
||||
self.base_url = base_url
|
||||
class _RemoteModelController(APIMixin, BaseModelController):
|
||||
def __init__(
|
||||
self,
|
||||
urls: str,
|
||||
health_check_interval_secs: int = 5,
|
||||
health_check_timeout_secs: int = 30,
|
||||
check_health: bool = True,
|
||||
choice_type: Literal["latest_first", "random"] = "latest_first",
|
||||
) -> None:
|
||||
APIMixin.__init__(
|
||||
self,
|
||||
urls=urls,
|
||||
health_check_path="/api/health",
|
||||
health_check_interval_secs=health_check_interval_secs,
|
||||
health_check_timeout_secs=health_check_timeout_secs,
|
||||
check_health=check_health,
|
||||
choice_type=choice_type,
|
||||
)
|
||||
BaseModelController.__init__(self)
|
||||
|
||||
@api_remote(path="/api/controller/models", method="POST")
|
||||
async def register_instance(self, instance: ModelInstance) -> bool:
|
||||
@@ -139,13 +154,19 @@ controller = ModelControllerAdapter()
|
||||
|
||||
|
||||
def initialize_controller(
|
||||
app=None, remote_controller_addr: str = None, host: str = None, port: int = None
|
||||
app=None,
|
||||
remote_controller_addr: str = None,
|
||||
host: str = None,
|
||||
port: int = None,
|
||||
registry: Optional[ModelRegistry] = None,
|
||||
):
|
||||
global controller
|
||||
if remote_controller_addr:
|
||||
controller.backend = _RemoteModelController(remote_controller_addr)
|
||||
else:
|
||||
controller.backend = LocalModelController()
|
||||
if not registry:
|
||||
registry = EmbeddedModelRegistry()
|
||||
controller.backend = LocalModelController(registry=registry)
|
||||
|
||||
if app:
|
||||
app.include_router(router, prefix="/api", tags=["Model"])
|
||||
@@ -158,6 +179,12 @@ def initialize_controller(
|
||||
uvicorn.run(app, host=host, port=port, log_level="info")
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def api_health_check():
|
||||
"""Health check API."""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post("/controller/models")
|
||||
async def api_register_instance(request: ModelInstance):
|
||||
return await controller.register_instance(request)
|
||||
@@ -179,6 +206,87 @@ async def api_model_heartbeat(request: ModelInstance):
|
||||
return await controller.send_heartbeat(request)
|
||||
|
||||
|
||||
def _create_registry(controller_params: ModelControllerParameters) -> ModelRegistry:
|
||||
"""Create a model registry based on the controller parameters.
|
||||
|
||||
Registry will store the metadata of all model instances, it will be a high
|
||||
availability service for model instances if you use a database registry now. Also,
|
||||
we can implement more registry types in the future.
|
||||
"""
|
||||
registry_type = controller_params.registry_type.strip()
|
||||
if controller_params.registry_type == "embedded":
|
||||
return EmbeddedModelRegistry(
|
||||
heartbeat_interval_secs=controller_params.heartbeat_interval_secs,
|
||||
heartbeat_timeout_secs=controller_params.heartbeat_timeout_secs,
|
||||
)
|
||||
elif controller_params.registry_type == "database":
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote_plus as urlquote
|
||||
|
||||
from dbgpt.model.cluster.registry_impl.storage import StorageModelRegistry
|
||||
|
||||
try_to_create_db = False
|
||||
|
||||
if controller_params.registry_db_type == "mysql":
|
||||
db_name = controller_params.registry_db_name
|
||||
db_host = controller_params.registry_db_host
|
||||
db_port = controller_params.registry_db_port
|
||||
db_user = controller_params.registry_db_user
|
||||
db_password = controller_params.registry_db_password
|
||||
if not db_name:
|
||||
raise ValueError(
|
||||
"Registry DB name is required when using MySQL registry."
|
||||
)
|
||||
if not db_host:
|
||||
raise ValueError(
|
||||
"Registry DB host is required when using MySQL registry."
|
||||
)
|
||||
if not db_port:
|
||||
raise ValueError(
|
||||
"Registry DB port is required when using MySQL registry."
|
||||
)
|
||||
if not db_user:
|
||||
raise ValueError(
|
||||
"Registry DB user is required when using MySQL registry."
|
||||
)
|
||||
if not db_password:
|
||||
raise ValueError(
|
||||
"Registry DB password is required when using MySQL registry."
|
||||
)
|
||||
db_url = (
|
||||
f"mysql+pymysql://{quote(db_user)}:"
|
||||
f"{urlquote(db_password)}@"
|
||||
f"{db_host}:"
|
||||
f"{str(db_port)}/"
|
||||
f"{db_name}?charset=utf8mb4"
|
||||
)
|
||||
elif controller_params.registry_db_type == "sqlite":
|
||||
db_name = controller_params.registry_db_name
|
||||
if not db_name:
|
||||
raise ValueError(
|
||||
"Registry DB name is required when using SQLite registry."
|
||||
)
|
||||
db_url = f"sqlite:///{db_name}"
|
||||
try_to_create_db = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported registry DB type: {controller_params.registry_db_type}"
|
||||
)
|
||||
|
||||
registry = StorageModelRegistry.from_url(
|
||||
db_url,
|
||||
db_name,
|
||||
pool_size=controller_params.registry_db_pool_size,
|
||||
max_overflow=controller_params.registry_db_max_overflow,
|
||||
try_to_create_db=try_to_create_db,
|
||||
heartbeat_interval_secs=controller_params.heartbeat_interval_secs,
|
||||
heartbeat_timeout_secs=controller_params.heartbeat_timeout_secs,
|
||||
)
|
||||
return registry
|
||||
else:
|
||||
raise ValueError(f"Unsupported registry type: {registry_type}")
|
||||
|
||||
|
||||
def run_model_controller():
|
||||
parser = EnvArgumentParser()
|
||||
env_prefix = "controller_"
|
||||
@@ -192,8 +300,11 @@ def run_model_controller():
|
||||
logging_level=controller_params.log_level,
|
||||
logger_filename=controller_params.log_file,
|
||||
)
|
||||
registry = _create_registry(controller_params)
|
||||
|
||||
initialize_controller(host=controller_params.host, port=controller_params.port)
|
||||
initialize_controller(
|
||||
host=controller_params.host, port=controller_params.port, registry=registry
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
0
dbgpt/model/cluster/registry_impl/__init__.py
Normal file
0
dbgpt/model/cluster/registry_impl/__init__.py
Normal file
116
dbgpt/model/cluster/registry_impl/db_storage.py
Normal file
116
dbgpt/model/cluster/registry_impl/db_storage.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Float,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dbgpt.core.interface.storage import ResourceIdentifier, StorageItemAdapter
|
||||
from dbgpt.storage.metadata import Model
|
||||
|
||||
from .storage import ModelInstanceStorageItem
|
||||
|
||||
|
||||
class ModelInstanceEntity(Model):
|
||||
"""Model instance entity.
|
||||
|
||||
Use database as the registry, here is the table schema of the model instance.
|
||||
"""
|
||||
|
||||
__tablename__ = "dbgpt_cluster_registry_instance"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"model_name",
|
||||
"host",
|
||||
"port",
|
||||
"sys_code",
|
||||
name="uk_model_instance",
|
||||
),
|
||||
)
|
||||
id = Column(Integer, primary_key=True, comment="Auto increment id")
|
||||
model_name = Column(String(128), nullable=False, comment="Model name")
|
||||
host = Column(String(128), nullable=False, comment="Host of the model")
|
||||
port = Column(Integer, nullable=False, comment="Port of the model")
|
||||
weight = Column(Float, nullable=True, default=1.0, comment="Weight of the model")
|
||||
check_healthy = Column(
|
||||
Boolean,
|
||||
nullable=True,
|
||||
default=True,
|
||||
comment="Whether to check the health of the model",
|
||||
)
|
||||
healthy = Column(
|
||||
Boolean, nullable=True, default=False, comment="Whether the model is healthy"
|
||||
)
|
||||
enabled = Column(
|
||||
Boolean, nullable=True, default=True, comment="Whether the model is enabled"
|
||||
)
|
||||
prompt_template = Column(
|
||||
String(128),
|
||||
nullable=True,
|
||||
comment="Prompt template for the model instance",
|
||||
)
|
||||
last_heartbeat = Column(
|
||||
DateTime,
|
||||
nullable=True,
|
||||
comment="Last heartbeat time of the model instance",
|
||||
)
|
||||
user_name = Column(String(128), nullable=True, comment="User name")
|
||||
sys_code = Column(String(128), nullable=True, comment="System code")
|
||||
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
|
||||
gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
|
||||
|
||||
|
||||
class ModelInstanceItemAdapter(
|
||||
StorageItemAdapter[ModelInstanceStorageItem, ModelInstanceEntity]
|
||||
):
|
||||
def to_storage_format(self, item: ModelInstanceStorageItem) -> ModelInstanceEntity:
|
||||
return ModelInstanceEntity(
|
||||
model_name=item.model_name,
|
||||
host=item.host,
|
||||
port=item.port,
|
||||
weight=item.weight,
|
||||
check_healthy=item.check_healthy,
|
||||
healthy=item.healthy,
|
||||
enabled=item.enabled,
|
||||
prompt_template=item.prompt_template,
|
||||
last_heartbeat=item.last_heartbeat,
|
||||
# user_name=item.user_name,
|
||||
# sys_code=item.sys_code,
|
||||
)
|
||||
|
||||
def from_storage_format(
|
||||
self, model: ModelInstanceEntity
|
||||
) -> ModelInstanceStorageItem:
|
||||
return ModelInstanceStorageItem(
|
||||
model_name=model.model_name,
|
||||
host=model.host,
|
||||
port=model.port,
|
||||
weight=model.weight,
|
||||
check_healthy=model.check_healthy,
|
||||
healthy=model.healthy,
|
||||
enabled=model.enabled,
|
||||
prompt_template=model.prompt_template,
|
||||
last_heartbeat=model.last_heartbeat,
|
||||
)
|
||||
|
||||
def get_query_for_identifier(
|
||||
self,
|
||||
storage_format: ModelInstanceEntity,
|
||||
resource_id: ResourceIdentifier,
|
||||
**kwargs,
|
||||
):
|
||||
session: Session = kwargs.get("session")
|
||||
if session is None:
|
||||
raise Exception("session is None")
|
||||
query_obj = session.query(ModelInstanceEntity)
|
||||
for key, value in resource_id.to_dict().items():
|
||||
if value is None:
|
||||
continue
|
||||
query_obj = query_obj.filter(getattr(ModelInstanceEntity, key) == value)
|
||||
return query_obj
|
374
dbgpt/model/cluster/registry_impl/storage.py
Normal file
374
dbgpt/model/cluster/registry_impl/storage.py
Normal file
@@ -0,0 +1,374 @@
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.core.interface.storage import (
|
||||
QuerySpec,
|
||||
ResourceIdentifier,
|
||||
StorageInterface,
|
||||
StorageItem,
|
||||
)
|
||||
from dbgpt.util.executor_utils import blocking_func_to_async
|
||||
|
||||
from ...base import ModelInstance
|
||||
from ..registry import ModelRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInstanceIdentifier(ResourceIdentifier):
|
||||
identifier_split: str = field(default="___$$$$___", init=False)
|
||||
model_name: str
|
||||
host: str
|
||||
port: int
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post init method."""
|
||||
if self.model_name is None:
|
||||
raise ValueError("model_name is required.")
|
||||
if self.host is None:
|
||||
raise ValueError("host is required.")
|
||||
if self.port is None:
|
||||
raise ValueError("port is required.")
|
||||
|
||||
if any(
|
||||
self.identifier_split in key
|
||||
for key in [self.model_name, self.host, str(self.port)]
|
||||
if key is not None
|
||||
):
|
||||
raise ValueError(
|
||||
f"identifier_split {self.identifier_split} is not allowed in "
|
||||
f"model_name, host, port."
|
||||
)
|
||||
|
||||
@property
|
||||
def str_identifier(self) -> str:
|
||||
"""Return the string identifier of the identifier."""
|
||||
return self.identifier_split.join(
|
||||
key
|
||||
for key in [
|
||||
self.model_name,
|
||||
self.host,
|
||||
str(self.port),
|
||||
]
|
||||
if key is not None
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert the identifier to a dict.
|
||||
|
||||
Returns:
|
||||
Dict: The dict of the identifier.
|
||||
"""
|
||||
return {
|
||||
"model_name": self.model_name,
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInstanceStorageItem(StorageItem):
|
||||
|
||||
model_name: str
|
||||
host: str
|
||||
port: int
|
||||
weight: Optional[float] = 1.0
|
||||
check_healthy: Optional[bool] = True
|
||||
healthy: Optional[bool] = False
|
||||
enabled: Optional[bool] = True
|
||||
prompt_template: Optional[str] = None
|
||||
last_heartbeat: Optional[datetime] = None
|
||||
_identifier: ModelInstanceIdentifier = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post init method."""
|
||||
# Convert last_heartbeat to datetime if it's a timestamp
|
||||
if isinstance(self.last_heartbeat, (int, float)):
|
||||
self.last_heartbeat = datetime.fromtimestamp(self.last_heartbeat)
|
||||
|
||||
self._identifier = ModelInstanceIdentifier(
|
||||
model_name=self.model_name,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
)
|
||||
|
||||
@property
|
||||
def identifier(self) -> ModelInstanceIdentifier:
|
||||
return self._identifier
|
||||
|
||||
def merge(self, other: "StorageItem") -> None:
|
||||
if not isinstance(other, ModelInstanceStorageItem):
|
||||
raise ValueError(f"Cannot merge with {type(other)}")
|
||||
self.from_object(other)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
last_heartbeat = self.last_heartbeat.timestamp()
|
||||
return {
|
||||
"model_name": self.model_name,
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"weight": self.weight,
|
||||
"check_healthy": self.check_healthy,
|
||||
"healthy": self.healthy,
|
||||
"enabled": self.enabled,
|
||||
"prompt_template": self.prompt_template,
|
||||
"last_heartbeat": last_heartbeat,
|
||||
}
|
||||
|
||||
def from_object(self, item: "ModelInstanceStorageItem") -> None:
|
||||
"""Build the item from another item."""
|
||||
self.model_name = item.model_name
|
||||
self.host = item.host
|
||||
self.port = item.port
|
||||
self.weight = item.weight
|
||||
self.check_healthy = item.check_healthy
|
||||
self.healthy = item.healthy
|
||||
self.enabled = item.enabled
|
||||
self.prompt_template = item.prompt_template
|
||||
self.last_heartbeat = item.last_heartbeat
|
||||
|
||||
@classmethod
|
||||
def from_model_instance(cls, instance: ModelInstance) -> "ModelInstanceStorageItem":
|
||||
return cls(
|
||||
model_name=instance.model_name,
|
||||
host=instance.host,
|
||||
port=instance.port,
|
||||
weight=instance.weight,
|
||||
check_healthy=instance.check_healthy,
|
||||
healthy=instance.healthy,
|
||||
enabled=instance.enabled,
|
||||
prompt_template=instance.prompt_template,
|
||||
last_heartbeat=instance.last_heartbeat,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def to_model_instance(cls, item: "ModelInstanceStorageItem") -> ModelInstance:
|
||||
return ModelInstance(
|
||||
model_name=item.model_name,
|
||||
host=item.host,
|
||||
port=item.port,
|
||||
weight=item.weight,
|
||||
check_healthy=item.check_healthy,
|
||||
healthy=item.healthy,
|
||||
enabled=item.enabled,
|
||||
prompt_template=item.prompt_template,
|
||||
last_heartbeat=item.last_heartbeat,
|
||||
)
|
||||
|
||||
|
||||
class StorageModelRegistry(ModelRegistry):
|
||||
def __init__(
|
||||
self,
|
||||
storage: StorageInterface,
|
||||
system_app: SystemApp | None = None,
|
||||
executor: Optional[Executor] = None,
|
||||
heartbeat_interval_secs: float | int = 60,
|
||||
heartbeat_timeout_secs: int = 120,
|
||||
):
|
||||
super().__init__(system_app)
|
||||
self._storage = storage
|
||||
self._executor = executor or ThreadPoolExecutor(max_workers=2)
|
||||
self.heartbeat_interval_secs = heartbeat_interval_secs
|
||||
self.heartbeat_timeout_secs = heartbeat_timeout_secs
|
||||
self.heartbeat_thread = threading.Thread(target=self._heartbeat_checker)
|
||||
self.heartbeat_thread.daemon = True
|
||||
self.heartbeat_thread.start()
|
||||
|
||||
@classmethod
|
||||
def from_url(
|
||||
cls,
|
||||
db_url: str,
|
||||
db_name: str,
|
||||
pool_size: int = 5,
|
||||
max_overflow: int = 10,
|
||||
try_to_create_db: bool = False,
|
||||
**kwargs,
|
||||
) -> "StorageModelRegistry":
|
||||
from dbgpt.storage.metadata.db_manager import DatabaseManager, initialize_db
|
||||
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
||||
from .db_storage import ModelInstanceEntity, ModelInstanceItemAdapter
|
||||
|
||||
engine_args = {
|
||||
"pool_size": pool_size,
|
||||
"max_overflow": max_overflow,
|
||||
"pool_timeout": 30,
|
||||
"pool_recycle": 3600,
|
||||
"pool_pre_ping": True,
|
||||
}
|
||||
|
||||
db: DatabaseManager = initialize_db(
|
||||
db_url, db_name, engine_args, try_to_create_db=try_to_create_db
|
||||
)
|
||||
storage_adapter = ModelInstanceItemAdapter()
|
||||
serializer = JsonSerializer()
|
||||
storage = SQLAlchemyStorage(
|
||||
db,
|
||||
ModelInstanceEntity,
|
||||
storage_adapter,
|
||||
serializer,
|
||||
)
|
||||
return cls(storage, **kwargs)
|
||||
|
||||
async def _get_instances_by_model(
|
||||
self, model_name: str, host: str, port: int, healthy_only: bool = False
|
||||
) -> Tuple[List[ModelInstanceStorageItem], List[ModelInstanceStorageItem]]:
|
||||
query_spec = QuerySpec(conditions={"model_name": model_name})
|
||||
# Query all instances of the model
|
||||
instances = await blocking_func_to_async(
|
||||
self._executor, self._storage.query, query_spec, ModelInstanceStorageItem
|
||||
)
|
||||
if healthy_only:
|
||||
instances = [ins for ins in instances if ins.healthy is True]
|
||||
exist_ins = [ins for ins in instances if ins.host == host and ins.port == port]
|
||||
return instances, exist_ins
|
||||
|
||||
def _heartbeat_checker(self):
|
||||
while True:
|
||||
all_instances: List[ModelInstanceStorageItem] = self._storage.query(
|
||||
QuerySpec(conditions={}), ModelInstanceStorageItem
|
||||
)
|
||||
for instance in all_instances:
|
||||
if (
|
||||
instance.check_healthy
|
||||
and datetime.now() - instance.last_heartbeat
|
||||
> timedelta(seconds=self.heartbeat_timeout_secs)
|
||||
):
|
||||
instance.healthy = False
|
||||
self._storage.update(instance)
|
||||
time.sleep(self.heartbeat_interval_secs)
|
||||
|
||||
async def register_instance(self, instance: ModelInstance) -> bool:
|
||||
model_name = instance.model_name.strip()
|
||||
host = instance.host.strip()
|
||||
port = instance.port
|
||||
_, exist_ins = await self._get_instances_by_model(
|
||||
model_name, host, port, healthy_only=False
|
||||
)
|
||||
if exist_ins:
|
||||
# Exist instances, just update the instance
|
||||
# One exist instance at most
|
||||
ins: ModelInstanceStorageItem = exist_ins[0]
|
||||
# Update instance
|
||||
ins.weight = instance.weight
|
||||
ins.healthy = True
|
||||
ins.prompt_template = instance.prompt_template
|
||||
ins.last_heartbeat = datetime.now()
|
||||
await blocking_func_to_async(self._executor, self._storage.update, ins)
|
||||
else:
|
||||
# No exist instance, save the new instance
|
||||
new_inst = ModelInstanceStorageItem.from_model_instance(instance)
|
||||
new_inst.healthy = True
|
||||
new_inst.last_heartbeat = datetime.now()
|
||||
await blocking_func_to_async(self._executor, self._storage.save, new_inst)
|
||||
return True
|
||||
|
||||
async def deregister_instance(self, instance: ModelInstance) -> bool:
|
||||
"""Deregister a model instance.
|
||||
|
||||
If the instance exists, set the instance as unhealthy, nothing to do if the
|
||||
instance does not exist.
|
||||
|
||||
Args:
|
||||
instance (ModelInstance): The instance to deregister.
|
||||
"""
|
||||
model_name = instance.model_name.strip()
|
||||
host = instance.host.strip()
|
||||
port = instance.port
|
||||
_, exist_ins = await self._get_instances_by_model(
|
||||
model_name, host, port, healthy_only=False
|
||||
)
|
||||
if exist_ins:
|
||||
ins = exist_ins[0]
|
||||
ins.healthy = False
|
||||
await blocking_func_to_async(self._executor, self._storage.update, ins)
|
||||
return True
|
||||
|
||||
async def get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
"""Get all instances of a model(Async).
|
||||
|
||||
Args:
|
||||
model_name (str): The model name.
|
||||
healthy_only (bool): Whether only get healthy instances. Defaults to False.
|
||||
"""
|
||||
return await blocking_func_to_async(
|
||||
self._executor, self.sync_get_all_instances, model_name, healthy_only
|
||||
)
|
||||
|
||||
def sync_get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
"""Get all instances of a model.
|
||||
|
||||
Args:
|
||||
model_name (str): The model name.
|
||||
healthy_only (bool): Whether only get healthy instances. Defaults to False.
|
||||
|
||||
Returns:
|
||||
List[ModelInstance]: The list of instances.
|
||||
"""
|
||||
instances = self._storage.query(
|
||||
QuerySpec(conditions={"model_name": model_name}), ModelInstanceStorageItem
|
||||
)
|
||||
if healthy_only:
|
||||
instances = [ins for ins in instances if ins.healthy is True]
|
||||
return [ModelInstanceStorageItem.to_model_instance(ins) for ins in instances]
|
||||
|
||||
async def get_all_model_instances(
|
||||
self, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
"""Get all model instances.
|
||||
|
||||
Args:
|
||||
healthy_only (bool): Whether only get healthy instances. Defaults to False.
|
||||
|
||||
Returns:
|
||||
List[ModelInstance]: The list of instances.
|
||||
"""
|
||||
all_instances = await blocking_func_to_async(
|
||||
self._executor,
|
||||
self._storage.query,
|
||||
QuerySpec(conditions={}),
|
||||
ModelInstanceStorageItem,
|
||||
)
|
||||
if healthy_only:
|
||||
all_instances = [ins for ins in all_instances if ins.healthy is True]
|
||||
return [
|
||||
ModelInstanceStorageItem.to_model_instance(ins) for ins in all_instances
|
||||
]
|
||||
|
||||
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
||||
"""Receive heartbeat from model instance.
|
||||
|
||||
Update the last heartbeat time of the instance. If the instance does not exist,
|
||||
register the instance.
|
||||
|
||||
Args:
|
||||
instance (ModelInstance): The instance to send heartbeat.
|
||||
|
||||
Returns:
|
||||
bool: True if the heartbeat is received successfully.
|
||||
"""
|
||||
model_name = instance.model_name.strip()
|
||||
host = instance.host.strip()
|
||||
port = instance.port
|
||||
_, exist_ins = await self._get_instances_by_model(
|
||||
model_name, host, port, healthy_only=False
|
||||
)
|
||||
if not exist_ins:
|
||||
# register new instance from heartbeat
|
||||
await self.register_instance(instance)
|
||||
return True
|
||||
else:
|
||||
ins = exist_ins[0]
|
||||
ins.last_heartbeat = datetime.now()
|
||||
ins.healthy = True
|
||||
await blocking_func_to_async(self._executor, self._storage.update, ins)
|
||||
return True
|
0
dbgpt/model/cluster/tests/registry_impl/__init__.py
Normal file
0
dbgpt/model/cluster/tests/registry_impl/__init__.py
Normal file
221
dbgpt/model/cluster/tests/registry_impl/test_storage_registry.py
Normal file
221
dbgpt/model/cluster/tests/registry_impl/test_storage_registry.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.core.interface.storage import InMemoryStorage, QuerySpec
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
||||
from ...registry_impl.storage import (
|
||||
ModelInstance,
|
||||
ModelInstanceStorageItem,
|
||||
StorageModelRegistry,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def in_memory_storage():
|
||||
return InMemoryStorage(serializer=JsonSerializer())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def thread_pool_executor():
|
||||
return ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def registry(in_memory_storage, thread_pool_executor):
|
||||
return StorageModelRegistry(
|
||||
storage=in_memory_storage,
|
||||
executor=thread_pool_executor,
|
||||
heartbeat_interval_secs=1,
|
||||
heartbeat_timeout_secs=2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_instance():
|
||||
return ModelInstance(
|
||||
model_name="test_model",
|
||||
host="localhost",
|
||||
port=8080,
|
||||
weight=1.0,
|
||||
check_healthy=True,
|
||||
healthy=True,
|
||||
enabled=True,
|
||||
prompt_template=None,
|
||||
last_heartbeat=datetime.now(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_instance_2():
|
||||
return ModelInstance(
|
||||
model_name="test_model",
|
||||
host="localhost",
|
||||
port=8081,
|
||||
weight=1.0,
|
||||
check_healthy=True,
|
||||
healthy=True,
|
||||
enabled=True,
|
||||
prompt_template=None,
|
||||
last_heartbeat=datetime.now(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_instance_3():
|
||||
return ModelInstance(
|
||||
model_name="test_model_2",
|
||||
host="localhost",
|
||||
port=8082,
|
||||
weight=1.0,
|
||||
check_healthy=True,
|
||||
healthy=True,
|
||||
enabled=True,
|
||||
prompt_template=None,
|
||||
last_heartbeat=datetime.now(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_instance_storage_item(model_instance):
|
||||
return ModelInstanceStorageItem.from_model_instance(model_instance)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_instance_new(registry, model_instance):
|
||||
"""Test registering a new model instance."""
|
||||
result = await registry.register_instance(model_instance)
|
||||
|
||||
assert result is True
|
||||
instances = await registry.get_all_instances(model_instance.model_name)
|
||||
assert len(instances) == 1
|
||||
saved_instance = instances[0]
|
||||
assert saved_instance.model_name == model_instance.model_name
|
||||
assert saved_instance.host == model_instance.host
|
||||
assert saved_instance.port == model_instance.port
|
||||
assert saved_instance.healthy is True
|
||||
assert saved_instance.last_heartbeat is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_instance_existing(
|
||||
registry, model_instance, model_instance_storage_item
|
||||
):
|
||||
"""Test registering an existing model instance and updating it."""
|
||||
await registry.register_instance(model_instance)
|
||||
|
||||
# Register the instance again with updated heartbeat
|
||||
result = await registry.register_instance(model_instance)
|
||||
|
||||
assert result is True
|
||||
instances = await registry.get_all_instances(model_instance.model_name)
|
||||
assert len(instances) == 1
|
||||
updated_instance = instances[0]
|
||||
assert updated_instance.model_name == model_instance.model_name
|
||||
assert updated_instance.host == model_instance.host
|
||||
assert updated_instance.port == model_instance.port
|
||||
assert updated_instance.healthy is True
|
||||
assert updated_instance.last_heartbeat is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deregister_instance(registry, model_instance):
|
||||
"""Test deregistering a model instance."""
|
||||
await registry.register_instance(model_instance)
|
||||
|
||||
result = await registry.deregister_instance(model_instance)
|
||||
|
||||
assert result is True
|
||||
instances = await registry.get_all_instances(model_instance.model_name)
|
||||
assert len(instances) == 1
|
||||
deregistered_instance = instances[0]
|
||||
assert deregistered_instance.healthy is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_instances(registry, model_instance):
|
||||
"""Test retrieving all model instances."""
|
||||
await registry.register_instance(model_instance)
|
||||
|
||||
result = await registry.get_all_instances(
|
||||
model_instance.model_name, healthy_only=True
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].model_name == model_instance.model_name
|
||||
|
||||
|
||||
def test_sync_get_all_instances(registry, model_instance):
|
||||
"""Test synchronously retrieving all model instances."""
|
||||
registry.sync_get_all_instances(model_instance.model_name, healthy_only=True)
|
||||
registry._storage.save(ModelInstanceStorageItem.from_model_instance(model_instance))
|
||||
|
||||
result = registry.sync_get_all_instances(
|
||||
model_instance.model_name, healthy_only=True
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].model_name == model_instance.model_name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_heartbeat_new_instance(registry, model_instance):
|
||||
"""Test sending a heartbeat for a new instance."""
|
||||
result = await registry.send_heartbeat(model_instance)
|
||||
|
||||
assert result is True
|
||||
instances = await registry.get_all_instances(model_instance.model_name)
|
||||
assert len(instances) == 1
|
||||
saved_instance = instances[0]
|
||||
assert saved_instance.model_name == model_instance.model_name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_heartbeat_existing_instance(registry, model_instance):
|
||||
"""Test sending a heartbeat for an existing instance."""
|
||||
await registry.register_instance(model_instance)
|
||||
|
||||
# Send heartbeat to update the instance
|
||||
result = await registry.send_heartbeat(model_instance)
|
||||
|
||||
assert result is True
|
||||
instances = await registry.get_all_instances(model_instance.model_name)
|
||||
assert len(instances) == 1
|
||||
updated_instance = instances[0]
|
||||
assert updated_instance.last_heartbeat > model_instance.last_heartbeat
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heartbeat_checker(
|
||||
in_memory_storage, thread_pool_executor, model_instance
|
||||
):
|
||||
"""Test the heartbeat checker mechanism."""
|
||||
heartbeat_timeout_secs = 1
|
||||
registry = StorageModelRegistry(
|
||||
storage=in_memory_storage,
|
||||
executor=thread_pool_executor,
|
||||
heartbeat_interval_secs=0.1,
|
||||
heartbeat_timeout_secs=heartbeat_timeout_secs,
|
||||
)
|
||||
|
||||
async def check_heartbeat(model_name: str, expected_healthy: bool):
|
||||
instances = await registry.get_all_instances(model_name)
|
||||
assert len(instances) == 1
|
||||
updated_instance = instances[0]
|
||||
assert updated_instance.healthy == expected_healthy
|
||||
|
||||
await registry.register_instance(model_instance)
|
||||
# First heartbeat should be successful
|
||||
await check_heartbeat(model_instance.model_name, True)
|
||||
# Wait heartbeat timeout
|
||||
await asyncio.sleep(heartbeat_timeout_secs + 0.5)
|
||||
await check_heartbeat(model_instance.model_name, False)
|
||||
|
||||
# Send heartbeat again
|
||||
await registry.send_heartbeat(model_instance)
|
||||
# Should be healthy again
|
||||
await check_heartbeat(model_instance.model_name, True)
|
@@ -1059,6 +1059,10 @@ def initialize_worker_manager_in_client(
|
||||
if not app:
|
||||
raise Exception("app can't be None")
|
||||
|
||||
if system_app:
|
||||
logger.info(f"Register WorkerManager {_DefaultWorkerManagerFactory.name}")
|
||||
system_app.register(_DefaultWorkerManagerFactory, worker_manager)
|
||||
|
||||
worker_params: ModelWorkerParameters = _parse_worker_params(
|
||||
model_name=model_name, model_path=model_path, controller_addr=controller_addr
|
||||
)
|
||||
@@ -1104,8 +1108,6 @@ def initialize_worker_manager_in_client(
|
||||
if include_router and app:
|
||||
# mount WorkerManager router
|
||||
app.include_router(router, prefix="/api")
|
||||
if system_app:
|
||||
system_app.register(_DefaultWorkerManagerFactory, worker_manager)
|
||||
|
||||
|
||||
def run_worker_manager(
|
||||
|
@@ -55,6 +55,84 @@ class ModelControllerParameters(BaseParameters):
|
||||
port: Optional[int] = field(
|
||||
default=8000, metadata={"help": "Model Controller deploy port"}
|
||||
)
|
||||
registry_type: Optional[str] = field(
|
||||
default="embedded",
|
||||
metadata={
|
||||
"help": "Registry type: embedded, database...",
|
||||
"valid_values": ["embedded", "database"],
|
||||
},
|
||||
)
|
||||
registry_db_type: Optional[str] = field(
|
||||
default="mysql",
|
||||
metadata={
|
||||
"help": "Registry database type, now only support sqlite and mysql, it is "
|
||||
"valid when registry_type is database",
|
||||
"valid_values": ["mysql", "sqlite"],
|
||||
},
|
||||
)
|
||||
registry_db_name: Optional[str] = field(
|
||||
default="dbgpt",
|
||||
metadata={
|
||||
"help": "Registry database name, just for database, it is valid when "
|
||||
"registry_type is database, please set to full database path for sqlite"
|
||||
},
|
||||
)
|
||||
registry_db_host: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Registry database host, just for database, it is valid when "
|
||||
"registry_type is database"
|
||||
},
|
||||
)
|
||||
registry_db_port: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Registry database port, just for database, it is valid when "
|
||||
"registry_type is database"
|
||||
},
|
||||
)
|
||||
registry_db_user: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Registry database user, just for database, it is valid when "
|
||||
"registry_type is database"
|
||||
},
|
||||
)
|
||||
registry_db_password: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Registry database password, just for database, it is valid when "
|
||||
"registry_type is database. We recommend to use environment variable to "
|
||||
"store password, you can set it in your environment variable like "
|
||||
"export CONTROLLER_REGISTRY_DB_PASSWORD='your_password'"
|
||||
},
|
||||
)
|
||||
registry_db_pool_size: Optional[int] = field(
|
||||
default=5,
|
||||
metadata={
|
||||
"help": "Registry database pool size, just for database, it is valid when "
|
||||
"registry_type is database"
|
||||
},
|
||||
)
|
||||
registry_db_max_overflow: Optional[int] = field(
|
||||
default=10,
|
||||
metadata={
|
||||
"help": "Registry database max overflow, just for database, it is valid "
|
||||
"when registry_type is database"
|
||||
},
|
||||
)
|
||||
|
||||
heartbeat_interval_secs: Optional[int] = field(
|
||||
default=20, metadata={"help": "The interval for checking heartbeats (seconds)"}
|
||||
)
|
||||
heartbeat_timeout_secs: Optional[int] = field(
|
||||
default=60,
|
||||
metadata={
|
||||
"help": "The timeout for checking heartbeats (seconds), it will be set "
|
||||
"unhealthy if the worker is not responding in this time"
|
||||
},
|
||||
)
|
||||
|
||||
daemon: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Run Model Controller in background"}
|
||||
)
|
||||
|
Reference in New Issue
Block a user