DB-GPT/dbgpt/model/cluster/worker/tests/test_manager.py
2024-01-10 10:39:04 +08:00

488 lines
15 KiB
Python

from dataclasses import asdict
from typing import Dict, Iterator, List, Tuple
from unittest.mock import AsyncMock, patch
import pytest
from dbgpt.model.base import ModelInstance, WorkerApplyType
from dbgpt.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest
from dbgpt.model.cluster.manager_base import WorkerRunData
from dbgpt.model.cluster.tests.conftest import (
MockModelWorker,
_create_workers,
_new_worker_params,
_start_worker_manager,
manager_2_embedding_workers,
manager_2_workers,
manager_with_2_workers,
)
from dbgpt.model.cluster.worker.manager import (
ApplyFunction,
DeregisterFunc,
LocalWorkerManager,
RegisterFunc,
SendHeartbeatFunc,
)
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
# Model name and model path constants for this test module.
# NOTE(review): neither constant is referenced in the visible part of this
# file — confirm whether they are still needed.
_TEST_MODEL_NAME = "vicuna-13b-v1.5"
_TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5"
@pytest.fixture
def worker():
    """Yield the worker object of the single entry built by ``_create_workers(1)``."""
    created = _create_workers(1)
    first_entry = created[0]
    # Index 0 of an entry is the worker itself (see the tuple unpacking used by
    # the tests in this module).
    yield first_entry[0]
@pytest.fixture
def worker_param():
    """Return worker parameters produced by ``_new_worker_params()`` with no overrides."""
    default_params = _new_worker_params()
    return default_params
@pytest.fixture
def manager(request):
    """Yield a ``LocalWorkerManager`` configured from optional fixture parameters.

    When the fixture is parametrized indirectly, ``request.param`` is expected
    to be a mapping that may contain ``register_func``, ``deregister_func``,
    ``send_heartbeat_func`` and ``model_registry``; missing keys default to
    ``None``. Without parametrization every option is ``None``.

    The previous version also read a ``workers`` value (by mistake from the
    ``"model_registry"`` key) that was never used; that dead local is removed.
    This fixture does not register any workers on the manager.
    """
    # ``request.param`` only exists under indirect parametrization.
    has_param = bool(request) and hasattr(request, "param")
    options = request.param if has_param else {}

    worker_manager = LocalWorkerManager(
        register_func=options.get("register_func"),
        deregister_func=options.get("deregister_func"),
        send_heartbeat_func=options.get("send_heartbeat_func"),
        model_registry=options.get("model_registry"),
    )
    yield worker_manager
@pytest.mark.asyncio
async def test_run_blocking_func(manager: LocalWorkerManager):
    """Exercise ``run_blocking_func`` with plain callables and a coroutine function.

    Plain functions must have their result returned, with positional arguments
    forwarded. Passing an ``async def`` function is expected to raise
    ``ValueError``.
    """

    def f1() -> int:
        return 0

    def f2(a: int, b: int) -> int:
        return a + b

    # Annotation fixed: the body returns an int (it was declared ``-> None``).
    async def error_f3() -> int:
        return 0

    assert await manager.run_blocking_func(f1) == 0
    assert await manager.run_blocking_func(f2, 1, 2) == 3
    with pytest.raises(ValueError):
        await manager.run_blocking_func(error_f3)
@pytest.mark.asyncio
async def test_add_worker(
    manager: LocalWorkerManager,
    worker: ModelWorker,
    worker_param: ModelWorkerParameters,
):
    """Check that ``add_worker`` accepts a new worker and rejects a duplicate.

    The first registration for a given set of parameters must return a truthy
    value; repeating it must return a falsy one. Registering under a second
    model name adds a second key to ``manager.workers``.
    """
    # TODO test with register function
    assert manager.add_worker(worker, worker_param)
    # Add again: the duplicate must be refused.
    assert not manager.add_worker(worker, worker_param)

    key = manager._worker_key(worker_param.worker_type, worker_param.model_name)
    assert len(manager.workers) == 1
    assert len(manager.workers[key]) == 1
    assert manager.workers[key][0].worker == worker

    # Fresh parameter objects are built for each call, as before; only the
    # keyword values are shared.
    chatglm_kwargs = {
        "model_name": "chatglm2-6b",
        "model_path": "/app/models/chatglm2-6b",
    }
    assert manager.add_worker(worker, _new_worker_params(**chatglm_kwargs))
    assert not manager.add_worker(worker, _new_worker_params(**chatglm_kwargs))
    assert len(manager.workers) == 2
@pytest.mark.asyncio
async def test__apply_worker(manager_2_workers: LocalWorkerManager):
    """Check ``_apply_worker`` for "all workers" and for a single-model request.

    With no request the apply function's result is collected once per worker
    of the two-worker manager. With a request naming one model of a freshly
    started four-worker manager, exactly one result is expected.
    """

    async def always_zero(wr: WorkerRunData) -> int:
        return 0

    # Apply to all workers
    all_results = await manager_2_workers._apply_worker(None, apply_func=always_zero)
    assert all_results == [0, 0]

    created = _create_workers(4)
    async with _start_worker_manager(workers=created) as started_manager:
        # Apply to single model
        first_params = created[0][1]
        single_request = WorkerApplyRequest(
            model=first_params.model_name,
            apply_type=WorkerApplyType.START,
            worker_type=WorkerType.LLM,
        )
        single_results = await started_manager._apply_worker(
            single_request, apply_func=always_zero
        )
        assert single_results == [0]
@pytest.mark.asyncio
@pytest.mark.parametrize("manager_2_workers", [{"start": False}], indirect=True)
async def test__start_all_worker(manager_2_workers: LocalWorkerManager):
    """Check that ``_start_all_worker(None)`` succeeds and keeps both workers.

    The fixture receives ``{"start": False}``.
    NOTE(review): presumably this means the manager is handed over not yet
    started — confirm against the conftest fixture.
    """
    result = await manager_2_workers._start_all_worker(None)
    assert result.success
    assert len(manager_2_workers.workers) == 2
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "manager_2_workers, is_error_worker",
    [
        ({"start": False, "error_worker": False}, False),
        ({"start": False, "error_worker": True}, True),
    ],
    indirect=["manager_2_workers"],
)
async def test_start_worker_manager(
    manager_2_workers: LocalWorkerManager, is_error_worker: bool
):
    """Check that ``start()`` completes normally, or raises with error workers.

    ``is_error_worker`` mirrors the ``error_worker`` flag given to the fixture.
    """
    if not is_error_worker:
        # Healthy workers: start must complete without raising.
        await manager_2_workers.start()
        return

    with pytest.raises(Exception):
        await manager_2_workers.start()
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "manager_2_workers, is_stop_error",
    [
        ({"stop": False, "stop_error": False}, False),
        ({"stop": False, "stop_error": True}, True),
    ],
    indirect=["manager_2_workers"],
)
async def test__stop_all_worker(
    manager_2_workers: LocalWorkerManager, is_stop_error: bool
):
    """Check that ``_stop_all_worker(None)`` reports success unless stopping fails.

    ``is_stop_error`` mirrors the ``stop_error`` flag given to the fixture;
    the output's ``success`` must be falsy exactly when it is set.
    """
    result = await manager_2_workers._stop_all_worker(None)
    succeeded = bool(result.success)
    assert succeeded != is_stop_error
@pytest.mark.asyncio
async def test__restart_all_worker(manager_2_workers: LocalWorkerManager):
    """Check that ``_restart_all_worker(None)`` reports success."""
    result = await manager_2_workers._restart_all_worker(None)
    assert result.success
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "manager_2_workers, is_stop_error",
    [
        ({"stop": False, "stop_error": False}, False),
        ({"stop": False, "stop_error": True}, True),
    ],
    indirect=["manager_2_workers"],
)
async def test_stop_worker_manager(
    manager_2_workers: LocalWorkerManager, is_stop_error: bool
):
    """Check that ``stop()`` completes normally, or raises when stopping fails.

    ``is_stop_error`` mirrors the ``stop_error`` flag given to the fixture.
    """
    if not is_stop_error:
        # No stop error configured: stop must complete without raising.
        await manager_2_workers.stop()
        return

    with pytest.raises(Exception):
        await manager_2_workers.stop()
@pytest.mark.asyncio
async def test__remove_worker():
    """Check ``_remove_worker`` for registered and for unknown parameters.

    The test only requires that the calls do not raise; it makes no assertion
    about ``manager.workers`` after removal.
    """
    created = _create_workers(3)
    async with _start_worker_manager(workers=created, stop=False) as manager:
        assert len(manager.workers) == 3
        for _worker, params, _instance in created:
            manager._remove_worker(params)

        # Parameters that were never registered with this manager.
        unknown_params = _new_worker_params(
            model_name="this is a not exist worker params"
        )
        manager._remove_worker(unknown_params)
@pytest.mark.asyncio
@patch("dbgpt.model.cluster.worker.manager._build_worker")
async def test_model_startup(mock_build_worker):
    """Check ``model_startup`` for a first start, a repeated start and an error worker.

    ``_build_worker`` is patched to return the prepared mock worker. The first
    startup must succeed. Repeating the same request must raise, and so must
    a startup whose worker was created with ``error_worker=True``.
    """

    def make_request(params) -> WorkerStartupRequest:
        # Same host, port and worker type for every request in this test.
        return WorkerStartupRequest(
            host="127.0.0.1",
            port=8001,
            model=params.model_name,
            worker_type=WorkerType.LLM,
            params=asdict(params),
        )

    async with _start_worker_manager() as manager:
        healthy_worker, healthy_params, _ = _create_workers(1)[0]
        mock_build_worker.return_value = healthy_worker

        startup_request = make_request(healthy_params)
        await manager.model_startup(startup_request)
        # A second startup with the identical request is expected to fail.
        with pytest.raises(Exception):
            await manager.model_startup(startup_request)

    async with _start_worker_manager() as manager:
        failing_worker, failing_params, _ = _create_workers(1, error_worker=True)[0]
        mock_build_worker.return_value = failing_worker

        failing_request = make_request(failing_params)
        with pytest.raises(Exception):
            await manager.model_startup(failing_request)
@pytest.mark.asyncio
@patch("dbgpt.model.cluster.worker.manager._build_worker")
async def test_model_shutdown(mock_build_worker):
    """Check that a model started through ``model_startup`` can be shut down.

    ``_build_worker`` is patched to return the prepared mock worker. The test
    passes when neither call raises.
    """
    async with _start_worker_manager(start=False, stop=False) as manager:
        mock_worker, params, _ = _create_workers(1)[0]
        mock_build_worker.return_value = mock_worker

        request_fields = {
            "host": "127.0.0.1",
            "port": 8001,
            "model": params.model_name,
            "worker_type": WorkerType.LLM,
            "params": asdict(params),
        }
        startup_request = WorkerStartupRequest(**request_fields)
        await manager.model_startup(startup_request)
        # The same request object is reused for the shutdown call.
        await manager.model_shutdown(startup_request)
@pytest.mark.asyncio
async def test_supported_models(manager_2_workers: LocalWorkerManager):
    """Check that ``supported_models`` returns one entry listing more than ten models."""
    supported = await manager_2_workers.supported_models()
    assert len(supported) == 1

    model_list = supported[0].models
    assert len(model_list) > 10
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "is_async",
    [
        True,
        False,
    ],
)
async def test_get_model_instances(is_async):
    """Check instance lookup through the async and the sync API.

    Every registered model must resolve to exactly one instance, and an
    unknown model name must yield a falsy result. ``is_async`` selects
    ``get_model_instances`` or ``sync_get_model_instances``.
    """

    async def lookup(mgr, worker_type, model_name):
        # Dispatch to the API variant under test.
        if is_async:
            return await mgr.get_model_instances(worker_type, model_name)
        return mgr.sync_get_model_instances(worker_type, model_name)

    created = _create_workers(3)
    async with _start_worker_manager(workers=created, stop=False) as manager:
        assert len(manager.workers) == 3
        for _, worker_params, _ in created:
            model_name = worker_params.model_name
            worker_type = worker_params.worker_type
            found = await lookup(manager, worker_type, model_name)
            assert len(found) == 1

        # ``worker_type`` deliberately carries over from the last loop iteration.
        missing = await lookup(
            manager, worker_type, "this is not exist model instances"
        )
        assert not missing
@pytest.mark.asyncio
async def test__simple_select(
    manager_with_2_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ]
):
    """Check that ``_simple_select`` picks an instance matching the worker's params.

    NOTE(review): the annotation declares 2-tuples, but each entry is unpacked
    into three values below — confirm the fixture's real element type.
    """
    manager, workers = manager_with_2_workers
    for _, expected_params, _ in workers:
        name = expected_params.model_name
        kind = expected_params.worker_type

        candidates = await manager.get_model_instances(kind, name)
        assert candidates

        chosen = manager._simple_select(kind, name, candidates)
        assert chosen is not None
        assert chosen.worker_params == expected_params
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "is_async",
    [
        True,
        False,
    ],
)
async def test_select_one_instance(
    is_async: bool,
    manager_with_2_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ],
):
    """Check that an instance can be selected for every registered worker.

    ``is_async`` selects ``select_one_instance`` or ``sync_select_one_instance``.
    """
    manager, workers = manager_with_2_workers
    for _, worker_params, _ in workers:
        lookup_args = (worker_params.worker_type, worker_params.model_name)
        if not is_async:
            selected = manager.sync_select_one_instance(*lookup_args)
        else:
            selected = await manager.select_one_instance(*lookup_args)
        assert selected is not None
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "is_async",
    [
        True,
        False,
    ],
)
async def test__get_model(
    is_async: bool,
    manager_with_2_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ],
):
    """Check that ``_get_model`` / ``_sync_get_model`` resolve every registered model.

    The request payload carries only the ``"model"`` key. ``is_async`` selects
    the API variant under test.
    """
    manager, workers = manager_with_2_workers
    for _, worker_params, _ in workers:
        payload = {"model": worker_params.model_name}
        kind = worker_params.worker_type
        if not is_async:
            resolved = manager._sync_get_model(payload, worker_type=kind)
        else:
            resolved = await manager._get_model(payload, worker_type=kind)
        assert resolved is not None
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "manager_with_2_workers, expected_messages",
    [
        ({"stream_messags": ["Hello", " world."]}, "Hello world."),
        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
    ],
    indirect=["manager_with_2_workers"],
)
async def test_generate_stream(
    manager_with_2_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ],
    expected_messages: str,
):
    """Check that ``generate_stream`` ends with the full expected message.

    The text of each streamed output overwrites the previous one, so the
    assertion compares the last output's ``text`` with ``expected_messages``.
    NOTE(review): this presumes each output carries the text accumulated so
    far — confirm against the mock worker.
    """
    manager, workers = manager_with_2_workers
    for _, worker_params, _ in workers:
        params = {"model": worker_params.model_name}
        text = ""
        async for out in manager.generate_stream(params):
            # Keep only the most recent output; no concatenation happens here.
            text = out.text
        assert text == expected_messages
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "manager_with_2_workers, expected_messages",
    [
        ({"stream_messags": ["Hello", " world."]}, "Hello world."),
        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
    ],
    indirect=["manager_with_2_workers"],
)
async def test_generate(
    manager_with_2_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ],
    expected_messages: str,
):
    """Check that ``generate`` returns the full expected message for every worker.

    The request payload carries only the ``"model"`` key; the output's ``text``
    must equal ``expected_messages``.
    """
    manager, workers = manager_with_2_workers
    for _, worker_params, _ in workers:
        params = {"model": worker_params.model_name}
        out = await manager.generate(params)
        assert out.text == expected_messages
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "manager_2_embedding_workers, expected_embedding, is_async",
    [
        ({"embeddings": [[1, 2, 3], [4, 5, 6]]}, [[1, 2, 3], [4, 5, 6]], True),
        ({"embeddings": [[0, 0, 0], [1, 1, 1]]}, [[0, 0, 0], [1, 1, 1]], False),
    ],
    indirect=["manager_2_embedding_workers"],
)
async def test_embeddings(
    manager_2_embedding_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ],
    expected_embedding: List[List[int]],
    is_async: bool,
):
    """Check that ``embeddings`` / ``sync_embeddings`` return the expected vectors.

    The request carries the model name and two input strings. ``is_async``
    selects the API variant under test.
    """
    manager, workers = manager_2_embedding_workers
    for _, worker_params, _ in workers:
        params = {"model": worker_params.model_name, "input": ["hello", "world"]}
        if is_async:
            out = await manager.embeddings(params)
        else:
            out = manager.sync_embeddings(params)
        assert out == expected_embedding
@pytest.mark.asyncio
async def test_parameter_descriptions(
    manager_with_2_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ]
):
    """Check that every worker exposes more than five parameter descriptions."""
    manager, workers = manager_with_2_workers
    for _, worker_params, _ in workers:
        descriptions = await manager.parameter_descriptions(
            worker_params.worker_type, worker_params.model_name
        )
        assert descriptions is not None
        assert len(descriptions) > 5
@pytest.mark.asyncio
async def test__update_all_worker_params():
    """Placeholder for a ``_update_all_worker_params`` test; asserts nothing yet."""
    # TODO
    return None