feat(model): Support database model registry (#1656)

This commit is contained in:
Fangyin Cheng
2024-06-24 19:07:10 +08:00
committed by GitHub
parent c57ee0289b
commit 47d205f676
35 changed files with 2014 additions and 792 deletions

View File

@@ -26,6 +26,7 @@ from dbgpt.util.parameter_utils import (
build_lazy_click_command,
)
# Your can set environment variable CONTROLLER_ADDRESS to set the default address
MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
logger = logging.getLogger("dbgpt_cli")

View File

@@ -1,22 +1,11 @@
import importlib.metadata as metadata
import pytest
import pytest_asyncio
from aioresponses import aioresponses
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from httpx import AsyncClient, HTTPError
from httpx import ASGITransport, AsyncClient, HTTPError
from dbgpt.component import SystemApp
from dbgpt.model.cluster.apiserver.api import (
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
ModelList,
UsageInfo,
api_settings,
initialize_apiserver,
)
@@ -56,12 +45,13 @@ async def client(request, system_app: SystemApp):
if api_settings:
# Clear global api keys
api_settings.api_keys = []
async with AsyncClient(app=app, base_url="http://test", headers=headers) as client:
async with AsyncClient(
transport=ASGITransport(app), base_url="http://test", headers=headers
) as client:
async with _new_cluster(**param) as cluster:
worker_manager, model_registry = cluster
system_app.register(_DefaultWorkerManagerFactory, worker_manager)
system_app.register_instance(model_registry)
# print(f"Instances {model_registry.registry}")
initialize_apiserver(None, app, system_app, api_keys=api_keys)
yield client
@@ -113,7 +103,11 @@ async def test_chat_completions(client: AsyncClient, expected_messages):
"Hello world.",
"abc",
),
({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
(
{"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]},
"你好,我是张三。",
"abc",
),
],
indirect=["client"],
)
@@ -160,7 +154,11 @@ async def test_chat_completions_with_openai_lib_async_no_stream(
"Hello world.",
"abc",
),
({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
(
{"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]},
"你好,我是张三。",
"abc",
),
],
indirect=["client"],
)

View File

@@ -1,6 +1,6 @@
import logging
from abc import ABC, abstractmethod
from typing import List
from typing import List, Literal, Optional
from fastapi import APIRouter
@@ -8,6 +8,7 @@ from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.model.base import ModelInstance
from dbgpt.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
from dbgpt.model.parameter import ModelControllerParameters
from dbgpt.util.api_utils import APIMixin
from dbgpt.util.api_utils import _api_remote as api_remote
from dbgpt.util.api_utils import _sync_api_remote as sync_api_remote
from dbgpt.util.fastapi import create_app
@@ -46,9 +47,7 @@ class BaseModelController(BaseComponent, ABC):
class LocalModelController(BaseModelController):
def __init__(self, registry: ModelRegistry = None) -> None:
if not registry:
registry = EmbeddedModelRegistry()
def __init__(self, registry: ModelRegistry) -> None:
self.registry = registry
self.deployment = None
@@ -75,9 +74,25 @@ class LocalModelController(BaseModelController):
return await self.registry.send_heartbeat(instance)
class _RemoteModelController(BaseModelController):
def __init__(self, base_url: str) -> None:
self.base_url = base_url
class _RemoteModelController(APIMixin, BaseModelController):
def __init__(
self,
urls: str,
health_check_interval_secs: int = 5,
health_check_timeout_secs: int = 30,
check_health: bool = True,
choice_type: Literal["latest_first", "random"] = "latest_first",
) -> None:
APIMixin.__init__(
self,
urls=urls,
health_check_path="/api/health",
health_check_interval_secs=health_check_interval_secs,
health_check_timeout_secs=health_check_timeout_secs,
check_health=check_health,
choice_type=choice_type,
)
BaseModelController.__init__(self)
@api_remote(path="/api/controller/models", method="POST")
async def register_instance(self, instance: ModelInstance) -> bool:
@@ -139,13 +154,19 @@ controller = ModelControllerAdapter()
def initialize_controller(
app=None, remote_controller_addr: str = None, host: str = None, port: int = None
app=None,
remote_controller_addr: str = None,
host: str = None,
port: int = None,
registry: Optional[ModelRegistry] = None,
):
global controller
if remote_controller_addr:
controller.backend = _RemoteModelController(remote_controller_addr)
else:
controller.backend = LocalModelController()
if not registry:
registry = EmbeddedModelRegistry()
controller.backend = LocalModelController(registry=registry)
if app:
app.include_router(router, prefix="/api", tags=["Model"])
@@ -158,6 +179,12 @@ def initialize_controller(
uvicorn.run(app, host=host, port=port, log_level="info")
@router.get("/health")
async def api_health_check():
"""Health check API."""
return {"status": "ok"}
@router.post("/controller/models")
async def api_register_instance(request: ModelInstance):
return await controller.register_instance(request)
@@ -179,6 +206,87 @@ async def api_model_heartbeat(request: ModelInstance):
return await controller.send_heartbeat(request)
def _create_registry(controller_params: ModelControllerParameters) -> ModelRegistry:
"""Create a model registry based on the controller parameters.
Registry will store the metadata of all model instances, it will be a high
availability service for model instances if you use a database registry now. Also,
we can implement more registry types in the future.
"""
registry_type = controller_params.registry_type.strip()
if controller_params.registry_type == "embedded":
return EmbeddedModelRegistry(
heartbeat_interval_secs=controller_params.heartbeat_interval_secs,
heartbeat_timeout_secs=controller_params.heartbeat_timeout_secs,
)
elif controller_params.registry_type == "database":
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from dbgpt.model.cluster.registry_impl.storage import StorageModelRegistry
try_to_create_db = False
if controller_params.registry_db_type == "mysql":
db_name = controller_params.registry_db_name
db_host = controller_params.registry_db_host
db_port = controller_params.registry_db_port
db_user = controller_params.registry_db_user
db_password = controller_params.registry_db_password
if not db_name:
raise ValueError(
"Registry DB name is required when using MySQL registry."
)
if not db_host:
raise ValueError(
"Registry DB host is required when using MySQL registry."
)
if not db_port:
raise ValueError(
"Registry DB port is required when using MySQL registry."
)
if not db_user:
raise ValueError(
"Registry DB user is required when using MySQL registry."
)
if not db_password:
raise ValueError(
"Registry DB password is required when using MySQL registry."
)
db_url = (
f"mysql+pymysql://{quote(db_user)}:"
f"{urlquote(db_password)}@"
f"{db_host}:"
f"{str(db_port)}/"
f"{db_name}?charset=utf8mb4"
)
elif controller_params.registry_db_type == "sqlite":
db_name = controller_params.registry_db_name
if not db_name:
raise ValueError(
"Registry DB name is required when using SQLite registry."
)
db_url = f"sqlite:///{db_name}"
try_to_create_db = True
else:
raise ValueError(
f"Unsupported registry DB type: {controller_params.registry_db_type}"
)
registry = StorageModelRegistry.from_url(
db_url,
db_name,
pool_size=controller_params.registry_db_pool_size,
max_overflow=controller_params.registry_db_max_overflow,
try_to_create_db=try_to_create_db,
heartbeat_interval_secs=controller_params.heartbeat_interval_secs,
heartbeat_timeout_secs=controller_params.heartbeat_timeout_secs,
)
return registry
else:
raise ValueError(f"Unsupported registry type: {registry_type}")
def run_model_controller():
parser = EnvArgumentParser()
env_prefix = "controller_"
@@ -192,8 +300,11 @@ def run_model_controller():
logging_level=controller_params.log_level,
logger_filename=controller_params.log_file,
)
registry = _create_registry(controller_params)
initialize_controller(host=controller_params.host, port=controller_params.port)
initialize_controller(
host=controller_params.host, port=controller_params.port, registry=registry
)
if __name__ == "__main__":

View File

@@ -0,0 +1,116 @@
from datetime import datetime
from sqlalchemy import (
Boolean,
Column,
DateTime,
Float,
Integer,
String,
UniqueConstraint,
)
from sqlalchemy.orm import Session
from dbgpt.core.interface.storage import ResourceIdentifier, StorageItemAdapter
from dbgpt.storage.metadata import Model
from .storage import ModelInstanceStorageItem
class ModelInstanceEntity(Model):
"""Model instance entity.
Use database as the registry, here is the table schema of the model instance.
"""
__tablename__ = "dbgpt_cluster_registry_instance"
__table_args__ = (
UniqueConstraint(
"model_name",
"host",
"port",
"sys_code",
name="uk_model_instance",
),
)
id = Column(Integer, primary_key=True, comment="Auto increment id")
model_name = Column(String(128), nullable=False, comment="Model name")
host = Column(String(128), nullable=False, comment="Host of the model")
port = Column(Integer, nullable=False, comment="Port of the model")
weight = Column(Float, nullable=True, default=1.0, comment="Weight of the model")
check_healthy = Column(
Boolean,
nullable=True,
default=True,
comment="Whether to check the health of the model",
)
healthy = Column(
Boolean, nullable=True, default=False, comment="Whether the model is healthy"
)
enabled = Column(
Boolean, nullable=True, default=True, comment="Whether the model is enabled"
)
prompt_template = Column(
String(128),
nullable=True,
comment="Prompt template for the model instance",
)
last_heartbeat = Column(
DateTime,
nullable=True,
comment="Last heartbeat time of the model instance",
)
user_name = Column(String(128), nullable=True, comment="User name")
sys_code = Column(String(128), nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
class ModelInstanceItemAdapter(
StorageItemAdapter[ModelInstanceStorageItem, ModelInstanceEntity]
):
def to_storage_format(self, item: ModelInstanceStorageItem) -> ModelInstanceEntity:
return ModelInstanceEntity(
model_name=item.model_name,
host=item.host,
port=item.port,
weight=item.weight,
check_healthy=item.check_healthy,
healthy=item.healthy,
enabled=item.enabled,
prompt_template=item.prompt_template,
last_heartbeat=item.last_heartbeat,
# user_name=item.user_name,
# sys_code=item.sys_code,
)
def from_storage_format(
self, model: ModelInstanceEntity
) -> ModelInstanceStorageItem:
return ModelInstanceStorageItem(
model_name=model.model_name,
host=model.host,
port=model.port,
weight=model.weight,
check_healthy=model.check_healthy,
healthy=model.healthy,
enabled=model.enabled,
prompt_template=model.prompt_template,
last_heartbeat=model.last_heartbeat,
)
def get_query_for_identifier(
self,
storage_format: ModelInstanceEntity,
resource_id: ResourceIdentifier,
**kwargs,
):
session: Session = kwargs.get("session")
if session is None:
raise Exception("session is None")
query_obj = session.query(ModelInstanceEntity)
for key, value in resource_id.to_dict().items():
if value is None:
continue
query_obj = query_obj.filter(getattr(ModelInstanceEntity, key) == value)
return query_obj

View File

@@ -0,0 +1,374 @@
import threading
import time
from concurrent.futures import Executor, ThreadPoolExecutor
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from dbgpt.component import SystemApp
from dbgpt.core.interface.storage import (
QuerySpec,
ResourceIdentifier,
StorageInterface,
StorageItem,
)
from dbgpt.util.executor_utils import blocking_func_to_async
from ...base import ModelInstance
from ..registry import ModelRegistry
@dataclass
class ModelInstanceIdentifier(ResourceIdentifier):
identifier_split: str = field(default="___$$$$___", init=False)
model_name: str
host: str
port: int
def __post_init__(self):
"""Post init method."""
if self.model_name is None:
raise ValueError("model_name is required.")
if self.host is None:
raise ValueError("host is required.")
if self.port is None:
raise ValueError("port is required.")
if any(
self.identifier_split in key
for key in [self.model_name, self.host, str(self.port)]
if key is not None
):
raise ValueError(
f"identifier_split {self.identifier_split} is not allowed in "
f"model_name, host, port."
)
@property
def str_identifier(self) -> str:
"""Return the string identifier of the identifier."""
return self.identifier_split.join(
key
for key in [
self.model_name,
self.host,
str(self.port),
]
if key is not None
)
def to_dict(self) -> Dict:
"""Convert the identifier to a dict.
Returns:
Dict: The dict of the identifier.
"""
return {
"model_name": self.model_name,
"host": self.host,
"port": self.port,
}
@dataclass
class ModelInstanceStorageItem(StorageItem):
model_name: str
host: str
port: int
weight: Optional[float] = 1.0
check_healthy: Optional[bool] = True
healthy: Optional[bool] = False
enabled: Optional[bool] = True
prompt_template: Optional[str] = None
last_heartbeat: Optional[datetime] = None
_identifier: ModelInstanceIdentifier = field(init=False)
def __post_init__(self):
"""Post init method."""
# Convert last_heartbeat to datetime if it's a timestamp
if isinstance(self.last_heartbeat, (int, float)):
self.last_heartbeat = datetime.fromtimestamp(self.last_heartbeat)
self._identifier = ModelInstanceIdentifier(
model_name=self.model_name,
host=self.host,
port=self.port,
)
@property
def identifier(self) -> ModelInstanceIdentifier:
return self._identifier
def merge(self, other: "StorageItem") -> None:
if not isinstance(other, ModelInstanceStorageItem):
raise ValueError(f"Cannot merge with {type(other)}")
self.from_object(other)
def to_dict(self) -> Dict:
last_heartbeat = self.last_heartbeat.timestamp()
return {
"model_name": self.model_name,
"host": self.host,
"port": self.port,
"weight": self.weight,
"check_healthy": self.check_healthy,
"healthy": self.healthy,
"enabled": self.enabled,
"prompt_template": self.prompt_template,
"last_heartbeat": last_heartbeat,
}
def from_object(self, item: "ModelInstanceStorageItem") -> None:
"""Build the item from another item."""
self.model_name = item.model_name
self.host = item.host
self.port = item.port
self.weight = item.weight
self.check_healthy = item.check_healthy
self.healthy = item.healthy
self.enabled = item.enabled
self.prompt_template = item.prompt_template
self.last_heartbeat = item.last_heartbeat
@classmethod
def from_model_instance(cls, instance: ModelInstance) -> "ModelInstanceStorageItem":
return cls(
model_name=instance.model_name,
host=instance.host,
port=instance.port,
weight=instance.weight,
check_healthy=instance.check_healthy,
healthy=instance.healthy,
enabled=instance.enabled,
prompt_template=instance.prompt_template,
last_heartbeat=instance.last_heartbeat,
)
@classmethod
def to_model_instance(cls, item: "ModelInstanceStorageItem") -> ModelInstance:
return ModelInstance(
model_name=item.model_name,
host=item.host,
port=item.port,
weight=item.weight,
check_healthy=item.check_healthy,
healthy=item.healthy,
enabled=item.enabled,
prompt_template=item.prompt_template,
last_heartbeat=item.last_heartbeat,
)
class StorageModelRegistry(ModelRegistry):
def __init__(
self,
storage: StorageInterface,
system_app: SystemApp | None = None,
executor: Optional[Executor] = None,
heartbeat_interval_secs: float | int = 60,
heartbeat_timeout_secs: int = 120,
):
super().__init__(system_app)
self._storage = storage
self._executor = executor or ThreadPoolExecutor(max_workers=2)
self.heartbeat_interval_secs = heartbeat_interval_secs
self.heartbeat_timeout_secs = heartbeat_timeout_secs
self.heartbeat_thread = threading.Thread(target=self._heartbeat_checker)
self.heartbeat_thread.daemon = True
self.heartbeat_thread.start()
@classmethod
def from_url(
cls,
db_url: str,
db_name: str,
pool_size: int = 5,
max_overflow: int = 10,
try_to_create_db: bool = False,
**kwargs,
) -> "StorageModelRegistry":
from dbgpt.storage.metadata.db_manager import DatabaseManager, initialize_db
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer
from .db_storage import ModelInstanceEntity, ModelInstanceItemAdapter
engine_args = {
"pool_size": pool_size,
"max_overflow": max_overflow,
"pool_timeout": 30,
"pool_recycle": 3600,
"pool_pre_ping": True,
}
db: DatabaseManager = initialize_db(
db_url, db_name, engine_args, try_to_create_db=try_to_create_db
)
storage_adapter = ModelInstanceItemAdapter()
serializer = JsonSerializer()
storage = SQLAlchemyStorage(
db,
ModelInstanceEntity,
storage_adapter,
serializer,
)
return cls(storage, **kwargs)
async def _get_instances_by_model(
self, model_name: str, host: str, port: int, healthy_only: bool = False
) -> Tuple[List[ModelInstanceStorageItem], List[ModelInstanceStorageItem]]:
query_spec = QuerySpec(conditions={"model_name": model_name})
# Query all instances of the model
instances = await blocking_func_to_async(
self._executor, self._storage.query, query_spec, ModelInstanceStorageItem
)
if healthy_only:
instances = [ins for ins in instances if ins.healthy is True]
exist_ins = [ins for ins in instances if ins.host == host and ins.port == port]
return instances, exist_ins
def _heartbeat_checker(self):
while True:
all_instances: List[ModelInstanceStorageItem] = self._storage.query(
QuerySpec(conditions={}), ModelInstanceStorageItem
)
for instance in all_instances:
if (
instance.check_healthy
and datetime.now() - instance.last_heartbeat
> timedelta(seconds=self.heartbeat_timeout_secs)
):
instance.healthy = False
self._storage.update(instance)
time.sleep(self.heartbeat_interval_secs)
async def register_instance(self, instance: ModelInstance) -> bool:
model_name = instance.model_name.strip()
host = instance.host.strip()
port = instance.port
_, exist_ins = await self._get_instances_by_model(
model_name, host, port, healthy_only=False
)
if exist_ins:
# Exist instances, just update the instance
# One exist instance at most
ins: ModelInstanceStorageItem = exist_ins[0]
# Update instance
ins.weight = instance.weight
ins.healthy = True
ins.prompt_template = instance.prompt_template
ins.last_heartbeat = datetime.now()
await blocking_func_to_async(self._executor, self._storage.update, ins)
else:
# No exist instance, save the new instance
new_inst = ModelInstanceStorageItem.from_model_instance(instance)
new_inst.healthy = True
new_inst.last_heartbeat = datetime.now()
await blocking_func_to_async(self._executor, self._storage.save, new_inst)
return True
async def deregister_instance(self, instance: ModelInstance) -> bool:
"""Deregister a model instance.
If the instance exists, set the instance as unhealthy, nothing to do if the
instance does not exist.
Args:
instance (ModelInstance): The instance to deregister.
"""
model_name = instance.model_name.strip()
host = instance.host.strip()
port = instance.port
_, exist_ins = await self._get_instances_by_model(
model_name, host, port, healthy_only=False
)
if exist_ins:
ins = exist_ins[0]
ins.healthy = False
await blocking_func_to_async(self._executor, self._storage.update, ins)
return True
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
"""Get all instances of a model(Async).
Args:
model_name (str): The model name.
healthy_only (bool): Whether only get healthy instances. Defaults to False.
"""
return await blocking_func_to_async(
self._executor, self.sync_get_all_instances, model_name, healthy_only
)
def sync_get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
"""Get all instances of a model.
Args:
model_name (str): The model name.
healthy_only (bool): Whether only get healthy instances. Defaults to False.
Returns:
List[ModelInstance]: The list of instances.
"""
instances = self._storage.query(
QuerySpec(conditions={"model_name": model_name}), ModelInstanceStorageItem
)
if healthy_only:
instances = [ins for ins in instances if ins.healthy is True]
return [ModelInstanceStorageItem.to_model_instance(ins) for ins in instances]
async def get_all_model_instances(
self, healthy_only: bool = False
) -> List[ModelInstance]:
"""Get all model instances.
Args:
healthy_only (bool): Whether only get healthy instances. Defaults to False.
Returns:
List[ModelInstance]: The list of instances.
"""
all_instances = await blocking_func_to_async(
self._executor,
self._storage.query,
QuerySpec(conditions={}),
ModelInstanceStorageItem,
)
if healthy_only:
all_instances = [ins for ins in all_instances if ins.healthy is True]
return [
ModelInstanceStorageItem.to_model_instance(ins) for ins in all_instances
]
async def send_heartbeat(self, instance: ModelInstance) -> bool:
"""Receive heartbeat from model instance.
Update the last heartbeat time of the instance. If the instance does not exist,
register the instance.
Args:
instance (ModelInstance): The instance to send heartbeat.
Returns:
bool: True if the heartbeat is received successfully.
"""
model_name = instance.model_name.strip()
host = instance.host.strip()
port = instance.port
_, exist_ins = await self._get_instances_by_model(
model_name, host, port, healthy_only=False
)
if not exist_ins:
# register new instance from heartbeat
await self.register_instance(instance)
return True
else:
ins = exist_ins[0]
ins.last_heartbeat = datetime.now()
ins.healthy = True
await blocking_func_to_async(self._executor, self._storage.update, ins)
return True

View File

@@ -0,0 +1,221 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
import pytest
from dbgpt.core.interface.storage import InMemoryStorage, QuerySpec
from dbgpt.util.serialization.json_serialization import JsonSerializer
from ...registry_impl.storage import (
ModelInstance,
ModelInstanceStorageItem,
StorageModelRegistry,
)
@pytest.fixture
def in_memory_storage():
return InMemoryStorage(serializer=JsonSerializer())
@pytest.fixture
def thread_pool_executor():
return ThreadPoolExecutor(max_workers=2)
@pytest.fixture
def registry(in_memory_storage, thread_pool_executor):
return StorageModelRegistry(
storage=in_memory_storage,
executor=thread_pool_executor,
heartbeat_interval_secs=1,
heartbeat_timeout_secs=2,
)
@pytest.fixture
def model_instance():
return ModelInstance(
model_name="test_model",
host="localhost",
port=8080,
weight=1.0,
check_healthy=True,
healthy=True,
enabled=True,
prompt_template=None,
last_heartbeat=datetime.now(),
)
@pytest.fixture
def model_instance_2():
return ModelInstance(
model_name="test_model",
host="localhost",
port=8081,
weight=1.0,
check_healthy=True,
healthy=True,
enabled=True,
prompt_template=None,
last_heartbeat=datetime.now(),
)
@pytest.fixture
def model_instance_3():
return ModelInstance(
model_name="test_model_2",
host="localhost",
port=8082,
weight=1.0,
check_healthy=True,
healthy=True,
enabled=True,
prompt_template=None,
last_heartbeat=datetime.now(),
)
@pytest.fixture
def model_instance_storage_item(model_instance):
return ModelInstanceStorageItem.from_model_instance(model_instance)
@pytest.mark.asyncio
async def test_register_instance_new(registry, model_instance):
"""Test registering a new model instance."""
result = await registry.register_instance(model_instance)
assert result is True
instances = await registry.get_all_instances(model_instance.model_name)
assert len(instances) == 1
saved_instance = instances[0]
assert saved_instance.model_name == model_instance.model_name
assert saved_instance.host == model_instance.host
assert saved_instance.port == model_instance.port
assert saved_instance.healthy is True
assert saved_instance.last_heartbeat is not None
@pytest.mark.asyncio
async def test_register_instance_existing(
registry, model_instance, model_instance_storage_item
):
"""Test registering an existing model instance and updating it."""
await registry.register_instance(model_instance)
# Register the instance again with updated heartbeat
result = await registry.register_instance(model_instance)
assert result is True
instances = await registry.get_all_instances(model_instance.model_name)
assert len(instances) == 1
updated_instance = instances[0]
assert updated_instance.model_name == model_instance.model_name
assert updated_instance.host == model_instance.host
assert updated_instance.port == model_instance.port
assert updated_instance.healthy is True
assert updated_instance.last_heartbeat is not None
@pytest.mark.asyncio
async def test_deregister_instance(registry, model_instance):
"""Test deregistering a model instance."""
await registry.register_instance(model_instance)
result = await registry.deregister_instance(model_instance)
assert result is True
instances = await registry.get_all_instances(model_instance.model_name)
assert len(instances) == 1
deregistered_instance = instances[0]
assert deregistered_instance.healthy is False
@pytest.mark.asyncio
async def test_get_all_instances(registry, model_instance):
"""Test retrieving all model instances."""
await registry.register_instance(model_instance)
result = await registry.get_all_instances(
model_instance.model_name, healthy_only=True
)
assert len(result) == 1
assert result[0].model_name == model_instance.model_name
def test_sync_get_all_instances(registry, model_instance):
"""Test synchronously retrieving all model instances."""
registry.sync_get_all_instances(model_instance.model_name, healthy_only=True)
registry._storage.save(ModelInstanceStorageItem.from_model_instance(model_instance))
result = registry.sync_get_all_instances(
model_instance.model_name, healthy_only=True
)
assert len(result) == 1
assert result[0].model_name == model_instance.model_name
@pytest.mark.asyncio
async def test_send_heartbeat_new_instance(registry, model_instance):
"""Test sending a heartbeat for a new instance."""
result = await registry.send_heartbeat(model_instance)
assert result is True
instances = await registry.get_all_instances(model_instance.model_name)
assert len(instances) == 1
saved_instance = instances[0]
assert saved_instance.model_name == model_instance.model_name
@pytest.mark.asyncio
async def test_send_heartbeat_existing_instance(registry, model_instance):
"""Test sending a heartbeat for an existing instance."""
await registry.register_instance(model_instance)
# Send heartbeat to update the instance
result = await registry.send_heartbeat(model_instance)
assert result is True
instances = await registry.get_all_instances(model_instance.model_name)
assert len(instances) == 1
updated_instance = instances[0]
assert updated_instance.last_heartbeat > model_instance.last_heartbeat
@pytest.mark.asyncio
async def test_heartbeat_checker(
in_memory_storage, thread_pool_executor, model_instance
):
"""Test the heartbeat checker mechanism."""
heartbeat_timeout_secs = 1
registry = StorageModelRegistry(
storage=in_memory_storage,
executor=thread_pool_executor,
heartbeat_interval_secs=0.1,
heartbeat_timeout_secs=heartbeat_timeout_secs,
)
async def check_heartbeat(model_name: str, expected_healthy: bool):
instances = await registry.get_all_instances(model_name)
assert len(instances) == 1
updated_instance = instances[0]
assert updated_instance.healthy == expected_healthy
await registry.register_instance(model_instance)
# First heartbeat should be successful
await check_heartbeat(model_instance.model_name, True)
# Wait heartbeat timeout
await asyncio.sleep(heartbeat_timeout_secs + 0.5)
await check_heartbeat(model_instance.model_name, False)
# Send heartbeat again
await registry.send_heartbeat(model_instance)
# Should be healthy again
await check_heartbeat(model_instance.model_name, True)

View File

@@ -1059,6 +1059,10 @@ def initialize_worker_manager_in_client(
if not app:
raise Exception("app can't be None")
if system_app:
logger.info(f"Register WorkerManager {_DefaultWorkerManagerFactory.name}")
system_app.register(_DefaultWorkerManagerFactory, worker_manager)
worker_params: ModelWorkerParameters = _parse_worker_params(
model_name=model_name, model_path=model_path, controller_addr=controller_addr
)
@@ -1104,8 +1108,6 @@ def initialize_worker_manager_in_client(
if include_router and app:
# mount WorkerManager router
app.include_router(router, prefix="/api")
if system_app:
system_app.register(_DefaultWorkerManagerFactory, worker_manager)
def run_worker_manager(

View File

@@ -55,6 +55,84 @@ class ModelControllerParameters(BaseParameters):
port: Optional[int] = field(
default=8000, metadata={"help": "Model Controller deploy port"}
)
registry_type: Optional[str] = field(
default="embedded",
metadata={
"help": "Registry type: embedded, database...",
"valid_values": ["embedded", "database"],
},
)
registry_db_type: Optional[str] = field(
default="mysql",
metadata={
"help": "Registry database type, now only support sqlite and mysql, it is "
"valid when registry_type is database",
"valid_values": ["mysql", "sqlite"],
},
)
registry_db_name: Optional[str] = field(
default="dbgpt",
metadata={
"help": "Registry database name, just for database, it is valid when "
"registry_type is database, please set to full database path for sqlite"
},
)
registry_db_host: Optional[str] = field(
default=None,
metadata={
"help": "Registry database host, just for database, it is valid when "
"registry_type is database"
},
)
registry_db_port: Optional[int] = field(
default=None,
metadata={
"help": "Registry database port, just for database, it is valid when "
"registry_type is database"
},
)
registry_db_user: Optional[str] = field(
default=None,
metadata={
"help": "Registry database user, just for database, it is valid when "
"registry_type is database"
},
)
registry_db_password: Optional[str] = field(
default=None,
metadata={
"help": "Registry database password, just for database, it is valid when "
"registry_type is database. We recommend to use environment variable to "
"store password, you can set it in your environment variable like "
"export CONTROLLER_REGISTRY_DB_PASSWORD='your_password'"
},
)
registry_db_pool_size: Optional[int] = field(
default=5,
metadata={
"help": "Registry database pool size, just for database, it is valid when "
"registry_type is database"
},
)
registry_db_max_overflow: Optional[int] = field(
default=10,
metadata={
"help": "Registry database max overflow, just for database, it is valid "
"when registry_type is database"
},
)
heartbeat_interval_secs: Optional[int] = field(
default=20, metadata={"help": "The interval for checking heartbeats (seconds)"}
)
heartbeat_timeout_secs: Optional[int] = field(
default=60,
metadata={
"help": "The timeout for checking heartbeats (seconds), it will be set "
"unhealthy if the worker is not responding in this time"
},
)
daemon: Optional[bool] = field(
default=False, metadata={"help": "Run Model Controller in background"}
)