DB-GPT/dbgpt/model/cluster/manager_base.py
2024-01-10 10:39:04 +08:00

174 lines
5.3 KiB
Python

import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import Future
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, Dict, Iterator, List, Optional
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import ModelMetadata, ModelOutput
from dbgpt.model.base import WorkerApplyOutput, WorkerSupportedModel
from dbgpt.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters
from dbgpt.util.parameter_utils import ParameterDescription
@dataclass
class WorkerRunData:
    """Runtime record describing a single model worker instance.

    Groups the worker object together with the host/port it is recorded
    under, the parameter objects it was configured with, and a few
    bookkeeping fields (stop flag, optional semaphore, heartbeat state).
    """

    host: str
    port: int
    worker_key: str
    worker: ModelWorker
    worker_params: ModelWorkerParameters
    model_params: ModelParameters
    # Event whose "set" state is reported by the ``stopped`` property.
    stop_event: asyncio.Event
    # The fields below all default to None, so they are declared Optional.
    # NOTE(review): semaphore presumably bounds concurrent requests to the
    # worker -- confirm against the manager implementation.
    semaphore: Optional[asyncio.Semaphore] = None
    command_args: Optional[List[str]] = None
    # Heartbeat bookkeeping; never assigned inside this class.
    _heartbeat_future: Optional[Future] = None
    _last_heartbeat: Optional[datetime] = None

    def _to_print_key(self) -> str:
        """Return a human-readable label for this worker.

        Returns:
            str: ``"model {model_name}@{model_type}({host}:{port})"``.
            Parameter objects that have no ``model_type`` attribute are
            labelled with the type ``"text2vec"``.
        """
        model_name = self.model_params.model_name
        # getattr with a default replaces the hasattr/ternary pair.
        model_type = getattr(self.model_params, "model_type", "text2vec")
        return f"model {model_name}@{model_type}({self.host}:{self.port})"

    @property
    def stopped(self) -> bool:
        """Check if the worker is stopped, i.e. ``stop_event`` is set."""
        return self.stop_event.is_set()
class WorkerManager(ABC):
    """Abstract interface of a model worker manager.

    Every method is abstract; concrete managers supply the behaviour. Most
    operations are coroutines, with synchronous variants provided for a few
    of them (``sync_*``).
    """

    @abstractmethod
    async def start(self):
        """Bring the worker manager up.

        Raises:
            Exception: when the worker manager could not be started.
        """

    @abstractmethod
    async def stop(self, ignore_exception: bool = False):
        """Shut the worker manager down.

        Args:
            ignore_exception (bool, optional): Defaults to False.
                NOTE(review): presumably suppresses errors raised while
                stopping -- confirm against the implementations.
        """

    @abstractmethod
    def after_start(self, listener: Callable[["WorkerManager"], None]):
        """Register a listener to run after the WorkerManager has started.

        Args:
            listener: callable that accepts a WorkerManager and returns None.
        """

    @abstractmethod
    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Asynchronously look up instances by worker type and model name.

        Args:
            worker_type (str): worker type
            model_name (str): model name
            healthy_only (bool, optional): Defaults to True.

        Returns:
            List[WorkerRunData]: worker run data list
        """

    @abstractmethod
    async def get_all_model_instances(
        self, worker_type: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Asynchronously fetch every model instance of a worker type.

        Args:
            worker_type (str): worker type
            healthy_only (bool, optional): only return healthy instances.
                Defaults to True.

        Returns:
            List[WorkerRunData]: worker run data list
        """

    @abstractmethod
    def sync_get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Synchronous lookup of instances by worker type and model name."""

    @abstractmethod
    async def select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Asynchronously pick a single instance.

        The instance is chosen for the given worker type and model name; how
        it is chosen is left to the implementation.
        """

    @abstractmethod
    def sync_select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Synchronously pick a single instance."""

    @abstractmethod
    async def supported_models(self) -> List[WorkerSupportedModel]:
        """Return the list of supported models."""

    @abstractmethod
    async def model_startup(self, startup_req: WorkerStartupRequest):
        """Create a model instance and start it."""

    @abstractmethod
    async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
        """Shut a model instance down.

        The request uses the same ``WorkerStartupRequest`` type as
        ``model_startup``.
        """

    @abstractmethod
    async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
        """Produce a streamed result (chat scene).

        NOTE(review): declared ``async def`` yet annotated as returning a
        plain ``Iterator``; implementations are presumably async generators,
        in which case ``AsyncIterator`` would be the accurate annotation --
        confirm before changing.
        """

    @abstractmethod
    async def generate(self, params: Dict) -> ModelOutput:
        """Produce a complete, non-streamed result."""

    @abstractmethod
    async def embeddings(self, params: Dict) -> List[List[float]]:
        """Asynchronously embed the input."""

    @abstractmethod
    def sync_embeddings(self, params: Dict) -> List[List[float]]:
        """Embed the input synchronously.

        A synchronous version is required because this function may be
        handed to a third-party system that calls it synchronously.
        """

    @abstractmethod
    async def count_token(self, params: Dict) -> int:
        """Count the tokens of a prompt.

        Args:
            params (Dict): parameters,
                eg. {"prompt": "hello", "model": "vicuna-13b-v1.5"}

        Returns:
            int: token count
        """

    @abstractmethod
    async def get_model_metadata(self, params: Dict) -> ModelMetadata:
        """Fetch the metadata of a model.

        Args:
            params (Dict): parameters, eg. {"model": "vicuna-13b-v1.5"}
        """

    @abstractmethod
    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        """Apply a worker request and return its output."""

    @abstractmethod
    async def parameter_descriptions(
        self, worker_type: str, model_name: str
    ) -> List[ParameterDescription]:
        """Return the parameter descriptions of a model."""
class WorkerManagerFactory(BaseComponent, ABC):
    """Abstract component that creates a WorkerManager.

    Its component ``name`` is the value of
    ``ComponentType.WORKER_MANAGER_FACTORY``.
    """

    name = ComponentType.WORKER_MANAGER_FACTORY.value

    def init_app(self, system_app: SystemApp):
        """Do nothing; ``system_app`` is not used by this base class."""
        return None

    @abstractmethod
    def create(self) -> WorkerManager:
        """Create and return a worker manager."""