Merge remote-tracking branch 'origin/main' into feat_rag_graph

aries_ckt 2023-10-18 20:37:08 +08:00
commit f6694d95ec
13 changed files with 929 additions and 85 deletions

View File

@ -64,6 +64,20 @@ class WorkerApplyOutput:
# Time in seconds spent applying an action to the worker instances
timecost: Optional[int] = -1
@staticmethod
def reduce(outs: List["WorkerApplyOutput"]) -> "WorkerApplyOutput":
"""Merge all outputs
Args:
outs (List["WorkerApplyOutput"]): The list of WorkerApplyOutput
"""
if not outs:
return WorkerApplyOutput("Not outputs")
combined_success = all(out.success for out in outs)
max_timecost = max(out.timecost for out in outs)
combined_message = ", ".join(out.message for out in outs)
return WorkerApplyOutput(combined_message, combined_success, max_timecost)
@dataclass
class SupportedModel:
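A quick sketch of how the new WorkerApplyOutput.reduce combines results, using hypothetical values and the positional order (message, success, timecost) that reduce itself uses:

outs = [
    WorkerApplyOutput("worker-a started", True, 2),
    WorkerApplyOutput("worker-b start failed", False, 5),
]
merged = WorkerApplyOutput.reduce(outs)
# merged.message == "worker-a started, worker-b start failed"
# merged.success == False (all results must succeed)
# merged.timecost == 5 (max over the individual timecosts)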

View File

@ -185,7 +185,7 @@ def run_model_controller():
setup_logging(
"pilot",
logging_level=controller_params.log_level,
logger_filename="dbgpt_model_controller.log",
logger_filename=controller_params.log_file,
)
initialize_controller(host=controller_params.host, port=controller_params.port)

View File

@ -26,11 +26,22 @@ class WorkerRunData:
_heartbeat_future: Optional[Future] = None
_last_heartbeat: Optional[datetime] = None
def _to_print_key(self):
model_name = self.model_params.model_name
model_type = self.model_params.model_type
host = self.host
port = self.port
# e.g. "model vicuna-13b-v1.5@llm(127.0.0.1:8000)" (hypothetical values)
return f"model {model_name}@{model_type}({host}:{port})"
class WorkerManager(ABC):
@abstractmethod
async def start(self):
"""Start worker manager"""
"""Start worker manager
Raises:
Exception: if the worker manager fails to start
"""
@abstractmethod
async def stop(self, ignore_exception: bool = False):
@ -69,11 +80,11 @@ class WorkerManager(ABC):
"""List supported models"""
@abstractmethod
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
async def model_startup(self, startup_req: WorkerStartupRequest):
"""Create and start a model instance"""
@abstractmethod
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
"""Shutdown model instance"""
@abstractmethod

View File

@ -104,12 +104,16 @@ class LocalWorkerManager(WorkerManager):
return f"{model_name}@{worker_type}"
async def run_blocking_func(self, func, *args):
if asyncio.iscoroutinefunction(func):
raise ValueError(f"The function {func} is not a blocking function")
loop = asyncio.get_event_loop()
return await loop.run_in_executor(self.executor, func, *args)
async def start(self):
if len(self.workers) > 0:
await self._start_all_worker(apply_req=None)
out = await self._start_all_worker(apply_req=None)
if not out.success:
raise Exception(out.message)
if self.register_func:
await self.register_func(self.run_data)
if self.send_heartbeat_func:
@ -143,7 +147,9 @@ class LocalWorkerManager(WorkerManager):
else:
stop_tasks.append(self.deregister_func(self.run_data))
await asyncio.gather(*stop_tasks)
results = await asyncio.gather(*stop_tasks)
if not results[0].success and not ignore_exception:
raise Exception(results[0].message)
def after_start(self, listener: Callable[["WorkerManager"], None]):
self.start_listeners.append(listener)
@ -193,7 +199,15 @@ class LocalWorkerManager(WorkerManager):
logger.warn(f"Instance {worker_key} already exists")
return False
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
def _remove_worker(self, worker_params: ModelWorkerParameters) -> None:
worker_key = self._worker_key(
worker_params.worker_type, worker_params.model_name
)
instances = self.workers.get(worker_key)
if instances:
del self.workers[worker_key]
async def model_startup(self, startup_req: WorkerStartupRequest):
"""Start model"""
model_name = startup_req.model
worker_type = startup_req.worker_type
@ -213,22 +227,30 @@ class LocalWorkerManager(WorkerManager):
self.add_worker, worker, worker_params, command_args
)
if not success:
logger.warn(
f"Add worker failed, worker instances is exist, worker_params: {worker_params}"
)
return False
msg = f"Add worker {model_name}@{worker_type} failed, worker instance already exists"
logger.warn(f"{msg}, worker_params: {worker_params}")
self._remove_worker(worker_params)
raise Exception(msg)
supported_types = WorkerType.values()
if worker_type not in supported_types:
self._remove_worker(worker_params)
raise ValueError(
f"Unsupported worker type: {worker_type}, now supported worker type: {supported_types}"
)
start_apply_req = WorkerApplyRequest(
model=model_name, apply_type=WorkerApplyType.START, worker_type=worker_type
)
await self.worker_apply(start_apply_req)
return True
out: WorkerApplyOutput = None
try:
out = await self.worker_apply(start_apply_req)
except Exception as e:
self._remove_worker(worker_params)
raise e
if not out.success:
self._remove_worker(worker_params)
raise Exception(out.message)
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
logger.info(f"Begin shutdown model, shutdown_req: {shutdown_req}")
apply_req = WorkerApplyRequest(
model=shutdown_req.model,
@ -236,8 +258,7 @@ class LocalWorkerManager(WorkerManager):
worker_type=shutdown_req.worker_type,
)
out = await self._stop_all_worker(apply_req)
if out.success:
return True
if not out.success:
raise Exception(out.message)
async def supported_models(self) -> List[WorkerSupportedModel]:
@ -253,7 +274,7 @@ class LocalWorkerManager(WorkerManager):
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
worker_key = self._worker_key(worker_type, model_name)
return self.workers.get(worker_key)
return self.workers.get(worker_key, [])
def _simple_select(
self, worker_type: str, model_name: str, worker_instances: List[WorkerRunData]
@ -424,10 +445,15 @@ class LocalWorkerManager(WorkerManager):
async def _start_all_worker(
self, apply_req: WorkerApplyRequest
) -> WorkerApplyOutput:
# TODO avoid start twice
start_time = time.time()
logger.info(f"Begin start all worker, apply_req: {apply_req}")
async def _start_worker(worker_run_data: WorkerRunData):
_start_time = time.time()
info = worker_run_data._to_print_key()
out = WorkerApplyOutput("")
try:
await self.run_blocking_func(
worker_run_data.worker.start,
worker_run_data.model_params,
@ -448,12 +474,18 @@ class LocalWorkerManager(WorkerManager):
self.send_heartbeat_func,
)
)
out.message = f"{info} start successfully"
except Exception as e:
out.success = False
out.message = f"{info} start failed, {str(e)}"
finally:
out.timecost = time.time() - _start_time
return out
await self._apply_worker(apply_req, _start_worker)
timecost = time.time() - start_time
return WorkerApplyOutput(
message=f"Worker started successfully", timecost=timecost
)
outs = await self._apply_worker(apply_req, _start_worker)
out = WorkerApplyOutput.reduce(outs)
out.timecost = time.time() - start_time
return out
async def _stop_all_worker(
self, apply_req: WorkerApplyRequest, ignore_exception: bool = False
@ -461,6 +493,10 @@ class LocalWorkerManager(WorkerManager):
start_time = time.time()
async def _stop_worker(worker_run_data: WorkerRunData):
_start_time = time.time()
info = worker_run_data._to_print_key()
out = WorkerApplyOutput("")
try:
await self.run_blocking_func(worker_run_data.worker.stop)
# Set stop event
worker_run_data.stop_event.set()
@ -486,17 +522,27 @@ class LocalWorkerManager(WorkerManager):
_deregister_func = safe_deregister_func
await _deregister_func(worker_run_data)
# Remove metadata
self._remove_worker(worker_run_data.worker_params)
out.message = f"{info} stop successfully"
except Exception as e:
out.success = False
out.message = f"{info} stop failed, {str(e)}"
finally:
out.timecost = time.time() - _start_time
return out
await self._apply_worker(apply_req, _stop_worker)
timecost = time.time() - start_time
return WorkerApplyOutput(
message=f"Worker stopped successfully", timecost=timecost
)
outs = await self._apply_worker(apply_req, _stop_worker)
out = WorkerApplyOutput.reduce(outs)
out.timecost = time.time() - start_time
return out
async def _restart_all_worker(
self, apply_req: WorkerApplyRequest
) -> WorkerApplyOutput:
await self._stop_all_worker(apply_req)
out = await self._stop_all_worker(apply_req, ignore_exception=True)
if not out.success:
return out
return await self._start_all_worker(apply_req)
async def _update_all_worker_params(
@ -541,10 +587,10 @@ class WorkerManagerAdapter(WorkerManager):
async def supported_models(self) -> List[WorkerSupportedModel]:
return await self.worker_manager.supported_models()
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
async def model_startup(self, startup_req: WorkerStartupRequest):
return await self.worker_manager.model_startup(startup_req)
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
return await self.worker_manager.model_shutdown(shutdown_req)
async def get_model_instances(
@ -960,7 +1006,7 @@ def run_worker_manager(
setup_logging(
"pilot",
logging_level=worker_params.log_level,
logger_filename="dbgpt_model_worker_manager.log",
logger_filename=worker_params.log_file,
)
embedded_mod = True
@ -973,7 +1019,7 @@ def run_worker_manager(
system_app = SystemApp(app)
initialize_tracer(
system_app,
os.path.join(LOGDIR, "dbgpt_model_worker_manager_tracer.jsonl"),
os.path.join(LOGDIR, worker_params.tracer_file),
root_operation_name="DB-GPT-WorkerManager-Entry",
)

View File

@ -96,7 +96,7 @@ class RemoteWorkerManager(LocalWorkerManager):
raise Exception(error_msg)
return worker_instances
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
async def model_startup(self, startup_req: WorkerStartupRequest):
worker_instances = await self._get_worker_service_instance(
startup_req.host, startup_req.port
)
@ -107,10 +107,10 @@ class RemoteWorkerManager(LocalWorkerManager):
"/models/startup",
method="POST",
json=startup_req.dict(),
success_handler=lambda x: True,
success_handler=lambda x: None,
)
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
worker_instances = await self._get_worker_service_instance(
shutdown_req.host, shutdown_req.port
)
@ -121,7 +121,7 @@ class RemoteWorkerManager(LocalWorkerManager):
"/models/shutdown",
method="POST",
json=shutdown_req.dict(),
success_handler=lambda x: True,
success_handler=lambda x: None,
)
def _build_worker_instances(
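With these changes, model_startup and model_shutdown signal failure by raising instead of returning a bool, so a caller checks with try/except rather than testing a return value (a hypothetical sketch):

try:
    await worker_manager.model_startup(startup_req)
except Exception as e:
    logger.error(f"Model startup failed: {e}")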

View File

@ -0,0 +1,168 @@
import pytest
import pytest_asyncio
from contextlib import contextmanager, asynccontextmanager
from typing import List, Iterator, Dict, Tuple
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from pilot.model.base import ModelOutput
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.worker.manager import (
LocalWorkerManager,
RegisterFunc,
DeregisterFunc,
SendHeartbeatFunc,
ApplyFunction,
)
class MockModelWorker(ModelWorker):
def __init__(
self,
model_parameters: ModelParameters,
error_worker: bool = False,
stop_error: bool = False,
stream_messags: List[str] = None,
embeddings: List[List[float]] = None,
) -> None:
super().__init__()
if not stream_messags:
stream_messags = []
if not embeddings:
embeddings = []
self.model_parameters = model_parameters
self.error_worker = error_worker
self.stop_error = stop_error
self.stream_messags = stream_messags
self._embeddings = embeddings
def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
return self.model_parameters
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
pass
def start(
self, model_params: ModelParameters = None, command_args: List[str] = None
) -> None:
if self.error_worker:
raise Exception("Start worker error for mock")
def stop(self) -> None:
if self.stop_error:
raise Exception("Stop worker error for mock")
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
for msg in self.stream_messags:
yield ModelOutput(text=msg, error_code=0)
def generate(self, params: Dict) -> ModelOutput:
# The mock returns only the final chunk of generate_stream (no
# accumulation), so non-streaming tests expect just the last message.
output = None
for out in self.generate_stream(params):
output = out
return output
def embeddings(self, params: Dict) -> List[List[float]]:
return self._embeddings
_TEST_MODEL_NAME = "vicuna-13b-v1.5"
_TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5"
def _new_worker_params(
model_name: str = _TEST_MODEL_NAME,
model_path: str = _TEST_MODEL_PATH,
worker_type: str = WorkerType.LLM.value,
) -> ModelWorkerParameters:
return ModelWorkerParameters(
model_name=model_name, model_path=model_path, worker_type=worker_type
)
def _create_workers(
num_workers: int,
error_worker: bool = False,
stop_error: bool = False,
worker_type: str = WorkerType.LLM.value,
stream_messags: List[str] = None,
embeddings: List[List[float]] = None,
) -> List[Tuple[ModelWorker, ModelWorkerParameters]]:
workers = []
for i in range(num_workers):
model_name = f"test-model-name-{i}"
model_path = f"test-model-path-{i}"
model_parameters = ModelParameters(model_name=model_name, model_path=model_path)
worker = MockModelWorker(
model_parameters,
error_worker=error_worker,
stop_error=stop_error,
stream_messags=stream_messags,
embeddings=embeddings,
)
worker_params = _new_worker_params(
model_name, model_path, worker_type=worker_type
)
workers.append((worker, worker_params))
return workers
@asynccontextmanager
async def _start_worker_manager(**kwargs):
register_func = kwargs.get("register_func")
deregister_func = kwargs.get("deregister_func")
send_heartbeat_func = kwargs.get("send_heartbeat_func")
model_registry = kwargs.get("model_registry")
workers = kwargs.get("workers")
num_workers = int(kwargs.get("num_workers", 0))
start = kwargs.get("start", True)
stop = kwargs.get("stop", True)
error_worker = kwargs.get("error_worker", False)
stop_error = kwargs.get("stop_error", False)
stream_messags = kwargs.get("stream_messags", [])
embeddings = kwargs.get("embeddings", [])
worker_manager = LocalWorkerManager(
register_func=register_func,
deregister_func=deregister_func,
send_heartbeat_func=send_heartbeat_func,
model_registry=model_registry,
)
for worker, worker_params in _create_workers(
num_workers,
error_worker,
stop_error,
stream_messags=stream_messags,
embeddings=embeddings,
):
worker_manager.add_worker(worker, worker_params)
if workers:
for worker, worker_params in workers:
worker_manager.add_worker(worker, worker_params)
if start:
await worker_manager.start()
yield worker_manager
if stop:
await worker_manager.stop()
@pytest_asyncio.fixture
async def manager_2_workers(request):
param = getattr(request, "param", {})
async with _start_worker_manager(num_workers=2, **param) as worker_manager:
yield worker_manager
@pytest_asyncio.fixture
async def manager_with_2_workers(request):
param = getattr(request, "param", {})
workers = _create_workers(2, stream_messags=param.get("stream_messags", []))
async with _start_worker_manager(workers=workers, **param) as worker_manager:
yield (worker_manager, workers)
@pytest_asyncio.fixture
async def manager_2_embedding_workers(request):
param = getattr(request, "param", {})
workers = _create_workers(
2, worker_type=WorkerType.TEXT2VEC.value, embeddings=param.get("embeddings", [])
)
async with _start_worker_manager(workers=workers, **param) as worker_manager:
yield (worker_manager, workers)
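A typical use of these helpers, mirroring the tests in the next file (a sketch; start=False so the workers are started exactly once by the explicit call):

workers = _create_workers(2)
async with _start_worker_manager(workers=workers, start=False) as manager:
    out = await manager._start_all_worker(None)
    assert out.success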

View File

@ -0,0 +1,486 @@
from unittest.mock import patch, AsyncMock
import pytest
from typing import List, Iterator, Dict, Tuple
from dataclasses import asdict
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from pilot.model.base import ModelOutput, WorkerApplyType
from pilot.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.manager_base import WorkerRunData
from pilot.model.cluster.worker.manager import (
LocalWorkerManager,
RegisterFunc,
DeregisterFunc,
SendHeartbeatFunc,
ApplyFunction,
)
from pilot.model.cluster.worker.tests.base_tests import (
MockModelWorker,
manager_2_workers,
manager_with_2_workers,
manager_2_embedding_workers,
_create_workers,
_start_worker_manager,
_new_worker_params,
)
_TEST_MODEL_NAME = "vicuna-13b-v1.5"
_TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5"
@pytest.fixture
def worker():
mock_worker = _create_workers(1)
yield mock_worker[0][0]
@pytest.fixture
def worker_param():
return _new_worker_params()
@pytest.fixture
def manager(request):
if not request or not hasattr(request, "param"):
register_func = None
deregister_func = None
send_heartbeat_func = None
model_registry = None
workers = []
else:
register_func = request.param.get("register_func")
deregister_func = request.param.get("deregister_func")
send_heartbeat_func = request.param.get("send_heartbeat_func")
model_registry = request.param.get("model_registry")
workers = request.param.get("workers")
worker_manager = LocalWorkerManager(
register_func=register_func,
deregister_func=deregister_func,
send_heartbeat_func=send_heartbeat_func,
model_registry=model_registry,
)
yield worker_manager
@pytest.mark.asyncio
async def test_run_blocking_func(manager: LocalWorkerManager):
def f1() -> int:
return 0
def f2(a: int, b: int) -> int:
return a + b
async def error_f3() -> int:
# Coroutine functions are rejected by run_blocking_func
return 0
assert await manager.run_blocking_func(f1) == 0
assert await manager.run_blocking_func(f2, 1, 2) == 3
with pytest.raises(ValueError):
await manager.run_blocking_func(error_f3)
@pytest.mark.asyncio
async def test_add_worker(
manager: LocalWorkerManager,
worker: ModelWorker,
worker_param: ModelWorkerParameters,
):
# TODO test with register function
assert manager.add_worker(worker, worker_param)
# Add again
assert not manager.add_worker(worker, worker_param)
key = manager._worker_key(worker_param.worker_type, worker_param.model_name)
assert len(manager.workers) == 1
assert len(manager.workers[key]) == 1
assert manager.workers[key][0].worker == worker
assert manager.add_worker(
worker,
_new_worker_params(
model_name="chatglm2-6b", model_path="/app/models/chatglm2-6b"
),
)
assert not manager.add_worker(
worker,
_new_worker_params(
model_name="chatglm2-6b", model_path="/app/models/chatglm2-6b"
),
)
assert len(manager.workers) == 2
@pytest.mark.asyncio
async def test__apply_worker(manager_2_workers: LocalWorkerManager):
manager = manager_2_workers
async def f1(wr: WorkerRunData) -> int:
return 0
# Apply to all workers
assert await manager._apply_worker(None, apply_func=f1) == [0, 0]
workers = _create_workers(4)
async with _start_worker_manager(workers=workers) as manager:
# Apply to single model
req = WorkerApplyRequest(
model=workers[0][1].model_name,
apply_type=WorkerApplyType.START,
worker_type=WorkerType.LLM,
)
assert await manager._apply_worker(req, apply_func=f1) == [0]
@pytest.mark.asyncio
@pytest.mark.parametrize("manager_2_workers", [{"start": False}], indirect=True)
async def test__start_all_worker(manager_2_workers: LocalWorkerManager):
manager = manager_2_workers
out = await manager._start_all_worker(None)
assert out.success
assert len(manager.workers) == 2
@pytest.mark.asyncio
@pytest.mark.parametrize(
"manager_2_workers, is_error_worker",
[
({"start": False, "error_worker": False}, False),
({"start": False, "error_worker": True}, True),
],
indirect=["manager_2_workers"],
)
async def test_start_worker_manager(
manager_2_workers: LocalWorkerManager, is_error_worker: bool
):
manager = manager_2_workers
if is_error_worker:
with pytest.raises(Exception):
await manager.start()
else:
await manager.start()
@pytest.mark.asyncio
@pytest.mark.parametrize(
"manager_2_workers, is_stop_error",
[
({"stop": False, "stop_error": False}, False),
({"stop": False, "stop_error": True}, True),
],
indirect=["manager_2_workers"],
)
async def test__stop_all_worker(
manager_2_workers: LocalWorkerManager, is_stop_error: bool
):
manager = manager_2_workers
out = await manager._stop_all_worker(None)
if is_stop_error:
assert not out.success
else:
assert out.success
@pytest.mark.asyncio
async def test__restart_all_worker(manager_2_workers: LocalWorkerManager):
manager = manager_2_workers
out = await manager._restart_all_worker(None)
assert out.success
@pytest.mark.asyncio
@pytest.mark.parametrize(
"manager_2_workers, is_stop_error",
[
({"stop": False, "stop_error": False}, False),
({"stop": False, "stop_error": True}, True),
],
indirect=["manager_2_workers"],
)
async def test_stop_worker_manager(
manager_2_workers: LocalWorkerManager, is_stop_error: bool
):
manager = manager_2_workers
if is_stop_error:
with pytest.raises(Exception):
await manager.stop()
else:
await manager.stop()
@pytest.mark.asyncio
async def test__remove_worker():
workers = _create_workers(3)
async with _start_worker_manager(workers=workers, stop=False) as manager:
assert len(manager.workers) == 3
for _, worker_params in workers:
manager._remove_worker(worker_params)
not_exist_params = _new_worker_params(
model_name="nonexistent-model"
)
# Removing a worker that does not exist should be a no-op
manager._remove_worker(not_exist_params)
@pytest.mark.asyncio
@patch("pilot.model.cluster.worker.manager._build_worker")
async def test_model_startup(mock_build_worker):
async with _start_worker_manager() as manager:
workers = _create_workers(1)
worker, worker_params = workers[0]
mock_build_worker.return_value = worker
req = WorkerStartupRequest(
host="127.0.0.1",
port=8001,
model=worker_params.model_name,
worker_type=WorkerType.LLM,
params=asdict(worker_params),
)
await manager.model_startup(req)
with pytest.raises(Exception):
await manager.model_startup(req)
async with _start_worker_manager() as manager:
workers = _create_workers(1, error_worker=True)
worker, worker_params = workers[0]
mock_build_worker.return_value = worker
req = WorkerStartupRequest(
host="127.0.0.1",
port=8001,
model=worker_params.model_name,
worker_type=WorkerType.LLM,
params=asdict(worker_params),
)
with pytest.raises(Exception):
await manager.model_startup(req)
@pytest.mark.asyncio
@patch("pilot.model.cluster.worker.manager._build_worker")
async def test_model_shutdown(mock_build_worker):
async with _start_worker_manager(start=False, stop=False) as manager:
workers = _create_workers(1)
worker, worker_params = workers[0]
mock_build_worker.return_value = worker
req = WorkerStartupRequest(
host="127.0.0.1",
port=8001,
model=worker_params.model_name,
worker_type=WorkerType.LLM,
params=asdict(worker_params),
)
await manager.model_startup(req)
await manager.model_shutdown(req)
@pytest.mark.asyncio
async def test_supported_models(manager_2_workers: LocalWorkerManager):
manager = manager_2_workers
models = await manager.supported_models()
assert len(models) == 1
models = models[0].models
assert len(models) > 10
@pytest.mark.asyncio
@pytest.mark.parametrize(
"is_async",
[
True,
False,
],
)
async def test_get_model_instances(is_async):
workers = _create_workers(3)
async with _start_worker_manager(workers=workers, stop=False) as manager:
assert len(manager.workers) == 3
for _, worker_params in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
if is_async:
assert (
len(await manager.get_model_instances(worker_type, model_name)) == 1
)
else:
assert (
len(manager.sync_get_model_instances(worker_type, model_name)) == 1
)
if is_async:
assert not await manager.get_model_instances(
worker_type, "nonexistent-model"
)
else:
assert not manager.sync_get_model_instances(
worker_type, "nonexistent-model"
)
@pytest.mark.asyncio
async def test__simple_select(
manager_with_2_workers: Tuple[
LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
]
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
instances = await manager.get_model_instances(worker_type, model_name)
assert instances
inst = manager._simple_select(worker_params.worker_type, model_name, instances)
assert inst is not None
assert inst.worker_params == worker_params
@pytest.mark.asyncio
@pytest.mark.parametrize(
"is_async",
[
True,
False,
],
)
async def test_select_one_instance(
is_async: bool,
manager_with_2_workers: Tuple[
LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
],
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
if is_async:
inst = await manager.select_one_instance(worker_type, model_name)
else:
inst = manager.sync_select_one_instance(worker_type, model_name)
assert inst is not None
@pytest.mark.asyncio
@pytest.mark.parametrize(
"is_async",
[
True,
False,
],
)
async def test__get_model(
is_async: bool,
manager_with_2_workers: Tuple[
LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
],
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = {"model": model_name}
if is_async:
wr = await manager._get_model(params, worker_type=worker_type)
else:
wr = manager._sync_get_model(params, worker_type=worker_type)
assert wr is not None
@pytest.mark.asyncio
@pytest.mark.parametrize(
"manager_with_2_workers, expected_messages",
[
({"stream_messags": ["Hello", " world."]}, "Hello world."),
({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
],
indirect=["manager_with_2_workers"],
)
async def test_generate_stream(
manager_with_2_workers: Tuple[
LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
],
expected_messages: str,
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = {"model": model_name}
text = ""
async for out in manager.generate_stream(params):
text += out.text
assert text == expected_messages
@pytest.mark.asyncio
@pytest.mark.parametrize(
"manager_with_2_workers, expected_messages",
[
({"stream_messags": ["Hello", " world."]}, " world."),
({"stream_messags": ["你好,我是", "张三。"]}, "张三。"),
],
indirect=["manager_with_2_workers"],
)
async def test_generate(
manager_with_2_workers: Tuple[
LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
],
expected_messages: str,
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = {"model": model_name}
out = await manager.generate(params)
assert out.text == expected_messages
@pytest.mark.asyncio
@pytest.mark.parametrize(
"manager_2_embedding_workers, expected_embedding, is_async",
[
({"embeddings": [[1, 2, 3], [4, 5, 6]]}, [[1, 2, 3], [4, 5, 6]], True),
({"embeddings": [[0, 0, 0], [1, 1, 1]]}, [[0, 0, 0], [1, 1, 1]], False),
],
indirect=["manager_2_embedding_workers"],
)
async def test_embeddings(
manager_2_embedding_workers: Tuple[
LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
],
expected_embedding: List[List[int]],
is_async: bool,
):
manager, workers = manager_2_embedding_workers
for _, worker_params in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = {"model": model_name, "input": ["hello", "world"]}
if is_async:
out = await manager.embeddings(params)
else:
out = manager.sync_embeddings(params)
assert out == expected_embedding
@pytest.mark.asyncio
async def test_parameter_descriptions(
manager_with_2_workers: Tuple[
LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
]
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = await manager.parameter_descriptions(worker_type, model_name)
assert params is not None
assert len(params) > 5
@pytest.mark.asyncio
async def test__update_all_worker_params():
# TODO
pass

View File

@ -46,6 +46,18 @@ class ModelControllerParameters(BaseParameters):
],
},
)
log_file: Optional[str] = field(
default="dbgpt_model_controller.log",
metadata={
"help": "The filename to store log",
},
)
tracer_file: Optional[str] = field(
default="dbgpt_model_controller_tracer.jsonl",
metadata={
"help": "The filename to store tracer span records",
},
)
@dataclass
@ -122,6 +134,18 @@ class ModelWorkerParameters(BaseModelParameters):
],
},
)
log_file: Optional[str] = field(
default="dbgpt_model_worker_manager.log",
metadata={
"help": "The filename to store log",
},
)
tracer_file: Optional[str] = field(
default="dbgpt_model_worker_manager_tracer.jsonl",
metadata={
"help": "The filename to store tracer span records",
},
)
@dataclass

View File

@ -95,6 +95,19 @@ class WebWerverParameters(BaseParameters):
daemon: Optional[bool] = field(
default=False, metadata={"help": "Run Webserver in background"}
)
controller_addr: Optional[str] = field(
default=None,
metadata={
"help": "The Model controller address to connect. If None, read model controller address from environment key `MODEL_SERVER`."
},
)
model_name: Optional[str] = field(
default=None,
metadata={
"help": "The default model name to use. If None, read the model name from the environment key `LLM_MODEL`.",
"tags": "fixed",
},
)
share: Optional[bool] = field(
default=False,
metadata={
@ -123,3 +136,15 @@ class WebWerverParameters(BaseParameters):
},
)
light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"})
log_file: Optional[str] = field(
default="dbgpt_webserver.log",
metadata={
"help": "The filename to store log",
},
)
tracer_file: Optional[str] = field(
default="dbgpt_webserver_tracer.jsonl",
metadata={
"help": "The filename to store tracer span records",
},
)

View File

@ -119,7 +119,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
if not param.log_level:
param.log_level = _get_logging_level()
setup_logging(
"pilot", logging_level=param.log_level, logger_filename="dbgpt_webserver.log"
"pilot", logging_level=param.log_level, logger_filename=param.log_file
)
# Before start
system_app.before_start()
@ -133,14 +133,16 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
model_start_listener = _create_model_start_listener(system_app)
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL)
model_name = param.model_name or CFG.LLM_MODEL
model_path = LLM_MODEL_CONFIG.get(model_name)
if not param.light:
print("Model Unified Deployment Mode!")
if not param.remote_embedding:
embedding_model_name, embedding_model_path = None, None
initialize_worker_manager_in_client(
app=app,
model_name=CFG.LLM_MODEL,
model_name=model_name,
model_path=model_path,
local_port=param.port,
embedding_model_name=embedding_model_name,
@ -152,12 +154,13 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
CFG.NEW_SERVER_MODE = True
else:
# MODEL_SERVER is controller address now
controller_addr = param.controller_addr or CFG.MODEL_SERVER
initialize_worker_manager_in_client(
app=app,
model_name=CFG.LLM_MODEL,
model_name=model_name,
model_path=model_path,
run_locally=False,
controller_addr=CFG.MODEL_SERVER,
controller_addr=controller_addr,
local_port=param.port,
start_listener=model_start_listener,
system_app=system_app,
@ -182,7 +185,7 @@ def run_uvicorn(param: WebWerverParameters):
def run_webserver(param: WebWerverParameters = None):
if not param:
param = _get_webserver_params()
initialize_tracer(system_app, os.path.join(LOGDIR, "dbgpt_webserver_tracer.jsonl"))
initialize_tracer(system_app, os.path.join(LOGDIR, param.tracer_file))
with root_tracer.start_span(
"run_webserver",

View File

@ -1,6 +1,7 @@
import os
import json
import time
import datetime
import threading
import queue
import logging
@ -27,6 +28,13 @@ class FileSpanStorage(SpanStorage):
def __init__(self, filename: str, batch_size=10, flush_interval=10):
super().__init__()
self.filename = filename
# Split filename into prefix and suffix
self.filename_prefix, self.filename_suffix = os.path.splitext(filename)
if not self.filename_suffix:
self.filename_suffix = ".log"
self.last_date = (
datetime.datetime.now().date()
) # Store the current date for checking date changes
self.queue = queue.Queue()
self.batch_size = batch_size
self.flush_interval = flush_interval
@ -52,7 +60,21 @@ class FileSpanStorage(SpanStorage):
except queue.Full:
pass # If the signal queue is full, it's okay. The flush thread will handle it.
def _get_dated_filename(self, date: datetime.date) -> str:
"""Return the filename based on a specific date."""
date_str = date.strftime("%Y-%m-%d")
return f"{self.filename_prefix}_{date_str}{self.filename_suffix}"
def _roll_over_if_needed(self):
"""Checks if a day has changed since the last write, and if so, renames the current file."""
current_date = datetime.datetime.now().date()
if current_date != self.last_date:
if os.path.exists(self.filename):
os.rename(self.filename, self._get_dated_filename(self.last_date))
self.last_date = current_date
def _write_to_file(self):
self._roll_over_if_needed()
spans_to_write = []
while not self.queue.empty():
spans_to_write.append(self.queue.get())
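A minimal sketch of the resulting rollover naming, assuming a tracer file named dbgpt_webserver_tracer.jsonl and the first write after midnight on 2023-10-18:

# os.path.splitext -> prefix "dbgpt_webserver_tracer", suffix ".jsonl"
# _roll_over_if_needed() renames the current file to the previous day's name:
#   dbgpt_webserver_tracer.jsonl -> dbgpt_webserver_tracer_2023-10-18.jsonl
# subsequent spans are written to a fresh dbgpt_webserver_tracer.jsonl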

View File

@ -4,6 +4,8 @@ import asyncio
import json
import tempfile
import time
from unittest.mock import patch
from datetime import datetime, timedelta
from pilot.utils.tracer import SpanStorage, FileSpanStorage, Span, SpanType
@ -122,3 +124,46 @@ def test_non_existent_file(storage: SpanStorage):
assert len(spans_in_file) == 2
assert spans_in_file[0]["trace_id"] == "1"
assert spans_in_file[1]["trace_id"] == "2"
@pytest.mark.parametrize(
"storage", [{"batch_size": 1, "file_does_not_exist": True}], indirect=True
)
def test_log_rollover(storage: SpanStorage):
# Mock the start date. patch("datetime.datetime") swaps the class on the
# datetime module, which FileSpanStorage resolves at call time via
# datetime.datetime.now(); the datetime imported above via
# "from datetime import datetime" is still the real class.
mock_start_date = datetime(2023, 10, 18, 23, 59)
with patch("datetime.datetime") as mock_datetime:
mock_datetime.now.return_value = mock_start_date
span1 = Span("1", "a", SpanType.BASE, "b", "op1")
storage.append_span(span1)
time.sleep(0.1)
# mock new day
mock_datetime.now.return_value = mock_start_date + timedelta(minutes=1)
span2 = Span("2", "c", SpanType.BASE, "d", "op2")
storage.append_span(span2)
time.sleep(0.1)
# the original filename must still exist
assert os.path.exists(storage.filename)
# build the expected rolled-over filename
dated_filename = os.path.join(
os.path.dirname(storage.filename),
f"{os.path.basename(storage.filename).split('.')[0]}_2023-10-18.jsonl",
)
assert os.path.exists(dated_filename)
# the original file should contain only the second span
spans_in_original_file = read_spans_from_file(storage.filename)
assert len(spans_in_original_file) == 1
assert spans_in_original_file[0]["trace_id"] == "2"
# the rolled-over file should contain only the first span
spans_in_dated_file = read_spans_from_file(dated_filename)
assert len(spans_in_dated_file) == 1
assert spans_in_dated_file[0]["trace_id"] == "1"