mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-28 22:37:31 +00:00
225 lines
7.7 KiB
Python
225 lines
7.7 KiB
Python
import itertools
|
|
import logging
|
|
import random
|
|
import threading
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
from collections import defaultdict
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
|
from dbgpt.model.base import ModelInstance
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ModelRegistry(BaseComponent, ABC):
|
|
"""
|
|
Abstract base class for a model registry. It provides an interface
|
|
for registering, deregistering, fetching instances, and sending heartbeats
|
|
for instances.
|
|
"""
|
|
|
|
name = ComponentType.MODEL_REGISTRY
|
|
|
|
def __init__(self, system_app: SystemApp | None = None):
|
|
self.system_app = system_app
|
|
super().__init__(system_app)
|
|
|
|
def init_app(self, system_app: SystemApp):
|
|
"""Initialize the component with the main application."""
|
|
self.system_app = system_app
|
|
|
|
@abstractmethod
|
|
async def register_instance(self, instance: ModelInstance) -> bool:
|
|
"""
|
|
Register a given model instance.
|
|
|
|
Args:
|
|
- instance (ModelInstance): The instance of the model to register.
|
|
|
|
Returns:
|
|
- bool: True if registration is successful, False otherwise.
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def deregister_instance(self, instance: ModelInstance) -> bool:
|
|
"""
|
|
Deregister a given model instance.
|
|
|
|
Args:
|
|
- instance (ModelInstance): The instance of the model to deregister.
|
|
|
|
Returns:
|
|
- bool: True if deregistration is successful, False otherwise.
|
|
"""
|
|
|
|
@abstractmethod
|
|
async def get_all_instances(
|
|
self, model_name: str, healthy_only: bool = False
|
|
) -> List[ModelInstance]:
|
|
"""
|
|
Fetch all instances of a given model. Optionally, fetch only the healthy instances.
|
|
|
|
Args:
|
|
- model_name (str): Name of the model to fetch instances for.
|
|
- healthy_only (bool, optional): If set to True, fetches only the healthy instances.
|
|
Defaults to False.
|
|
|
|
Returns:
|
|
- List[ModelInstance]: A list of instances for the given model.
|
|
"""
|
|
|
|
@abstractmethod
|
|
def sync_get_all_instances(
|
|
self, model_name: str, healthy_only: bool = False
|
|
) -> List[ModelInstance]:
|
|
"""Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
|
|
|
|
@abstractmethod
|
|
async def get_all_model_instances(
|
|
self, healthy_only: bool = False
|
|
) -> List[ModelInstance]:
|
|
"""
|
|
Fetch all instances of all models, Optionally, fetch only the healthy instances.
|
|
|
|
Returns:
|
|
- List[ModelInstance]: A list of instances for the all models.
|
|
"""
|
|
|
|
async def select_one_health_instance(self, model_name: str) -> ModelInstance:
|
|
"""
|
|
Selects one healthy and enabled instance for a given model.
|
|
|
|
Args:
|
|
- model_name (str): Name of the model.
|
|
|
|
Returns:
|
|
- ModelInstance: One randomly selected healthy and enabled instance, or None if no such instance exists.
|
|
"""
|
|
instances = await self.get_all_instances(model_name, healthy_only=True)
|
|
instances = [i for i in instances if i.enabled]
|
|
if not instances:
|
|
return None
|
|
return random.choice(instances)
|
|
|
|
@abstractmethod
|
|
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
|
"""
|
|
Send a heartbeat for a given model instance. This can be used to
|
|
verify if the instance is still alive and functioning.
|
|
|
|
Args:
|
|
- instance (ModelInstance): The instance of the model to send a heartbeat for.
|
|
|
|
Returns:
|
|
- bool: True if heartbeat is successful, False otherwise.
|
|
"""
|
|
|
|
|
|
class EmbeddedModelRegistry(ModelRegistry):
|
|
def __init__(
|
|
self,
|
|
system_app: SystemApp | None = None,
|
|
heartbeat_interval_secs: int = 60,
|
|
heartbeat_timeout_secs: int = 120,
|
|
):
|
|
super().__init__(system_app)
|
|
self.registry: Dict[str, List[ModelInstance]] = defaultdict(list)
|
|
self.heartbeat_interval_secs = heartbeat_interval_secs
|
|
self.heartbeat_timeout_secs = heartbeat_timeout_secs
|
|
self.heartbeat_thread = threading.Thread(target=self._heartbeat_checker)
|
|
self.heartbeat_thread.daemon = True
|
|
self.heartbeat_thread.start()
|
|
|
|
def _get_instances(
|
|
self, model_name: str, host: str, port: int, healthy_only: bool = False
|
|
) -> Tuple[List[ModelInstance], List[ModelInstance]]:
|
|
instances = self.registry[model_name]
|
|
if healthy_only:
|
|
instances = [ins for ins in instances if ins.healthy == True]
|
|
exist_ins = [ins for ins in instances if ins.host == host and ins.port == port]
|
|
return instances, exist_ins
|
|
|
|
def _heartbeat_checker(self):
|
|
while True:
|
|
for instances in self.registry.values():
|
|
for instance in instances:
|
|
if (
|
|
instance.check_healthy
|
|
and datetime.now() - instance.last_heartbeat
|
|
> timedelta(seconds=self.heartbeat_timeout_secs)
|
|
):
|
|
instance.healthy = False
|
|
time.sleep(self.heartbeat_interval_secs)
|
|
|
|
async def register_instance(self, instance: ModelInstance) -> bool:
|
|
model_name = instance.model_name.strip()
|
|
host = instance.host.strip()
|
|
port = instance.port
|
|
|
|
instances, exist_ins = self._get_instances(
|
|
model_name, host, port, healthy_only=False
|
|
)
|
|
if exist_ins:
|
|
# One exist instance at most
|
|
ins = exist_ins[0]
|
|
# Update instance
|
|
ins.weight = instance.weight
|
|
ins.healthy = True
|
|
ins.prompt_template = instance.prompt_template
|
|
ins.last_heartbeat = datetime.now()
|
|
else:
|
|
instance.healthy = True
|
|
instance.last_heartbeat = datetime.now()
|
|
instances.append(instance)
|
|
return True
|
|
|
|
async def deregister_instance(self, instance: ModelInstance) -> bool:
|
|
model_name = instance.model_name.strip()
|
|
host = instance.host.strip()
|
|
port = instance.port
|
|
_, exist_ins = self._get_instances(model_name, host, port, healthy_only=False)
|
|
if exist_ins:
|
|
ins = exist_ins[0]
|
|
ins.healthy = False
|
|
return True
|
|
|
|
async def get_all_instances(
|
|
self, model_name: str, healthy_only: bool = False
|
|
) -> List[ModelInstance]:
|
|
return self.sync_get_all_instances(model_name, healthy_only)
|
|
|
|
def sync_get_all_instances(
|
|
self, model_name: str, healthy_only: bool = False
|
|
) -> List[ModelInstance]:
|
|
instances = self.registry[model_name]
|
|
if healthy_only:
|
|
instances = [ins for ins in instances if ins.healthy == True]
|
|
return instances
|
|
|
|
async def get_all_model_instances(
|
|
self, healthy_only: bool = False
|
|
) -> List[ModelInstance]:
|
|
logger.debug("Current registry metadata:\n{self.registry}")
|
|
instances = list(itertools.chain(*self.registry.values()))
|
|
if healthy_only:
|
|
instances = [ins for ins in instances if ins.healthy == True]
|
|
return instances
|
|
|
|
async def send_heartbeat(self, instance: ModelInstance) -> bool:
|
|
_, exist_ins = self._get_instances(
|
|
instance.model_name, instance.host, instance.port, healthy_only=False
|
|
)
|
|
if not exist_ins:
|
|
# register new install from heartbeat
|
|
await self.register_instance(instance)
|
|
return True
|
|
|
|
ins = exist_ins[0]
|
|
ins.last_heartbeat = datetime.now()
|
|
ins.healthy = True
|
|
return True
|