diff --git a/pilot/model/base.py b/pilot/model/base.py index 035cee044..e89b243c9 100644 --- a/pilot/model/base.py +++ b/pilot/model/base.py @@ -64,6 +64,20 @@ class WorkerApplyOutput: # The seconds cost to apply some action to worker instances timecost: Optional[int] = -1 + @staticmethod + def reduce(outs: List["WorkerApplyOutput"]) -> "WorkerApplyOutput": + """Merge all outputs + + Args: + outs (List["WorkerApplyOutput"]): The list of WorkerApplyOutput + """ + if not outs: + return WorkerApplyOutput("Not outputs") + combined_success = all(out.success for out in outs) + max_timecost = max(out.timecost for out in outs) + combined_message = ", ".join(out.message for out in outs) + return WorkerApplyOutput(combined_message, combined_success, max_timecost) + @dataclass class SupportedModel: diff --git a/pilot/model/cluster/controller/controller.py b/pilot/model/cluster/controller/controller.py index 826ffef03..1ec3965dc 100644 --- a/pilot/model/cluster/controller/controller.py +++ b/pilot/model/cluster/controller/controller.py @@ -185,7 +185,7 @@ def run_model_controller(): setup_logging( "pilot", logging_level=controller_params.log_level, - logger_filename="dbgpt_model_controller.log", + logger_filename=controller_params.log_file, ) initialize_controller(host=controller_params.host, port=controller_params.port) diff --git a/pilot/model/cluster/manager_base.py b/pilot/model/cluster/manager_base.py index 10c351fa6..80170ce2f 100644 --- a/pilot/model/cluster/manager_base.py +++ b/pilot/model/cluster/manager_base.py @@ -26,11 +26,22 @@ class WorkerRunData: _heartbeat_future: Optional[Future] = None _last_heartbeat: Optional[datetime] = None + def _to_print_key(self): + model_name = self.model_params.model_name + model_type = self.model_params.model_type + host = self.host + port = self.port + return f"model {model_name}@{model_type}({host}:{port})" + class WorkerManager(ABC): @abstractmethod async def start(self): - """Start worker manager""" + """Start worker manager + 
+ Raises: + Exception: if start worker manager not successfully + """ @abstractmethod async def stop(self, ignore_exception: bool = False): @@ -69,11 +80,11 @@ class WorkerManager(ABC): """List supported models""" @abstractmethod - async def model_startup(self, startup_req: WorkerStartupRequest) -> bool: + async def model_startup(self, startup_req: WorkerStartupRequest): """Create and start a model instance""" @abstractmethod - async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool: + async def model_shutdown(self, shutdown_req: WorkerStartupRequest): """Shutdown model instance""" @abstractmethod diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py index a85ee0ed7..cc5ef97d6 100644 --- a/pilot/model/cluster/worker/manager.py +++ b/pilot/model/cluster/worker/manager.py @@ -104,12 +104,16 @@ class LocalWorkerManager(WorkerManager): return f"{model_name}@{worker_type}" async def run_blocking_func(self, func, *args): + if asyncio.iscoroutinefunction(func): + raise ValueError(f"The function {func} is not blocking function") loop = asyncio.get_event_loop() return await loop.run_in_executor(self.executor, func, *args) async def start(self): if len(self.workers) > 0: - await self._start_all_worker(apply_req=None) + out = await self._start_all_worker(apply_req=None) + if not out.success: + raise Exception(out.message) if self.register_func: await self.register_func(self.run_data) if self.send_heartbeat_func: @@ -143,7 +147,9 @@ class LocalWorkerManager(WorkerManager): else: stop_tasks.append(self.deregister_func(self.run_data)) - await asyncio.gather(*stop_tasks) + results = await asyncio.gather(*stop_tasks) + if not results[0].success and not ignore_exception: + raise Exception(results[0].message) def after_start(self, listener: Callable[["WorkerManager"], None]): self.start_listeners.append(listener) @@ -193,7 +199,15 @@ class LocalWorkerManager(WorkerManager): logger.warn(f"Instance {worker_key} exist") return 
False - async def model_startup(self, startup_req: WorkerStartupRequest) -> bool: + def _remove_worker(self, worker_params: ModelWorkerParameters) -> None: + worker_key = self._worker_key( + worker_params.worker_type, worker_params.model_name + ) + instances = self.workers.get(worker_key) + if instances: + del self.workers[worker_key] + + async def model_startup(self, startup_req: WorkerStartupRequest): """Start model""" model_name = startup_req.model worker_type = startup_req.worker_type @@ -213,22 +227,30 @@ class LocalWorkerManager(WorkerManager): self.add_worker, worker, worker_params, command_args ) if not success: - logger.warn( - f"Add worker failed, worker instances is exist, worker_params: {worker_params}" - ) - return False + msg = f"Add worker {model_name}@{worker_type}, worker instances is exist" + logger.warn(f"{msg}, worker_params: {worker_params}") + self._remove_worker(worker_params) + raise Exception(msg) supported_types = WorkerType.values() if worker_type not in supported_types: + self._remove_worker(worker_params) raise ValueError( f"Unsupported worker type: {worker_type}, now supported worker type: {supported_types}" ) start_apply_req = WorkerApplyRequest( model=model_name, apply_type=WorkerApplyType.START, worker_type=worker_type ) - await self.worker_apply(start_apply_req) - return True + out: WorkerApplyOutput = None + try: + out = await self.worker_apply(start_apply_req) + except Exception as e: + self._remove_worker(worker_params) + raise e + if not out.success: + self._remove_worker(worker_params) + raise Exception(out.message) - async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool: + async def model_shutdown(self, shutdown_req: WorkerStartupRequest): logger.info(f"Begin shutdown model, shutdown_req: {shutdown_req}") apply_req = WorkerApplyRequest( model=shutdown_req.model, @@ -236,9 +258,8 @@ class LocalWorkerManager(WorkerManager): worker_type=shutdown_req.worker_type, ) out = await self._stop_all_worker(apply_req) 
- if out.success: - return True - raise Exception(out.message) + if not out.success: + raise Exception(out.message) async def supported_models(self) -> List[WorkerSupportedModel]: models = await self.run_blocking_func(list_supported_models) @@ -253,7 +274,7 @@ class LocalWorkerManager(WorkerManager): self, worker_type: str, model_name: str, healthy_only: bool = True ) -> List[WorkerRunData]: worker_key = self._worker_key(worker_type, model_name) - return self.workers.get(worker_key) + return self.workers.get(worker_key, []) def _simple_select( self, worker_type: str, model_name: str, worker_instances: List[WorkerRunData] @@ -424,36 +445,47 @@ class LocalWorkerManager(WorkerManager): async def _start_all_worker( self, apply_req: WorkerApplyRequest ) -> WorkerApplyOutput: + # TODO avoid start twice start_time = time.time() logger.info(f"Begin start all worker, apply_req: {apply_req}") async def _start_worker(worker_run_data: WorkerRunData): - await self.run_blocking_func( - worker_run_data.worker.start, - worker_run_data.model_params, - worker_run_data.command_args, - ) - worker_run_data.stop_event.clear() - if worker_run_data.worker_params.register and self.register_func: - # Register worker to controller - await self.register_func(worker_run_data) - if ( - worker_run_data.worker_params.send_heartbeat - and self.send_heartbeat_func - ): - asyncio.create_task( - _async_heartbeat_sender( - worker_run_data, - worker_run_data.worker_params.heartbeat_interval, - self.send_heartbeat_func, + _start_time = time.time() + info = worker_run_data._to_print_key() + out = WorkerApplyOutput("") + try: + await self.run_blocking_func( + worker_run_data.worker.start, + worker_run_data.model_params, + worker_run_data.command_args, + ) + worker_run_data.stop_event.clear() + if worker_run_data.worker_params.register and self.register_func: + # Register worker to controller + await self.register_func(worker_run_data) + if ( + worker_run_data.worker_params.send_heartbeat + and 
self.send_heartbeat_func + ): + asyncio.create_task( + _async_heartbeat_sender( + worker_run_data, + worker_run_data.worker_params.heartbeat_interval, + self.send_heartbeat_func, + ) ) - ) + out.message = f"{info} start successfully" + except Exception as e: + out.success = False + out.message = f"{info} start failed, {str(e)}" + finally: + out.timecost = time.time() - _start_time + return out - await self._apply_worker(apply_req, _start_worker) - timecost = time.time() - start_time - return WorkerApplyOutput( - message=f"Worker started successfully", timecost=timecost - ) + outs = await self._apply_worker(apply_req, _start_worker) + out = WorkerApplyOutput.reduce(outs) + out.timecost = time.time() - start_time + return out async def _stop_all_worker( self, apply_req: WorkerApplyRequest, ignore_exception: bool = False @@ -461,42 +493,56 @@ class LocalWorkerManager(WorkerManager): start_time = time.time() async def _stop_worker(worker_run_data: WorkerRunData): - await self.run_blocking_func(worker_run_data.worker.stop) - # Set stop event - worker_run_data.stop_event.set() - if worker_run_data._heartbeat_future: - # Wait thread finish - worker_run_data._heartbeat_future.result() - worker_run_data._heartbeat_future = None - if ( - worker_run_data.worker_params.register - and self.register_func - and self.deregister_func - ): - _deregister_func = self.deregister_func - if ignore_exception: + _start_time = time.time() + info = worker_run_data._to_print_key() + out = WorkerApplyOutput("") + try: + await self.run_blocking_func(worker_run_data.worker.stop) + # Set stop event + worker_run_data.stop_event.set() + if worker_run_data._heartbeat_future: + # Wait thread finish + worker_run_data._heartbeat_future.result() + worker_run_data._heartbeat_future = None + if ( + worker_run_data.worker_params.register + and self.register_func + and self.deregister_func + ): + _deregister_func = self.deregister_func + if ignore_exception: - async def safe_deregister_func(run_data): - 
try: - await self.deregister_func(run_data) - except Exception as e: - logger.warning( - f"Stop worker, ignored exception from deregister_func: {e}" - ) + async def safe_deregister_func(run_data): + try: + await self.deregister_func(run_data) + except Exception as e: + logger.warning( + f"Stop worker, ignored exception from deregister_func: {e}" + ) - _deregister_func = safe_deregister_func - await _deregister_func(worker_run_data) + _deregister_func = safe_deregister_func + await _deregister_func(worker_run_data) + # Remove metadata + self._remove_worker(worker_run_data.worker_params) + out.message = f"{info} stop successfully" + except Exception as e: + out.success = False + out.message = f"{info} stop failed, {str(e)}" + finally: + out.timecost = time.time() - _start_time + return out - await self._apply_worker(apply_req, _stop_worker) - timecost = time.time() - start_time - return WorkerApplyOutput( - message=f"Worker stopped successfully", timecost=timecost - ) + outs = await self._apply_worker(apply_req, _stop_worker) + out = WorkerApplyOutput.reduce(outs) + out.timecost = time.time() - start_time + return out async def _restart_all_worker( self, apply_req: WorkerApplyRequest ) -> WorkerApplyOutput: - await self._stop_all_worker(apply_req) + out = await self._stop_all_worker(apply_req, ignore_exception=True) + if not out.success: + return out return await self._start_all_worker(apply_req) async def _update_all_worker_params( @@ -541,10 +587,10 @@ class WorkerManagerAdapter(WorkerManager): async def supported_models(self) -> List[WorkerSupportedModel]: return await self.worker_manager.supported_models() - async def model_startup(self, startup_req: WorkerStartupRequest) -> bool: + async def model_startup(self, startup_req: WorkerStartupRequest): return await self.worker_manager.model_startup(startup_req) - async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool: + async def model_shutdown(self, shutdown_req: WorkerStartupRequest): return 
await self.worker_manager.model_shutdown(shutdown_req) async def get_model_instances( @@ -960,7 +1006,7 @@ def run_worker_manager( setup_logging( "pilot", logging_level=worker_params.log_level, - logger_filename="dbgpt_model_worker_manager.log", + logger_filename=worker_params.log_file, ) embedded_mod = True @@ -973,7 +1019,7 @@ def run_worker_manager( system_app = SystemApp(app) initialize_tracer( system_app, - os.path.join(LOGDIR, "dbgpt_model_worker_manager_tracer.jsonl"), + os.path.join(LOGDIR, worker_params.tracer_file), root_operation_name="DB-GPT-WorkerManager-Entry", ) diff --git a/pilot/model/cluster/worker/remote_manager.py b/pilot/model/cluster/worker/remote_manager.py index 3aa9673bb..61b608cc7 100644 --- a/pilot/model/cluster/worker/remote_manager.py +++ b/pilot/model/cluster/worker/remote_manager.py @@ -96,7 +96,7 @@ class RemoteWorkerManager(LocalWorkerManager): raise Exception(error_msg) return worker_instances - async def model_startup(self, startup_req: WorkerStartupRequest) -> bool: + async def model_startup(self, startup_req: WorkerStartupRequest): worker_instances = await self._get_worker_service_instance( startup_req.host, startup_req.port ) @@ -107,10 +107,10 @@ class RemoteWorkerManager(LocalWorkerManager): "/models/startup", method="POST", json=startup_req.dict(), - success_handler=lambda x: True, + success_handler=lambda x: None, ) - async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool: + async def model_shutdown(self, shutdown_req: WorkerStartupRequest): worker_instances = await self._get_worker_service_instance( shutdown_req.host, shutdown_req.port ) @@ -121,7 +121,7 @@ class RemoteWorkerManager(LocalWorkerManager): "/models/shutdown", method="POST", json=shutdown_req.dict(), - success_handler=lambda x: True, + success_handler=lambda x: None, ) def _build_worker_instances( diff --git a/pilot/model/cluster/worker/tests/__init__.py b/pilot/model/cluster/worker/tests/__init__.py new file mode 100644 index 
000000000..e69de29bb diff --git a/pilot/model/cluster/worker/tests/base_tests.py b/pilot/model/cluster/worker/tests/base_tests.py new file mode 100644 index 000000000..21821d9f9 --- /dev/null +++ b/pilot/model/cluster/worker/tests/base_tests.py @@ -0,0 +1,168 @@ +import pytest +import pytest_asyncio +from contextlib import contextmanager, asynccontextmanager +from typing import List, Iterator, Dict, Tuple +from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType +from pilot.model.base import ModelOutput +from pilot.model.cluster.worker_base import ModelWorker +from pilot.model.cluster.worker.manager import ( + LocalWorkerManager, + RegisterFunc, + DeregisterFunc, + SendHeartbeatFunc, + ApplyFunction, +) + + +class MockModelWorker(ModelWorker): + def __init__( + self, + model_parameters: ModelParameters, + error_worker: bool = False, + stop_error: bool = False, + stream_messags: List[str] = None, + embeddings: List[List[float]] = None, + ) -> None: + super().__init__() + if not stream_messags: + stream_messags = [] + if not embeddings: + embeddings = [] + self.model_parameters = model_parameters + self.error_worker = error_worker + self.stop_error = stop_error + self.stream_messags = stream_messags + self._embeddings = embeddings + + def parse_parameters(self, command_args: List[str] = None) -> ModelParameters: + return self.model_parameters + + def load_worker(self, model_name: str, model_path: str, **kwargs) -> None: + pass + + def start( + self, model_params: ModelParameters = None, command_args: List[str] = None + ) -> None: + if self.error_worker: + raise Exception("Start worker error for mock") + + def stop(self) -> None: + if self.stop_error: + raise Exception("Stop worker error for mock") + + def generate_stream(self, params: Dict) -> Iterator[ModelOutput]: + for msg in self.stream_messags: + yield ModelOutput(text=msg, error_code=0) + + def generate(self, params: Dict) -> ModelOutput: + output = None + for out in 
self.generate_stream(params): + output = out + return output + + def embeddings(self, params: Dict) -> List[List[float]]: + return self._embeddings + + +_TEST_MODEL_NAME = "vicuna-13b-v1.5" +_TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5" + + +def _new_worker_params( + model_name: str = _TEST_MODEL_NAME, + model_path: str = _TEST_MODEL_PATH, + worker_type: str = WorkerType.LLM.value, +) -> ModelWorkerParameters: + return ModelWorkerParameters( + model_name=model_name, model_path=model_path, worker_type=worker_type + ) + + +def _create_workers( + num_workers: int, + error_worker: bool = False, + stop_error: bool = False, + worker_type: str = WorkerType.LLM.value, + stream_messags: List[str] = None, + embeddings: List[List[float]] = None, +) -> List[Tuple[ModelWorker, ModelWorkerParameters]]: + workers = [] + for i in range(num_workers): + model_name = f"test-model-name-{i}" + model_path = f"test-model-path-{i}" + model_parameters = ModelParameters(model_name=model_name, model_path=model_path) + worker = MockModelWorker( + model_parameters, + error_worker=error_worker, + stop_error=stop_error, + stream_messags=stream_messags, + embeddings=embeddings, + ) + worker_params = _new_worker_params( + model_name, model_path, worker_type=worker_type + ) + workers.append((worker, worker_params)) + return workers + + +@asynccontextmanager +async def _start_worker_manager(**kwargs): + register_func = kwargs.get("register_func") + deregister_func = kwargs.get("deregister_func") + send_heartbeat_func = kwargs.get("send_heartbeat_func") + model_registry = kwargs.get("model_registry") + workers = kwargs.get("workers") + num_workers = int(kwargs.get("num_workers", 0)) + start = kwargs.get("start", True) + stop = kwargs.get("stop", True) + error_worker = kwargs.get("error_worker", False) + stop_error = kwargs.get("stop_error", False) + stream_messags = kwargs.get("stream_messags", []) + embeddings = kwargs.get("embeddings", []) + + worker_manager = LocalWorkerManager( + 
register_func=register_func, + deregister_func=deregister_func, + send_heartbeat_func=send_heartbeat_func, + model_registry=model_registry, + ) + + for worker, worker_params in _create_workers( + num_workers, error_worker, stop_error, stream_messags, embeddings + ): + worker_manager.add_worker(worker, worker_params) + if workers: + for worker, worker_params in workers: + worker_manager.add_worker(worker, worker_params) + + if start: + await worker_manager.start() + + yield worker_manager + if stop: + await worker_manager.stop() + + +@pytest_asyncio.fixture +async def manager_2_workers(request): + param = getattr(request, "param", {}) + async with _start_worker_manager(num_workers=2, **param) as worker_manager: + yield worker_manager + + +@pytest_asyncio.fixture +async def manager_with_2_workers(request): + param = getattr(request, "param", {}) + workers = _create_workers(2, stream_messags=param.get("stream_messags", [])) + async with _start_worker_manager(workers=workers, **param) as worker_manager: + yield (worker_manager, workers) + + +@pytest_asyncio.fixture +async def manager_2_embedding_workers(request): + param = getattr(request, "param", {}) + workers = _create_workers( + 2, worker_type=WorkerType.TEXT2VEC.value, embeddings=param.get("embeddings", []) + ) + async with _start_worker_manager(workers=workers, **param) as worker_manager: + yield (worker_manager, workers) diff --git a/pilot/model/cluster/worker/tests/test_manager.py b/pilot/model/cluster/worker/tests/test_manager.py new file mode 100644 index 000000000..919e64f99 --- /dev/null +++ b/pilot/model/cluster/worker/tests/test_manager.py @@ -0,0 +1,486 @@ +from unittest.mock import patch, AsyncMock +import pytest +from typing import List, Iterator, Dict, Tuple +from dataclasses import asdict +from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType +from pilot.model.base import ModelOutput, WorkerApplyType +from pilot.model.cluster.base import WorkerApplyRequest, 
WorkerStartupRequest +from pilot.model.cluster.worker_base import ModelWorker +from pilot.model.cluster.manager_base import WorkerRunData +from pilot.model.cluster.worker.manager import ( + LocalWorkerManager, + RegisterFunc, + DeregisterFunc, + SendHeartbeatFunc, + ApplyFunction, +) +from pilot.model.cluster.worker.tests.base_tests import ( + MockModelWorker, + manager_2_workers, + manager_with_2_workers, + manager_2_embedding_workers, + _create_workers, + _start_worker_manager, + _new_worker_params, +) + + +_TEST_MODEL_NAME = "vicuna-13b-v1.5" +_TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5" + + +@pytest.fixture +def worker(): + mock_worker = _create_workers(1) + yield mock_worker[0][0] + + +@pytest.fixture +def worker_param(): + return _new_worker_params() + + +@pytest.fixture +def manager(request): + if not request or not hasattr(request, "param"): + register_func = None + deregister_func = None + send_heartbeat_func = None + model_registry = None + workers = [] + else: + register_func = request.param.get("register_func") + deregister_func = request.param.get("deregister_func") + send_heartbeat_func = request.param.get("send_heartbeat_func") + model_registry = request.param.get("model_registry") + workers = request.param.get("model_registry") + + worker_manager = LocalWorkerManager( + register_func=register_func, + deregister_func=deregister_func, + send_heartbeat_func=send_heartbeat_func, + model_registry=model_registry, + ) + yield worker_manager + + +@pytest.mark.asyncio +async def test_run_blocking_func(manager: LocalWorkerManager): + def f1() -> int: + return 0 + + def f2(a: int, b: int) -> int: + return a + b + + async def error_f3() -> None: + return 0 + + assert await manager.run_blocking_func(f1) == 0 + assert await manager.run_blocking_func(f2, 1, 2) == 3 + with pytest.raises(ValueError): + await manager.run_blocking_func(error_f3) + + +@pytest.mark.asyncio +async def test_add_worker( + manager: LocalWorkerManager, + worker: ModelWorker, + 
worker_param: ModelWorkerParameters, +): + # TODO test with register function + assert manager.add_worker(worker, worker_param) + # Add again + assert manager.add_worker(worker, worker_param) == False + key = manager._worker_key(worker_param.worker_type, worker_param.model_name) + assert len(manager.workers) == 1 + assert len(manager.workers[key]) == 1 + assert manager.workers[key][0].worker == worker + + assert manager.add_worker( + worker, + _new_worker_params( + model_name="chatglm2-6b", model_path="/app/models/chatglm2-6b" + ), + ) + assert ( + manager.add_worker( + worker, + _new_worker_params( + model_name="chatglm2-6b", model_path="/app/models/chatglm2-6b" + ), + ) + == False + ) + assert len(manager.workers) == 2 + + +@pytest.mark.asyncio +async def test__apply_worker(manager_2_workers: LocalWorkerManager): + manager = manager_2_workers + + async def f1(wr: WorkerRunData) -> int: + return 0 + + # Apply to all workers + assert await manager._apply_worker(None, apply_func=f1) == [0, 0] + + workers = _create_workers(4) + async with _start_worker_manager(workers=workers) as manager: + # Apply to single model + req = WorkerApplyRequest( + model=workers[0][1].model_name, + apply_type=WorkerApplyType.START, + worker_type=WorkerType.LLM, + ) + assert await manager._apply_worker(req, apply_func=f1) == [0] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("manager_2_workers", [{"start": False}], indirect=True) +async def test__start_all_worker(manager_2_workers: LocalWorkerManager): + manager = manager_2_workers + out = await manager._start_all_worker(None) + assert out.success + assert len(manager.workers) == 2 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "manager_2_workers, is_error_worker", + [ + ({"start": False, "error_worker": False}, False), + ({"start": False, "error_worker": True}, True), + ], + indirect=["manager_2_workers"], +) +async def test_start_worker_manager( + manager_2_workers: LocalWorkerManager, is_error_worker: bool +): + manager = 
manager_2_workers + if is_error_worker: + with pytest.raises(Exception): + await manager.start() + else: + await manager.start() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "manager_2_workers, is_stop_error", + [ + ({"stop": False, "stop_error": False}, False), + ({"stop": False, "stop_error": True}, True), + ], + indirect=["manager_2_workers"], +) +async def test__stop_all_worker( + manager_2_workers: LocalWorkerManager, is_stop_error: bool +): + manager = manager_2_workers + out = await manager._stop_all_worker(None) + if is_stop_error: + assert not out.success + else: + assert out.success + + +@pytest.mark.asyncio +async def test__restart_all_worker(manager_2_workers: LocalWorkerManager): + manager = manager_2_workers + out = await manager._restart_all_worker(None) + assert out.success + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "manager_2_workers, is_stop_error", + [ + ({"stop": False, "stop_error": False}, False), + ({"stop": False, "stop_error": True}, True), + ], + indirect=["manager_2_workers"], +) +async def test_stop_worker_manager( + manager_2_workers: LocalWorkerManager, is_stop_error: bool +): + manager = manager_2_workers + if is_stop_error: + with pytest.raises(Exception): + await manager.stop() + else: + await manager.stop() + + +@pytest.mark.asyncio +async def test__remove_worker(): + workers = _create_workers(3) + async with _start_worker_manager(workers=workers, stop=False) as manager: + assert len(manager.workers) == 3 + for _, worker_params in workers: + manager._remove_worker(worker_params) + not_exist_parmas = _new_worker_params( + model_name="this is a not exist worker params" + ) + manager._remove_worker(not_exist_parmas) + + +@pytest.mark.asyncio +@patch("pilot.model.cluster.worker.manager._build_worker") +async def test_model_startup(mock_build_worker): + async with _start_worker_manager() as manager: + workers = _create_workers(1) + worker, worker_params = workers[0] + mock_build_worker.return_value = worker + + req 
= WorkerStartupRequest( + host="127.0.0.1", + port=8001, + model=worker_params.model_name, + worker_type=WorkerType.LLM, + params=asdict(worker_params), + ) + await manager.model_startup(req) + with pytest.raises(Exception): + await manager.model_startup(req) + + async with _start_worker_manager() as manager: + workers = _create_workers(1, error_worker=True) + worker, worker_params = workers[0] + mock_build_worker.return_value = worker + req = WorkerStartupRequest( + host="127.0.0.1", + port=8001, + model=worker_params.model_name, + worker_type=WorkerType.LLM, + params=asdict(worker_params), + ) + with pytest.raises(Exception): + await manager.model_startup(req) + + +@pytest.mark.asyncio +@patch("pilot.model.cluster.worker.manager._build_worker") +async def test_model_shutdown(mock_build_worker): + async with _start_worker_manager(start=False, stop=False) as manager: + workers = _create_workers(1) + worker, worker_params = workers[0] + mock_build_worker.return_value = worker + + req = WorkerStartupRequest( + host="127.0.0.1", + port=8001, + model=worker_params.model_name, + worker_type=WorkerType.LLM, + params=asdict(worker_params), + ) + await manager.model_startup(req) + await manager.model_shutdown(req) + + +@pytest.mark.asyncio +async def test_supported_models(manager_2_workers: LocalWorkerManager): + manager = manager_2_workers + models = await manager.supported_models() + assert len(models) == 1 + models = models[0].models + assert len(models) > 10 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "is_async", + [ + True, + False, + ], +) +async def test_get_model_instances(is_async): + workers = _create_workers(3) + async with _start_worker_manager(workers=workers, stop=False) as manager: + assert len(manager.workers) == 3 + for _, worker_params in workers: + model_name = worker_params.model_name + worker_type = worker_params.worker_type + if is_async: + assert ( + len(await manager.get_model_instances(worker_type, model_name)) == 1 + ) + else: + assert 
( + len(manager.sync_get_model_instances(worker_type, model_name)) == 1 + ) + if is_async: + assert not await manager.get_model_instances( + worker_type, "this is not exist model instances" + ) + else: + assert not manager.sync_get_model_instances( + worker_type, "this is not exist model instances" + ) + + +@pytest.mark.asyncio +async def test__simple_select( + manager_with_2_workers: Tuple[ + LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]] + ] +): + manager, workers = manager_with_2_workers + for _, worker_params in workers: + model_name = worker_params.model_name + worker_type = worker_params.worker_type + instances = await manager.get_model_instances(worker_type, model_name) + assert instances + inst = manager._simple_select(worker_params.worker_type, model_name, instances) + assert inst is not None + assert inst.worker_params == worker_params + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "is_async", + [ + True, + False, + ], +) +async def test_select_one_instance( + is_async: bool, + manager_with_2_workers: Tuple[ + LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]] + ], +): + manager, workers = manager_with_2_workers + for _, worker_params in workers: + model_name = worker_params.model_name + worker_type = worker_params.worker_type + if is_async: + inst = await manager.select_one_instance(worker_type, model_name) + else: + inst = manager.sync_select_one_instance(worker_type, model_name) + assert inst is not None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "is_async", + [ + True, + False, + ], +) +async def test__get_model( + is_async: bool, + manager_with_2_workers: Tuple[ + LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]] + ], +): + manager, workers = manager_with_2_workers + for _, worker_params in workers: + model_name = worker_params.model_name + worker_type = worker_params.worker_type + params = {"model": model_name} + if is_async: + wr = await manager._get_model(params, 
worker_type=worker_type)
+    else:
+        wr = manager._sync_get_model(params, worker_type=worker_type)
+    assert wr is not None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "manager_with_2_workers, expected_messages",
+    [
+        ({"stream_messags": ["Hello", " world."]}, "Hello world."),
+        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
+    ],
+    indirect=["manager_with_2_workers"],
+)
+async def test_generate_stream(
+    manager_with_2_workers: Tuple[
+        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
+    ],
+    expected_messages: str,
+):
+    manager, workers = manager_with_2_workers
+    for _, worker_params in workers:
+        model_name = worker_params.model_name
+        worker_type = worker_params.worker_type
+        params = {"model": model_name}
+        text = ""
+        async for out in manager.generate_stream(params):
+            text += out.text
+        assert text == expected_messages
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "manager_with_2_workers, expected_messages",
+    [
+        ({"stream_messags": ["Hello", " world."]}, " world."),
+        ({"stream_messags": ["你好,我是", "张三。"]}, "张三。"),
+    ],
+    indirect=["manager_with_2_workers"],
+)
+async def test_generate(
+    manager_with_2_workers: Tuple[
+        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
+    ],
+    expected_messages: str,
+):
+    manager, workers = manager_with_2_workers
+    for _, worker_params in workers:
+        model_name = worker_params.model_name
+        worker_type = worker_params.worker_type
+        params = {"model": model_name}
+        out = await manager.generate(params)
+        assert out.text == expected_messages
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "manager_2_embedding_workers, expected_embedding, is_async",
+    [
+        ({"embeddings": [[1, 2, 3], [4, 5, 6]]}, [[1, 2, 3], [4, 5, 6]], True),
+        ({"embeddings": [[0, 0, 0], [1, 1, 1]]}, [[0, 0, 0], [1, 1, 1]], False),
+    ],
+    indirect=["manager_2_embedding_workers"],
+)
+async def test_embeddings(
+    manager_2_embedding_workers: Tuple[
+        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
+    ],
+    expected_embedding: List[List[int]],
+    is_async: bool,
+):
+    manager, workers = manager_2_embedding_workers
+    for _, worker_params in workers:
+        model_name = worker_params.model_name
+        worker_type = worker_params.worker_type
+        params = {"model": model_name, "input": ["hello", "world"]}
+        if is_async:
+            out = await manager.embeddings(params)
+        else:
+            out = manager.sync_embeddings(params)
+        assert out == expected_embedding
+
+
+@pytest.mark.asyncio
+async def test_parameter_descriptions(
+    manager_with_2_workers: Tuple[
+        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
+    ]
+):
+    manager, workers = manager_with_2_workers
+    for _, worker_params in workers:
+        model_name = worker_params.model_name
+        worker_type = worker_params.worker_type
+        params = await manager.parameter_descriptions(worker_type, model_name)
+        assert params is not None
+        assert len(params) > 5
+
+
+@pytest.mark.asyncio
+async def test__update_all_worker_params():
+    # TODO
+    pass
diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py
index 0ad048c24..ba0000435 100644
--- a/pilot/model/parameter.py
+++ b/pilot/model/parameter.py
@@ -46,6 +46,18 @@ class ModelControllerParameters(BaseParameters):
             ],
         },
     )
+    log_file: Optional[str] = field(
+        default="dbgpt_model_controller.log",
+        metadata={
+            "help": "The filename to store log",
+        },
+    )
+    tracer_file: Optional[str] = field(
+        default="dbgpt_model_controller_tracer.jsonl",
+        metadata={
+            "help": "The filename to store tracer span records",
+        },
+    )
 
 
 @dataclass
@@ -122,6 +134,18 @@ class ModelWorkerParameters(BaseModelParameters):
             ],
         },
     )
+    log_file: Optional[str] = field(
+        default="dbgpt_model_worker_manager.log",
+        metadata={
+            "help": "The filename to store log",
+        },
+    )
+    tracer_file: Optional[str] = field(
+        default="dbgpt_model_worker_manager_tracer.jsonl",
+        metadata={
+            "help": "The filename to store tracer span records",
+        },
+    )
 
 
 @dataclass
diff --git a/pilot/server/base.py b/pilot/server/base.py
index 3b2d7010b..d34b14b28 100644
--- a/pilot/server/base.py
+++ b/pilot/server/base.py
@@ -95,6 +95,19 @@ class WebWerverParameters(BaseParameters):
     daemon: Optional[bool] = field(
         default=False, metadata={"help": "Run Webserver in background"}
    )
+    controller_addr: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The Model controller address to connect. If None, read model controller address from environment key `MODEL_SERVER`."
+        },
+    )
+    model_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The default model name to use. If None, read model name from environment key `LLM_MODEL`.",
+            "tags": "fixed",
+        },
+    )
     share: Optional[bool] = field(
         default=False,
         metadata={
@@ -123,3 +136,15 @@ class WebWerverParameters(BaseParameters):
         },
     )
     light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"})
+    log_file: Optional[str] = field(
+        default="dbgpt_webserver.log",
+        metadata={
+            "help": "The filename to store log",
+        },
+    )
+    tracer_file: Optional[str] = field(
+        default="dbgpt_webserver_tracer.jsonl",
+        metadata={
+            "help": "The filename to store tracer span records",
+        },
+    )
diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py
index 2b35eaf10..6762fd32a 100644
--- a/pilot/server/dbgpt_server.py
+++ b/pilot/server/dbgpt_server.py
@@ -119,7 +119,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
     if not param.log_level:
         param.log_level = _get_logging_level()
     setup_logging(
-        "pilot", logging_level=param.log_level, logger_filename="dbgpt_webserver.log"
+        "pilot", logging_level=param.log_level, logger_filename=param.log_file
     )
     # Before start
     system_app.before_start()
@@ -133,14 +133,16 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
     model_start_listener = _create_model_start_listener(system_app)
     initialize_components(param, system_app, embedding_model_name, embedding_model_path)
 
-    model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL)
+    model_name = param.model_name or CFG.LLM_MODEL
+
+    model_path = LLM_MODEL_CONFIG.get(model_name)
     if not param.light:
         print("Model Unified Deployment Mode!")
         if not param.remote_embedding:
             embedding_model_name, embedding_model_path = None, None
         initialize_worker_manager_in_client(
             app=app,
-            model_name=CFG.LLM_MODEL,
+            model_name=model_name,
             model_path=model_path,
             local_port=param.port,
             embedding_model_name=embedding_model_name,
@@ -152,12 +154,13 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
         CFG.NEW_SERVER_MODE = True
     else:
         # MODEL_SERVER is controller address now
+        controller_addr = param.controller_addr or CFG.MODEL_SERVER
         initialize_worker_manager_in_client(
             app=app,
-            model_name=CFG.LLM_MODEL,
+            model_name=model_name,
             model_path=model_path,
             run_locally=False,
-            controller_addr=CFG.MODEL_SERVER,
+            controller_addr=controller_addr,
             local_port=param.port,
             start_listener=model_start_listener,
             system_app=system_app,
@@ -182,7 +185,7 @@ def run_uvicorn(param: WebWerverParameters):
 def run_webserver(param: WebWerverParameters = None):
     if not param:
         param = _get_webserver_params()
-    initialize_tracer(system_app, os.path.join(LOGDIR, "dbgpt_webserver_tracer.jsonl"))
+    initialize_tracer(system_app, os.path.join(LOGDIR, param.tracer_file))
 
     with root_tracer.start_span(
         "run_webserver",
diff --git a/pilot/utils/tracer/span_storage.py b/pilot/utils/tracer/span_storage.py
index 914aa0126..3070fb834 100644
--- a/pilot/utils/tracer/span_storage.py
+++ b/pilot/utils/tracer/span_storage.py
@@ -1,6 +1,7 @@
 import os
 import json
 import time
+import datetime
 import threading
 import queue
 import logging
@@ -27,6 +28,13 @@ class FileSpanStorage(SpanStorage):
     def __init__(self, filename: str, batch_size=10, flush_interval=10):
         super().__init__()
         self.filename = filename
+        # Split filename into prefix and suffix
+        self.filename_prefix, self.filename_suffix = os.path.splitext(filename)
+        if not self.filename_suffix:
+            self.filename_suffix = ".log"
+        self.last_date = (
+            datetime.datetime.now().date()
+        )  # Store the current date for checking date changes
         self.queue = queue.Queue()
         self.batch_size = batch_size
         self.flush_interval = flush_interval
@@ -52,7 +60,21 @@ class FileSpanStorage(SpanStorage):
         except queue.Full:
             pass  # If the signal queue is full, it's okay. The flush thread will handle it.
 
+    def _get_dated_filename(self, date: datetime.date) -> str:
+        """Return the filename based on a specific date."""
+        date_str = date.strftime("%Y-%m-%d")
+        return f"{self.filename_prefix}_{date_str}{self.filename_suffix}"
+
+    def _roll_over_if_needed(self):
+        """Checks if a day has changed since the last write, and if so, renames the current file."""
+        current_date = datetime.datetime.now().date()
+        if current_date != self.last_date:
+            if os.path.exists(self.filename):
+                os.rename(self.filename, self._get_dated_filename(self.last_date))
+            self.last_date = current_date
+
     def _write_to_file(self):
+        self._roll_over_if_needed()
         spans_to_write = []
         while not self.queue.empty():
             spans_to_write.append(self.queue.get())
diff --git a/pilot/utils/tracer/tests/test_span_storage.py b/pilot/utils/tracer/tests/test_span_storage.py
index 0c63992a6..9ca727995 100644
--- a/pilot/utils/tracer/tests/test_span_storage.py
+++ b/pilot/utils/tracer/tests/test_span_storage.py
@@ -4,6 +4,8 @@ import asyncio
 import json
 import tempfile
 import time
+from unittest.mock import patch
+from datetime import datetime, timedelta
 
 from pilot.utils.tracer import SpanStorage, FileSpanStorage, Span, SpanType
 
@@ -122,3 +124,46 @@ def test_non_existent_file(storage: SpanStorage):
     assert len(spans_in_file) == 2
     assert spans_in_file[0]["trace_id"] == "1"
     assert spans_in_file[1]["trace_id"] == "2"
+
+
+@pytest.mark.parametrize(
+    "storage", [{"batch_size": 1, "file_does_not_exist": True}], indirect=True
+)
+def test_log_rollover(storage: SpanStorage):
+    # mock a start time just before midnight
+    mock_start_date = datetime(2023, 10, 18, 23, 59)
+
+    with patch("datetime.datetime") as mock_datetime:
+        mock_datetime.now.return_value = mock_start_date
+
+        span1 = Span("1", "a", SpanType.BASE, "b", "op1")
+        storage.append_span(span1)
+        time.sleep(0.1)
+
+        # mock the clock crossing into a new day
+        mock_datetime.now.return_value = mock_start_date + timedelta(minutes=1)
+
+        span2 = Span("2", "c", SpanType.BASE, "d", "op2")
+        storage.append_span(span2)
+        time.sleep(0.1)
+
+        # the original filename must still exist
+        assert os.path.exists(storage.filename)
+
+        # build the expected rolled-over filename
+        dated_filename = os.path.join(
+            os.path.dirname(storage.filename),
+            f"{os.path.basename(storage.filename).split('.')[0]}_2023-10-18.jsonl",
+        )
+
+        assert os.path.exists(dated_filename)
+
+        # the original file should contain only the second span
+        spans_in_original_file = read_spans_from_file(storage.filename)
+        assert len(spans_in_original_file) == 1
+        assert spans_in_original_file[0]["trace_id"] == "2"
+
+        # the rolled-over file should contain only the first span
+        spans_in_dated_file = read_spans_from_file(dated_filename)
+        assert len(spans_in_dated_file) == 1
+        assert spans_in_dated_file[0]["trace_id"] == "1"