mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-19 08:47:32 +00:00
test(model): Add unit test for worker manager (#681)
- Add unit test for worker manager. - Fix model start error after start failed last time.
This commit is contained in:
commit
9fa0f45264
@ -64,6 +64,20 @@ class WorkerApplyOutput:
|
||||
# The seconds cost to apply some action to worker instances
|
||||
timecost: Optional[int] = -1
|
||||
|
||||
@staticmethod
def reduce(outs: List["WorkerApplyOutput"]) -> "WorkerApplyOutput":
    """Merge the outputs of several worker instances into one output.

    Args:
        outs (List["WorkerApplyOutput"]): The list of WorkerApplyOutput

    Returns:
        WorkerApplyOutput: ``success`` is true only when every output
        succeeded, ``timecost`` is the largest known time cost (-1 when no
        output carries one) and ``message`` is all messages joined by ", ".
        An empty ``outs`` yields an output whose message is "Not outputs".
    """
    if not outs:
        return WorkerApplyOutput("Not outputs")
    combined_success = all(out.success for out in outs)
    # ``timecost`` is declared Optional, so skip None values: comparing None
    # with a number inside max() would raise TypeError.
    max_timecost = max(
        (out.timecost for out in outs if out.timecost is not None), default=-1
    )
    combined_message = ", ".join(out.message for out in outs)
    return WorkerApplyOutput(combined_message, combined_success, max_timecost)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SupportedModel:
|
||||
|
@ -26,11 +26,22 @@ class WorkerRunData:
|
||||
_heartbeat_future: Optional[Future] = None
|
||||
_last_heartbeat: Optional[datetime] = None
|
||||
|
||||
def _to_print_key(self):
    """Return a printable label: ``model <name>@<type>(<host>:<port>)``."""
    params = self.model_params
    return f"model {params.model_name}@{params.model_type}({self.host}:{self.port})"
|
||||
|
||||
|
||||
class WorkerManager(ABC):
|
||||
@abstractmethod
|
||||
async def start(self):
|
||||
"""Start worker manager"""
|
||||
"""Start worker manager
|
||||
|
||||
Raises:
|
||||
Exception: if start worker manager not successfully
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self, ignore_exception: bool = False):
|
||||
@ -69,11 +80,11 @@ class WorkerManager(ABC):
|
||||
"""List supported models"""
|
||||
|
||||
@abstractmethod
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest):
|
||||
"""Create and start a model instance"""
|
||||
|
||||
@abstractmethod
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
|
||||
"""Shutdown model instance"""
|
||||
|
||||
@abstractmethod
|
||||
|
@ -104,12 +104,16 @@ class LocalWorkerManager(WorkerManager):
|
||||
return f"{model_name}@{worker_type}"
|
||||
|
||||
async def run_blocking_func(self, func, *args):
    """Run a blocking (non-coroutine) callable on ``self.executor`` and await it.

    Args:
        func: A plain callable; coroutine functions are rejected.
        *args: Positional arguments forwarded to ``func``.

    Returns:
        Whatever ``func(*args)`` returns.

    Raises:
        ValueError: If ``func`` is a coroutine function.
    """
    if asyncio.iscoroutinefunction(func):
        raise ValueError(f"The function {func} is not blocking function")
    # We are already inside a coroutine, so the running loop is the right one
    # to schedule on; get_running_loop() never creates a new loop implicitly.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(self.executor, func, *args)
|
||||
|
||||
async def start(self):
|
||||
if len(self.workers) > 0:
|
||||
await self._start_all_worker(apply_req=None)
|
||||
out = await self._start_all_worker(apply_req=None)
|
||||
if not out.success:
|
||||
raise Exception(out.message)
|
||||
if self.register_func:
|
||||
await self.register_func(self.run_data)
|
||||
if self.send_heartbeat_func:
|
||||
@ -143,7 +147,9 @@ class LocalWorkerManager(WorkerManager):
|
||||
else:
|
||||
stop_tasks.append(self.deregister_func(self.run_data))
|
||||
|
||||
await asyncio.gather(*stop_tasks)
|
||||
results = await asyncio.gather(*stop_tasks)
|
||||
if not results[0].success and not ignore_exception:
|
||||
raise Exception(results[0].message)
|
||||
|
||||
def after_start(self, listener: Callable[["WorkerManager"], None]):
|
||||
self.start_listeners.append(listener)
|
||||
@ -193,7 +199,15 @@ class LocalWorkerManager(WorkerManager):
|
||||
logger.warn(f"Instance {worker_key} exist")
|
||||
return False
|
||||
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
|
||||
def _remove_worker(self, worker_params: "ModelWorkerParameters") -> None:
    """Drop the registry entry for the worker described by ``worker_params``.

    Nothing happens when the key is absent or maps to an empty value.
    """
    key = self._worker_key(worker_params.worker_type, worker_params.model_name)
    if self.workers.get(key):
        del self.workers[key]
|
||||
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest):
|
||||
"""Start model"""
|
||||
model_name = startup_req.model
|
||||
worker_type = startup_req.worker_type
|
||||
@ -213,22 +227,30 @@ class LocalWorkerManager(WorkerManager):
|
||||
self.add_worker, worker, worker_params, command_args
|
||||
)
|
||||
if not success:
|
||||
logger.warn(
|
||||
f"Add worker failed, worker instances is exist, worker_params: {worker_params}"
|
||||
)
|
||||
return False
|
||||
msg = f"Add worker {model_name}@{worker_type}, worker instances is exist"
|
||||
logger.warn(f"{msg}, worker_params: {worker_params}")
|
||||
self._remove_worker(worker_params)
|
||||
raise Exception(msg)
|
||||
supported_types = WorkerType.values()
|
||||
if worker_type not in supported_types:
|
||||
self._remove_worker(worker_params)
|
||||
raise ValueError(
|
||||
f"Unsupported worker type: {worker_type}, now supported worker type: {supported_types}"
|
||||
)
|
||||
start_apply_req = WorkerApplyRequest(
|
||||
model=model_name, apply_type=WorkerApplyType.START, worker_type=worker_type
|
||||
)
|
||||
await self.worker_apply(start_apply_req)
|
||||
return True
|
||||
out: WorkerApplyOutput = None
|
||||
try:
|
||||
out = await self.worker_apply(start_apply_req)
|
||||
except Exception as e:
|
||||
self._remove_worker(worker_params)
|
||||
raise e
|
||||
if not out.success:
|
||||
self._remove_worker(worker_params)
|
||||
raise Exception(out.message)
|
||||
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
|
||||
logger.info(f"Begin shutdown model, shutdown_req: {shutdown_req}")
|
||||
apply_req = WorkerApplyRequest(
|
||||
model=shutdown_req.model,
|
||||
@ -236,8 +258,7 @@ class LocalWorkerManager(WorkerManager):
|
||||
worker_type=shutdown_req.worker_type,
|
||||
)
|
||||
out = await self._stop_all_worker(apply_req)
|
||||
if out.success:
|
||||
return True
|
||||
if not out.success:
|
||||
raise Exception(out.message)
|
||||
|
||||
async def supported_models(self) -> List[WorkerSupportedModel]:
|
||||
@ -253,7 +274,7 @@ class LocalWorkerManager(WorkerManager):
|
||||
self, worker_type: str, model_name: str, healthy_only: bool = True
|
||||
) -> List[WorkerRunData]:
|
||||
worker_key = self._worker_key(worker_type, model_name)
|
||||
return self.workers.get(worker_key)
|
||||
return self.workers.get(worker_key, [])
|
||||
|
||||
def _simple_select(
|
||||
self, worker_type: str, model_name: str, worker_instances: List[WorkerRunData]
|
||||
@ -424,10 +445,15 @@ class LocalWorkerManager(WorkerManager):
|
||||
async def _start_all_worker(
|
||||
self, apply_req: WorkerApplyRequest
|
||||
) -> WorkerApplyOutput:
|
||||
# TODO avoid start twice
|
||||
start_time = time.time()
|
||||
logger.info(f"Begin start all worker, apply_req: {apply_req}")
|
||||
|
||||
async def _start_worker(worker_run_data: WorkerRunData):
|
||||
_start_time = time.time()
|
||||
info = worker_run_data._to_print_key()
|
||||
out = WorkerApplyOutput("")
|
||||
try:
|
||||
await self.run_blocking_func(
|
||||
worker_run_data.worker.start,
|
||||
worker_run_data.model_params,
|
||||
@ -448,12 +474,18 @@ class LocalWorkerManager(WorkerManager):
|
||||
self.send_heartbeat_func,
|
||||
)
|
||||
)
|
||||
out.message = f"{info} start successfully"
|
||||
except Exception as e:
|
||||
out.success = False
|
||||
out.message = f"{info} start failed, {str(e)}"
|
||||
finally:
|
||||
out.timecost = time.time() - _start_time
|
||||
return out
|
||||
|
||||
await self._apply_worker(apply_req, _start_worker)
|
||||
timecost = time.time() - start_time
|
||||
return WorkerApplyOutput(
|
||||
message=f"Worker started successfully", timecost=timecost
|
||||
)
|
||||
outs = await self._apply_worker(apply_req, _start_worker)
|
||||
out = WorkerApplyOutput.reduce(outs)
|
||||
out.timecost = time.time() - start_time
|
||||
return out
|
||||
|
||||
async def _stop_all_worker(
|
||||
self, apply_req: WorkerApplyRequest, ignore_exception: bool = False
|
||||
@ -461,6 +493,10 @@ class LocalWorkerManager(WorkerManager):
|
||||
start_time = time.time()
|
||||
|
||||
async def _stop_worker(worker_run_data: WorkerRunData):
|
||||
_start_time = time.time()
|
||||
info = worker_run_data._to_print_key()
|
||||
out = WorkerApplyOutput("")
|
||||
try:
|
||||
await self.run_blocking_func(worker_run_data.worker.stop)
|
||||
# Set stop event
|
||||
worker_run_data.stop_event.set()
|
||||
@ -486,17 +522,27 @@ class LocalWorkerManager(WorkerManager):
|
||||
|
||||
_deregister_func = safe_deregister_func
|
||||
await _deregister_func(worker_run_data)
|
||||
# Remove metadata
|
||||
self._remove_worker(worker_run_data.worker_params)
|
||||
out.message = f"{info} stop successfully"
|
||||
except Exception as e:
|
||||
out.success = False
|
||||
out.message = f"{info} stop failed, {str(e)}"
|
||||
finally:
|
||||
out.timecost = time.time() - _start_time
|
||||
return out
|
||||
|
||||
await self._apply_worker(apply_req, _stop_worker)
|
||||
timecost = time.time() - start_time
|
||||
return WorkerApplyOutput(
|
||||
message=f"Worker stopped successfully", timecost=timecost
|
||||
)
|
||||
outs = await self._apply_worker(apply_req, _stop_worker)
|
||||
out = WorkerApplyOutput.reduce(outs)
|
||||
out.timecost = time.time() - start_time
|
||||
return out
|
||||
|
||||
async def _restart_all_worker(
|
||||
self, apply_req: WorkerApplyRequest
|
||||
) -> WorkerApplyOutput:
|
||||
await self._stop_all_worker(apply_req)
|
||||
out = await self._stop_all_worker(apply_req, ignore_exception=True)
|
||||
if not out.success:
|
||||
return out
|
||||
return await self._start_all_worker(apply_req)
|
||||
|
||||
async def _update_all_worker_params(
|
||||
@ -541,10 +587,10 @@ class WorkerManagerAdapter(WorkerManager):
|
||||
async def supported_models(self) -> List[WorkerSupportedModel]:
|
||||
return await self.worker_manager.supported_models()
|
||||
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest):
|
||||
return await self.worker_manager.model_startup(startup_req)
|
||||
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
|
||||
return await self.worker_manager.model_shutdown(shutdown_req)
|
||||
|
||||
async def get_model_instances(
|
||||
|
@ -96,7 +96,7 @@ class RemoteWorkerManager(LocalWorkerManager):
|
||||
raise Exception(error_msg)
|
||||
return worker_instances
|
||||
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest):
|
||||
worker_instances = await self._get_worker_service_instance(
|
||||
startup_req.host, startup_req.port
|
||||
)
|
||||
@ -107,10 +107,10 @@ class RemoteWorkerManager(LocalWorkerManager):
|
||||
"/models/startup",
|
||||
method="POST",
|
||||
json=startup_req.dict(),
|
||||
success_handler=lambda x: True,
|
||||
success_handler=lambda x: None,
|
||||
)
|
||||
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
|
||||
worker_instances = await self._get_worker_service_instance(
|
||||
shutdown_req.host, shutdown_req.port
|
||||
)
|
||||
@ -121,7 +121,7 @@ class RemoteWorkerManager(LocalWorkerManager):
|
||||
"/models/shutdown",
|
||||
method="POST",
|
||||
json=shutdown_req.dict(),
|
||||
success_handler=lambda x: True,
|
||||
success_handler=lambda x: None,
|
||||
)
|
||||
|
||||
def _build_worker_instances(
|
||||
|
0
pilot/model/cluster/worker/tests/__init__.py
Normal file
0
pilot/model/cluster/worker/tests/__init__.py
Normal file
168
pilot/model/cluster/worker/tests/base_tests.py
Normal file
168
pilot/model/cluster/worker/tests/base_tests.py
Normal file
@ -0,0 +1,168 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from contextlib import contextmanager, asynccontextmanager
|
||||
from typing import List, Iterator, Dict, Tuple
|
||||
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
|
||||
from pilot.model.base import ModelOutput
|
||||
from pilot.model.cluster.worker_base import ModelWorker
|
||||
from pilot.model.cluster.worker.manager import (
|
||||
LocalWorkerManager,
|
||||
RegisterFunc,
|
||||
DeregisterFunc,
|
||||
SendHeartbeatFunc,
|
||||
ApplyFunction,
|
||||
)
|
||||
|
||||
|
||||
class MockModelWorker(ModelWorker):
    """ModelWorker test double with canned outputs and optional start/stop failures."""

    def __init__(
        self,
        model_parameters: ModelParameters,
        error_worker: bool = False,
        stop_error: bool = False,
        stream_messags: List[str] = None,
        embeddings: List[List[float]] = None,
    ) -> None:
        super().__init__()
        self.model_parameters = model_parameters
        # When set, start() / stop() raise to simulate a broken worker.
        self.error_worker = error_worker
        self.stop_error = stop_error
        # Falsy inputs (None or empty) are normalised to a fresh empty list.
        self.stream_messags = stream_messags or []
        self._embeddings = embeddings or []

    def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
        """Return the parameters given at construction; ``command_args`` is ignored."""
        return self.model_parameters

    def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
        """Do nothing: the mock has no model to load."""

    def start(
        self, model_params: ModelParameters = None, command_args: List[str] = None
    ) -> None:
        """Raise when the mock was configured with ``error_worker``."""
        if self.error_worker:
            raise Exception("Start worker error for mock")

    def stop(self) -> None:
        """Raise when the mock was configured with ``stop_error``."""
        if self.stop_error:
            raise Exception("Stop worker error for mock")

    def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
        """Yield one ModelOutput per configured stream message."""
        for message in self.stream_messags:
            yield ModelOutput(text=message, error_code=0)

    def generate(self, params: Dict) -> ModelOutput:
        """Return the last streamed output, or None when nothing is streamed."""
        last = None
        for last in self.generate_stream(params):
            pass
        return last

    def embeddings(self, params: Dict) -> List[List[float]]:
        """Return the embeddings given at construction."""
        return self._embeddings
|
||||
|
||||
|
||||
_TEST_MODEL_NAME = "vicuna-13b-v1.5"
|
||||
_TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5"
|
||||
|
||||
|
||||
def _new_worker_params(
    model_name: str = _TEST_MODEL_NAME,
    model_path: str = _TEST_MODEL_PATH,
    worker_type: str = WorkerType.LLM.value,
) -> ModelWorkerParameters:
    """Build ModelWorkerParameters, defaulting to the shared test model."""
    fields = {
        "model_name": model_name,
        "model_path": model_path,
        "worker_type": worker_type,
    }
    return ModelWorkerParameters(**fields)
|
||||
|
||||
|
||||
def _create_workers(
    num_workers: int,
    error_worker: bool = False,
    stop_error: bool = False,
    worker_type: str = WorkerType.LLM.value,
    stream_messags: List[str] = None,
    embeddings: List[List[float]] = None,
) -> List[Tuple[ModelWorker, ModelWorkerParameters]]:
    """Create ``num_workers`` mock workers paired with their worker parameters.

    Worker ``i`` is named ``test-model-name-{i}`` with path
    ``test-model-path-{i}``; the remaining arguments are applied to every worker.
    """

    def _build(index: int) -> Tuple[ModelWorker, ModelWorkerParameters]:
        name = f"test-model-name-{index}"
        path = f"test-model-path-{index}"
        mock = MockModelWorker(
            ModelParameters(model_name=name, model_path=path),
            error_worker=error_worker,
            stop_error=stop_error,
            stream_messags=stream_messags,
            embeddings=embeddings,
        )
        return mock, _new_worker_params(name, path, worker_type=worker_type)

    return [_build(index) for index in range(num_workers)]
|
||||
|
||||
|
||||
@asynccontextmanager
async def _start_worker_manager(**kwargs):
    """Create a LocalWorkerManager for tests and yield it.

    Recognised keyword arguments:
        register_func, deregister_func, send_heartbeat_func, model_registry:
            forwarded to ``LocalWorkerManager``.
        workers: explicit ``(worker, worker_params)`` pairs to register.
        num_workers: number of additional mock workers to create (default 0).
        error_worker, stop_error, stream_messags, embeddings: configuration
            applied to the mock workers created for ``num_workers``.
        start: call ``start()`` before yielding (default True).
        stop: call ``stop()`` after the ``async with`` body (default True).
    """
    register_func = kwargs.get("register_func")
    deregister_func = kwargs.get("deregister_func")
    send_heartbeat_func = kwargs.get("send_heartbeat_func")
    model_registry = kwargs.get("model_registry")
    workers = kwargs.get("workers")
    num_workers = int(kwargs.get("num_workers", 0))
    start = kwargs.get("start", True)
    stop = kwargs.get("stop", True)
    error_worker = kwargs.get("error_worker", False)
    stop_error = kwargs.get("stop_error", False)
    stream_messags = kwargs.get("stream_messags", [])
    embeddings = kwargs.get("embeddings", [])

    worker_manager = LocalWorkerManager(
        register_func=register_func,
        deregister_func=deregister_func,
        send_heartbeat_func=send_heartbeat_func,
        model_registry=model_registry,
    )

    # Pass by keyword: the fourth positional parameter of _create_workers is
    # worker_type, so positional passing would send stream_messags as the
    # worker type and embeddings as the stream messages.
    for worker, worker_params in _create_workers(
        num_workers,
        error_worker=error_worker,
        stop_error=stop_error,
        stream_messags=stream_messags,
        embeddings=embeddings,
    ):
        worker_manager.add_worker(worker, worker_params)
    if workers:
        for worker, worker_params in workers:
            worker_manager.add_worker(worker, worker_params)

    if start:
        await worker_manager.start()

    yield worker_manager
    if stop:
        await worker_manager.stop()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def manager_2_workers(request):
    """Yield a worker manager holding two generated mock workers.

    Indirect parametrization (``request.param``) is forwarded as keyword
    arguments to ``_start_worker_manager``.
    """
    options = getattr(request, "param", {})
    async with _start_worker_manager(num_workers=2, **options) as started:
        yield started
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def manager_with_2_workers(request):
    """Yield ``(manager, workers)`` for two explicitly created mock workers.

    ``stream_messags`` from ``request.param`` configures the mock workers; the
    whole ``request.param`` dict is also forwarded to ``_start_worker_manager``.
    """
    options = getattr(request, "param", {})
    messages = options.get("stream_messags", [])
    workers = _create_workers(2, stream_messags=messages)
    async with _start_worker_manager(workers=workers, **options) as started:
        yield (started, workers)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def manager_2_embedding_workers(request):
    """Yield ``(manager, workers)`` for two text2vec mock workers.

    ``embeddings`` from ``request.param`` is what the mock workers return; the
    whole ``request.param`` dict is also forwarded to ``_start_worker_manager``.
    """
    options = getattr(request, "param", {})
    vectors = options.get("embeddings", [])
    workers = _create_workers(
        2, worker_type=WorkerType.TEXT2VEC.value, embeddings=vectors
    )
    async with _start_worker_manager(workers=workers, **options) as started:
        yield (started, workers)
|
486
pilot/model/cluster/worker/tests/test_manager.py
Normal file
486
pilot/model/cluster/worker/tests/test_manager.py
Normal file
@ -0,0 +1,486 @@
|
||||
from unittest.mock import patch, AsyncMock
|
||||
import pytest
|
||||
from typing import List, Iterator, Dict, Tuple
|
||||
from dataclasses import asdict
|
||||
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
|
||||
from pilot.model.base import ModelOutput, WorkerApplyType
|
||||
from pilot.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest
|
||||
from pilot.model.cluster.worker_base import ModelWorker
|
||||
from pilot.model.cluster.manager_base import WorkerRunData
|
||||
from pilot.model.cluster.worker.manager import (
|
||||
LocalWorkerManager,
|
||||
RegisterFunc,
|
||||
DeregisterFunc,
|
||||
SendHeartbeatFunc,
|
||||
ApplyFunction,
|
||||
)
|
||||
from pilot.model.cluster.worker.tests.base_tests import (
|
||||
MockModelWorker,
|
||||
manager_2_workers,
|
||||
manager_with_2_workers,
|
||||
manager_2_embedding_workers,
|
||||
_create_workers,
|
||||
_start_worker_manager,
|
||||
_new_worker_params,
|
||||
)
|
||||
|
||||
|
||||
_TEST_MODEL_NAME = "vicuna-13b-v1.5"
|
||||
_TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5"
|
||||
|
||||
|
||||
@pytest.fixture
def worker():
    """Yield a single mock worker (without its parameters)."""
    [(mock, _params)] = _create_workers(1)
    yield mock
|
||||
|
||||
|
||||
@pytest.fixture
def worker_param():
    """Return worker parameters built entirely from the helper's defaults."""
    default_params = _new_worker_params()
    return default_params
|
||||
|
||||
|
||||
@pytest.fixture
def manager(request):
    """Yield a LocalWorkerManager that has not been started.

    With indirect parametrization, ``request.param`` may be a dict with the
    optional keys ``register_func``, ``deregister_func``,
    ``send_heartbeat_func``, ``model_registry`` and ``workers`` (an iterable of
    ``(worker, worker_params)`` pairs to register). Without parametrization
    every option is None and no worker is registered.
    """
    param = getattr(request, "param", None) or {}

    worker_manager = LocalWorkerManager(
        register_func=param.get("register_func"),
        deregister_func=param.get("deregister_func"),
        send_heartbeat_func=param.get("send_heartbeat_func"),
        model_registry=param.get("model_registry"),
    )
    # Read the "workers" key (not "model_registry") and actually register the
    # workers, so that parametrizing this fixture with workers has an effect.
    for worker, worker_params in param.get("workers") or []:
        worker_manager.add_worker(worker, worker_params)
    yield worker_manager
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_blocking_func(manager: LocalWorkerManager):
    """Plain callables run and return their value; coroutine functions are rejected."""

    def no_args() -> int:
        return 0

    def add(a: int, b: int) -> int:
        return a + b

    async def not_blocking() -> None:
        return 0

    result_no_args = await manager.run_blocking_func(no_args)
    assert result_no_args == 0
    result_add = await manager.run_blocking_func(add, 1, 2)
    assert result_add == 3
    with pytest.raises(ValueError):
        await manager.run_blocking_func(not_blocking)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_worker(
    manager: LocalWorkerManager,
    worker: ModelWorker,
    worker_param: ModelWorkerParameters,
):
    """Adding a worker succeeds once per worker key and is refused afterwards."""
    # TODO test with register function
    assert manager.add_worker(worker, worker_param)
    # Add again: the same worker key must be refused.
    assert not manager.add_worker(worker, worker_param)
    key = manager._worker_key(worker_param.worker_type, worker_param.model_name)
    assert len(manager.workers) == 1
    assert len(manager.workers[key]) == 1
    assert manager.workers[key][0].worker == worker

    # A different model name is a different key, so it is accepted once.
    other_param = _new_worker_params(
        model_name="chatglm2-6b", model_path="/app/models/chatglm2-6b"
    )
    assert manager.add_worker(worker, other_param)
    # Use freshly built but equal parameters to show the key, not object
    # identity, decides whether a worker already exists.
    assert not manager.add_worker(
        worker,
        _new_worker_params(
            model_name="chatglm2-6b", model_path="/app/models/chatglm2-6b"
        ),
    )
    assert len(manager.workers) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test__apply_worker(manager_2_workers: LocalWorkerManager):
    """_apply_worker returns one result per worker matched by the request."""
    manager = manager_2_workers

    async def always_zero(wr: WorkerRunData) -> int:
        return 0

    # Apply to all workers
    all_results = await manager._apply_worker(None, apply_func=always_zero)
    assert all_results == [0, 0]

    workers = _create_workers(4)
    async with _start_worker_manager(workers=workers) as manager:
        # Apply to single model
        first_params = workers[0][1]
        req = WorkerApplyRequest(
            model=first_params.model_name,
            apply_type=WorkerApplyType.START,
            worker_type=WorkerType.LLM,
        )
        single_result = await manager._apply_worker(req, apply_func=always_zero)
        assert single_result == [0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("manager_2_workers", [{"start": False}], indirect=True)
async def test__start_all_worker(manager_2_workers: LocalWorkerManager):
    """Starting all workers of a not-yet-started manager succeeds and keeps both."""
    result = await manager_2_workers._start_all_worker(None)
    assert result.success
    assert len(manager_2_workers.workers) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "manager_2_workers, is_error_worker",
    [
        ({"start": False, "error_worker": False}, False),
        ({"start": False, "error_worker": True}, True),
    ],
    indirect=["manager_2_workers"],
)
async def test_start_worker_manager(
    manager_2_workers: LocalWorkerManager, is_error_worker: bool
):
    """start() raises when the workers fail to start and succeeds otherwise."""
    if not is_error_worker:
        await manager_2_workers.start()
        return
    with pytest.raises(Exception):
        await manager_2_workers.start()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "manager_2_workers, is_stop_error",
    [
        ({"stop": False, "stop_error": False}, False),
        ({"stop": False, "stop_error": True}, True),
    ],
    indirect=["manager_2_workers"],
)
async def test__stop_all_worker(
    manager_2_workers: LocalWorkerManager, is_stop_error: bool
):
    """_stop_all_worker reports failure exactly when the workers fail to stop."""
    result = await manager_2_workers._stop_all_worker(None)
    expect_success = not is_stop_error
    assert bool(result.success) == expect_success
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test__restart_all_worker(manager_2_workers: LocalWorkerManager):
    """Restarting all workers of a running manager reports success."""
    result = await manager_2_workers._restart_all_worker(None)
    assert result.success
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "manager_2_workers, is_stop_error",
    [
        ({"stop": False, "stop_error": False}, False),
        ({"stop": False, "stop_error": True}, True),
    ],
    indirect=["manager_2_workers"],
)
async def test_stop_worker_manager(
    manager_2_workers: LocalWorkerManager, is_stop_error: bool
):
    """stop() raises when the workers fail to stop and succeeds otherwise."""
    if not is_stop_error:
        await manager_2_workers.stop()
        return
    with pytest.raises(Exception):
        await manager_2_workers.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test__remove_worker():
    """_remove_worker drops registered workers and tolerates unknown ones."""
    workers = _create_workers(3)
    async with _start_worker_manager(workers=workers, stop=False) as manager:
        assert len(manager.workers) == 3
        for _, worker_params in workers:
            manager._remove_worker(worker_params)
        # Every registered worker was removed, so the registry must be empty.
        assert len(manager.workers) == 0

        # Removing parameters that were never registered must not raise and
        # must leave the registry untouched.
        not_exist_params = _new_worker_params(
            model_name="this is a not exist worker params"
        )
        manager._remove_worker(not_exist_params)
        assert len(manager.workers) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("pilot.model.cluster.worker.manager._build_worker")
async def test_model_startup(mock_build_worker):
    """model_startup succeeds once, then raises on duplicates and on failing workers."""

    def _startup_request(worker_params) -> WorkerStartupRequest:
        return WorkerStartupRequest(
            host="127.0.0.1",
            port=8001,
            model=worker_params.model_name,
            worker_type=WorkerType.LLM,
            params=asdict(worker_params),
        )

    # A healthy worker starts once; starting the same model again is an error.
    async with _start_worker_manager() as manager:
        worker, worker_params = _create_workers(1)[0]
        mock_build_worker.return_value = worker

        req = _startup_request(worker_params)
        await manager.model_startup(req)
        with pytest.raises(Exception):
            await manager.model_startup(req)

    # A worker whose start() raises makes model_startup raise.
    async with _start_worker_manager() as manager:
        worker, worker_params = _create_workers(1, error_worker=True)[0]
        mock_build_worker.return_value = worker
        req = _startup_request(worker_params)
        with pytest.raises(Exception):
            await manager.model_startup(req)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("pilot.model.cluster.worker.manager._build_worker")
async def test_model_shutdown(mock_build_worker):
    """A model that was started can be shut down with the same request."""
    async with _start_worker_manager(start=False, stop=False) as manager:
        worker, worker_params = _create_workers(1)[0]
        mock_build_worker.return_value = worker

        request_fields = dict(
            host="127.0.0.1",
            port=8001,
            model=worker_params.model_name,
            worker_type=WorkerType.LLM,
            params=asdict(worker_params),
        )
        req = WorkerStartupRequest(**request_fields)
        await manager.model_startup(req)
        await manager.model_shutdown(req)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_supported_models(manager_2_workers: LocalWorkerManager):
    """supported_models returns one entry that lists more than ten models."""
    supported = await manager_2_workers.supported_models()
    assert len(supported) == 1
    model_list = supported[0].models
    assert len(model_list) > 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "is_async",
    [
        True,
        False,
    ],
)
async def test_get_model_instances(is_async):
    """Each registered model has one instance; an unknown model has none."""
    workers = _create_workers(3)
    async with _start_worker_manager(workers=workers, stop=False) as manager:
        assert len(manager.workers) == 3
        for _, worker_params in workers:
            model_name = worker_params.model_name
            worker_type = worker_params.worker_type
            if is_async:
                found = await manager.get_model_instances(worker_type, model_name)
            else:
                found = manager.sync_get_model_instances(worker_type, model_name)
            assert len(found) == 1
        # Uses the worker type of the last worker from the loop above.
        unknown_name = "this is not exist model instances"
        if is_async:
            missing = await manager.get_model_instances(worker_type, unknown_name)
        else:
            missing = manager.sync_get_model_instances(worker_type, unknown_name)
        assert not missing
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test__simple_select(
    manager_with_2_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ]
):
    """_simple_select picks the instance whose parameters match the model."""
    manager, workers = manager_with_2_workers
    for _, expected_params in workers:
        model_name = expected_params.model_name
        worker_type = expected_params.worker_type
        candidates = await manager.get_model_instances(worker_type, model_name)
        assert candidates
        chosen = manager._simple_select(
            expected_params.worker_type, model_name, candidates
        )
        assert chosen is not None
        assert chosen.worker_params == expected_params
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "is_async",
    [
        True,
        False,
    ],
)
async def test_select_one_instance(
    is_async: bool,
    manager_with_2_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ],
):
    """Both the async and the sync selector return an instance for each model."""
    manager, workers = manager_with_2_workers
    for _, worker_params in workers:
        model_name = worker_params.model_name
        worker_type = worker_params.worker_type
        selected = (
            await manager.select_one_instance(worker_type, model_name)
            if is_async
            else manager.sync_select_one_instance(worker_type, model_name)
        )
        assert selected is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "is_async",
    [
        True,
        False,
    ],
)
async def test__get_model(
    is_async: bool,
    manager_with_2_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ],
):
    """Both the async and the sync lookup resolve a worker for each model."""
    manager, workers = manager_with_2_workers
    for _, worker_params in workers:
        model_name = worker_params.model_name
        worker_type = worker_params.worker_type
        params = {"model": model_name}
        resolved = (
            await manager._get_model(params, worker_type=worker_type)
            if is_async
            else manager._sync_get_model(params, worker_type=worker_type)
        )
        assert resolved is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "manager_with_2_workers, expected_messages",
    [
        ({"stream_messags": ["Hello", " world."]}, "Hello world."),
        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
    ],
    indirect=["manager_with_2_workers"],
)
async def test_generate_stream(
    manager_with_2_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ],
    expected_messages: str,
):
    """Concatenating the streamed chunks yields the full expected message."""
    manager, workers = manager_with_2_workers
    for _, worker_params in workers:
        model_name = worker_params.model_name
        worker_type = worker_params.worker_type
        params = {"model": model_name}
        chunks = [out.text async for out in manager.generate_stream(params)]
        assert "".join(chunks) == expected_messages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "manager_with_2_workers, expected_messages",
    [
        ({"stream_messags": ["Hello", " world."]}, " world."),
        ({"stream_messags": ["你好,我是", "张三。"]}, "张三。"),
    ],
    indirect=["manager_with_2_workers"],
)
async def test_generate(
    manager_with_2_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ],
    expected_messages: str,
):
    """generate() returns the text of the last configured stream message."""
    manager, workers = manager_with_2_workers
    for _, worker_params in workers:
        model_name = worker_params.model_name
        worker_type = worker_params.worker_type
        request_params = {"model": model_name}
        result = await manager.generate(request_params)
        assert result.text == expected_messages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "manager_2_embedding_workers, expected_embedding, is_async",
    [
        ({"embeddings": [[1, 2, 3], [4, 5, 6]]}, [[1, 2, 3], [4, 5, 6]], True),
        ({"embeddings": [[0, 0, 0], [1, 1, 1]]}, [[0, 0, 0], [1, 1, 1]], False),
    ],
    indirect=["manager_2_embedding_workers"],
)
async def test_embeddings(
    manager_2_embedding_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ],
    expected_embedding: List[List[int]],
    is_async: bool,
):
    """The async and sync embedding calls return the workers' configured vectors."""
    manager, workers = manager_2_embedding_workers
    for _, worker_params in workers:
        model_name = worker_params.model_name
        worker_type = worker_params.worker_type
        request_params = {"model": model_name, "input": ["hello", "world"]}
        vectors = (
            await manager.embeddings(request_params)
            if is_async
            else manager.sync_embeddings(request_params)
        )
        assert vectors == expected_embedding
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parameter_descriptions(
    manager_with_2_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ]
):
    """Every worker exposes more than five parameter descriptions."""
    manager, workers = manager_with_2_workers
    for _, worker_params in workers:
        model_name = worker_params.model_name
        worker_type = worker_params.worker_type
        descriptions = await manager.parameter_descriptions(worker_type, model_name)
        assert descriptions is not None
        assert len(descriptions) > 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test__update_all_worker_params():
    """Placeholder: _update_all_worker_params is not covered yet."""
    # TODO
    return None
|
Loading…
Reference in New Issue
Block a user