mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-02 08:40:36 +00:00
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com> Co-authored-by: csunny <cfqsunny@163.com>
374 lines
13 KiB
Python
374 lines
13 KiB
Python
import threading
|
|
import time
|
|
from concurrent.futures import Executor, ThreadPoolExecutor
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
from dbgpt.component import SystemApp
|
|
from dbgpt.core.interface.storage import (
|
|
QuerySpec,
|
|
ResourceIdentifier,
|
|
StorageInterface,
|
|
StorageItem,
|
|
)
|
|
from dbgpt.util.executor_utils import blocking_func_to_async
|
|
|
|
from ...base import ModelInstance
|
|
from ..registry import ModelRegistry
|
|
|
|
|
|
@dataclass
class ModelInstanceIdentifier(ResourceIdentifier):
    """Identifier of a model instance, made of model name, host and port."""

    identifier_split: str = field(default="___$$$$___", init=False)
    model_name: str
    host: str
    port: int

    def __post_init__(self):
        """Validate the identifier fields.

        Raises:
            ValueError: If ``model_name``, ``host`` or ``port`` is None, or if
                any of them (``port`` as a string) contains
                ``identifier_split``.
        """
        # Checked in declaration order so the first missing field is reported.
        required_fields = (
            (self.model_name, "model_name is required."),
            (self.host, "host is required."),
            (self.port, "port is required."),
        )
        for value, message in required_fields:
            if value is None:
                raise ValueError(message)

        # The separator must not occur inside a part, otherwise the joined
        # string identifier would be ambiguous.
        for part in (self.model_name, self.host, str(self.port)):
            if part is not None and self.identifier_split in part:
                raise ValueError(
                    f"identifier_split {self.identifier_split} is not allowed in "
                    f"model_name, host, port."
                )

    @property
    def str_identifier(self) -> str:
        """Return the parts joined by ``identifier_split`` as one string."""
        parts = (self.model_name, self.host, str(self.port))
        return self.identifier_split.join(
            part for part in parts if part is not None
        )

    def to_dict(self) -> Dict:
        """Convert the identifier to a dict.

        Returns:
            Dict: Mapping with the keys ``model_name``, ``host`` and ``port``.
        """
        return dict(
            model_name=self.model_name,
            host=self.host,
            port=self.port,
        )
|
|
|
|
|
|
@dataclass
class ModelInstanceStorageItem(StorageItem):
    """Storage item that mirrors the fields of a ``ModelInstance``."""

    model_name: str
    host: str
    port: int
    weight: Optional[float] = 1.0
    check_healthy: Optional[bool] = True
    healthy: Optional[bool] = False
    enabled: Optional[bool] = True
    prompt_template: Optional[str] = None
    last_heartbeat: Optional[datetime] = None
    _identifier: ModelInstanceIdentifier = field(init=False)

    def __post_init__(self):
        """Normalize ``last_heartbeat`` and build the identifier.

        Raises:
            ValueError: Propagated from ``ModelInstanceIdentifier`` when the
                model name, host or port is invalid.
        """
        # Convert last_heartbeat to datetime if it's a timestamp
        if isinstance(self.last_heartbeat, (int, float)):
            self.last_heartbeat = datetime.fromtimestamp(self.last_heartbeat)

        self._identifier = ModelInstanceIdentifier(
            model_name=self.model_name,
            host=self.host,
            port=self.port,
        )

    @property
    def identifier(self) -> ModelInstanceIdentifier:
        """Return the identifier built in ``__post_init__``."""
        return self._identifier

    def merge(self, other: "StorageItem") -> None:
        """Overwrite this item's fields with those of ``other``.

        Args:
            other (StorageItem): The item to copy from.

        Raises:
            ValueError: If ``other`` is not a ``ModelInstanceStorageItem``.
        """
        if not isinstance(other, ModelInstanceStorageItem):
            raise ValueError(f"Cannot merge with {type(other)}")
        self.from_object(other)

    def to_dict(self) -> Dict:
        """Convert the item to a dict.

        Returns:
            Dict: The item fields; ``last_heartbeat`` is a POSIX timestamp, or
                None when no heartbeat has been recorded.
        """
        # last_heartbeat is Optional and defaults to None, so guard the
        # conversion instead of raising AttributeError on a fresh item.
        last_heartbeat = (
            self.last_heartbeat.timestamp()
            if self.last_heartbeat is not None
            else None
        )
        return {
            "model_name": self.model_name,
            "host": self.host,
            "port": self.port,
            "weight": self.weight,
            "check_healthy": self.check_healthy,
            "healthy": self.healthy,
            "enabled": self.enabled,
            "prompt_template": self.prompt_template,
            "last_heartbeat": last_heartbeat,
        }

    def from_object(self, item: "ModelInstanceStorageItem") -> None:
        """Build the item from another item.

        Copies every data field from ``item``. The cached ``_identifier`` is
        not rebuilt here.
        """
        self.model_name = item.model_name
        self.host = item.host
        self.port = item.port
        self.weight = item.weight
        self.check_healthy = item.check_healthy
        self.healthy = item.healthy
        self.enabled = item.enabled
        self.prompt_template = item.prompt_template
        self.last_heartbeat = item.last_heartbeat

    @classmethod
    def from_model_instance(cls, instance: ModelInstance) -> "ModelInstanceStorageItem":
        """Create a storage item from a ``ModelInstance``."""
        return cls(
            model_name=instance.model_name,
            host=instance.host,
            port=instance.port,
            weight=instance.weight,
            check_healthy=instance.check_healthy,
            healthy=instance.healthy,
            enabled=instance.enabled,
            prompt_template=instance.prompt_template,
            last_heartbeat=instance.last_heartbeat,
        )

    @classmethod
    def to_model_instance(cls, item: "ModelInstanceStorageItem") -> ModelInstance:
        """Create a ``ModelInstance`` from a storage item."""
        return ModelInstance(
            model_name=item.model_name,
            host=item.host,
            port=item.port,
            weight=item.weight,
            check_healthy=item.check_healthy,
            healthy=item.healthy,
            enabled=item.enabled,
            prompt_template=item.prompt_template,
            last_heartbeat=item.last_heartbeat,
        )
|
|
|
|
|
|
class StorageModelRegistry(ModelRegistry):
    """Model registry that keeps its instances in a ``StorageInterface``.

    A daemon thread started in ``__init__`` periodically marks instances whose
    heartbeat is older than ``heartbeat_timeout_secs`` as unhealthy.
    """

    def __init__(
        self,
        storage: StorageInterface,
        system_app: SystemApp | None = None,
        executor: Optional[Executor] = None,
        heartbeat_interval_secs: float | int = 60,
        heartbeat_timeout_secs: int = 120,
    ):
        """Create the registry and start the heartbeat checker thread.

        Args:
            storage (StorageInterface): Storage holding the instance items.
            system_app (SystemApp | None): Passed to the parent constructor.
            executor (Optional[Executor]): Executor used to run blocking
                storage calls from async methods. Defaults to a
                ``ThreadPoolExecutor`` with two workers.
            heartbeat_interval_secs (float | int): Seconds to sleep between
                two heartbeat checks.
            heartbeat_timeout_secs (int): Age in seconds after which a
                heartbeat is considered expired.
        """
        super().__init__(system_app)
        self._storage = storage
        self._executor = executor or ThreadPoolExecutor(max_workers=2)
        self.heartbeat_interval_secs = heartbeat_interval_secs
        self.heartbeat_timeout_secs = heartbeat_timeout_secs
        self.heartbeat_thread = threading.Thread(target=self._heartbeat_checker)
        self.heartbeat_thread.daemon = True
        self.heartbeat_thread.start()

    @classmethod
    def from_url(
        cls,
        db_url: str,
        db_name: str,
        pool_size: int = 5,
        max_overflow: int = 10,
        try_to_create_db: bool = False,
        **kwargs,
    ) -> "StorageModelRegistry":
        """Create a registry backed by a SQLAlchemy storage.

        Args:
            db_url (str): Database URL passed to ``initialize_db``.
            db_name (str): Database name passed to ``initialize_db``.
            pool_size (int): ``pool_size`` engine argument.
            max_overflow (int): ``max_overflow`` engine argument.
            try_to_create_db (bool): Forwarded to ``initialize_db``.
            **kwargs: Extra keyword arguments for the registry constructor.

        Returns:
            StorageModelRegistry: The new registry.
        """
        from dbgpt.storage.metadata.db_manager import DatabaseManager, initialize_db
        from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
        from dbgpt.util.serialization.json_serialization import JsonSerializer

        from .db_storage import ModelInstanceEntity, ModelInstanceItemAdapter

        engine_args = {
            "pool_size": pool_size,
            "max_overflow": max_overflow,
            "pool_timeout": 30,
            "pool_recycle": 3600,
            "pool_pre_ping": True,
        }

        db: DatabaseManager = initialize_db(
            db_url, db_name, engine_args, try_to_create_db=try_to_create_db
        )
        storage_adapter = ModelInstanceItemAdapter()
        serializer = JsonSerializer()
        storage = SQLAlchemyStorage(
            db,
            ModelInstanceEntity,
            storage_adapter,
            serializer,
        )
        return cls(storage, **kwargs)

    async def _get_instances_by_model(
        self, model_name: str, host: str, port: int, healthy_only: bool = False
    ) -> Tuple[List[ModelInstanceStorageItem], List[ModelInstanceStorageItem]]:
        """Query the instances of a model and those matching host and port.

        Args:
            model_name (str): The model name to query.
            host (str): Host that an existing instance must match.
            port (int): Port that an existing instance must match.
            healthy_only (bool): Whether to keep only healthy instances.

        Returns:
            Tuple: All (optionally healthy-only) instances of the model, and
                the subset whose host and port equal the given ones.
        """
        query_spec = QuerySpec(conditions={"model_name": model_name})
        # Query all instances of the model
        instances = await blocking_func_to_async(
            self._executor, self._storage.query, query_spec, ModelInstanceStorageItem
        )
        if healthy_only:
            instances = [ins for ins in instances if ins.healthy is True]
        exist_ins = [ins for ins in instances if ins.host == host and ins.port == port]
        return instances, exist_ins

    def _heartbeat_checker(self):
        """Mark instances with an expired heartbeat as unhealthy, forever.

        Runs as the target of the daemon thread started in ``__init__``. Each
        pass queries all instances, flags those with ``check_healthy`` set and
        a heartbeat older than ``heartbeat_timeout_secs``, then sleeps for
        ``heartbeat_interval_secs``.
        """
        # Imported here because the module does not import logging itself.
        import logging

        logger = logging.getLogger(__name__)
        while True:
            # An uncaught exception would end this thread for good and stop
            # all further health checks, so a failed pass is logged and the
            # loop retries after the usual interval.
            try:
                timeout = timedelta(seconds=self.heartbeat_timeout_secs)
                all_instances: List[ModelInstanceStorageItem] = self._storage.query(
                    QuerySpec(conditions={}), ModelInstanceStorageItem
                )
                for instance in all_instances:
                    if not instance.check_healthy:
                        continue
                    # last_heartbeat is Optional; without a timestamp there is
                    # nothing to compare against, so the instance is skipped.
                    if instance.last_heartbeat is None:
                        continue
                    if datetime.now() - instance.last_heartbeat > timeout:
                        instance.healthy = False
                        self._storage.update(instance)
            except Exception:
                logger.exception("Failed to check model instance heartbeats")
            time.sleep(self.heartbeat_interval_secs)

    async def register_instance(self, instance: ModelInstance) -> bool:
        """Register a model instance.

        If an instance with the same model name, host and port exists, its
        weight, prompt template, health flag and heartbeat are updated.
        Otherwise a new item is saved.

        Args:
            instance (ModelInstance): The instance to register.

        Returns:
            bool: Always True.
        """
        # NOTE(review): the lookup uses the stripped model name and host, while
        # a newly saved item keeps the values exactly as given on `instance`.
        model_name = instance.model_name.strip()
        host = instance.host.strip()
        port = instance.port
        _, exist_ins = await self._get_instances_by_model(
            model_name, host, port, healthy_only=False
        )
        if exist_ins:
            # Exist instances, just update the instance
            # One exist instance at most
            ins: ModelInstanceStorageItem = exist_ins[0]
            # Update instance
            ins.weight = instance.weight
            ins.healthy = True
            ins.prompt_template = instance.prompt_template
            ins.last_heartbeat = datetime.now()
            await blocking_func_to_async(self._executor, self._storage.update, ins)
        else:
            # No exist instance, save the new instance
            new_inst = ModelInstanceStorageItem.from_model_instance(instance)
            new_inst.healthy = True
            new_inst.last_heartbeat = datetime.now()
            await blocking_func_to_async(self._executor, self._storage.save, new_inst)
        return True

    async def deregister_instance(self, instance: ModelInstance) -> bool:
        """Deregister a model instance.

        If the instance exists, set the instance as unhealthy, nothing to do if the
        instance does not exist.

        Args:
            instance (ModelInstance): The instance to deregister.

        Returns:
            bool: Always True.
        """
        model_name = instance.model_name.strip()
        host = instance.host.strip()
        port = instance.port
        _, exist_ins = await self._get_instances_by_model(
            model_name, host, port, healthy_only=False
        )
        if exist_ins:
            ins = exist_ins[0]
            ins.healthy = False
            await blocking_func_to_async(self._executor, self._storage.update, ins)
        return True

    async def get_all_instances(
        self, model_name: str, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Get all instances of a model(Async).

        Args:
            model_name (str): The model name.
            healthy_only (bool): Whether only get healthy instances. Defaults to False.

        Returns:
            List[ModelInstance]: The list of instances.
        """
        return await blocking_func_to_async(
            self._executor, self.sync_get_all_instances, model_name, healthy_only
        )

    def sync_get_all_instances(
        self, model_name: str, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Get all instances of a model.

        Args:
            model_name (str): The model name.
            healthy_only (bool): Whether only get healthy instances. Defaults to False.

        Returns:
            List[ModelInstance]: The list of instances.
        """
        instances = self._storage.query(
            QuerySpec(conditions={"model_name": model_name}), ModelInstanceStorageItem
        )
        if healthy_only:
            instances = [ins for ins in instances if ins.healthy is True]
        return [ModelInstanceStorageItem.to_model_instance(ins) for ins in instances]

    async def get_all_model_instances(
        self, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """Get all model instances.

        Args:
            healthy_only (bool): Whether only get healthy instances. Defaults to False.

        Returns:
            List[ModelInstance]: The list of instances.
        """
        all_instances = await blocking_func_to_async(
            self._executor,
            self._storage.query,
            QuerySpec(conditions={}),
            ModelInstanceStorageItem,
        )
        if healthy_only:
            all_instances = [ins for ins in all_instances if ins.healthy is True]
        return [
            ModelInstanceStorageItem.to_model_instance(ins) for ins in all_instances
        ]

    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        """Receive heartbeat from model instance.

        Update the last heartbeat time of the instance. If the instance does not exist,
        register the instance.

        Args:
            instance (ModelInstance): The instance to send heartbeat.

        Returns:
            bool: True if the heartbeat is received successfully.
        """
        model_name = instance.model_name.strip()
        host = instance.host.strip()
        port = instance.port
        _, exist_ins = await self._get_instances_by_model(
            model_name, host, port, healthy_only=False
        )
        if not exist_ins:
            # register new instance from heartbeat
            await self.register_instance(instance)
            return True
        else:
            ins = exist_ins[0]
            ins.last_heartbeat = datetime.now()
            ins.healthy = True
            await blocking_func_to_async(self._executor, self._storage.update, ins)
            return True
|