"""Shared pytest fixtures and mock model workers for dbgpt model cluster tests."""

from contextlib import asynccontextmanager, contextmanager
from typing import Dict, Iterator, List, Tuple

import pytest
import pytest_asyncio

from dbgpt.core import ModelMetadata, ModelOutput
from dbgpt.model.base import ModelInstance
from dbgpt.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
from dbgpt.model.cluster.worker.manager import (
    ApplyFunction,
    DeregisterFunc,
    LocalWorkerManager,
    RegisterFunc,
    SendHeartbeatFunc,
    WorkerManager,
)
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType


@pytest.fixture
def model_registry(request):
    return EmbeddedModelRegistry()


@pytest.fixture
def model_instance():
    return ModelInstance(
        model_name="test_model",
        host="192.168.1.1",
        port=5000,
    )


class MockModelWorker(ModelWorker):
    """A ModelWorker stub that replays canned stream messages and embeddings."""

    def __init__(
        self,
        model_parameters: ModelParameters,
        error_worker: bool = False,
        stop_error: bool = False,
        stream_messags: List[str] = None,
        embeddings: List[List[float]] = None,
    ) -> None:
        super().__init__()
        if not stream_messags:
            stream_messags = []
        if not embeddings:
            embeddings = []
        self.model_parameters = model_parameters
        self.error_worker = error_worker
        self.stop_error = stop_error
        self.stream_messags = stream_messags
        self._embeddings = embeddings

    def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
        return self.model_parameters

    def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
        pass

    def start(
        self, model_params: ModelParameters = None, command_args: List[str] = None
    ) -> None:
        if self.error_worker:
            raise Exception("Start worker error for mock")

    def stop(self) -> None:
        if self.stop_error:
            raise Exception("Stop worker error for mock")

    def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
        # Yield incrementally longer outputs, mimicking a streaming LLM response.
        full_text = ""
        for msg in self.stream_messags:
            full_text += msg
            yield ModelOutput(text=full_text, error_code=0)

    def generate(self, params: Dict) -> ModelOutput:
        output = None
        for out in self.generate_stream(params):
            output = out
        return output

    def count_token(self, prompt: str) -> int:
        return len(prompt)

    def get_model_metadata(self, params: Dict) -> ModelMetadata:
        return ModelMetadata(
            model=self.model_parameters.model_name,
        )

    def embeddings(self, params: Dict) -> List[List[float]]:
        return self._embeddings


_TEST_MODEL_NAME = "vicuna-13b-v1.5"
_TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5"

ClusterType = Tuple[WorkerManager, ModelRegistry]


def _new_worker_params(
    model_name: str = _TEST_MODEL_NAME,
    model_path: str = _TEST_MODEL_PATH,
    worker_type: str = WorkerType.LLM.value,
) -> ModelWorkerParameters:
    return ModelWorkerParameters(
        model_name=model_name, model_path=model_path, worker_type=worker_type
    )


def _create_workers(
    num_workers: int,
    error_worker: bool = False,
    stop_error: bool = False,
    worker_type: str = WorkerType.LLM.value,
    stream_messags: List[str] = None,
    embeddings: List[List[float]] = None,
    host: str = "127.0.0.1",
    start_port=8001,
) -> List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]]:
    """Build mock workers with matching worker parameters and model instances."""
    workers = []
    for i in range(num_workers):
        model_name = f"test-model-name-{i}"
        model_path = f"test-model-path-{i}"
        model_parameters = ModelParameters(
            model_name=model_name, model_path=model_path
        )
        worker = MockModelWorker(
            model_parameters,
            error_worker=error_worker,
            stop_error=stop_error,
            stream_messags=stream_messags,
            embeddings=embeddings,
        )
        model_instance = ModelInstance(
            model_name=WorkerType.to_worker_key(model_name, worker_type),
            host=host,
            port=start_port + i,
            healthy=True,
        )
        worker_params = _new_worker_params(
            model_name, model_path, worker_type=worker_type
        )
        workers.append((worker, worker_params, model_instance))
    return workers


@asynccontextmanager
async def _start_worker_manager(**kwargs):
    """Build a LocalWorkerManager, add the requested workers, and manage its lifecycle."""
    register_func = kwargs.get("register_func")
    deregister_func = kwargs.get("deregister_func")
    send_heartbeat_func = kwargs.get("send_heartbeat_func")
    model_registry = kwargs.get("model_registry")
    workers = kwargs.get("workers")
    num_workers = int(kwargs.get("num_workers", 0))
    start = kwargs.get("start", True)
    stop = kwargs.get("stop", True)
    error_worker = kwargs.get("error_worker", False)
    stop_error = kwargs.get("stop_error", False)
    stream_messags = kwargs.get("stream_messags", [])
    embeddings = kwargs.get("embeddings", [])

    worker_manager = LocalWorkerManager(
        register_func=register_func,
        deregister_func=deregister_func,
        send_heartbeat_func=send_heartbeat_func,
        model_registry=model_registry,
    )

    # Pass these by keyword so they do not bind to ``worker_type`` positionally.
    for worker, worker_params, model_instance in _create_workers(
        num_workers,
        error_worker=error_worker,
        stop_error=stop_error,
        stream_messags=stream_messags,
        embeddings=embeddings,
    ):
        worker_manager.add_worker(worker, worker_params)
    if workers:
        for worker, worker_params, model_instance in workers:
            worker_manager.add_worker(worker, worker_params)
    if start:
        await worker_manager.start()
    yield worker_manager
    if stop:
        await worker_manager.stop()


async def _create_model_registry(
    workers: List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]]
) -> ModelRegistry:
    registry = EmbeddedModelRegistry()
    for _, _, inst in workers:
        assert await registry.register_instance(inst) == True
    return registry


@pytest_asyncio.fixture
async def manager_2_workers(request):
    param = getattr(request, "param", {})
    async with _start_worker_manager(num_workers=2, **param) as worker_manager:
        yield worker_manager


@pytest_asyncio.fixture
async def manager_with_2_workers(request):
    param = getattr(request, "param", {})
    workers = _create_workers(2, stream_messags=param.get("stream_messags", []))
    async with _start_worker_manager(workers=workers, **param) as worker_manager:
        yield (worker_manager, workers)


@pytest_asyncio.fixture
async def manager_2_embedding_workers(request):
    param = getattr(request, "param", {})
    workers = _create_workers(
        2,
        worker_type=WorkerType.TEXT2VEC.value,
        embeddings=param.get("embeddings", []),
    )
    async with _start_worker_manager(workers=workers, **param) as worker_manager:
        yield (worker_manager, workers)


@asynccontextmanager
async def _new_cluster(**kwargs) -> ClusterType:
    """Create a worker manager plus a registry pre-populated with the same workers."""
    num_workers = kwargs.get("num_workers", 0)
    workers = _create_workers(
        num_workers, stream_messags=kwargs.get("stream_messags", [])
    )
    if "num_workers" in kwargs:
        del kwargs["num_workers"]
    registry = await _create_model_registry(
        workers,
    )
    async with _start_worker_manager(workers=workers, **kwargs) as worker_manager:
        yield (worker_manager, registry)


@pytest_asyncio.fixture
async def cluster_2_workers(request):
    param = getattr(request, "param", {})
    workers = _create_workers(2)
    registry = await _create_model_registry(workers)
    async with _start_worker_manager(workers=workers, **param) as worker_manager:
        yield (worker_manager, registry)
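

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original fixtures above): how a test
# can consume ``manager_with_2_workers``. It assumes ``WorkerManager.generate`` is
# a coroutine that routes a params dict by its ``"model"`` key and returns a
# ``ModelOutput``; adjust to the real worker-manager API if it differs.
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "manager_with_2_workers",
    [{"stream_messags": ["Hello", " world."]}],
    indirect=True,
)
async def test_generate_usage_example(manager_with_2_workers):
    worker_manager, workers = manager_with_2_workers
    for _, worker_params, _ in workers:
        # Each mock worker replays its stream messages, so the final output text is
        # the concatenation of the parametrized messages.
        out = await worker_manager.generate(
            {"model": worker_params.model_name, "prompt": "hi"}
        )
        assert out.text == "Hello world."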