DB-GPT/dbgpt/model/cluster/registry_impl/storage.py
明天 b124ecc10b
feat: (0.6)New UI (#1855)
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com>
Co-authored-by: csunny <cfqsunny@163.com>
2024-08-21 17:37:45 +08:00

374 lines
13 KiB
Python

import threading
import time
from concurrent.futures import Executor, ThreadPoolExecutor
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from dbgpt.component import SystemApp
from dbgpt.core.interface.storage import (
QuerySpec,
ResourceIdentifier,
StorageInterface,
StorageItem,
)
from dbgpt.util.executor_utils import blocking_func_to_async
from ...base import ModelInstance
from ..registry import ModelRegistry
@dataclass
class ModelInstanceIdentifier(ResourceIdentifier):
identifier_split: str = field(default="___$$$$___", init=False)
model_name: str
host: str
port: int
def __post_init__(self):
"""Post init method."""
if self.model_name is None:
raise ValueError("model_name is required.")
if self.host is None:
raise ValueError("host is required.")
if self.port is None:
raise ValueError("port is required.")
if any(
self.identifier_split in key
for key in [self.model_name, self.host, str(self.port)]
if key is not None
):
raise ValueError(
f"identifier_split {self.identifier_split} is not allowed in "
f"model_name, host, port."
)
@property
def str_identifier(self) -> str:
"""Return the string identifier of the identifier."""
return self.identifier_split.join(
key
for key in [
self.model_name,
self.host,
str(self.port),
]
if key is not None
)
def to_dict(self) -> Dict:
"""Convert the identifier to a dict.
Returns:
Dict: The dict of the identifier.
"""
return {
"model_name": self.model_name,
"host": self.host,
"port": self.port,
}
@dataclass
class ModelInstanceStorageItem(StorageItem):
model_name: str
host: str
port: int
weight: Optional[float] = 1.0
check_healthy: Optional[bool] = True
healthy: Optional[bool] = False
enabled: Optional[bool] = True
prompt_template: Optional[str] = None
last_heartbeat: Optional[datetime] = None
_identifier: ModelInstanceIdentifier = field(init=False)
def __post_init__(self):
"""Post init method."""
# Convert last_heartbeat to datetime if it's a timestamp
if isinstance(self.last_heartbeat, (int, float)):
self.last_heartbeat = datetime.fromtimestamp(self.last_heartbeat)
self._identifier = ModelInstanceIdentifier(
model_name=self.model_name,
host=self.host,
port=self.port,
)
@property
def identifier(self) -> ModelInstanceIdentifier:
return self._identifier
def merge(self, other: "StorageItem") -> None:
if not isinstance(other, ModelInstanceStorageItem):
raise ValueError(f"Cannot merge with {type(other)}")
self.from_object(other)
def to_dict(self) -> Dict:
last_heartbeat = self.last_heartbeat.timestamp()
return {
"model_name": self.model_name,
"host": self.host,
"port": self.port,
"weight": self.weight,
"check_healthy": self.check_healthy,
"healthy": self.healthy,
"enabled": self.enabled,
"prompt_template": self.prompt_template,
"last_heartbeat": last_heartbeat,
}
def from_object(self, item: "ModelInstanceStorageItem") -> None:
"""Build the item from another item."""
self.model_name = item.model_name
self.host = item.host
self.port = item.port
self.weight = item.weight
self.check_healthy = item.check_healthy
self.healthy = item.healthy
self.enabled = item.enabled
self.prompt_template = item.prompt_template
self.last_heartbeat = item.last_heartbeat
@classmethod
def from_model_instance(cls, instance: ModelInstance) -> "ModelInstanceStorageItem":
return cls(
model_name=instance.model_name,
host=instance.host,
port=instance.port,
weight=instance.weight,
check_healthy=instance.check_healthy,
healthy=instance.healthy,
enabled=instance.enabled,
prompt_template=instance.prompt_template,
last_heartbeat=instance.last_heartbeat,
)
@classmethod
def to_model_instance(cls, item: "ModelInstanceStorageItem") -> ModelInstance:
return ModelInstance(
model_name=item.model_name,
host=item.host,
port=item.port,
weight=item.weight,
check_healthy=item.check_healthy,
healthy=item.healthy,
enabled=item.enabled,
prompt_template=item.prompt_template,
last_heartbeat=item.last_heartbeat,
)
class StorageModelRegistry(ModelRegistry):
def __init__(
self,
storage: StorageInterface,
system_app: SystemApp | None = None,
executor: Optional[Executor] = None,
heartbeat_interval_secs: float | int = 60,
heartbeat_timeout_secs: int = 120,
):
super().__init__(system_app)
self._storage = storage
self._executor = executor or ThreadPoolExecutor(max_workers=2)
self.heartbeat_interval_secs = heartbeat_interval_secs
self.heartbeat_timeout_secs = heartbeat_timeout_secs
self.heartbeat_thread = threading.Thread(target=self._heartbeat_checker)
self.heartbeat_thread.daemon = True
self.heartbeat_thread.start()
@classmethod
def from_url(
cls,
db_url: str,
db_name: str,
pool_size: int = 5,
max_overflow: int = 10,
try_to_create_db: bool = False,
**kwargs,
) -> "StorageModelRegistry":
from dbgpt.storage.metadata.db_manager import DatabaseManager, initialize_db
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer
from .db_storage import ModelInstanceEntity, ModelInstanceItemAdapter
engine_args = {
"pool_size": pool_size,
"max_overflow": max_overflow,
"pool_timeout": 30,
"pool_recycle": 3600,
"pool_pre_ping": True,
}
db: DatabaseManager = initialize_db(
db_url, db_name, engine_args, try_to_create_db=try_to_create_db
)
storage_adapter = ModelInstanceItemAdapter()
serializer = JsonSerializer()
storage = SQLAlchemyStorage(
db,
ModelInstanceEntity,
storage_adapter,
serializer,
)
return cls(storage, **kwargs)
async def _get_instances_by_model(
self, model_name: str, host: str, port: int, healthy_only: bool = False
) -> Tuple[List[ModelInstanceStorageItem], List[ModelInstanceStorageItem]]:
query_spec = QuerySpec(conditions={"model_name": model_name})
# Query all instances of the model
instances = await blocking_func_to_async(
self._executor, self._storage.query, query_spec, ModelInstanceStorageItem
)
if healthy_only:
instances = [ins for ins in instances if ins.healthy is True]
exist_ins = [ins for ins in instances if ins.host == host and ins.port == port]
return instances, exist_ins
def _heartbeat_checker(self):
while True:
all_instances: List[ModelInstanceStorageItem] = self._storage.query(
QuerySpec(conditions={}), ModelInstanceStorageItem
)
for instance in all_instances:
if (
instance.check_healthy
and datetime.now() - instance.last_heartbeat
> timedelta(seconds=self.heartbeat_timeout_secs)
):
instance.healthy = False
self._storage.update(instance)
time.sleep(self.heartbeat_interval_secs)
async def register_instance(self, instance: ModelInstance) -> bool:
model_name = instance.model_name.strip()
host = instance.host.strip()
port = instance.port
_, exist_ins = await self._get_instances_by_model(
model_name, host, port, healthy_only=False
)
if exist_ins:
# Exist instances, just update the instance
# One exist instance at most
ins: ModelInstanceStorageItem = exist_ins[0]
# Update instance
ins.weight = instance.weight
ins.healthy = True
ins.prompt_template = instance.prompt_template
ins.last_heartbeat = datetime.now()
await blocking_func_to_async(self._executor, self._storage.update, ins)
else:
# No exist instance, save the new instance
new_inst = ModelInstanceStorageItem.from_model_instance(instance)
new_inst.healthy = True
new_inst.last_heartbeat = datetime.now()
await blocking_func_to_async(self._executor, self._storage.save, new_inst)
return True
async def deregister_instance(self, instance: ModelInstance) -> bool:
"""Deregister a model instance.
If the instance exists, set the instance as unhealthy, nothing to do if the
instance does not exist.
Args:
instance (ModelInstance): The instance to deregister.
"""
model_name = instance.model_name.strip()
host = instance.host.strip()
port = instance.port
_, exist_ins = await self._get_instances_by_model(
model_name, host, port, healthy_only=False
)
if exist_ins:
ins = exist_ins[0]
ins.healthy = False
await blocking_func_to_async(self._executor, self._storage.update, ins)
return True
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
"""Get all instances of a model(Async).
Args:
model_name (str): The model name.
healthy_only (bool): Whether only get healthy instances. Defaults to False.
"""
return await blocking_func_to_async(
self._executor, self.sync_get_all_instances, model_name, healthy_only
)
def sync_get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
"""Get all instances of a model.
Args:
model_name (str): The model name.
healthy_only (bool): Whether only get healthy instances. Defaults to False.
Returns:
List[ModelInstance]: The list of instances.
"""
instances = self._storage.query(
QuerySpec(conditions={"model_name": model_name}), ModelInstanceStorageItem
)
if healthy_only:
instances = [ins for ins in instances if ins.healthy is True]
return [ModelInstanceStorageItem.to_model_instance(ins) for ins in instances]
async def get_all_model_instances(
self, healthy_only: bool = False
) -> List[ModelInstance]:
"""Get all model instances.
Args:
healthy_only (bool): Whether only get healthy instances. Defaults to False.
Returns:
List[ModelInstance]: The list of instances.
"""
all_instances = await blocking_func_to_async(
self._executor,
self._storage.query,
QuerySpec(conditions={}),
ModelInstanceStorageItem,
)
if healthy_only:
all_instances = [ins for ins in all_instances if ins.healthy is True]
return [
ModelInstanceStorageItem.to_model_instance(ins) for ins in all_instances
]
async def send_heartbeat(self, instance: ModelInstance) -> bool:
"""Receive heartbeat from model instance.
Update the last heartbeat time of the instance. If the instance does not exist,
register the instance.
Args:
instance (ModelInstance): The instance to send heartbeat.
Returns:
bool: True if the heartbeat is received successfully.
"""
model_name = instance.model_name.strip()
host = instance.host.strip()
port = instance.port
_, exist_ins = await self._get_instances_by_model(
model_name, host, port, healthy_only=False
)
if not exist_ins:
# register new instance from heartbeat
await self.register_instance(instance)
return True
else:
ins = exist_ins[0]
ins.last_heartbeat = datetime.now()
ins.healthy = True
await blocking_func_to_async(self._executor, self._storage.update, ins)
return True