DB-GPT/dbgpt/model/cluster/registry.py
2024-01-10 10:39:04 +08:00

225 lines
7.7 KiB
Python

import itertools
import logging
import random
import threading
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.model.base import ModelInstance
logger = logging.getLogger(__name__)
class ModelRegistry(BaseComponent, ABC):
"""
Abstract base class for a model registry. It provides an interface
for registering, deregistering, fetching instances, and sending heartbeats
for instances.
"""
name = ComponentType.MODEL_REGISTRY
def __init__(self, system_app: SystemApp | None = None):
self.system_app = system_app
super().__init__(system_app)
def init_app(self, system_app: SystemApp):
"""Initialize the component with the main application."""
self.system_app = system_app
@abstractmethod
async def register_instance(self, instance: ModelInstance) -> bool:
"""
Register a given model instance.
Args:
- instance (ModelInstance): The instance of the model to register.
Returns:
- bool: True if registration is successful, False otherwise.
"""
pass
@abstractmethod
async def deregister_instance(self, instance: ModelInstance) -> bool:
"""
Deregister a given model instance.
Args:
- instance (ModelInstance): The instance of the model to deregister.
Returns:
- bool: True if deregistration is successful, False otherwise.
"""
@abstractmethod
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
"""
Fetch all instances of a given model. Optionally, fetch only the healthy instances.
Args:
- model_name (str): Name of the model to fetch instances for.
- healthy_only (bool, optional): If set to True, fetches only the healthy instances.
Defaults to False.
Returns:
- List[ModelInstance]: A list of instances for the given model.
"""
@abstractmethod
def sync_get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
"""Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
@abstractmethod
async def get_all_model_instances(
self, healthy_only: bool = False
) -> List[ModelInstance]:
"""
Fetch all instances of all models, Optionally, fetch only the healthy instances.
Returns:
- List[ModelInstance]: A list of instances for the all models.
"""
async def select_one_health_instance(self, model_name: str) -> ModelInstance:
"""
Selects one healthy and enabled instance for a given model.
Args:
- model_name (str): Name of the model.
Returns:
- ModelInstance: One randomly selected healthy and enabled instance, or None if no such instance exists.
"""
instances = await self.get_all_instances(model_name, healthy_only=True)
instances = [i for i in instances if i.enabled]
if not instances:
return None
return random.choice(instances)
@abstractmethod
async def send_heartbeat(self, instance: ModelInstance) -> bool:
"""
Send a heartbeat for a given model instance. This can be used to
verify if the instance is still alive and functioning.
Args:
- instance (ModelInstance): The instance of the model to send a heartbeat for.
Returns:
- bool: True if heartbeat is successful, False otherwise.
"""
class EmbeddedModelRegistry(ModelRegistry):
def __init__(
self,
system_app: SystemApp | None = None,
heartbeat_interval_secs: int = 60,
heartbeat_timeout_secs: int = 120,
):
super().__init__(system_app)
self.registry: Dict[str, List[ModelInstance]] = defaultdict(list)
self.heartbeat_interval_secs = heartbeat_interval_secs
self.heartbeat_timeout_secs = heartbeat_timeout_secs
self.heartbeat_thread = threading.Thread(target=self._heartbeat_checker)
self.heartbeat_thread.daemon = True
self.heartbeat_thread.start()
def _get_instances(
self, model_name: str, host: str, port: int, healthy_only: bool = False
) -> Tuple[List[ModelInstance], List[ModelInstance]]:
instances = self.registry[model_name]
if healthy_only:
instances = [ins for ins in instances if ins.healthy == True]
exist_ins = [ins for ins in instances if ins.host == host and ins.port == port]
return instances, exist_ins
def _heartbeat_checker(self):
while True:
for instances in self.registry.values():
for instance in instances:
if (
instance.check_healthy
and datetime.now() - instance.last_heartbeat
> timedelta(seconds=self.heartbeat_timeout_secs)
):
instance.healthy = False
time.sleep(self.heartbeat_interval_secs)
async def register_instance(self, instance: ModelInstance) -> bool:
model_name = instance.model_name.strip()
host = instance.host.strip()
port = instance.port
instances, exist_ins = self._get_instances(
model_name, host, port, healthy_only=False
)
if exist_ins:
# One exist instance at most
ins = exist_ins[0]
# Update instance
ins.weight = instance.weight
ins.healthy = True
ins.prompt_template = instance.prompt_template
ins.last_heartbeat = datetime.now()
else:
instance.healthy = True
instance.last_heartbeat = datetime.now()
instances.append(instance)
return True
async def deregister_instance(self, instance: ModelInstance) -> bool:
model_name = instance.model_name.strip()
host = instance.host.strip()
port = instance.port
_, exist_ins = self._get_instances(model_name, host, port, healthy_only=False)
if exist_ins:
ins = exist_ins[0]
ins.healthy = False
return True
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
return self.sync_get_all_instances(model_name, healthy_only)
def sync_get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
instances = self.registry[model_name]
if healthy_only:
instances = [ins for ins in instances if ins.healthy == True]
return instances
async def get_all_model_instances(
self, healthy_only: bool = False
) -> List[ModelInstance]:
logger.debug("Current registry metadata:\n{self.registry}")
instances = list(itertools.chain(*self.registry.values()))
if healthy_only:
instances = [ins for ins in instances if ins.healthy == True]
return instances
async def send_heartbeat(self, instance: ModelInstance) -> bool:
_, exist_ins = self._get_instances(
instance.model_name, instance.host, instance.port, healthy_only=False
)
if not exist_ins:
# register new install from heartbeat
await self.register_instance(instance)
return True
ins = exist_ins[0]
ins.last_heartbeat = datetime.now()
ins.healthy = True
return True